diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..cdf5b64d --- /dev/null +++ b/.gitignore @@ -0,0 +1,48 @@ +# Compiled Object files +*.slo +*.lo +*.o +*.obj + +# Precompiled Headers +*.gch +*.pch +*.ipch + +# Compiled Dynamic libraries +*.so +*.dylib +*.dll + +# Fortran module files +*.mod + +# Compiled Static libraries +*.lai +*.la +*.a +*.lib + +# Executables +*.exe +*.out +*.app + +# vim tags +tags +.tags +.*.swp + +# Editors +.vscode + +# build-in-source directory +build* + +# emacs temporary/backup files +.\#* +\#*\# +*~ + +# GDB temporary files +.gdb_history diff --git a/CMakeLists.txt b/CMakeLists.txt index 306e6ca6..9f706207 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,10 +1,26 @@ -cmake_minimum_required(VERSION 3.5) +cmake_minimum_required(VERSION 3.14) + +# Check support for CUDA/HIP in Cmake project(composable_kernel) list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") +enable_testing() + +set(ROCM_SYMLINK_LIBS OFF) +find_package(ROCM 0.8 REQUIRED PATHS /opt/rocm) + +include(ROCMInstallTargets) +include(ROCMPackageConfigHelpers) +include(ROCMSetupVersion) +include(ROCMInstallSymlinks) +include(ROCMCreatePackage) include(CheckCXXCompilerFlag) +rocm_setup_version(VERSION 0.2.0) +include(TargetFlags) +list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip) + ## C++ enable_language(CXX) set(CMAKE_CXX_STANDARD 17) @@ -30,35 +46,42 @@ message("OpenMP_gomp_LIBRARY: ${OpenMP_gomp_LIBRARY}") message("OpenMP_pthread_LIBRARY: ${OpenMP_pthread_LIBRARY}") message("OpenMP_CXX_FLAGS: ${OpenMP_CXX_FLAGS}") -set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") link_libraries(${OpenMP_gomp_LIBRARY}) link_libraries(${OpenMP_pthread_LIBRARY}) ## HIP find_package(HIP REQUIRED) -message(STATUS "Build with HIP ${hip_VERSION}") - -## half -#find_path(HALF_INCLUDE_DIR half.hpp) -message("HALF_INCLUDE_DIR: ${HALF_INCLUDE_DIR}") - -# CMAKE_CXX_FLAGS -SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV") -if(BUILD_DEV) - string(APPEND CMAKE_CXX_FLAGS " -Werror -Weverything") +# Override HIP version in config.h, if necessary. +# The variables set by find_package() can't be overwritten, +# therefore let's use intermediate variables. 
+set(CK_HIP_VERSION_MAJOR "${HIP_VERSION_MAJOR}") +set(CK_HIP_VERSION_MINOR "${HIP_VERSION_MINOR}") +set(CK_HIP_VERSION_PATCH "${HIP_VERSION_PATCH}") +if( DEFINED CK_OVERRIDE_HIP_VERSION_MAJOR ) + set(CK_HIP_VERSION_MAJOR "${CK_OVERRIDE_HIP_VERSION_MAJOR}") + message(STATUS "CK_HIP_VERSION_MAJOR overriden with ${CK_OVERRIDE_HIP_VERSION_MAJOR}") endif() -message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") +if( DEFINED CK_OVERRIDE_HIP_VERSION_MINOR ) + set(CK_HIP_VERSION_MINOR "${CK_OVERRIDE_HIP_VERSION_MINOR}") + message(STATUS "CK_HIP_VERSION_MINOR overriden with ${CK_OVERRIDE_HIP_VERSION_MINOR}") +endif() +if( DEFINED CK_OVERRIDE_HIP_VERSION_PATCH ) + set(CK_HIP_VERSION_PATCH "${CK_OVERRIDE_HIP_VERSION_PATCH}") + message(STATUS "CK_HIP_VERSION_PATCH overriden with ${CK_OVERRIDE_HIP_VERSION_PATCH}") +endif() +message(STATUS "Build with HIP ${HIP_VERSION}") ## tidy include(EnableCompilerWarnings) -set(MIOPEN_TIDY_ERRORS ERRORS * -readability-inconsistent-declaration-parameter-name) +set(CK_TIDY_ERRORS ERRORS * -readability-inconsistent-declaration-parameter-name) if(CMAKE_CXX_COMPILER MATCHES ".*hcc" OR CMAKE_CXX_COMPILER MATCHES ".*clang\\+\\+") - set(MIOPEN_TIDY_CHECKS -modernize-use-override -readability-non-const-parameter) + set(CK_TIDY_CHECKS -modernize-use-override -readability-non-const-parameter) # Enable tidy on hip -elseif(MIOPEN_BACKEND STREQUAL "HIP" OR MIOPEN_BACKEND STREQUAL "HIPNOGPU") - set(MIOPEN_TIDY_ERRORS ALL) +elseif(CK_BACKEND STREQUAL "HIP" OR CK_BACKEND STREQUAL "HIPNOGPU") + set(CK_TIDY_ERRORS ALL) endif() + include(ClangTidy) enable_clang_tidy( CHECKS @@ -150,13 +173,12 @@ enable_clang_tidy( -cppcoreguidelines-narrowing-conversions -altera-struct-pack-align -cppcoreguidelines-prefer-member-initializer - - ${MIOPEN_TIDY_CHECKS} - ${MIOPEN_TIDY_ERRORS} + ${CK_TIDY_CHECKS} + ${CK_TIDY_ERRORS} HEADER_FILTER "\.hpp$" EXTRA_ARGS - -DMIOPEN_USE_CLANG_TIDY + -DCK_USE_CLANG_TIDY ) include(CppCheck) @@ -180,19 +202,74 @@ enable_cppcheck( unmatchedSuppression FORCE SOURCES - host/host_tensor/src - host/driver_offline/src - composable_kernel/src/kernel_wrapper + library/src INCLUDE - host/host_tensor/include - host/solver/include - host/driver_offline/include - composable_kernel/include/* ${CMAKE_CURRENT_SOURCE_DIR}/include ${CMAKE_CURRENT_BINARY_DIR}/include + ${CMAKE_CURRENT_SOURCE_DIR}/library/include DEFINE CPPCHECK=1 __linux__=1 ) -add_subdirectory(host) +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib) +set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib) +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin) + +include_directories(BEFORE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/library/include +) + + +SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV") +if(BUILD_DEV) + add_compile_options(-Werror) + add_compile_options(-Weverything) +endif() +message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") + +add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR}) + +rocm_package_setup_component(tests + LIBRARY_NAME composablekernel + PACKAGE_NAME tests # Prevent -static suffix on package name +) + +add_subdirectory(library) +add_subdirectory(example) +add_subdirectory(test) +add_subdirectory(profiler) + +#Create an interface target for the include only files and call it "composablekernels" +include(CMakePackageConfigHelpers) + +set(version 1.0.0) +write_basic_package_version_file( + "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfigVersion.cmake" + VERSION "${version}" + COMPATIBILITY 
AnyNewerVersion +) + +configure_package_config_file(${CMAKE_CURRENT_SOURCE_DIR}/Config.cmake.in + "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfig.cmake" + INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel + NO_CHECK_REQUIRED_COMPONENTS_MACRO +) + +rocm_install(FILES + "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfig.cmake" + "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfigVersion.cmake" + DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel +) + +set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE") +set(CPACK_RPM_PACKAGE_LICENSE "MIT") + +rocm_create_package( + NAME composablekernel + DESCRIPTION "High Performance Composable Kernel for AMD GPUs" + MAINTAINER "MIOpen Kernels Dev Team " + LDCONFIG + HEADER_ONLY +) diff --git a/Config.cmake.in b/Config.cmake.in new file mode 100644 index 00000000..12b5c331 --- /dev/null +++ b/Config.cmake.in @@ -0,0 +1,11 @@ +@PACKAGE_INIT@ + +set(_composable_kernel_supported_components device_operations host_tensor) + +foreach(_comp ${composable_kernel_FIND_COMPONENTS}) + if(NOT _comp IN_LIST _composable_kernel_supported_components) + set(composable_kernel_FOUND False) + set(composable_kernel_NOT_FOUND_MESSAGE "Unsupported component: ${_comp}") + endif() + include("${CMAKE_CURRENT_LIST_DIR}/composable_kernel${_comp}Targets.cmake") +endforeach() diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..0d32b52f --- /dev/null +++ b/Dockerfile @@ -0,0 +1,95 @@ +FROM ubuntu:18.04 + +ARG ROCMVERSION=5.1 +ARG OSDB_BKC_VERSION + +RUN set -xe + +ARG BUILD_THREADS=8 +ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/ +# Add rocm repository +RUN apt-get update +RUN apt-get install -y wget gnupg +RUN wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - +RUN sh -c "echo deb [arch=amd64] $DEB_ROCM_REPO ubuntu main > /etc/apt/sources.list.d/rocm.list" +RUN wget --no-check-certificate -qO - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | apt-key add - +RUN sh -c "echo deb https://apt.kitware.com/ubuntu/ bionic main | tee -a /etc/apt/sources.list" + +# ADD requirements.txt requirements.txt +# Install dependencies +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ + apt-utils \ + build-essential \ + cmake-data=3.15.1-0kitware1 \ + cmake=3.15.1-0kitware1 \ + curl \ + g++ \ + gdb \ + git \ + hip-rocclr \ + jq \ + libelf-dev \ + libncurses5-dev \ + libnuma-dev \ + libpthread-stubs0-dev \ + llvm-amdgpu \ + pkg-config \ + python \ + python3.8 \ + python-dev \ + python3-dev \ + python-pip \ + python3-pip \ + software-properties-common \ + wget \ + rocm-dev \ + rocm-device-libs \ + rocm-cmake \ + vim \ + zlib1g-dev \ + openssh-server \ + clang-format-10 \ + kmod && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +# Setup ubsan environment to printstacktrace +RUN ln -s /usr/bin/llvm-symbolizer-3.8 /usr/local/bin/llvm-symbolizer +ENV UBSAN_OPTIONS=print_stacktrace=1 + +# Install an init system +RUN wget https://github.com/Yelp/dumb-init/releases/download/v1.2.0/dumb-init_1.2.0_amd64.deb +RUN dpkg -i dumb-init_*.deb && rm dumb-init_*.deb + +# Install cget +RUN pip install cget + +# Install rclone +RUN pip install https://github.com/pfultz2/rclone/archive/master.tar.gz + +ARG PREFIX=/opt/rocm +# Install dependencies +RUN cget install pfultz2/rocm-recipes +# Install rbuild +RUN pip3 install https://github.com/RadeonOpenCompute/rbuild/archive/6d78a0553babdaea8d2da5de15cbda7e869594b8.tar.gz +# Install packages 
for processing the performance results +RUN pip3 install --upgrade pip +RUN pip3 install sqlalchemy +RUN pip3 install pymysql +RUN pip3 install pandas +RUN pip3 install setuptools-rust +RUN pip3 install sshtunnel +# Setup ubsan environment to printstacktrace +ENV UBSAN_OPTIONS=print_stacktrace=1 + +ENV LC_ALL=C.UTF-8 +ENV LANG=C.UTF-8 +ADD rbuild.ini /rbuild.ini +ADD dev-requirements.txt dev-requirements.txt +RUN rbuild prepare -s develop -d $PREFIX +RUN groupadd -f render + +# Install the new rocm-cmake version +RUN git clone -b master https://github.com/RadeonOpenCompute/rocm-cmake.git && \ + cd rocm-cmake && mkdir build && cd build && \ + cmake .. && cmake --build . && cmake --build . --target install diff --git a/Jenkinsfile b/Jenkinsfile new file mode 100644 index 00000000..15be3e54 --- /dev/null +++ b/Jenkinsfile @@ -0,0 +1,441 @@ +def rocmnode(name) { + return 'rocmtest && miopen && ' + name +} + +def show_node_info() { + sh """ + echo "NODE_NAME = \$NODE_NAME" + lsb_release -sd + uname -r + ls /opt/ -la + """ +} + +def cmake_build(Map conf=[:]){ + + def compiler = conf.get("compiler","/opt/rocm/bin/hipcc") + def config_targets = conf.get("config_targets","check") + def debug_flags = "-g -fno-omit-frame-pointer -fsanitize=undefined -fno-sanitize-recover=undefined " + conf.get("extradebugflags", "") + def build_envs = "CTEST_PARALLEL_LEVEL=4 " + conf.get("build_env","") + def prefixpath = conf.get("prefixpath","/opt/rocm") + def setup_args = conf.get("setup_args","") + + if (prefixpath != "/usr/local"){ + setup_args = setup_args + " -DCMAKE_PREFIX_PATH=${prefixpath} " + } + + def build_type_debug = (conf.get("build_type",'release') == 'debug') + + //cmake_env can overwrite default CXX variables. + def cmake_envs = "CXX=${compiler} CXXFLAGS='-Werror' " + conf.get("cmake_ex_env","") + + def package_build = (conf.get("package_build","") == "true") + + if (package_build == true) { + config_targets = "package" + } + + if(conf.get("build_install","") == "true") + { + config_targets = 'install ' + config_targets + setup_args = ' -DBUILD_DEV=Off -DCMAKE_INSTALL_PREFIX=../install' + setup_args + } else{ + setup_args = ' -DBUILD_DEV=On' + setup_args + } + + if(build_type_debug){ + setup_args = " -DCMAKE_BUILD_TYPE=debug -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'" + setup_args + }else{ + setup_args = " -DCMAKE_BUILD_TYPE=release" + setup_args + } + + def pre_setup_cmd = """ + echo \$HSA_ENABLE_SDMA + ulimit -c unlimited + rm -rf build + mkdir build + rm -rf install + mkdir install + cd build + """ + def setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. 
") + // reduce parallelism when compiling, clang uses too much memory + def build_cmd = conf.get("build_cmd", "${build_envs} dumb-init make -j\$(( \$(nproc) / 1 )) ${config_targets}") + def execute_cmd = conf.get("execute_cmd", "") + + def cmd = conf.get("cmd", """ + ${pre_setup_cmd} + ${setup_cmd} + ${build_cmd} + ${execute_cmd} + """) + + echo cmd + sh cmd + + // Only archive from master or develop + if (package_build == true && (env.BRANCH_NAME == "develop" || env.BRANCH_NAME == "master")) { + archiveArtifacts artifacts: "build/*.deb", allowEmptyArchive: true, fingerprint: true + } +} + +def buildHipClangJob(Map conf=[:]){ + show_node_info() + + env.HSA_ENABLE_SDMA=0 + checkout scm + + def image = "composable_kernels" + def prefixpath = conf.get("prefixpath", "/opt/rocm") + def gpu_arch = conf.get("gpu_arch", "gfx908") + + // Jenkins is complaining about the render group + // def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + if (conf.get("enforce_xnack_on", false)) { + dockerOpts = dockerOpts + " --env HSA_XNACK=1" + } + def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg GPU_ARCH='${gpu_arch}' " + + def variant = env.STAGE_NAME + + def retimage + + gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') { + if (params.USE_DOCKERFILE){ + try { + retimage = docker.build("${image}", dockerArgs + '.') + withDockerContainer(image: image, args: dockerOpts) { + timeout(time: 5, unit: 'MINUTES') + { + sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo' + } + } + } + catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){ + echo "The job was cancelled or aborted" + throw e + } + catch(Exception ex) { + retimage = docker.build("${image}", dockerArgs + "--no-cache .") + withDockerContainer(image: image, args: dockerOpts) { + timeout(time: 5, unit: 'MINUTES') + { + sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo' + } + } + } + } + else{ + timeout(time: 3, unit: 'HOURS'){ + retimage = docker.image('compute-artifactory.amd.com:5000/rocm-plus-docker/framework/compute-rocm-dkms-no-npi-hipclang:9110_ubuntu18.04_py3.6_pytorch_rocm5.0_internal_testing_7ff5b54').pull() + image="b56f8ac0d6ea" + sh "docker images" + } + } + + withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { + timeout(time: 5, unit: 'HOURS') + { + sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo' + cmake_build(conf) + } + } + } + return retimage +} + +def reboot(){ + build job: 'reboot-slaves', propagate: false , parameters: [string(name: 'server', value: "${env.NODE_NAME}"),] +} + + + + + +def buildHipClangJobAndReboot(Map conf=[:]){ + try{ + buildHipClangJob(conf) + } + catch(e){ + echo "throwing error exception for the stage" + echo 'Exception occurred: ' + e.toString() + throw e + } + finally{ + if (!conf.get("no_reboot", false)) { + reboot() + } + } +} + + +def runCKProfiler(Map conf=[:]){ + show_node_info() + + env.HSA_ENABLE_SDMA=0 + checkout scm + + def image = "composable_kernels" + def prefixpath = conf.get("prefixpath", "/opt/rocm") + def gpu_arch = conf.get("gpu_arch", "gfx908") + + // Jenkins is complaining about the render group + // def dockerOpts="--device=/dev/kfd 
--device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + if (conf.get("enforce_xnack_on", false)) { + dockerOpts = dockerOpts + " --env HSA_XNACK=1" + } + def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg GPU_ARCH='${gpu_arch}' " + + def variant = env.STAGE_NAME + + def retimage + + gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') { + if (params.USE_DOCKERFILE){ + try { + retimage = docker.build("${image}", dockerArgs + '.') + withDockerContainer(image: image, args: dockerOpts) { + timeout(time: 5, unit: 'MINUTES') + { + sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo' + } + } + } + catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){ + echo "The job was cancelled or aborted" + throw e + } + catch(Exception ex) { + retimage = docker.build("${image}", dockerArgs + "--no-cache .") + withDockerContainer(image: image, args: dockerOpts) { + timeout(time: 5, unit: 'MINUTES') + { + sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo' + } + } + } + } + else{ + timeout(time: 3, unit: 'HOURS'){ + retimage = docker.image('compute-artifactory.amd.com:5000/rocm-plus-docker/framework/compute-rocm-dkms-no-npi-hipclang:9110_ubuntu18.04_py3.6_pytorch_rocm5.0_internal_testing_7ff5b54').pull() + image="b56f8ac0d6ea" + sh "docker images" + } + } + + withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { + timeout(time: 5, unit: 'HOURS') + { + cmake_build(conf) + dir("script"){ + //run gemm performance tests + def gemm_log = "perf_gemm_${gpu_arch}.log" + sh "rm -f ${gemm_log}" + sh "echo Branch name: ${env.BRANCH_NAME} > ${gemm_log}" + sh "echo Node name: ${NODE_NAME} >> ${gemm_log}" + sh "echo GPU_arch name: ${gpu_arch} >> ${gemm_log}" + sh "rocminfo | grep 'Compute Unit:' >> ${gemm_log} " + sh "hipcc --version | grep -e 'HIP version' >> ${gemm_log}" + sh "/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> ${gemm_log}" + sh "./profile_gemm.sh gemm 0 0 0 1 0 5 | tee -a ${gemm_log}" + sh "./profile_gemm.sh gemm 1 0 0 1 0 5 | tee -a ${gemm_log}" + sh "./profile_gemm.sh gemm 2 0 0 1 0 5 | tee -a ${gemm_log}" + sh "./profile_gemm.sh gemm 3 0 0 1 0 5 | tee -a ${gemm_log}" + sh "./profile_gemm.sh gemm 0 1 0 1 0 5 | tee -a ${gemm_log}" + sh "./profile_gemm.sh gemm 1 1 0 1 0 5 | tee -a ${gemm_log}" + sh "./profile_gemm.sh gemm 2 1 0 1 0 5 | tee -a ${gemm_log}" + sh "./profile_gemm.sh gemm 3 1 0 1 0 5 | tee -a ${gemm_log}" + sh "./profile_gemm.sh gemm 0 2 0 1 0 5 | tee -a ${gemm_log}" + sh "./profile_gemm.sh gemm 1 2 0 1 0 5 | tee -a ${gemm_log}" + sh "./profile_gemm.sh gemm 2 2 0 1 0 5 | tee -a ${gemm_log}" + sh "./profile_gemm.sh gemm 3 2 0 1 0 5 | tee -a ${gemm_log}" + sh "./profile_gemm.sh gemm 0 3 0 1 0 5 | tee -a ${gemm_log}" + sh "./profile_gemm.sh gemm 1 3 0 1 0 5 | tee -a ${gemm_log}" + sh "./profile_gemm.sh gemm 2 3 0 1 0 5 | tee -a ${gemm_log}" + sh "./profile_gemm.sh gemm 3 3 0 1 0 5 | tee -a ${gemm_log}" + //results will be parsed, stored, and analyzed within the python script + //the script will return 0 if the performance criteria are met + //or return 1 if the criteria are not met + archiveArtifacts "${gemm_log}" + sh "python3 parse_perf_data.py ${gemm_log} " + //run resnet50 test + def 
resnet_log = "perf_resnet50_${gpu_arch}.log" + sh "rm -f ${resnet_log}" + sh "echo Branch name: ${env.BRANCH_NAME} > ${resnet_log}" + sh "echo Node name: ${NODE_NAME} >> ${resnet_log}" + sh "echo GPU_arch name: ${gpu_arch} >> ${resnet_log}" + sh "rocminfo | grep 'Compute Unit:' >> ${resnet_log} " + sh "hipcc --version | grep -e 'HIP version' >> ${resnet_log}" + sh "/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> ${resnet_log}" + //first run tests with N=256 + sh "./profile_conv.sh conv_fwd_bias_relu 1 1 1 1 0 2 0 1 256 | tee -a ${resnet_log}" + //then run with N=4 + sh "./profile_conv.sh conv_fwd_bias_relu 1 1 1 1 0 2 0 1 4 | tee -a ${resnet_log}" + archiveArtifacts "${resnet_log}" + //the script will put the results from N=256 and N=4 runs into separate tables + sh "python3 parse_perf_data.py ${resnet_log} " + } + } + } + } + return retimage +} + + +def runPerfTest(Map conf=[:]){ + try{ + runCKProfiler(conf) + } + catch(e){ + echo "throwing error exception in performance tests" + echo 'Exception occurred: ' + e.toString() + throw e + } + finally{ + if (!conf.get("no_reboot", false)) { + reboot() + } + } +} + +pipeline { + agent none + options { + parallelsAlwaysFailFast() + } + parameters { + booleanParam( + name: "USE_DOCKERFILE", + defaultValue: true, + description: "") + } + environment{ + dbuser = "${dbuser}" + dbpassword = "${dbpassword}" + dbsship = "${dbsship}" + dbsshport = "${dbsshport}" + dbsshuser = "${dbsshuser}" + dbsshpassword = "${dbsshpassword}" + status_wrapper_creds = "${status_wrapper_creds}" + } + stages{ + stage("Static checks") { + parallel{ + // enable after we move from hipcc to hip-clang + // stage('Tidy') { + // agent{ label rocmnode("nogpu") } + // environment{ + // // setup_cmd = "CXX='/opt/rocm/bin/hipcc' cmake -DBUILD_DEV=On .. " + // build_cmd = "make -j\$(nproc) -k analyze" + // } + // steps{ + // buildHipClangJobAndReboot(build_cmd: build_cmd, no_reboot:true, prefixpath: '/opt/rocm', build_type: 'debug') + // } + // } + stage('Clang Format') { + agent{ label rocmnode("nogpu") } + environment{ + execute_cmd = "find .. 
-iname \'*.h\' \ + -o -iname \'*.hpp\' \ + -o -iname \'*.cpp\' \ + -o -iname \'*.h.in\' \ + -o -iname \'*.hpp.in\' \ + -o -iname \'*.cpp.in\' \ + -o -iname \'*.cl\' \ + | grep -v 'build/' \ + | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-10 -style=file {} | diff - {}\'" + } + steps{ + buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true) + } + } + } + } + stage("Tests") + { + parallel + { + stage("Run Tests: gfx908") + { + agent{ label rocmnode("gfx908")} + environment{ + setup_args = """ -D CMAKE_CXX_FLAGS=" --offload-arch=gfx908 -O3 " -DBUILD_DEV=On """ + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "check", no_reboot:true, build_type: 'Release', gpu_arch: "gfx908") + } + } + stage("Run Tests: gfx90a") + { + agent{ label rocmnode("gfx90a")} + environment{ + setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx90a -O3 " -DBUILD_DEV=On """ + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "check", no_reboot:true, build_type: 'Release', gpu_arch: "gfx90a") + } + } + } + } + stage("Client App") + { + parallel + { + stage("Run Client App") + { + agent{ label rocmnode("gfx908")} + environment{ + setup_args = """ -D -DBUILD_DEV=Off -DCMAKE_INSTALL_PREFIX=../install CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc .. && make -j """ + } + steps{ + buildHipClangJobAndReboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + } + } + } + } + stage("Performance Tests") + { + parallel + { + stage("Run ckProfiler: gfx908") + { + agent{ label rocmnode("gfx908")} + environment{ + setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """ + } + steps{ + runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release', gpu_arch: "gfx908") + } + } + stage("Run ckProfiler: gfx90a") + { + agent{ label rocmnode("gfx90a")} + environment{ + setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx90a -O3 " -DBUILD_DEV=On """ + } + steps{ + runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release', gpu_arch: "gfx90a") + } + } + } + } + /* enable after the cmake file supports packaging + stage("Packages") { + when { + expression { params.BUILD_PACKAGES && params.TARGET_NOGPU && params.DATATYPE_NA } + } + parallel { + stage("Package /opt/rocm") { + agent{ label rocmnode("nogpu") } + steps{ + buildHipClangJobAndReboot( package_build: "true", prefixpath: '/opt/rocm', gpu_arch: "gfx906;gfx908;gfx90a") + } + } + } + } + */ + } +} diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..2fe9a845 --- /dev/null +++ b/LICENSE @@ -0,0 +1,28 @@ +Copyright (c) 2018- , Advanced Micro Devices, Inc. (Chao Liu, Jing Zhang) +Copyright (c) 2019- , Advanced Micro Devices, Inc. (Letao Qin, Qianfeng Zhang, Liang Huang, Shaojie Wang) +Copyright (c) 2022- , Advanced Micro Devices, Inc. (Anthony Chang, Chunyu Lai, Illia Silin, Adam Osewski, Poyen Chen, Jehandad Khan) +Copyright (c) 2019-2021, Advanced Micro Devices, Inc. (Hanwen Chang) +Copyright (c) 2019-2020, Advanced Micro Devices, Inc. (Tejash Shah) +Copyright (c) 2020 , Advanced Micro Devices, Inc. (Xiaoyan Zhou) +Copyright (c) 2021-2022, Advanced Micro Devices, Inc. 
(Jianfeng Yan)
+
+SPDX-License-Identifier: MIT
+Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
index 4f071d58..bbc4d2bc 100644
--- a/README.md
+++ b/README.md
@@ -1,177 +1,66 @@
-# How to build and run
-
-# Docker
-```
-docker run \
--it \
---rm \
---privileged \
---group-add sudo \
--w /root/workspace \
--v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
-rocm/tensorflow:rocm4.2-tf2.4-dev \
-/bin/bash
-```
+## Docker script
+```bash
+docker run \
+-it \
+--privileged \
+--group-add sudo \
+-w /root/workspace \
+-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
+rocm/tensorflow:rocm5.1-tf2.6-dev \
+/bin/bash
+```

-# Install Boost for online compilation
-https://www.boost.org/doc/libs/1_66_0/more/getting_started/unix-variants.html#easy-build-and-install
+## Install a newer version of rocm-cmake
+https://github.com/RadeonOpenCompute/rocm-cmake

-
-# Build
-Add path of Boost
-```
- export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
-```
-
-```
+## Build
+```bash
+mkdir build && cd build
+```

-cmake cmd. Need to Specify target ID, example below is gfx908
-```
-cmake \
--D CMAKE_BUILD_TYPE=Release \
--D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 -O3 --amdgpu-target=gfx908 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \
--D HIP_ONLINE_COMPILER_FLAGS="-DCK_AMD_GPU_GFX908" \
--D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
--D CMAKE_PREFIX_PATH=/opt/rocm \
--D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
+```bash
+# Need to specify target ID, example below is gfx908 and gfx90a
+cmake \
+-D BUILD_DEV=OFF \
+-D CMAKE_BUILD_TYPE=Release \
+-D CMAKE_CXX_FLAGS=" --offload-arch=gfx908 --offload-arch=gfx90a -O3" \
+-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
+-D CMAKE_PREFIX_PATH=/opt/rocm \
+-D CMAKE_INSTALL_PREFIX=${PATH_TO_CK_INSTALL_DIRECTORY} \
 ..
 ```

-Build drivers: \
-``conv_fwd_driver_offline`` is (offline compilation) driver for forward convolution, \
-``conv_bwd_driver_offline`` is (offline compilation) driver for backward-data convolution \
-``conv_fwd_driver_online`` is (online compilation) driver for forward convolution
-```
- make -j conv_fwd_driver_offline
- make -j conv_bwd_driver_offline
- make -j conv_fwd_driver_online
-```
+### Build and Run Examples
+```bash
+ make -j examples
+```
+Instructions for running each individual example are under ```example/```

-# Run
-* layout: 0 = NCHW; 1 = NHWC
-* algo: algorithm
-* verify: 0 = no verification; 1 = do verification
-* init: 0 ~ 5.
initialization method -* log: 0 = no log; 1 = do log -* repeat: number of time kernel being launched -``` -######################################################## layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads - ./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 - ./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1 - ./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 - ./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1 - ./host/driver_offline/conv_bwd_driver_offline 1 5 0 0 0 1 256 256 1024 3 3 14 14 1 1 1 1 1 1 1 1 +## Tests +```bash + make -j examples tests + make test ``` -# Result -Forward convoltuion, FP16, NCHW +## Build ckProfiler +```bash + make -j ckProfiler ``` -./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 +Instructions for running ckProfiler are under ```profiler/``` -layout: 0 -in: dim 4, lengths {128, 192, 71, 71}, strides {967872, 5041, 71, 1} -wei: dim 4, lengths {256, 192, 3, 3}, strides {1728, 9, 3, 1} -out: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1296, 36, 1} -InLeftPads size 2, {1, 1, } -InRightPads size 2, {1, 1, } -ConvStrides size 2, {2, 2, } -ConvDilations size 2, {1, 1, } -device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw -a_k0_m_k1_grid_desc{216, 256, 8} -b_k0_n_k1_grid_desc{216, 165888, 8} -c_m_n_grid_desc{ 256, 165888} -launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 1 times... -Average time : 1.4155 ms, 103.686 TFlop/s +## Install CK +```bash +make install ``` -Forward convoltuion, FP16, NCHW -``` - ./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1 - - layout: 0 -in: dim 4, lengths {256, 256, 14, 14}, strides {50176, 196, 14, 1} -wei: dim 4, lengths {1024, 256, 3, 3}, strides {2304, 9, 3, 1} -out: dim 4, lengths {256, 1024, 14, 14}, strides {200704, 196, 14, 1} -InLeftPads size 2, {1, 1, } -InRightPads size 2, {1, 1, } -ConvStrides size 2, {1, 1, } -ConvDilations size 2, {1, 1, } -device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw -a_k0_m_k1_grid_desc{288, 1024, 8} -b_k0_n_k1_grid_desc{288, 50176, 8} -c_m_n_grid_desc{ 1024, 50176} -launch_and_time_kernel: grid_dim {1568, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 1 times... -Average time : 2.21357 ms, 106.959 TFlop/s - ``` - - Forward convolution, FP16, NHWC - ``` - ./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 - - layout: 1 -in: dim 4, lengths {128, 71, 71, 192}, strides {967872, 13632, 192, 1} -wei: dim 4, lengths {256, 3, 3, 192}, strides {1728, 576, 192, 1} -out: dim 4, lengths {128, 36, 36, 256}, strides {331776, 9216, 256, 1} -InLeftPads size 2, {1, 1, } -InRightPads size 2, {1, 1, } -ConvStrides size 2, {2, 2, } -ConvDilations size 2, {1, 1, } -device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk -a_k0_m_k1_grid_desc{216, 165888, 8} -b_k0_n_k1_grid_desc{216, 256, 8} -c_m_n_grid_desc{ 165888, 256} -launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 1 times... 
-Average time : 1.12014 ms, 131.025 TFlop/s
- ```
-
- Forward convolution, FP16, NHWC
- ```
- ./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1
-
- layout: 1
-in: dim 4, lengths {256, 14, 14, 256}, strides {50176, 3584, 256, 1}
-wei: dim 4, lengths {1024, 3, 3, 256}, strides {2304, 768, 256, 1}
-out: dim 4, lengths {256, 14, 14, 1024}, strides {200704, 14336, 1024, 1}
-InLeftPads size 2, {1, 1, }
-InRightPads size 2, {1, 1, }
-ConvStrides size 2, {1, 1, }
-ConvDilations size 2, {1, 1, }
-device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
-a_k0_m_k1_grid_desc{288, 50176, 8}
-b_k0_n_k1_grid_desc{288, 1024, 8}
-c_m_n_grid_desc{ 50176, 1024}
-launch_and_time_kernel: grid_dim {1568, 1, 1}, block_dim {256, 1, 1}
-Warm up
-Start running 1 times...
-Average time : 1.86877 ms, 126.693 TFlop/s
- ```
-
- Backward data convolution, FP16, NHWC
- ```
- ./host/driver_offline/conv_bwd_driver_offline 1 1 0 3 0 1 256 256 1024 3 3 14 14 1 1 1 1 1 1 1 1
-
- layout: 1
-in: dim 4, lengths {256, 14, 14, 1024}, strides {200704, 14336, 1024, 1}
-wei: dim 4, lengths {256, 3, 3, 1024}, strides {9216, 3072, 1024, 1}
-out: dim 4, lengths {256, 14, 14, 256}, strides {50176, 3584, 256, 1}
-InLeftPads size 2, {1, 1, }
-InRightPads size 2, {1, 1, }
-ConvStrides size 2, {1, 1, }
-ConvDilations size 2, {1, 1, }
-device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
-a_k0_m_k1_grid_desc{288, 50176, 8}
-b_k0_n_k1_grid_desc{288, 1024, 8}
-c_m_n_grid_desc{ 50176, 1024}
-launch_and_time_kernel: grid_dim {1568, 1, 1}, block_dim {256, 1, 1}
-Warm up
-Start running 1 times...
-Average time : 2.22461 ms, 106.428 TFlop/s
-```
+## Using CK as a pre-built kernel library
+Instructions for using CK as a pre-built kernel library are under ```client_example/```
+
+## Caveat
+### Kernel Timing and Verification
+CK's own kernel timer warms up the kernel once, then runs it multiple times
+to get the average kernel time. For kernels that use atomic add, the repeated
+runs accumulate into the output buffer and cause verification failures.
+To work around this, do not enable CK's own timer and verification at the
+same time. The timer and verification can each be enabled or disabled from
+the command line in every example and in ckProfiler.
diff --git a/client_example/01_gemm/CMakeLists.txt b/client_example/01_gemm/CMakeLists.txt
new file mode 100644
index 00000000..9e741192
--- /dev/null
+++ b/client_example/01_gemm/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_executable(client_gemm gemm.cpp)
+target_link_libraries(client_gemm PRIVATE composable_kernel::device_operations)
diff --git a/client_example/01_gemm/gemm.cpp b/client_example/01_gemm/gemm.cpp
new file mode 100644
index 00000000..9b7b7a66
--- /dev/null
+++ b/client_example/01_gemm/gemm.cpp
@@ -0,0 +1,218 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
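The caveat above about timing versus verification shows up directly in the client examples that follow: each instance is timed with `StreamConfig{nullptr, true}` and then given a final untimed run with `StreamConfig{nullptr, false}`. A minimal sketch of that pattern, assuming `op` stands for any CK device-op pointer and `arg` for the argument produced by its `MakeArgumentPointer`:

```cpp
// Sketch only: `op` and `arg` stand for any CK device operation pointer and
// its argument pointer, as created in the examples below.
auto invoker = op->MakeInvokerPointer();

// Timed mode: warms up, then re-runs the kernel to average its time. Do not
// verify afterwards if the kernel accumulates its output with atomic adds.
float avg_ms = invoker->Run(arg.get(), StreamConfig{nullptr, true});

// Untimed single launch: the safe mode for copying back and verifying output.
invoker->Run(arg.get(), StreamConfig{nullptr, false});
```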
+
+#include <cstdio>
+#include <cstdlib>
+#include <iomanip>
+#include <iostream>
+#include <string>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/gpu/gemm.hpp"
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+using AElementOp = PassThrough;
+using BElementOp = PassThrough;
+using CElementOp = PassThrough;
+
+using ADataType = F16;
+using BDataType = F16;
+using CDataType = F16;
+
+using ALayout = Row;
+using BLayout = Col;
+using CLayout = Row;
+
+struct SimpleDeviceMem
+{
+    SimpleDeviceMem() = delete;
+
+    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
+    {
+        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
+    }
+
+    void* GetDeviceBuffer() { return p_mem_; }
+
+    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }
+
+    void* p_mem_;
+};
+
+int main(int argc, char* argv[])
+{
+    // GEMM shape
+    ck::index_t M = 3840;
+    ck::index_t N = 4096;
+    ck::index_t K = 4096;
+
+    ck::index_t StrideA = 4096;
+    ck::index_t StrideB = 4096;
+    ck::index_t StrideC = 4096;
+
+    if(argc == 1)
+    {
+        // use default case
+    }
+    else if(argc == 7)
+    {
+        M = std::stoi(argv[1]);
+        N = std::stoi(argv[2]);
+        K = std::stoi(argv[3]);
+
+        StrideA = std::stoi(argv[4]);
+        StrideB = std::stoi(argv[5]);
+        StrideC = std::stoi(argv[6]);
+    }
+    else
+    {
+        printf("arg1 to 6: M, N, K, StrideA, StrideB, StrideC\n");
+        exit(0);
+    }
+
+    auto f_matrix_space_size =
+        [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
+            using Layout = decltype(layout);
+
+            if(std::is_same<Layout, Row>::value)
+            {
+                return (nRow - 1) * stride + nCol;
+            }
+            else
+            {
+                return (nCol - 1) * stride + nRow;
+            }
+        };
+
+    SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{}));
+    SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{}));
+    SimpleDeviceMem c_device_buf(sizeof(CDataType) * f_matrix_space_size(M, N, StrideC, CLayout{}));
+
+    using DeviceOp = ck::tensor_operation::device::DeviceGemm<ALayout,
+                                                              BLayout,
+                                                              CLayout,
+                                                              ADataType,
+                                                              BDataType,
+                                                              CDataType,
+                                                              AElementOp,
+                                                              BElementOp,
+                                                              CElementOp>;
+
+    // get device op instances
+    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
+        DeviceOp>::GetInstances();
+
+    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
+
+    const auto a_element_op = AElementOp{};
+    const auto b_element_op = BElementOp{};
+    const auto c_element_op = CElementOp{};
+
+    std::string best_op_name;
+    bool found = false;
+    int best_op_id = -1;
+    float best_ave_time = 0;
+    float best_tflops = 0;
+    float best_gb_per_sec = 0;
+
+    // profile device operation instances
+    std::cout << "Run all instances and do timing" << std::endl;
+
+    for(int i = 0; i < op_ptrs.size(); ++i)
+    {
+        auto& op_ptr = op_ptrs[i];
+
+        auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
+                                                        b_device_buf.GetDeviceBuffer(),
+                                                        c_device_buf.GetDeviceBuffer(),
+                                                        M,
+                                                        N,
+                                                        K,
+                                                        StrideA,
+                                                        StrideB,
+                                                        StrideC,
+                                                        a_element_op,
+                                                        b_element_op,
+                                                        c_element_op);
+
+        auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+        std::string op_name = op_ptr->GetTypeString();
+
+        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+        {
+            float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
+
+            std::size_t flop = std::size_t(2) * M * N * K;
+
+            std::size_t num_btype =
+                sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
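+            // A GEMM performs 2*M*N*K flops (one multiply and one add per
+            // accumulation step); with ave_time in ms, flop / 1e9 / ave_time
+            // yields TFLOPS, and num_btype (each matrix moved once) divided
+            // by 1e6 and by ave_time yields GB/s.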
+            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+            float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+            std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
+                      << gb_per_sec << " GB/s, " << op_name << std::endl;
+
+            if(tflops > best_tflops)
+            {
+                found = true;
+                best_op_id = i;
+                best_op_name = op_name;
+                best_tflops = tflops;
+                best_ave_time = ave_time;
+                best_gb_per_sec = gb_per_sec;
+            }
+        }
+        else
+        {
+            std::cout << op_name << " does not support this problem" << std::endl;
+        }
+    }
+
+    if(!found)
+    {
+        std::cout << "no instance supports this problem" << std::endl;
+        return -1;
+    }
+
+    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
+              << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
+
+    // run the best instance
+    {
+        auto& op_ptr = op_ptrs[best_op_id];
+
+        std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
+                  << std::endl;
+
+        auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
+                                                        b_device_buf.GetDeviceBuffer(),
+                                                        c_device_buf.GetDeviceBuffer(),
+                                                        M,
+                                                        N,
+                                                        K,
+                                                        StrideA,
+                                                        StrideB,
+                                                        StrideC,
+                                                        a_element_op,
+                                                        b_element_op,
+                                                        c_element_op);
+
+        auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+        {
+            invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
+        }
+
+        std::cout << "Done" << std::endl;
+    }
+
+    return 0;
+}
diff --git a/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt b/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt
new file mode 100644
index 00000000..1064abc8
--- /dev/null
+++ b/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_executable(client_gemm_add_add_fastgelu gemm_add_add_fastgelu.cpp)
+target_link_libraries(client_gemm_add_add_fastgelu PRIVATE composable_kernel::device_operations)
diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp
new file mode 100644
index 00000000..dbf2e634
--- /dev/null
+++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp
@@ -0,0 +1,241 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
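+
+// This example profiles DeviceGemmMultipleD instances that fuse two extra
+// source tensors and a FastGelu activation into the GEMM epilogue:
+//     E = FastGelu(A*B + D0 + D1)
+// FastGelu is a fast GELU approximation; the common tanh form is
+//     0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))
+// though CK's device-side implementation may use a different polynomial.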
+ +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAddFastGelu; + +using ADataType = F16; +using BDataType = F16; +using D0DataType = F16; +using D1DataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DDELayout = Row; +using DDELayout = Row; +using DELayout = Row; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD0 = 0; + ck::index_t StrideD1 = 4096; + ck::index_t StrideE = 4096; + + if(argc == 1) + { + // use default case + } + else if(argc == 9) + { + M = std::stoi(argv[1]); + N = std::stoi(argv[2]); + K = std::stoi(argv[3]); + + StrideA = std::stoi(argv[4]); + StrideB = std::stoi(argv[5]); + StrideD0 = std::stoi(argv[6]); + StrideD1 = std::stoi(argv[7]); + StrideE = std::stoi(argv[8]); + } + else + { + printf("arg1 to 8: M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE\n"); + exit(0); + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{})); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{})); + SimpleDeviceMem d0_m_n_device_buf(sizeof(D0DataType) * + f_matrix_space_size(M, N, StrideD0, DDELayout{})); + SimpleDeviceMem d1_m_n_device_buf(sizeof(D1DataType) * + f_matrix_space_size(M, N, StrideD1, DDELayout{})); + SimpleDeviceMem e_device_buf(sizeof(EDataType) * + f_matrix_space_size(M, N, StrideE, DELayout{})); + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, + BLayout, + DDELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AddAddFastGelu>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + 
float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d0_m_n_device_buf.GetDeviceBuffer(), + d1_m_n_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD0, StrideD1}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d0_m_n_device_buf.GetDeviceBuffer(), + d1_m_n_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD0, StrideD1}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/03_gemm_layernorm/CMakeLists.txt b/client_example/03_gemm_layernorm/CMakeLists.txt new file mode 100644 index 00000000..3742e708 --- /dev/null +++ b/client_example/03_gemm_layernorm/CMakeLists.txt @@ -0,0 +1,2 @@ +add_executable(client_gemm_add_add_reduce_normalize gemm_add_add_layernorm.cpp) +target_link_libraries(client_gemm_add_add_reduce_normalize PRIVATE composable_kernel::device_operations) diff --git a/client_example/03_gemm_layernorm/gemm_add_add_layernorm.cpp b/client_example/03_gemm_layernorm/gemm_add_add_layernorm.cpp new file mode 100644 index 00000000..8f142937 --- /dev/null +++ b/client_example/03_gemm_layernorm/gemm_add_add_layernorm.cpp @@ -0,0 +1,271 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
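+
+// This example assembles GEMM + layernorm from two device operations: (1) a
+// fused GEMM that also reduces its output C into per-row mean(C) and
+// mean(C^2), and (2) an element-wise Normalize kernel that consumes C, the
+// two reduced moments, gamma, and beta. Using var = mean(c^2) - mean(c)^2,
+// the normalize step computes roughly
+//     y = gamma * (c - mean) / sqrt(var + epsilon) + beta
+// (the exact epsilon handling lives inside CK's Normalize functor).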
+
+#include <array>
+#include <iomanip>
+#include <iostream>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
+#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp"
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using ADataType = F16;
+using BDataType = F16;
+using BiasDataType = F32;
+using CDataType = F16;
+using D0DataType = F16;
+using ReduceDataType = F32;
+using GammaDataType = F16;
+using BetaDataType = F16;
+using LayerNormOutDataType = F16;
+
+using ALayout = ck::tensor_layout::gemm::RowMajor;
+using BLayout = ck::tensor_layout::gemm::ColumnMajor;
+using CLayout = ck::tensor_layout::gemm::RowMajor;
+
+struct SimpleDeviceMem
+{
+    SimpleDeviceMem() = delete;
+
+    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
+    {
+        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
+    }
+
+    void* GetDeviceBuffer() { return p_mem_; }
+
+    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }
+
+    void* p_mem_;
+};
+
+template <typename gemm_reduce_op_ptr>
+bool RunDeviceGemmMeanSquareMean(gemm_reduce_op_ptr& p_op,
+                                 const void* p_a,
+                                 const void* p_b,
+                                 const void* p_bias,
+                                 const void* p_d0,
+                                 void* p_c,
+                                 void* p_mean,
+                                 void* p_square_mean,
+                                 int M,
+                                 int N,
+                                 int K,
+                                 int StrideA,
+                                 int StrideB,
+                                 int StrideC,
+                                 int StrideD0,
+                                 bool time_kernel)
+{
+    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+    using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
+    using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
+
+    auto passOp = PassThrough{};
+    auto squareOp = UnarySquareElementOp{};
+    auto divOp = UnaryDivElementOp{N};
+
+    auto argument_ptr =
+        p_op->MakeArgumentPointer(p_a,
+                                  p_b,
+                                  p_bias,
+                                  {p_d0},
+                                  p_c,
+                                  {p_mean, p_square_mean},
+                                  M,
+                                  N,
+                                  K,
+                                  StrideA,
+                                  StrideB,
+                                  StrideC,
+                                  {StrideD0},
+                                  {&passOp, &passOp, &passOp}, // functor for a, b, c
+                                  {&passOp},                   // functor for d0
+                                  {&passOp, &squareOp},        // functor for inputs of reduction
+                                  {&divOp, &divOp});           // functor for outputs of reduction
+
+    if(p_op->IsSupportedArgument(argument_ptr.get()))
+    {
+        auto invoker_ptr = p_op->MakeInvokerPointer();
+
+        // If we measure the running time of gemm_reduce, the output may be wrong:
+        // the reduction tensor must be initialized before the kernel runs, but
+        // with time_kernel = true the kernel is launched many times without
+        // reinitializing the reduction output tensor.
+        float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
+
+        if(time_kernel)
+            std::cout << "Gemm + reduce Perf: " << std::setw(10) << ave_time << " ms" << std::endl;
+
+        return true;
+    }
+
+    return false;
+}
+
+template <typename normalize_op_ptr>
+bool RunDeviceNormalize2D(normalize_op_ptr& p_op,
+                          const void* p_x,
+                          const void* p_mean,
+                          const void* p_square_mean,
+                          const void* p_gamma,
+                          const void* p_beta,
+                          void* p_y,
+                          int M,
+                          int N,
+                          int StrideX,
+                          bool time_kernel)
+{
+    std::array<const void*, 5> input = {p_x, p_mean, p_square_mean, p_gamma, p_beta};
+    std::array<void*, 1> output = {p_y};
+    auto normalize_functor = ck::tensor_operation::element_wise::Normalize{};
+
+    auto argument_ptr = p_op->MakeArgumentPointer(input,
+                                                  output,
+                                                  {M, N},
+                                                  {{StrideX, 1}, {1, 0}, {1, 0}, {0, 1}, {0, 1}},
+                                                  {{StrideX, 1}},
+                                                  normalize_functor);
+
+    if(p_op->IsSupportedArgument(argument_ptr.get()))
+    {
+        auto invoker_ptr = p_op->MakeInvokerPointer();
+        float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
+
+        if(time_kernel)
+            std::cout << "Normalize Perf: " << std::setw(10) << ave_time << " ms" << std::endl;
+
+        return true;
+    }
+
+    return false;
+}
+
+int main()
+{
+    ck::index_t M = 1024;
+    ck::index_t N = 1024;
+    ck::index_t K = 1024;
+
+    ck::index_t StrideA = 1024;
+    ck::index_t StrideB = 1024;
+    ck::index_t StrideC = 1024;
+    ck::index_t StrideD0 = 1024;
+
+    const auto gemm_reduce_ptrs =
+        ck::tensor_operation::device::instance::get_device_gemm_add_add_mean_squaremean_instances<
+            ADataType,
+            BDataType,
+            CDataType,
+            ALayout,
+            BLayout,
+            CLayout>();
+
+    const auto normalize_ptrs =
+        ck::tensor_operation::device::instance::get_device_normalize_from_mean_meansquare_instances<
+            CDataType,
+            ReduceDataType,
+            ReduceDataType,
+            GammaDataType,
+            BetaDataType,
+            LayerNormOutDataType>();
+
+    std::cout << "found " << gemm_reduce_ptrs.size()
+              << " gemm_reduceMean_reduceSquareMean instances" << std::endl;
+
+    std::cout << "found " << normalize_ptrs.size() << " normalize instances" << std::endl;
+
+    auto f_matrix_space_size =
+        [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
+            using Layout = decltype(layout);
+
+            if(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
+            {
+                return (nRow - 1) * stride + nCol;
+            }
+            else
+            {
+                return (nCol - 1) * stride + nRow;
+            }
+        };
+
+    SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{}));
+    SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{}));
+    SimpleDeviceMem bias_device_buf(sizeof(BiasDataType) * N);
+    SimpleDeviceMem c_device_buf(sizeof(CDataType) * f_matrix_space_size(M, N, StrideC, CLayout{}));
+    SimpleDeviceMem d0_device_buf(sizeof(D0DataType) *
+                                  f_matrix_space_size(M, N, StrideD0, CLayout{}));
+    SimpleDeviceMem reduceMean_device_buf(sizeof(ReduceDataType) * M);
+    SimpleDeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) * M);
+    SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * N);
+    SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * N);
+    SimpleDeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) * M * N);
+
+    bool b_time_kernel = true;
+    bool b_only_run_first_kernel = true;
+
+    // layernorm => (1) + (2)
+    // (1). c = gemm(a, b), reduce_mean(c), reduce_square_mean(c)
+    // (2).
normalize(c, mean, square_mean, gamma, beta) + for(auto& gemm_reduce_ptr : gemm_reduce_ptrs) + { + // run first available kernel + if(RunDeviceGemmMeanSquareMean(gemm_reduce_ptr, + a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + bias_device_buf.GetDeviceBuffer(), + d0_device_buf.GetDeviceBuffer(), + c_device_buf.GetDeviceBuffer(), + reduceMean_device_buf.GetDeviceBuffer(), + reduceMeanSquare_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideD0, + b_time_kernel)) + { + if(b_only_run_first_kernel) + break; + } + else + { + std::cout << gemm_reduce_ptr->GetTypeString() << " does not support this problem" + << std::endl; + } + } + + for(auto& normalize_ptr : normalize_ptrs) + { + if(RunDeviceNormalize2D(normalize_ptr, + c_device_buf.GetDeviceBuffer(), + reduceMean_device_buf.GetDeviceBuffer(), + reduceMeanSquare_device_buf.GetDeviceBuffer(), + gamma_device_buf.GetDeviceBuffer(), + beta_device_buf.GetDeviceBuffer(), + layerNorm_device_buf.GetDeviceBuffer(), + M, + N, + StrideC, + b_time_kernel)) + { + if(b_only_run_first_kernel) + break; + } + else + { + std::cout << normalize_ptr->GetTypeString() << " does not support this problem" + << std::endl; + } + } +} diff --git a/client_example/04_contraction/CMakeLists.txt b/client_example/04_contraction/CMakeLists.txt new file mode 100644 index 00000000..4bc6780f --- /dev/null +++ b/client_example/04_contraction/CMakeLists.txt @@ -0,0 +1,6 @@ +add_executable(client_contraction_scale contraction_scale.cpp) +target_link_libraries(client_contraction_scale PRIVATE composable_kernel::device_operations) + +add_executable(client_contraction_bilinear contraction_bilinear.cpp) +target_link_libraries(client_contraction_bilinear PRIVATE composable_kernel::device_operations) + diff --git a/client_example/04_contraction/contraction_bilinear.cpp b/client_example/04_contraction/contraction_bilinear.cpp new file mode 100644 index 00000000..b71c51c0 --- /dev/null +++ b/client_example/04_contraction/contraction_bilinear.cpp @@ -0,0 +1,241 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
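+
+// This example profiles 6-D tensor contractions (two M modes, two N modes,
+// two K modes) with a Bilinear epilogue:
+//     E[m0,m1,n0,n1] = alpha * sum_{k0,k1} A[m0,m1,k0,k1] * B[n0,n1,k0,k1]
+//                    + beta * D[m0,m1,n0,n1]
+// where alpha and beta are the two floats passed to CDEElementOp{alpha, beta}.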
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp" + +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Bilinear; + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F32; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; + + float alpha = 1.f; + float beta = 1.f; + + if(argc == 1) + { + // use default case + } + else if(argc == 25) + { + const ck::index_t M0 = std::stoi(argv[1]); + const ck::index_t M1 = std::stoi(argv[2]); + + const ck::index_t N0 = std::stoi(argv[3]); + const ck::index_t N1 = std::stoi(argv[4]); + + const ck::index_t K0 = std::stoi(argv[5]); + const ck::index_t K1 = std::stoi(argv[6]); + + a_ms_ks_lengths = {M0, M1, K0, K1}; + a_ms_ks_strides = { + std::stoi(argv[7]), std::stoi(argv[8]), std::stoi(argv[9]), std::stoi(argv[10])}; + + b_ns_ks_lengths = {N0, N1, K0, K1}; + b_ns_ks_strides = { + std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13]), std::stoi(argv[14])}; + + d_ms_ns_lengths = {M0, M1, N0, N1}; + d_ms_ns_strides = { + std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17]), std::stoi(argv[18])}; + + e_ms_ns_lengths = {M0, M1, N0, N1}; + e_ms_ns_strides = { + std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21]), std::stoi(argv[22])}; + + alpha = std::stof(argv[23]); + beta = std::stof(argv[24]); + } + else + { + printf("arg1 to 6: M0, M1, N0, N1, K0, K1\n"); + printf("arg7 to 10: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); + printf("arg11 to 14: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); + printf("arg15 to 18: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n"); + printf("arg19 to 22: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); + printf("arg23 to 24: alpha, beta\n"); + exit(0); + } + + auto f_tensor_space_size = [](auto lengths, auto strides) { + std::size_t space_size = 1; + for(std::size_t i = 0; i < lengths.size(); ++i) + { + space_size += (lengths[i] - 1) * strides[i]; + } + return space_size; + }; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * + f_tensor_space_size(a_ms_ks_lengths, 
+    SimpleDeviceMem a_device_buf(sizeof(ADataType) *
+                                 f_tensor_space_size(a_ms_ks_lengths, a_ms_ks_strides));
+    SimpleDeviceMem b_device_buf(sizeof(BDataType) *
+                                 f_tensor_space_size(b_ns_ks_lengths, b_ns_ks_strides));
+    SimpleDeviceMem d_device_buf(sizeof(DDataType) *
+                                 f_tensor_space_size(d_ms_ns_lengths, d_ms_ns_strides));
+    SimpleDeviceMem e_device_buf(sizeof(EDataType) *
+                                 f_tensor_space_size(e_ms_ns_lengths, e_ms_ns_strides));
+
+    using DeviceOp = ck::tensor_operation::device::DeviceContractionMultipleD<
+        NumDimM,
+        NumDimN,
+        NumDimK,
+        ADataType,
+        BDataType,
+        ck::Tuple<DDataType>,
+        EDataType,
+        ck::tensor_operation::element_wise::PassThrough,
+        ck::tensor_operation::element_wise::PassThrough,
+        ck::tensor_operation::element_wise::Bilinear>;
+
+    // get device op instances
+    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
+        DeviceOp>::GetInstances();
+
+    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
+
+    const auto a_element_op = AElementOp{};
+    const auto b_element_op = BElementOp{};
+    const auto cde_element_op = CDEElementOp{alpha, beta};
+
+    std::string best_op_name;
+    bool found = false;
+    int best_op_id = -1;
+    float best_ave_time = 0;
+    float best_tflops = 0;
+    float best_gb_per_sec = 0;
+
+    // profile device operation instances
+    std::cout << "Run all instances and do timing" << std::endl;
+
+    for(int i = 0; i < op_ptrs.size(); ++i)
+    {
+        auto& op_ptr = op_ptrs[i];
+
+        auto argument_ptr =
+            op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
+                                        b_device_buf.GetDeviceBuffer(),
+                                        std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
+                                        e_device_buf.GetDeviceBuffer(),
+                                        a_ms_ks_lengths,
+                                        a_ms_ks_strides,
+                                        b_ns_ks_lengths,
+                                        b_ns_ks_strides,
+                                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
+                                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
+                                        e_ms_ns_lengths,
+                                        e_ms_ns_strides,
+                                        a_element_op,
+                                        b_element_op,
+                                        cde_element_op);
+
+        auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+        std::string op_name = op_ptr->GetTypeString();
+
+        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+        {
+            float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
+
+            ck::index_t M = std::accumulate(e_ms_ns_lengths.begin(),
+                                            e_ms_ns_lengths.begin() + NumDimM,
+                                            ck::index_t{1},
+                                            std::multiplies<ck::index_t>{});
+
+            ck::index_t N = std::accumulate(e_ms_ns_lengths.begin() + NumDimM,
+                                            e_ms_ns_lengths.begin() + NumDimM + NumDimN,
+                                            ck::index_t{1},
+                                            std::multiplies<ck::index_t>{});
+
+            ck::index_t K = std::accumulate(a_ms_ks_lengths.begin() + NumDimM,
+                                            a_ms_ks_lengths.begin() + NumDimM + NumDimK,
+                                            ck::index_t{1},
+                                            std::multiplies<ck::index_t>{});
+
+            std::size_t flop = std::size_t(2) * M * N * K;
+            std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
+                                    sizeof(DDataType) * M * N + sizeof(EDataType) * M * N;
+
+            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+            float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+            std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
+                      << gb_per_sec << " GB/s, " << op_name << std::endl;
+
+            if(tflops > best_tflops)
+            {
+                found = true;
+                best_op_id = i;
+                best_op_name = op_name;
+                best_tflops = tflops;
+                best_ave_time = ave_time;
+                best_gb_per_sec = gb_per_sec;
+            }
+        }
+        else
+        {
+            std::cout << op_name << " does not support this problem" << std::endl;
+        }
+    }
+
+    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
+              << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
+
+    return 0;
+}
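Once built through the CMakeLists above, the example runs with no arguments for the default problem, or with the full 24-argument form from its usage text. An invocation sketch (binary location depends on your build tree):

```bash
# default shape, alpha = beta = 1
./04_contraction/client_contraction_bilinear

# M0 M1 N0 N1 K0 K1, four strides each for A, B, D, E, then alpha and beta
./04_contraction/client_contraction_bilinear 30 128 32 64 32 64 \
    524288 4096 128 1   524288 4096 128 1 \
    524288 4096 128 1   524288 4096 128 1   1.0 1.0
```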
diff --git a/client_example/04_contraction/contraction_scale.cpp b/client_example/04_contraction/contraction_scale.cpp
new file mode 100644
index 00000000..5908c1d8
--- /dev/null
+++ b/client_example/04_contraction/contraction_scale.cpp
@@ -0,0 +1,227 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <iomanip>
+#include <iostream>
+#include <numeric>
+#include <vector>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/gpu/contraction_scale.hpp"
+
+using F32 = float;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using Scale = ck::tensor_operation::element_wise::Scale;
+
+using AElementOp = PassThrough;
+using BElementOp = PassThrough;
+using CDEElementOp = Scale;
+
+using ADataType = F32;
+using BDataType = F32;
+using AccDataType = F32;
+using CShuffleDataType = F32;
+using DsDataType = ck::Tuple<>;
+using EDataType = F32;
+
+static constexpr ck::index_t NumDimM = 2;
+static constexpr ck::index_t NumDimN = 2;
+static constexpr ck::index_t NumDimK = 2;
+
+struct SimpleDeviceMem
+{
+    SimpleDeviceMem() = delete;
+
+    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
+    {
+        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
+    }
+
+    void* GetDeviceBuffer() { return p_mem_; }
+
+    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }
+
+    void* p_mem_;
+};
+
+int main(int argc, char* argv[])
+{
+    // A[M0, M1, K0, K1]
+    std::vector<ck::index_t> a_ms_ks_lengths{30, 128, 32, 64};
+    std::vector<ck::index_t> a_ms_ks_strides{524288, 4096, 128, 1};
+    // B[N0, N1, K0, K1]
+    std::vector<ck::index_t> b_ns_ks_lengths{32, 64, 32, 64};
+    std::vector<ck::index_t> b_ns_ks_strides{524288, 4096, 128, 1};
+    // E[M0, M1, N0, N1]
+    std::vector<ck::index_t> e_ms_ns_lengths{30, 128, 32, 64};
+    std::vector<ck::index_t> e_ms_ns_strides{524288, 4096, 128, 1};
+
+    float scale = 1.f;
+
+    if(argc == 1)
+    {
+        // use default case
+    }
+    else if(argc == 20)
+    {
+        const ck::index_t M0 = std::stoi(argv[1]);
+        const ck::index_t M1 = std::stoi(argv[2]);
+
+        const ck::index_t N0 = std::stoi(argv[3]);
+        const ck::index_t N1 = std::stoi(argv[4]);
+
+        const ck::index_t K0 = std::stoi(argv[5]);
+        const ck::index_t K1 = std::stoi(argv[6]);
+
+        a_ms_ks_lengths = {M0, M1, K0, K1};
+        a_ms_ks_strides = {
+            std::stoi(argv[7]), std::stoi(argv[8]), std::stoi(argv[9]), std::stoi(argv[10])};
+
+        b_ns_ks_lengths = {N0, N1, K0, K1};
+        b_ns_ks_strides = {
+            std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13]), std::stoi(argv[14])};
+
+        e_ms_ns_lengths = {M0, M1, N0, N1};
+        e_ms_ns_strides = {
+            std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17]), std::stoi(argv[18])};
+
+        scale = std::stof(argv[19]);
+    }
+    else
+    {
+        printf("arg1 to 6: M0, M1, N0, N1, K0, K1\n");
+        printf("arg7 to 10: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n");
+        printf("arg11 to 14: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n");
+        printf("arg15 to 18: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n");
+        printf("arg19: scale\n");
+        exit(0);
+    }
+
+    auto f_tensor_space_size = [](auto lengths, auto strides) {
+        std::size_t space_size = 1;
+        for(std::size_t i = 0; i < lengths.size(); ++i)
+        {
+            space_size += (lengths[i] - 1) * strides[i];
+        }
+        return space_size;
+    };
+
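`contraction_scale.cpp` differs from the bilinear example only in its epilogue: `Scale` takes no auxiliary D tensor (`DsDataType = ck::Tuple<>`), while `Bilinear` blends one in. Per output element, the two epilogues reduce to the following scalar sketch (illustrative semantics only, not CK's actual functor code):

```cpp
// acc is the raw A x B contraction result for one output element.
float bilinear_epilogue(float alpha, float beta, float acc, float d)
{
    return alpha * acc + beta * d; // E = alpha * (A x B) + beta * D
}

float scale_epilogue(float s, float acc)
{
    return s * acc; // E = scale * (A x B)
}
```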
+    SimpleDeviceMem a_device_buf(sizeof(ADataType) *
+                                 f_tensor_space_size(a_ms_ks_lengths, a_ms_ks_strides));
+    SimpleDeviceMem b_device_buf(sizeof(BDataType) *
+                                 f_tensor_space_size(b_ns_ks_lengths, b_ns_ks_strides));
+    SimpleDeviceMem e_device_buf(sizeof(EDataType) *
+                                 f_tensor_space_size(e_ms_ns_lengths, e_ms_ns_strides));
+
+    using DeviceOp = ck::tensor_operation::device::DeviceContractionMultipleD<
+        NumDimM,
+        NumDimN,
+        NumDimK,
+        ADataType,
+        BDataType,
+        ck::Tuple<>,
+        EDataType,
+        ck::tensor_operation::element_wise::PassThrough,
+        ck::tensor_operation::element_wise::PassThrough,
+        ck::tensor_operation::element_wise::Scale>;
+
+    // get device op instances
+    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
+        DeviceOp>::GetInstances();
+
+    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
+
+    const auto a_element_op = AElementOp{};
+    const auto b_element_op = BElementOp{};
+    const auto cde_element_op = CDEElementOp{scale};
+
+    std::string best_op_name;
+    bool found = false;
+    int best_op_id = -1;
+    float best_ave_time = 0;
+    float best_tflops = 0;
+    float best_gb_per_sec = 0;
+
+    // profile device operation instances
+    std::cout << "Run all instances and do timing" << std::endl;
+
+    for(int i = 0; i < op_ptrs.size(); ++i)
+    {
+        auto& op_ptr = op_ptrs[i];
+
+        auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
+                                                        b_device_buf.GetDeviceBuffer(),
+                                                        std::array<const void*, 0>{},
+                                                        e_device_buf.GetDeviceBuffer(),
+                                                        a_ms_ks_lengths,
+                                                        a_ms_ks_strides,
+                                                        b_ns_ks_lengths,
+                                                        b_ns_ks_strides,
+                                                        std::array<std::vector<ck::index_t>, 0>{},
+                                                        std::array<std::vector<ck::index_t>, 0>{},
+                                                        e_ms_ns_lengths,
+                                                        e_ms_ns_strides,
+                                                        a_element_op,
+                                                        b_element_op,
+                                                        cde_element_op);
+
+        auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+        std::string op_name = op_ptr->GetTypeString();
+
+        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+        {
+            float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
+
+            ck::index_t M = std::accumulate(e_ms_ns_lengths.begin(),
+                                            e_ms_ns_lengths.begin() + NumDimM,
+                                            ck::index_t{1},
+                                            std::multiplies<ck::index_t>{});
+
+            ck::index_t N = std::accumulate(e_ms_ns_lengths.begin() + NumDimM,
+                                            e_ms_ns_lengths.begin() + NumDimM + NumDimN,
+                                            ck::index_t{1},
+                                            std::multiplies<ck::index_t>{});
+
+            ck::index_t K = std::accumulate(a_ms_ks_lengths.begin() + NumDimM,
+                                            a_ms_ks_lengths.begin() + NumDimM + NumDimK,
+                                            ck::index_t{1},
+                                            std::multiplies<ck::index_t>{});
+
+            std::size_t flop = std::size_t(2) * M * N * K;
+            std::size_t num_btype =
+                sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
+
+            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+            float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+            std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
+                      << gb_per_sec << " GB/s, " << op_name << std::endl;
+
+            if(tflops > best_tflops)
+            {
+                found = true;
+                best_op_id = i;
+                best_op_name = op_name;
+                best_tflops = tflops;
+                best_ave_time = ave_time;
+                best_gb_per_sec = gb_per_sec;
+            }
+        }
+        else
+        {
+            std::cout << op_name << " does not support this problem" << std::endl;
+        }
+    }
+
+    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
+              << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
+
+    return 0;
+}
diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt
new file mode 100644
index 00000000..3e04a185
--- /dev/null
+++ b/client_example/CMakeLists.txt
@@ -0,0 +1,12 @@
+cmake_minimum_required(VERSION 3.15)
+project(ck_app)
+add_compile_options(-std=c++17)
+
+find_package(composable_kernel 1.0.0 COMPONENTS device_operations)
+find_package(hip REQUIRED PATHS /opt/rocm)
+message(STATUS "Build with HIP ${hip_VERSION}")
+
+add_subdirectory(01_gemm)
+add_subdirectory(02_gemm_add_add_fastgelu)
+add_subdirectory(03_gemm_layernorm)
+add_subdirectory(04_contraction)
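For reference, the operation both examples time is a 2-2-2 tensor contraction, E[m0,m1,n0,n1] = sum over (k0,k1) of A[m0,m1,k0,k1] * B[n0,n1,k0,k1], which is where the `2*M*N*K` flop count above comes from. A naive host-side sketch (not part of this patch, and far too slow for anything beyond spot-checking small shapes):

```cpp
#include <cstdint>
#include <vector>

// E[m0,m1,n0,n1] = sum_{k0,k1} A[m0,m1,k0,k1] * B[n0,n1,k0,k1]
// The length/stride arrays follow the same 4-entry layout as the client examples.
void contraction_ref(const std::vector<float>& A, const int64_t la[4], const int64_t sa[4],
                     const std::vector<float>& B, const int64_t lb[4], const int64_t sb[4],
                     std::vector<float>& E, const int64_t se[4])
{
    for(int64_t m0 = 0; m0 < la[0]; ++m0)
        for(int64_t m1 = 0; m1 < la[1]; ++m1)
            for(int64_t n0 = 0; n0 < lb[0]; ++n0)
                for(int64_t n1 = 0; n1 < lb[1]; ++n1)
                {
                    float acc = 0.f;
                    for(int64_t k0 = 0; k0 < la[2]; ++k0)
                        for(int64_t k1 = 0; k1 < la[3]; ++k1)
                            acc += A[m0 * sa[0] + m1 * sa[1] + k0 * sa[2] + k1 * sa[3]] *
                                   B[n0 * sb[0] + n1 * sb[1] + k0 * sb[2] + k1 * sb[3]];
                    E[m0 * se[0] + m1 * se[1] + n0 * se[2] + n1 * se[3]] = acc;
                }
}
```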
diff --git a/client_example/README.md b/client_example/README.md
new file mode 100644
index 00000000..64a7130d
--- /dev/null
+++ b/client_example/README.md
@@ -0,0 +1,21 @@
+## Overview
+The client applications link against the installed CK library, so CK must be installed before they can be built.
+
+
+## Build
+```bash
+mkdir -p client_example/build
+cd client_example/build
+```
+
+```bash
+cmake \
+-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
+-D CMAKE_PREFIX_PATH="/opt/rocm;${PATH_TO_CK_INSTALL_DIRECTORY}" \
+..
+```
+
+### Build client example
+```bash
+ make -j
+```
diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake
index 9f193b20..78133af0 100644
--- a/cmake/EnableCompilerWarnings.cmake
+++ b/cmake/EnableCompilerWarnings.cmake
@@ -66,7 +66,7 @@ else()
         -Wunreachable-code
         -Wunused
-        -Wno-sign-compare
+        -Wsign-compare
         -Wno-extra-semi-stmt
     )
     if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "Clang")
diff --git a/cmake/TargetFlags.cmake b/cmake/TargetFlags.cmake
new file mode 100644
index 00000000..4f83fb5d
--- /dev/null
+++ b/cmake/TargetFlags.cmake
@@ -0,0 +1,50 @@
+
+function(get_target_property2 VAR TARGET PROPERTY)
+    get_target_property(_pflags ${TARGET} ${PROPERTY})
+    if(_pflags)
+        set(${VAR} ${_pflags} PARENT_SCOPE)
+    else()
+        set(${VAR} "" PARENT_SCOPE)
+    endif()
+endfunction()
+
+
+macro(append_flags FLAGS TARGET PROPERTY PREFIX)
+    get_target_property2(_pflags ${TARGET} ${PROPERTY})
+    foreach(FLAG ${_pflags})
+        if(TARGET ${FLAG})
+            target_flags(_pflags2 ${FLAG})
+            string(APPEND ${FLAGS} " ${_pflags2}")
+        else()
+            string(APPEND ${FLAGS} " ${PREFIX}${FLAG}")
+        endif()
+    endforeach()
+endmacro()
+
+macro(append_link_flags FLAGS TARGET PROPERTY)
+    get_target_property2(_pflags ${TARGET} ${PROPERTY})
+    foreach(FLAG ${_pflags})
+        if(TARGET ${FLAG})
+            target_flags(_pflags2 ${FLAG})
+            string(APPEND ${FLAGS} " ${_pflags2}")
+        elseif(FLAG MATCHES "^-.*")
+            string(APPEND ${FLAGS} " ${FLAG}")
+        elseif(EXISTS ${FLAG})
+            string(APPEND ${FLAGS} " ${FLAG}")
+        else()
+            string(APPEND ${FLAGS} " -l${FLAG}")
+        endif()
+    endforeach()
+endmacro()
+
+function(target_flags FLAGS TARGET)
+    set(_flags)
+    append_flags(_flags ${TARGET} "INTERFACE_COMPILE_OPTIONS" "")
+    append_flags(_flags ${TARGET} "INTERFACE_COMPILE_DEFINITIONS" "-D")
+    append_flags(_flags ${TARGET} "INTERFACE_INCLUDE_DIRECTORIES" "-isystem ")
+    append_flags(_flags ${TARGET} "INTERFACE_LINK_DIRECTORIES" "-L ")
+    append_flags(_flags ${TARGET} "INTERFACE_LINK_OPTIONS" "")
+    append_link_flags(_flags ${TARGET} "INTERFACE_LINK_LIBRARIES" "")
+    # message("_flags: ${_flags}")
+    set(${FLAGS} ${_flags} PARENT_SCOPE)
+endfunction()
diff --git a/cmake/googletest.cmake b/cmake/googletest.cmake
new file mode 100644
index 00000000..3718b916
--- /dev/null
+++ b/cmake/googletest.cmake
@@ -0,0 +1,43 @@
+include(FetchContent)
+
+set(GOOGLETEST_DIR "" CACHE STRING "Location of local GoogleTest repo to build against")
+
+if(GOOGLETEST_DIR)
+    set(FETCHCONTENT_SOURCE_DIR_GOOGLETEST ${GOOGLETEST_DIR} CACHE STRING "GoogleTest source directory override")
+endif()
+
+message(STATUS "Fetching GoogleTest")
+
+list(APPEND GTEST_CMAKE_CXX_FLAGS
+    -Wno-undef
+    -Wno-reserved-identifier
+    -Wno-global-constructors
+    -Wno-missing-noreturn
+    -Wno-disabled-macro-expansion
+    -Wno-used-but-marked-unused
+    -Wno-switch-enum
+    -Wno-zero-as-null-pointer-constant
+    -Wno-unused-member-function
+    -Wno-comma
+    -Wno-old-style-cast
+)
+message(STATUS "Suppressing googletest warnings with flags: ${GTEST_CMAKE_CXX_FLAGS}")
+
+FetchContent_Declare( + googletest + GIT_REPOSITORY https://github.com/google/googletest.git + GIT_TAG b85864c64758dec007208e56af933fc3f52044ee +) + +# Will be necessary for windows build +# set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) +FetchContent_GetProperties(googletest) +if(NOT googletest_POPULATED) + FetchContent_Populate(googletest) + add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR} EXCLUDE_FROM_ALL) +endif() + +target_compile_options(gtest PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) +target_compile_options(gtest_main PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) +target_compile_options(gmock PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) +target_compile_options(gmock_main PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) diff --git a/composable_kernel/include/gridwise_operation_wrapper.hpp b/composable_kernel/include/gridwise_operation_wrapper.hpp deleted file mode 100644 index 0a1e07ec..00000000 --- a/composable_kernel/include/gridwise_operation_wrapper.hpp +++ /dev/null @@ -1,14 +0,0 @@ -#ifndef CK_GRIDWISE_OPERATION_KERNEL_WRAPPER -#define CK_GRIDWISE_OPERATION_KERNEL_WRAPPER - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - run_gridwise_operation(Xs... xs) -{ - GridwiseOp{}.Run(xs...); -} - -#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp deleted file mode 100644 index 5cc2f239..00000000 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp +++ /dev/null @@ -1,183 +0,0 @@ -#ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP -#define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP - -#include "common_header.hpp" -#include "threadwise_gemm_dlops_v3.hpp" - -namespace ck { - -template -struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 -{ - struct MatrixIndex - { - index_t k; - index_t h; - index_t w; - }; - - // HACK: fix this @Jing Zhang - static constexpr index_t KPerThreadSubC = 4; - - static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{})); - - static constexpr auto b_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple( - Number{}, Number<1>{}, Number{}, Number{})); - - static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple( - Number{}, Number<1>{}, Number{}, Number{})); - - using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1>, - 1, - ThreadGemmADataPerRead_K, - 1>; - - __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3() - : c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())}, - a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.k * KPerThread)} - { - static_assert(BlockMatrixA::IsKnownAtCompileTime() && - BlockMatrixB::IsKnownAtCompileTime() && - ThreadMatrixC::IsKnownAtCompileTime(), - "wrong! Desc should be known at compile-time"); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0), - "wrong! K dimension not consistent\n"); - - constexpr index_t K = BlockMatrixA{}.GetLength(I1); // A is transposed - constexpr index_t H = BlockMatrixB{}.GetLength(I2); - constexpr index_t W = BlockMatrixB{}.GetLength(I3); - - static_assert(K % KPerThread == 0 && H % HPerThread == 0 && W % WPerThread == 0, - "wrong! 
Cannot evenly divide work among\n"); - - constexpr auto KThreadCluster = K / KPerThread; - constexpr auto HThreadCluster = H / HPerThread; - constexpr auto WThreadCluster = W / WPerThread; - - static_assert(BlockSize == KThreadCluster * HThreadCluster * WThreadCluster, - "wrong! wrong blocksize\n"); - } - - __device__ static constexpr auto GetThreadMatrixCLengths() - { - return Sequence{}; - } - - __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) - { - constexpr index_t H = BlockMatrixB{}.GetLength(Number<2>{}); - constexpr index_t W = BlockMatrixB{}.GetLength(Number<3>{}); - - constexpr auto num_w_threads = W / WPerThread; - constexpr auto num_h_threads = H / HPerThread; - constexpr auto num_hw_threads = num_w_threads * num_h_threads; - - index_t k_thread_id = thread_id / num_hw_threads; - index_t hw_thread_id = thread_id % num_hw_threads; - - index_t h_thread_id = hw_thread_id / num_w_threads; - index_t w_thread_id = hw_thread_id % num_w_threads; - - return MatrixIndex{k_thread_id, h_thread_id, w_thread_id}; - } - - template - __device__ void Run(const ABlockBuffer& a_block_buf, - const BThreadBuffer& b_thread_buf, - CThreadBuffer& c_thread_buf) const - { - static_assert( - is_same, remove_cvref_t>::value && - is_same, remove_cvref_t>::value && - is_same, remove_cvref_t>::value && - "wrong! inconsistent type"); - - constexpr auto I0 = Number<0>{}; - - constexpr auto a_block_mtx = BlockMatrixA{}; - - constexpr auto EPerBlock = a_block_mtx.GetLength(I0); - - // HACK: fix this @Jing Zhang - constexpr auto HoPerThreadSubC = 2; - constexpr auto WoPerThreadSubC = 2; - - static_assert(KPerThread % KPerThreadSubC == 0, ""); - static_assert(HPerThread % HoPerThreadSubC == 0, ""); - static_assert(WPerThread % WoPerThreadSubC == 0, ""); - - // thread A buffer for GEMM - StaticBuffer - a_thread_buf; - - constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3{}; - - static_for<0, EPerBlock, EPerThreadLoop>{}([&](auto e_begin) { - static_for<0, KPerThread, KPerThreadSubC>{}([&](auto k_begin) { - a_thread_copy_.Run(a_block_mtx, - make_tuple(e_begin, k_begin), - a_block_buf, - a_thread_mtx_, - make_tuple(I0, I0), - a_thread_buf); - - static_for<0, HPerThread, HoPerThreadSubC>{}([&](auto h_begin) { - static_for<0, WPerThread, WoPerThreadSubC>{}([&](auto w_begin) { - threadwise_gemm.Run(a_thread_buf, - make_tuple(I0, I0), - b_thread_buf, - make_tuple(e_begin, I0, h_begin, w_begin), - c_thread_buf, - make_tuple(k_begin, I0, h_begin, w_begin)); - }); - }); - }); - }); - } - - template - __device__ void MoveASliceWindow(const BlockMatrixA&, - const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx) - { - a_thread_copy_.MoveSrcSliceWindow(BlockMatrixA{}, a_block_slice_move_step_idx); - } - - private: - MatrixIndex c_thread_begin_mtx_idx_; - - AThreadCopy a_thread_copy_; -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp deleted file mode 100644 index 36c67832..00000000 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp +++ /dev/null @@ -1,282 +0,0 @@ -#ifndef CK_BLOCKWISE_GEMM_XDLOPS_HPP -#define CK_BLOCKWISE_GEMM_XDLOPS_HPP - -#include "common_header.hpp" -#include "threadwise_tensor_slice_transfer.hpp" -#include "xdlops_gemm.hpp" -#include "tensor_adaptor.hpp" - -namespace ck { - -template -struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 -{ - static constexpr auto I0 = Number<0>{}; - static constexpr auto 
I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - - static constexpr index_t WaveSize = 64; - - static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); - static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); - - static constexpr index_t K0 = BK0NK1BlockDesc{}.GetLength(I0); - - static constexpr auto xdlops_gemm = XdlopsGemm{}; - - static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); - static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); - - StaticBufferV2, MRepeat * NRepeat, true> - c_thread_buf_; - - __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } - - __device__ static auto GetWaveIdx() - { - const index_t thread_id = get_thread_local_1d_id(); - - constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); - } - - __device__ static auto CalculateAThreadOriginDataIndex() - { - const auto wave_idx = GetWaveIdx(); - - const auto waveId_m = wave_idx[I0]; - - const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); - - return make_tuple(xdlops_a_idx[I0], 0, waveId_m, xdlops_a_idx[I1], 0); - } - - __device__ static auto CalculateBThreadOriginDataIndex() - { - const auto wave_idx = GetWaveIdx(); - - const auto waveId_n = wave_idx[I1]; - - const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); - - return make_tuple(xdlops_b_idx[I0], 0, waveId_n, xdlops_b_idx[I1], 0); - } - - template - __device__ static auto - CalculateCThreadOriginDataIndex(Number, Number, Number, Number) - { - const auto wave_idx = GetWaveIdx(); - - const auto waveId_m = wave_idx[I0]; - const auto waveId_n = wave_idx[I1]; - - const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); - - constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1, 2>{})); - - constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1, 2>{})); - - const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex( - make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; - const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex( - make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; - - return make_tuple(c_thread_m, c_thread_n); - } - - __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1() - { - static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && - BK0NK1BlockDesc::IsKnownAtCompileTime(), - "wrong! Desc should be known at compile-time"); - - static_assert(AK0MK1BlockDesc{}.GetLength(I0) == BK0NK1BlockDesc{}.GetLength(I0), - "wrong! K0 dimension not consistent"); - - static_assert(AK0MK1BlockDesc{}.GetLength(I2) == BK0NK1BlockDesc{}.GetLength(I2), - "wrong! 
K1 dimension not consistent"); - - static_assert(BlockSize == MWaves * NWaves * WaveSize, - "BlockSize != MWaves * NWaves * WaveSize\n"); - - static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, - "wrong!"); - } - - __host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2ThreadDescriptor() - { - constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); - - constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; - constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; - constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; - constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; - - return make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, I1, M0, M1, M2, N)); - } - - __host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2BlockDescriptor() - { - constexpr auto c_m0_n0_m1_n1_m2_n2_block_desc = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - Number{})); - - return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_block_desc); - } - - template - __host__ __device__ static constexpr auto - MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) - { - const auto c_m0_n0_m1_n1_m2_n2_grid_desc = transform_tensor_descriptor( - c_m_n_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL)), - make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); - - return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_grid_desc); - } - - __host__ __device__ static constexpr auto MakeAK0M0M1M2K1BlockDescriptor() - { - return transform_tensor_descriptor( - AK0MK1BlockDesc{}, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); - } - - __host__ __device__ static constexpr auto MakeBK0N0N1N2K1BlockDescriptor() - { - return transform_tensor_descriptor( - BK0NK1BlockDesc{}, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); - } - - static constexpr auto a_k0_m0_m1_m2_k1_block_desc = MakeAK0M0M1M2K1BlockDescriptor(); - static constexpr auto b_k0_n0_n1_n2_k1_block_desc = MakeBK0N0N1N2K1BlockDescriptor(); - - template - __device__ void Run(const ABlockBuffer& a_block_buf, - const BBlockBuffer& b_block_buf, - CThreadBuffer& c_thread_buf) const - { - auto a_thread_buf = make_static_buffer( - a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( - b_thread_desc_.GetElementSpaceSize()); - - static_for<0, MRepeat, 1>{}([&](auto m0) { - // read A - a_thread_copy_.Run(a_k0_m0_m1_m2_k1_block_desc, - make_tuple(I0, m0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I0, I0, I0, I0), - a_thread_buf); - - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read B - b_thread_copy_.Run(b_k0_n0_n1_n2_k1_block_desc, - make_tuple(I0, n0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I0, I0, I0, I0), - b_thread_buf); - - static_for<0, K0, xdlops_gemm.K0PerXdlops>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type 
b_thread_vec; - - static_for<0, K1, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = a_thread_buf - [Number{}]; - b_thread_vec.template AsType()(i) = b_thread_buf - [Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0)); - - xdlops_gemm.template Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVector(Number{})); - }); - }); - }); - } - - private: - // A[K, M] - static constexpr auto a_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1, I1, I1, Number{})); - - // B[K, N] - static constexpr auto b_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1, I1, I1, Number{})); - - static constexpr auto c_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{})); - - using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3, 4>, - 4, - K1, - K1>; - - using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3, 4>, - 4, - K1, - K1>; - - AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; - BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp b/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp deleted file mode 100644 index 0214b713..00000000 --- a/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp +++ /dev/null @@ -1,170 +0,0 @@ -#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP -#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "cluster_descriptor.hpp" -#include "threadwise_tensor_slice_transfer.hpp" - -namespace ck { - -// this version does following things to avoid scratch memory issue -// 1. Use StaticallyIndexedArray instead of C array for thread buffer -// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor -// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate -template -struct BlockwiseTensorSliceTransfer_v4 -{ - static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); - - using Index = MultiIndex; - - __device__ constexpr BlockwiseTensorSliceTransfer_v4(const SrcDesc& src_desc, - const Index& src_block_slice_origin, - const DstDesc& dst_desc, - const Index& dst_block_slice_origin) - : threadwise_transfer_( - src_desc, make_zero_multi_index(), dst_desc, make_zero_multi_index()) - - { - static_assert(nDim == remove_reference_t>::GetNumOfDimension() && - nDim == remove_reference_t>::GetNumOfDimension() && - nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() && - nDim == ThreadClusterLengths::Size() && - nDim == ThreadClusterArrangeOrder::Size() && - nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), - "wrong! nDim not consistent"); - - static_assert( - is_same{}, - "wrong! threads should be mapped to cover entire slicing window"); - - static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(), - "wrong! 
BlockSize too small"); - - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) - { - const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( - make_multi_index(get_thread_local_1d_id())); - - const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{}; - - threadwise_transfer_.SetSrcSliceOrigin(src_desc, - src_block_slice_origin + thread_data_idx_begin); - threadwise_transfer_.SetDstSliceOrigin(dst_desc, - dst_block_slice_origin + thread_data_idx_begin); - } - } - - template - __device__ void - RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks) - { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) - { - threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks); - } - } - - template - __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) - { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) - { - threadwise_transfer_.RunRead(src_desc, src_buf); - } - } - - template - __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf) - { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) - { - threadwise_transfer_.RunWrite(dst_desc, dst_buf); - } - } - - __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) - { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) - { - threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); - } - } - - // SrcMoveSliceWindowStepHack to control index calculation move slice window - template - __device__ void - MoveSrcSliceWindow(const SrcDesc& src_desc, - const Index& step, - const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) - { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) - { - threadwise_transfer_.MoveSrcSliceWindow( - src_desc, step, src_move_slice_window_step_hack); - } - } - - __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) - { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) - { - threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); - } - } - - private: - static constexpr auto thread_cluster_desc_ = - make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); - - using ThreadwiseTransfer = - ThreadwiseTensorSliceTransfer_v3; - - ThreadwiseTransfer threadwise_transfer_; -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r3.hpp deleted file mode 100644 index 2653dd43..00000000 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r3.hpp +++ /dev/null @@ -1,650 +0,0 @@ -#ifndef CK_GRIDWISE_GEMM_V1R3_HPP -#define CK_GRIDWISE_GEMM_V1R3_HPP - -#include "common_header.hpp" -#include "multi_index_transform_helper.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "blockwise_gemm_dlops_v2r3.hpp" -#include "blockwise_tensor_slice_transfer_v2.hpp" -#include "threadwise_tensor_slice_transfer_v2.hpp" -#include "threadwise_tensor_slice_set.hpp" - 
-namespace ck { - -#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_gemm_dlops_v1r3( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AK0M0M1K1GridDesc a_k0_m0_m1_k1_grid_desc, - const BK0N0N1K1GridDesc b_k0_n0_n1_k1_grid_desc, - const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc, - const CBlockIdToM0N0BlockClusterAdaptor c_blockid_to_m0_n0_block_cluster_adaptor) -{ - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_k0_m0_m1_k1_grid_desc, - b_k0_n0_n1_k1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor, - integral_constant{}, - integral_constant{}); -} -#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER -// pass tensor descriptor by CONSTANT void pointer -// CONSTANT is needed to inform compiler void pointers in the kernel signature are pointing to -// non-modifiable parameter address space, so compiler can enable corresponding optimization -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_gemm_dlops_v1r3(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const void CONSTANT* p_a_k0_m0_m1_k1_grid_desc, - const void CONSTANT* p_b_k0_n0_n1_k1_grid_desc, - const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc, - const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) -{ - // first cast void CONSTANT void* to void* - // second cast void* to Desc* - // the copy constructor of tensor descriptor doesn't take address_space(4) - const auto a_k0_m0_m1_k1_grid_desc = *reinterpret_cast( - cast_pointer_to_generic_address_space(p_a_k0_m0_m1_k1_grid_desc)); - const auto b_k0_n0_n1_k1_grid_desc = *reinterpret_cast( - cast_pointer_to_generic_address_space(p_b_k0_n0_n1_k1_grid_desc)); - const auto c_m0_m10_m11_n0_n10_n11_grid_desc = - *reinterpret_cast( - cast_pointer_to_generic_address_space(p_c_m0_m10_m11_n0_n10_n11_grid_desc)); - const auto c_blockid_to_m0_n0_block_cluster_adaptor = - *reinterpret_cast( - cast_pointer_to_generic_address_space(p_c_blockid_to_m0_n0_block_cluster_adaptor)); - - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_k0_m0_m1_k1_grid_desc, - b_k0_n0_n1_k1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor, - integral_constant{}, - integral_constant{}); -} -#endif - -template -struct GridwiseGemmDlops_km_kn_mn_v1r3 -{ - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - - // K1 should be Number<...> - static constexpr auto K1 = AK0MK1GridDesc{}.GetLength(I2); - - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // TODO: change this. 
I think it needs multi-dimensional alignment - constexpr auto max_lds_align = K1; - - // TODO: check alignment - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - - // TODO: check alignment - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - - // TODO: check alignment - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_aligned_space_size = - math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_aligned_space_size = - math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align); - - return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); - } - - __host__ __device__ static constexpr bool - CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc, - const BK0NK1GridDesc& b_k0_n_k1_grid_desc, - const CMNGridDesc& c_m_n_grid_desc) - { - const auto M = a_k0_m_k1_grid_desc.GetLength(I1); - const auto N = b_k0_n_k1_grid_desc.GetLength(I1); - const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); - - // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) - - return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && - K0 == b_k0_n_k1_grid_desc.GetLength(I0) && - K1 == a_k0_m_k1_grid_desc.GetLength(I2) && - K1 == b_k0_n_k1_grid_desc.GetLength(I2)) && - (M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K0 % KPerBlock == 0); - } - - __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) - { - const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1); - - return grid_size; - } - - __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0) - { - const bool has_main_k_block_loop = (K0 + KPerBlock) / (2 * KPerBlock) > 1; - - return has_main_k_block_loop; - } - - __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0) - { - const bool has_double_tail_k_block_loop = (K0 / KPerBlock) % 2 == 0; - - return has_double_tail_k_block_loop; - } - - __host__ __device__ static constexpr auto - MakeAK0M0M1K1GridDescriptor(const AK0MK1GridDesc& a_k0_m_k1_grid_desc) - { - const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); - const auto M = a_k0_m_k1_grid_desc.GetLength(I1); - - const auto M1 = Number{}; - const auto M0 = M / M1; - - const auto a_k0_m0_m1_k1_grid_desc = - transform_tensor_descriptor(a_k0_m_k1_grid_desc, - make_tuple(make_pass_through_transform(K0), - make_unmerge_transform(make_tuple(M0, M1)), - make_pass_through_transform(K1)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - return a_k0_m0_m1_k1_grid_desc; - } - - __host__ __device__ static constexpr auto - MakeBK0N0N1K1GridDescriptor(const BK0NK1GridDesc& b_k0_n_k1_grid_desc) - { - const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0); - const auto N = b_k0_n_k1_grid_desc.GetLength(I1); - - const auto N1 = Number{}; - const auto N0 = N / N1; - - const auto b_k0_n0_n1_k1_grid_desc = - transform_tensor_descriptor(b_k0_n_k1_grid_desc, - make_tuple(make_pass_through_transform(K0), - make_unmerge_transform(make_tuple(N0, N1)), - make_pass_through_transform(K1)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, 
Sequence<1, 2>{}, Sequence<3>{})); - - return b_k0_n0_n1_k1_grid_desc; - } - - __host__ __device__ static constexpr auto - MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) - { - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); - - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - constexpr auto M11 = - Number{}; - constexpr auto N11 = - Number{}; - - constexpr auto M10 = M1 / M11; - constexpr auto N10 = N1 / N11; - - const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor( - c_m_n_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)), - make_unmerge_transform(make_tuple(N0, N10, N11))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); - - return c_m0_m10_m11_n0_n10_n11_grid_desc; - } - - __host__ __device__ static constexpr auto - MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc) - { - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); - - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - const auto c_blockid_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))), - make_tuple(Sequence<0, 1>{}), - make_tuple(Sequence<0>{})); - - return c_blockid_to_m0_n0_block_cluster_adaptor; - } - - using AK0M0M1K1GridDesc = decltype(MakeAK0M0M1K1GridDescriptor(AK0MK1GridDesc{})); - using BK0N0N1K1GridDesc = decltype(MakeBK0N0N1K1GridDescriptor(BK0NK1GridDesc{})); - using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{})); - using CBlockIdToM0N0BlockClusterAdaptor = - decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{})); - - template - __device__ static void - Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - FloatAB* __restrict__ p_shared_block, - const AK0M0M1K1GridDesc& a_k0_m0_m1_k1_grid_desc, - const BK0N0N1K1GridDesc& b_k0_n0_n1_k1_grid_desc, - const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc, - const CBlockIdToM0N0BlockClusterAdaptor& c_blockid_to_m0_n0_block_cluster_adaptor, - integral_constant, - integral_constant) - { - const auto a_global_buf = make_dynamic_buffer( - p_a_grid, a_k0_m0_m1_k1_grid_desc.GetElementSpaceSize()); - const auto b_global_buf = make_dynamic_buffer( - p_b_grid, b_k0_n0_n1_k1_grid_desc.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize()); - - // divide block work by [M, N] - const auto c_m0_n0_block_cluster_idx = - c_blockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex( - make_multi_index(get_block_1d_id())); - - // HACK: this force index data into SGPR - const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]); - const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]); - - // TODO: change this. 
I think it needs multi-dimensional alignment - constexpr auto max_lds_align = K1; - - // TODO: check alignment - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_k0_m0_m1_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, I1, Number{}, K1), max_lds_align); - - // TODO: check alignment - // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_k0_n0_n1_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, I1, Number{}, K1), max_lds_align); - - // TODO: check alignment - // A matrix in LDS memory, for blockwise GEMM - constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - - // TODO: check alignment - // B matrix in LDS memory, for blockwise GEMM - constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - - static_assert(a_k0_m0_m1_k1_block_desc.GetElementSpaceSize() == - a_k0_m_k1_block_desc.GetElementSpaceSize() && - b_k0_n0_n1_k1_block_desc.GetElementSpaceSize() == - b_k0_n_k1_block_desc.GetElementSpaceSize() && - "wrong!"); - - // A matrix blockwise copy - auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1< - BlockSize, - InMemoryDataOperationEnum_t::Set, - Sequence, - ABlockTransferThreadSliceLengths_K0_M0_M1_K1, - ABlockTransferThreadClusterLengths_K0_M0_M1_K1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_k0_m0_m1_k1_grid_desc), - decltype(a_k0_m0_m1_k1_block_desc), - ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2, 3>, - ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths - ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths - ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder - Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder - false, - true>(a_k0_m0_m1_k1_grid_desc, - make_multi_index(0, im0, 0, 0), - a_k0_m0_m1_k1_block_desc, - make_multi_index(0, 0, 0, 0)); - - // B matrix blockwise copy - auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1< - BlockSize, - InMemoryDataOperationEnum_t::Set, - Sequence, - BBlockTransferThreadSliceLengths_K0_N0_N1_K1, - BBlockTransferThreadClusterLengths_K0_N0_N1_K1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_k0_n0_n1_k1_grid_desc), - decltype(b_k0_n0_n1_k1_block_desc), - BBlockTransferSrcAccessOrder, - Sequence<0, 1, 2, 3>, - BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths - BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths - BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder - Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder - false, - true>(b_k0_n0_n1_k1_grid_desc, - make_multi_index(0, in0, 0, 0), - b_k0_n0_n1_k1_block_desc, - make_multi_index(0, 0, 0, 0)); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[KPerBlock, MPerBlockM1] is in LDS - // b_mtx[KPerBlocl, NPerBlockN1] is in LDS - // c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in - // register - const auto blockwise_gemm = - BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< - BlockSize, - FloatAB, - FloatAB, - FloatAcc, - decltype(a_k0_m_k1_block_desc), - decltype(b_k0_n_k1_block_desc), - M1PerThreadM111, - N1PerThreadN111, - KPerThread, - 
M11N11ThreadClusterM110Xs, - M11N11ThreadClusterN110Xs, - M1PerThreadM111, - N1PerThreadN111>{}; - - constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = - decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); - - constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed( - sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths)); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_aligned_space_size = math::integer_least_multiple( - a_k0_m0_m1_k1_block_desc.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_aligned_space_size = math::integer_least_multiple( - b_k0_n0_n1_k1_block_desc.GetElementSpaceSize(), max_lds_align); - - FloatAB* p_a_block_double = p_shared_block; - FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; - - // register allocation for output - auto c_thread_buf = make_static_buffer( - c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize()); - - ThreadwiseTensorSliceSet_v1{} - .Run(c_m10_m11_n10_n11_thread_desc, - make_tuple(I0, I0, I0, I0), - c_thread_buf, - FloatAcc{0}); - - constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0); - - auto a_block_even_buf = make_dynamic_buffer( - p_a_block_double, a_k0_m0_m1_k1_block_desc.GetElementSpaceSize()); - auto b_block_even_buf = make_dynamic_buffer( - p_b_block_double, b_k0_n0_n1_k1_block_desc.GetElementSpaceSize()); - - auto a_block_odd_buf = make_dynamic_buffer( - p_a_block_double + a_block_aligned_space_size, - a_k0_m0_m1_k1_block_desc.GetElementSpaceSize()); - auto b_block_odd_buf = make_dynamic_buffer( - p_b_block_double + b_block_aligned_space_size, - b_k0_n0_n1_k1_block_desc.GetElementSpaceSize()); - - // LDS double buffer: preload data into LDS - { - a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); - b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); - - a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf); - b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf); - } - - if constexpr(HasMainKBlockLoop) - { - const auto K0 = a_k0_m0_m1_k1_grid_desc.GetLength(I0); - - index_t k_block_data_begin = 0; - - // LDS double buffer: main body - // use Do-While loop instead of For loop to simplify control flow - do - { - // even iteration - a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, - a_block_slice_copy_step, - AGridMoveSliceWindowStepHacks{}); - b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, - b_block_slice_copy_step, - BGridMoveSliceWindowStepHacks{}); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); - b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc, - a_block_even_buf, - b_block_even_buf, - c_thread_buf); - - // LDS double buffer: store next data to LDS - a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf); - b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf); - - // odd iteration - a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, - a_block_slice_copy_step, - AGridMoveSliceWindowStepHacks{}); - b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, - b_block_slice_copy_step, - 
BGridMoveSliceWindowStepHacks{}); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); - b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run( - c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); - - // LDS double buffer: store next data to LDS - a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf); - b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf); - - k_block_data_begin += 2 * KPerBlock; - } while(k_block_data_begin < K0 - 2 * KPerBlock); - } - - // LDS double buffer: tail - if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left - { - a_blockwise_copy.MoveSrcSliceWindow( - a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step, AGridMoveSliceWindowStepHacks{}); - b_blockwise_copy.MoveSrcSliceWindow( - b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step, BGridMoveSliceWindowStepHacks{}); - - __syncthreads(); - - // LDS double buffer: load last data from device mem - a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); - b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); - - // LDS double buffer: GEMM on 2nd-last data - blockwise_gemm.Run( - c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); - - // LDS double buffer: store last data to LDS - a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf); - b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf); - - __syncthreads(); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run( - c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); - } - else // if has 1 iteration left - { - __syncthreads(); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run( - c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); - } - - // output: register to global memory - { - constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - Number{}, - I1, - Number{}, - Number{})); - - const auto c_m10_m11_n10_n11_thread_origin_idx_on_block = - blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( - get_thread_local_1d_id()); - - ThreadwiseTensorSliceTransfer_v1r3< - FloatAcc, - FloatC, - decltype(c_m0_m10_m11_n0_n10_n11_thread_desc), - decltype(c_m0_m10_m11_n0_n10_n11_grid_desc), - Sequence<1, - c_m10_m11_n10_n11_thread_tensor_lengths[I0], - c_m10_m11_n10_n11_thread_tensor_lengths[I1], - 1, - c_m10_m11_n10_n11_thread_tensor_lengths[I2], - c_m10_m11_n10_n11_thread_tensor_lengths[I3]>, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - CGlobalMemoryDataOperation, - 1, - true>{c_m0_m10_m11_n0_n10_n11_grid_desc, - make_multi_index(im0, - c_m10_m11_n10_n11_thread_origin_idx_on_block[I0], - c_m10_m11_n10_n11_thread_origin_idx_on_block[I1], - in0, - c_m10_m11_n10_n11_thread_origin_idx_on_block[I2], - c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])} - .Run(c_m0_m10_m11_n0_n10_n11_thread_desc, - make_tuple(I0, I0, I0, I0, I0, I0), - c_thread_buf, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_grid_buf, - CGridStepHacks{}); - } - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp 
b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp deleted file mode 100644 index 86e047c9..00000000 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp +++ /dev/null @@ -1,639 +0,0 @@ -#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP -#define CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP - -#include "common_header.hpp" -#include "multi_index_transform_helper.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "blockwise_gemm_xdlops.hpp" -#include "blockwise_tensor_slice_transfer.hpp" -#include "threadwise_tensor_slice_transfer.hpp" -#include "threadwise_tensor_slice_set.hpp" - -namespace ck { - -#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AK0MK1GridDesc a_k0_m_k1_grid_desc, - const BK0NK1GridDesc b_k0_n_k1_grid_desc, - const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - const CBlockClusterAdaptor c_block_cluster_adaptor) -{ - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_block_cluster_adaptor); -} -#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const void CONSTANT* p_a_k0_m_k1_grid_desc, - const void CONSTANT* p_b_k0_n_k1_grid_desc, - const void CONSTANT* p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - const void CONSTANT* p_c_block_cluster_adaptor) -{ - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - const auto a_k0_m_k1_grid_desc = *reinterpret_cast( - cast_pointer_to_generic_address_space(p_a_k0_m_k1_grid_desc)); - const auto b_k0_n_k1_grid_desc = *reinterpret_cast( - cast_pointer_to_generic_address_space(p_b_k0_n_k1_grid_desc)); - const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc = - *reinterpret_cast( - cast_pointer_to_generic_address_space(p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc)); - const auto c_block_cluster_adaptor = *reinterpret_cast( - cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor)); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_block_cluster_adaptor); -} -#endif - -template -struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 -{ - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - - // K1 should be Number<...> - static constexpr auto K1 = Number{}; - - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - constexpr 
auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_k0_m_k1_block_desc = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_k0_n_k1_block_desc = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size = - math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size = - math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); - - return (a_block_space_size + b_block_space_size) * sizeof(FloatAB); - } - - // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} - __host__ __device__ static constexpr bool - CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc, - const BK0NK1GridDesc& b_k0_n_k1_grid_desc, - const CMNGridDesc& c_m_n_grid_desc, - index_t M01, - index_t N01) - { - static_assert(is_known_at_compile_time>::value, - "wrong! K1 need to be known at compile-time"); - - static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) && - (NPerBlock % (NRepeat * NPerXDL)) == 0, - "Invalid tuning param!"); - - const auto M = a_k0_m_k1_grid_desc.GetLength(I1); - const auto N = b_k0_n_k1_grid_desc.GetLength(I1); - const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); - - if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && - K0 == b_k0_n_k1_grid_desc.GetLength(I0) && K1 == a_k0_m_k1_grid_desc.GetLength(I2) && - K1 == b_k0_n_k1_grid_desc.GetLength(I2))) - return false; - - if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) - return false; - - // check M01, N01 - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - if(!(M0 % M01 == 0 && N0 % N01 == 0)) - return false; - - // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) - return true; - } - - __host__ __device__ static constexpr index_t - CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc) - { - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); - - const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); - - return grid_size; - } - - __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) - { - const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1; - - return has_main_k0_block_loop; - } - - __host__ __device__ static constexpr auto - MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) - { - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_k0_m_k1_block_desc = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // B matrix in LDS 
memory, dst of blockwise copy - constexpr auto b_k0_n_k1_block_desc = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - using BlockwiseGemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; - - return BlockwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc); - } - - // return block_id to C matrix tile idx (m0, n0) mapping - __host__ __device__ static constexpr auto - MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01) - { - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); - - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - const auto M00 = M0 / M01; - const auto N00 = N0 / N01; - - const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple(M00, M01)), - make_unmerge_transform(make_tuple(N00, N01))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); - - const auto c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))), - make_tuple(Sequence<0, 1, 2, 3>{}), - make_tuple(Sequence<0>{})); - - const auto c_blockid_to_m0_n0_block_cluster_adaptor = - chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, - c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor); - - return c_blockid_to_m0_n0_block_cluster_adaptor; - } - - using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{})); - using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1)); - - template - __device__ static void Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - FloatAB* __restrict__ p_shared_block, - const AK0MK1GridDesc& a_k0_m_k1_grid_desc, - const BK0NK1GridDesc& b_k0_n_k1_grid_desc, - const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - const CBlockClusterAdaptor& c_block_cluster_adaptor) - { - const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize()); - const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize()); - - const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); - - // divide block work by [M, N] - const auto block_work_idx = - c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - - // HACK: this force m/n_block_data_idx_on_grid into SGPR - const index_t m_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); - - const index_t n_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); - - // lds max alignment - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_k0_m_k1_block_desc = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return 
make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_k0_n_k1_block_desc = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // A matrix blockwise copy - auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4, - ABlockTransferThreadSliceLengths_K0_M_K1, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_k0_m_k1_grid_desc), - decltype(a_k0_m_k1_block_desc), - ABlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true>(a_k0_m_k1_grid_desc, - make_multi_index(0, m_block_data_idx_on_grid, 0), - a_k0_m_k1_block_desc, - make_multi_index(0, 0, 0)); - - // B matrix blockwise copy - auto b_blockwise_copy = - BlockwiseTensorSliceTransfer_v4, - BBlockTransferThreadSliceLengths_K0_N_K1, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_k0_n_k1_grid_desc), - decltype(b_k0_n_k1_block_desc), - BBlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>(b_k0_n_k1_grid_desc, - make_multi_index(0, n_block_data_idx_on_grid, 0), - b_k0_n_k1_block_desc, - make_multi_index(0, 0, 0)); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[K0PerBlock, MPerBlock] is in LDS - // b_mtx[K0PerBlock, NPerBlock] is in LDS - // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in - // register - // sanity check - - auto blockwise_gemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; - - auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size = - math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); - - FloatAB* p_a_block = p_shared_block; - FloatAB* p_b_block = p_shared_block + a_block_space_size; - - constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - - // hack to control index calculation when iterating over A and B matrix for threadwise copy - constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{}; - constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{}; - - // hack to control index calculation when move slice window for A and B matrix for - // threadwise copy - constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{}; - constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{}; - - auto a_block_buf = make_dynamic_buffer( - p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); - auto b_block_buf = make_dynamic_buffer( - p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); - - // preload data into LDS - { - a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks); - 
b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks); - - a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf); - } - - // main body - index_t k0_block_data_begin = 0; - - if constexpr(HasMainKBlockLoop) - { - do - { - a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc, - a_block_slice_copy_step, - a_k0_m_k1_grid_move_slice_window_step_hack); - b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc, - b_block_slice_copy_step, - b_k0_n_k1_grid_move_slice_window_step_hack); - - a_blockwise_copy.RunRead( - a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks); - - block_sync_lds(); - - b_blockwise_copy.RunRead( - b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - block_sync_lds(); - - a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf); - - k0_block_data_begin += K0PerBlock; - } while(k0_block_data_begin < (K0 - K0PerBlock)); - } - - // tail - { - block_sync_lds(); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - } - - // output: register to global memory - { - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = - blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor(); - - constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); - constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); - constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); - constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); - constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); - constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); - constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); - constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); - - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = - make_naive_tensor_descriptor_packed(make_tuple( - Number{}, Number{}, I1, I1, Number{}, I1, Number{}, I1)); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_grid = - m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; - - const index_t n_thread_data_on_grid = - n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; - - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{}; - - const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_grid_idx = - m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_grid)); - - const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_grid_idx = - n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_grid)); - - auto c_thread_copy = - ThreadwiseTensorSliceTransfer_v1r3, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - 
CThreadTransferDstScalarPerVector, - CGlobalMemoryDataOperation, - 1, - true>{ - - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - make_multi_index(m_thread_data_on_grid_idx[I0], - n_thread_data_on_grid_idx[I0], - m_thread_data_on_grid_idx[I1], - n_thread_data_on_grid_idx[I1], - m_thread_data_on_grid_idx[I2], - m_thread_data_on_grid_idx[I3], - m_thread_data_on_grid_idx[I4], - n_thread_data_on_grid_idx[I2])}; - - c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), - c_thread_buf, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_grid_buf, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); - } - } -}; // namespace ck - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_blockwise.hpp b/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_blockwise.hpp deleted file mode 100644 index c635da57..00000000 --- a/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_blockwise.hpp +++ /dev/null @@ -1,625 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
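The gridwise_gemm_xdlops_v2r3.hpp deletion above removed a kernel whose Run() follows a fixed software pipeline: preload the first A/B tiles into LDS, then loop over K0 blocks (move the source slice window, read the next tiles from global memory, sync, run the blockwise xdlops GEMM on the resident tiles, sync, write the new tiles to LDS), finish with a tail iteration, and copy the per-thread C fragments back to global memory. A minimal scalar-MAC analogue of that pipeline in plain HIP is sketched below; every name in it is illustrative, and it uses raw row-major pointers instead of the tensor descriptors and step hacks of the deleted header.

```cpp
// Minimal shared-memory tiled GEMM illustrating the same load -> sync ->
// compute -> sync pipeline as the deleted kernel (scalar MACs instead of
// xdlops, plain pointers instead of tensor descriptors). Illustrative only.
#include <hip/hip_runtime.h>

constexpr int TILE = 16;

__global__ void tiled_gemm(const float* A, const float* B, float* C,
                           int M, int N, int K)
{
    __shared__ float a_tile[TILE][TILE]; // stand-in for the A LDS block
    __shared__ float b_tile[TILE][TILE]; // stand-in for the B LDS block

    const int row = blockIdx.y * TILE + threadIdx.y;
    const int col = blockIdx.x * TILE + threadIdx.x;

    float acc = 0.0f; // per-thread C fragment kept in registers

    for(int k0 = 0; k0 < K; k0 += TILE) // "main K0 block loop"
    {
        // cooperative global -> LDS copy (RunRead/RunWrite in the original)
        a_tile[threadIdx.y][threadIdx.x] =
            (row < M && k0 + threadIdx.x < K) ? A[row * K + k0 + threadIdx.x] : 0.0f;
        b_tile[threadIdx.y][threadIdx.x] =
            (k0 + threadIdx.y < K && col < N) ? B[(k0 + threadIdx.y) * N + col] : 0.0f;

        __syncthreads(); // block_sync_lds() before consuming the tiles

        for(int k = 0; k < TILE; ++k) // blockwise_gemm.Run(...)
            acc += a_tile[threadIdx.y][k] * b_tile[k][threadIdx.x];

        __syncthreads(); // block_sync_lds() before overwriting the tiles
    }

    if(row < M && col < N)
        C[row * N + col] = acc; // register -> global writeback
}
```

The two __syncthreads() calls play the role of block_sync_lds() in the deleted code: the first guards the tiles against premature reads, the second against their being overwritten while still in use.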
- * - *******************************************************************************/ -#ifndef CK_GRIDWISE_GENERIC_2D_REDUCTION_BLOCKWISE_HPP -#define CK_GRIDWISE_GENERIC_2D_REDUCTION_BLOCKWISE_HPP - -#include "data_type.hpp" -#include "reduction_common.hpp" -#include "reduction_operator.hpp" -#include "reduction_functions_blockwise.hpp" - -#include "blockwise_tensor_slice_transfer.hpp" - -namespace ck { - -template -struct GridwiseReduction_xy_to_x_blockwise -{ - using opReduce = typename reduce_binary_operator::opType; - using preUnaryOpType = - typename reduce_unary_operator::preUnaryOp; - using posUnaryOpType = - typename reduce_unary_operator::posUnaryOp; - - static constexpr auto buffer2dDesc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{})); - using blockwise_reduce = - BlockwiseReduction_2d_block_buffer; - - static constexpr index_t BlockBufferSize = buffer2dDesc.GetElementSize(); - - static constexpr auto I0 = Number<0>{}; - - template - __device__ static void Run(const src2dDescType& src2dDesc, - const dst1dDescType& dst1dDesc, - int origReduceLen, - srcDataType alpha, - const srcDataType* const __restrict__ p_src_global, - dstDataType beta, - dstDataType* const __restrict__ p_dst_global, - const int* const __restrict__ ws_indices_global, - int* const __restrict__ indices_global); - - template <> - __device__ static void Run<1>(const src2dDescType& src2dDesc, - const dst1dDescType& dst1dDesc, - int origReduceLen, - srcDataType alpha, - const srcDataType* const __restrict__ p_src_global, - dstDataType beta, - dstDataType* const __restrict__ p_dst_global, - const int* const __restrict__ ws_indices_global, - int* const __restrict__ indices_global) - { - (void)ws_indices_global; - (void)indices_global; - - // LDS - __shared__ compType p_in_block_buffer[BlockBufferSize]; - - const auto zeroVal = opReduce::GetReductionZeroVal(); - - const auto src_global_buf = make_dynamic_buffer( - p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); - auto dst_global_buf = make_dynamic_buffer( - p_dst_global, dst1dDesc.GetElementSpaceSize()); - - auto in_block_buf = - make_dynamic_buffer(p_in_block_buffer, BlockBufferSize); - StaticBuffer accuValue_buf; - - accuValue_buf(I0) = zeroVal; - - const auto toReduceLength = src2dDesc.GetLength(Number<1>{}); - const int divider = origReduceLen; - - const preUnaryOpType preUnaryOp(divider); - const posUnaryOpType posUnaryOp(divider); - - const index_t thread_local_id = get_thread_local_1d_id(); - const index_t block_global_1d_id = get_block_1d_id(); - - constexpr auto in_block_desc = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number{})); - - using ThreadSliceLengths = Sequence<1, GredAccessesPerThreadInBlock>; - using ThreadClusterLengths = Sequence<1, BlockSize>; - - auto blockwise_src_load = - BlockwiseTensorSliceTransfer_v4, - ThreadSliceLengths, - ThreadClusterLengths, - Sequence<0, 1>, - srcDataType, - compType, - src2dDescType, - decltype(in_block_desc), - Sequence<0, 1>, - Sequence<0, 1>, - 1, - 1, - 1, - 1, - 1, - 1, - false, - true>(src2dDesc, - make_multi_index(block_global_1d_id, 0), - in_block_desc, - make_multi_index(0, 0)); - - constexpr auto in_block_copy_step = make_multi_index(0, BlockBufferSize); - - const index_t toReduceBlocks = (toReduceLength + BlockSize - 1) / BlockSize; - - for(index_t reducedBlocks = 0; reducedBlocks < toReduceBlocks; - reducedBlocks += GredAccessesPerThreadInBlock) - { - blockwise_src_load.RunRead(src2dDesc, src_global_buf); - 
blockwise_src_load.RunWrite(in_block_desc, in_block_buf); - - __syncthreads(); - - // do element-wise pre-reduction operation - blockwise_reduce::operate_on_elements(preUnaryOp, in_block_buf); - - index_t BlocksInOneOp = (reducedBlocks < toReduceBlocks - GredAccessesPerThreadInBlock) - ? GredAccessesPerThreadInBlock - : toReduceBlocks - reducedBlocks; - blockwise_reduce::Reduce(in_block_buf, BlocksInOneOp, accuValue_buf(I0)); - - blockwise_src_load.MoveSrcSliceWindow(src2dDesc, in_block_copy_step); - } - - accuValue_buf(I0) = posUnaryOp(accuValue_buf[I0]); - - constexpr auto ReducedDataDesc = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); - - // The first thread in the block stores the reduced result to the global location - // representing the block - if(thread_local_id == 0) - { - if(!float_equal_one{}(alpha)) - accuValue_buf(I0) *= type_convert{}(alpha); - - StaticBuffer dstValue_buf; - - dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); - - if(!float_equal_zero{}(beta)) - { - auto threadwise_dst_load = - ThreadwiseTensorSliceTransfer_v2, - Sequence<0>, - 0, - 1, - 1, - false>(dst1dDesc, - make_multi_index(block_global_1d_id)); - - StaticBuffer priorDstValue_buf; - - threadwise_dst_load.Run( - dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf); - - dstValue_buf(I0) += priorDstValue_buf[I0] * beta; - } - - auto threadwise_dst_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - false>(dst1dDesc, - make_multi_index(block_global_1d_id)); - - threadwise_dst_store.Run( - ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_buf); - } - }; - - template <> - __device__ static void Run<2>(const src2dDescType& src2dDesc, - const dst1dDescType& dst1dDesc, - int origReduceLen, - srcDataType alpha, - const srcDataType* const __restrict__ p_src_global, - dstDataType beta, - dstDataType* const __restrict__ p_dst_global, - const int* const __restrict__ ws_indices_global, - int* const __restrict__ indices_global) - { - (void)ws_indices_global; - - // LDS - __shared__ compType p_in_block_buffer[BlockBufferSize]; - __shared__ int block_indices_buffer[BlockBufferSize]; - - const auto zeroVal = opReduce::GetReductionZeroVal(); - - const auto src_global_buf = make_dynamic_buffer( - p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); - auto dst_global_val_buf = make_dynamic_buffer( - p_dst_global, dst1dDesc.GetElementSpaceSize()); - auto dst_global_idx_buf = make_dynamic_buffer( - indices_global, dst1dDesc.GetElementSpaceSize()); - - auto in_block_val_buf = - make_dynamic_buffer(p_in_block_buffer, BlockBufferSize); - auto in_block_idx_buf = - make_dynamic_buffer(block_indices_buffer, BlockBufferSize); - - StaticBuffer accuValue_buf; - StaticBuffer accuIndex_buf; - - accuValue_buf(I0) = zeroVal; - accuIndex_buf(I0) = 0; - - const auto toReduceLength = src2dDesc.GetLength(Number<1>{}); - const int divider = origReduceLen; - - const preUnaryOpType preUnaryOp(divider); - - const index_t thread_local_id = get_thread_local_1d_id(); - const index_t block_global_1d_id = get_block_1d_id(); - - constexpr auto in_block_desc = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number{})); - - using ThreadSliceLengths = Sequence<1, GredAccessesPerThreadInBlock>; - using ThreadClusterLengths = Sequence<1, BlockSize>; - - auto blockwise_src_load = - BlockwiseTensorSliceTransfer_v4, - ThreadSliceLengths, - ThreadClusterLengths, - Sequence<0, 1>, - srcDataType, - 
compType, - src2dDescType, - decltype(in_block_desc), - Sequence<0, 1>, - Sequence<0, 1>, - 1, - 1, - 1, - 1, - 1, - 1, - false, - true>(src2dDesc, - make_multi_index(block_global_1d_id, 0), - in_block_desc, - make_multi_index(0, 0)); - - constexpr auto in_block_copy_step = make_multi_index(0, BlockBufferSize); - - const index_t toReduceBlocks = (toReduceLength + BlockSize - 1) / BlockSize; - - int indexOffset = 0; - - for(index_t reducedBlocks = 0; reducedBlocks < toReduceBlocks; - reducedBlocks += GredAccessesPerThreadInBlock) - { - // load block data from global to LDS, no use of double buffers (to be improved) - blockwise_src_load.RunRead(src2dDesc, src_global_buf); - blockwise_src_load.RunWrite(in_block_desc, in_block_val_buf); - - __syncthreads(); - - // construct the indices for the current toReduce blocks - blockwise_reduce::init_buffer_indices(in_block_idx_buf, indexOffset); - - // unary operation before reducing, needed by AMAX; For MIN/MAX, nothing is actually - // done here - blockwise_reduce::operate_on_elements(preUnaryOp, in_block_val_buf); - - index_t BlocksInOneOp = (reducedBlocks < toReduceBlocks - GredAccessesPerThreadInBlock) - ? GredAccessesPerThreadInBlock - : toReduceBlocks - reducedBlocks; - - blockwise_reduce::Reduce2(in_block_val_buf, - in_block_idx_buf, - BlocksInOneOp, - accuValue_buf(I0), - accuIndex_buf(I0)); - - indexOffset += BlockBufferSize; - - blockwise_src_load.MoveSrcSliceWindow(src2dDesc, in_block_copy_step); - } - - constexpr auto ReducedDataDesc = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); - - // The first thread in the block stores the reduced result to the global location - // representing the block - if(thread_local_id == 0) - { - if(!float_equal_one{}(alpha)) - accuValue_buf(I0) *= type_convert{}(alpha); - - StaticBuffer dstValue_buf; - - dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); - - if(!float_equal_zero{}(beta)) - { - auto threadwise_dst_load = - ThreadwiseTensorSliceTransfer_v2, - Sequence<0>, - 0, - 1, - 1, - false>(dst1dDesc, - make_multi_index(block_global_1d_id)); - - StaticBuffer priorDstValue_buf; - - threadwise_dst_load.Run(dst1dDesc, - dst_global_val_buf, - ReducedDataDesc, - make_tuple(I0), - priorDstValue_buf); - - dstValue_buf(I0) += priorDstValue_buf[I0] * beta; - } - - auto threadwise_dst_val_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - false>(dst1dDesc, - make_multi_index(block_global_1d_id)); - - auto threadwise_dst_idx_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - false>(dst1dDesc, - make_multi_index(block_global_1d_id)); - - threadwise_dst_val_store.Run( - ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf); - threadwise_dst_idx_store.Run( - ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf); - } - }; - - template <> - __device__ static void Run<3>(const src2dDescType& src2dDesc, - const dst1dDescType& dst1dDesc, - int origReduceLen, - srcDataType alpha, - const srcDataType* const __restrict__ ws_values_global, - dstDataType beta, - dstDataType* const __restrict__ p_dst_global, - const int* const __restrict__ ws_indices_global, - int* const __restrict__ indices_global) - { - (void)origReduceLen; - - // LDS - __shared__ compType p_in_block_buffer[BlockBufferSize]; - __shared__ int block_indices_buffer[BlockBufferSize]; - - const auto zeroVal = opReduce::GetReductionZeroVal(); - - const auto src_global_val_buf 
= - make_dynamic_buffer(ws_values_global, - src2dDesc.GetElementSpaceSize(), - type_convert{}(zeroVal)); - const auto src_global_idx_buf = make_dynamic_buffer( - ws_indices_global, src2dDesc.GetElementSpaceSize()); - auto dst_global_val_buf = make_dynamic_buffer( - p_dst_global, dst1dDesc.GetElementSpaceSize()); - auto dst_global_idx_buf = make_dynamic_buffer( - indices_global, dst1dDesc.GetElementSpaceSize()); - - auto in_block_val_buf = - make_dynamic_buffer(p_in_block_buffer, BlockBufferSize); - auto in_block_idx_buf = - make_dynamic_buffer(block_indices_buffer, BlockBufferSize); - - StaticBuffer accuValue_buf; - StaticBuffer accuIndex_buf; - - accuValue_buf(I0) = zeroVal; - accuIndex_buf(I0) = 0; - - const auto toReduceLength = src2dDesc.GetLength(Number<1>{}); - - const index_t thread_local_id = get_thread_local_1d_id(); - const index_t block_global_1d_id = get_block_1d_id(); - - constexpr auto in_block_desc = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number{})); - - using ThreadSliceLengths = Sequence<1, GredAccessesPerThreadInBlock>; - using ThreadClusterLengths = Sequence<1, BlockSize>; - - auto blockwise_src_val_load = - BlockwiseTensorSliceTransfer_v4, - ThreadSliceLengths, - ThreadClusterLengths, - Sequence<0, 1>, - srcDataType, - compType, - src2dDescType, - decltype(in_block_desc), - Sequence<0, 1>, - Sequence<0, 1>, - 1, - 1, - 1, - 1, - 1, - 1, - false, - true>(src2dDesc, - make_multi_index(block_global_1d_id, 0), - in_block_desc, - make_multi_index(0, 0)); - - auto blockwise_src_idx_load = - BlockwiseTensorSliceTransfer_v4, - ThreadSliceLengths, - ThreadClusterLengths, - Sequence<0, 1>, - int, - int, - src2dDescType, - decltype(in_block_desc), - Sequence<0, 1>, - Sequence<0, 1>, - 1, - 1, - 1, - 1, - 1, - 1, - false, - true>(src2dDesc, - make_multi_index(block_global_1d_id, 0), - in_block_desc, - make_multi_index(0, 0)); - - constexpr auto in_block_copy_step = make_multi_index(0, BlockBufferSize); - - const index_t toReduceBlocks = (toReduceLength + BlockSize - 1) / BlockSize; - - for(index_t reducedBlocks = 0; reducedBlocks < toReduceBlocks; - reducedBlocks += GredAccessesPerThreadInBlock) - { - // load block data from global to LDS, no use of double buffers (to be improved) - blockwise_src_val_load.RunRead(src2dDesc, src_global_val_buf); - blockwise_src_idx_load.RunRead(src2dDesc, src_global_idx_buf); - blockwise_src_val_load.RunWrite(in_block_desc, in_block_val_buf); - blockwise_src_idx_load.RunWrite(in_block_desc, in_block_idx_buf); - - __syncthreads(); - - index_t BlocksInOneOp = (reducedBlocks < toReduceBlocks - GredAccessesPerThreadInBlock) - ? 
GredAccessesPerThreadInBlock - : toReduceBlocks - reducedBlocks; - - blockwise_reduce::Reduce2(in_block_val_buf, - in_block_idx_buf, - BlocksInOneOp, - accuValue_buf(I0), - accuIndex_buf(I0)); - - blockwise_src_val_load.MoveSrcSliceWindow(src2dDesc, in_block_copy_step); - blockwise_src_idx_load.MoveSrcSliceWindow(src2dDesc, in_block_copy_step); - } - - constexpr auto ReducedDataDesc = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); - - // The first thread in the block stores the reduced result to the global location - // representing the block - if(thread_local_id == 0) - { - if(!float_equal_one{}(alpha)) - accuValue_buf(I0) *= type_convert{}(alpha); - - StaticBuffer dstValue_buf; - - dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); - - if(!float_equal_zero{}(beta)) - { - auto threadwise_dst_load = - ThreadwiseTensorSliceTransfer_v2, - Sequence<0>, - 0, - 1, - 1, - true>(dst1dDesc, - make_multi_index(block_global_1d_id)); - - StaticBuffer priorDstValue_buf; - - threadwise_dst_load.Run(dst1dDesc, - dst_global_val_buf, - ReducedDataDesc, - make_tuple(I0), - priorDstValue_buf); - - dstValue_buf(I0) += priorDstValue_buf[I0] * beta; - } - - auto threadwise_dst_val_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - true>(dst1dDesc, - make_multi_index(block_global_1d_id)); - - auto threadwise_dst_idx_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - true>(dst1dDesc, - make_multi_index(block_global_1d_id)); - - threadwise_dst_val_store.Run( - ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf); - threadwise_dst_idx_store.Run( - ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf); - } - }; -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_threadwise.hpp b/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_threadwise.hpp deleted file mode 100644 index adfeacc0..00000000 --- a/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_threadwise.hpp +++ /dev/null @@ -1,503 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
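The blockwise reduction header deleted above reduces each row of a 2D [invariant, toReduce] tensor with one thread block: threads stage chunks of the row, combine them cooperatively, and the first thread of the block applies the alpha/beta epilogue and stores the result. A condensed sketch of that pattern follows (serial per-thread partials plus an LDS tree reduction; hypothetical names; BLOCK must be a power of two). The deleted code additionally supports the index-tracking variants Run<2>/Run<3> for operations such as AMAX.

```cpp
// One thread block reduces one row of src[num_rows][row_len]. Mirrors the
// deleted pattern: strided loads into a per-thread partial, shared-memory
// tree reduction, alpha/beta epilogue on thread 0. Illustrative names only.
#include <hip/hip_runtime.h>

template <int BLOCK>
__global__ void rowwise_sum(const float* src, float* dst, int row_len,
                            float alpha, float beta)
{
    __shared__ float lds[BLOCK];

    const float* row = src + blockIdx.x * static_cast<long>(row_len);

    float acc = 0.0f; // reduction "zero" value for add
    for(int i = threadIdx.x; i < row_len; i += BLOCK)
        acc += row[i];                      // per-thread serial part

    lds[threadIdx.x] = acc;
    __syncthreads();

    for(int s = BLOCK / 2; s > 0; s >>= 1)  // blockwise tree reduction
    {
        if(threadIdx.x < s)
            lds[threadIdx.x] += lds[threadIdx.x + s];
        __syncthreads();
    }

    if(threadIdx.x == 0)                    // first thread stores the result
    {
        float out = alpha * lds[0];
        if(beta != 0.0f)                    // matches the float_equal_zero
            out += beta * dst[blockIdx.x];  // guard: read prior dst only when needed
        dst[blockIdx.x] = out;
    }
}
```

A launch would look like rowwise_sum<256><<<num_rows, 256>>>(src, dst, row_len, alpha, beta); note that, as in the deleted epilogue, the prior dst value is only read when beta is non-zero.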
- * - *******************************************************************************/ -#ifndef CK_GRIDWISE_GENERIC_2D_REDUCTION_DIRECT_THREADWISE_HPP -#define CK_GRIDWISE_GENERIC_2D_REDUCTION_DIRECT_THREADWISE_HPP - -#include "data_type.hpp" -#include "reduction_common.hpp" -#include "reduction_operator.hpp" -#include "reduction_functions_threadwise.hpp" - -#include "threadwise_tensor_slice_transfer.hpp" - -namespace ck { - -template -struct GridwiseReduction_xy_to_x_direct_threadwise -{ - using opReduce = typename reduce_binary_operator::opType; - using preUnaryOpType = - typename reduce_unary_operator::preUnaryOp; - using posUnaryOpType = - typename reduce_unary_operator::posUnaryOp; - - static constexpr auto I0 = Number<0>{}; - - template - __device__ static void Run(const src2dDescType& src2dDesc, - const dst1dDescType& dst1dDesc, - int origReduceLen, - srcDataType alpha, - const srcDataType* const __restrict__ p_src_global, - dstDataType beta, - dstDataType* const __restrict__ p_dst_global, - const int* const __restrict__ ws_indices_global, - int* const __restrict__ indices_global); - - template <> - __device__ static void Run<1>(const src2dDescType& src2dDesc, - const dst1dDescType& dst1dDesc, - int origReduceLen, - srcDataType alpha, - const srcDataType* const __restrict__ p_src_global, - dstDataType beta, - dstDataType* const __restrict__ p_dst_global, - const int* const __restrict__ ws_indices_global, - int* const __restrict__ indices_global) - { - (void)ws_indices_global; - (void)indices_global; - - const auto zeroVal = opReduce::GetReductionZeroVal(); - - const auto src_global_buf = make_dynamic_buffer( - p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); - auto dst_global_buf = make_dynamic_buffer( - p_dst_global, dst1dDesc.GetElementSpaceSize()); - - StaticBuffer - in_thread_buf; - - using threadwise_reduce = ThreadReduce; - - StaticBuffer accuValue_buf; - - accuValue_buf(I0) = zeroVal; - - const auto toReduceLength = src2dDesc.GetLength(Number<1>{}); - const int divider = origReduceLen; - - const preUnaryOpType preUnaryOp(divider); - const posUnaryOpType posUnaryOp(divider); - - using ThreadBufferLengths = Sequence<1, GredThreadBufferLength>; - constexpr auto ThreadBufferDesc = make_naive_tensor_descriptor_packed( - make_tuple(Number<1>{}, Number{})); - - index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); - - auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2, - 1, - 1, - 1, - false>( - src2dDesc, make_multi_index(thread_global_1d_id, 0)); - - constexpr auto in_thread_copy_step = make_multi_index(0, GredThreadBufferLength); - - for(index_t reducedLength = 0; reducedLength < toReduceLength; - reducedLength += GredThreadBufferLength) - { - threadwise_src_load.Run( - src2dDesc, src_global_buf, ThreadBufferDesc, make_tuple(I0, I0), in_thread_buf); - - // do element-wise pre-reduction operation - threadwise_reduce::operate_on_elements(preUnaryOp, in_thread_buf); - - // do the reduction on the Thread Buffer - threadwise_reduce::Reduce(in_thread_buf, accuValue_buf(I0)); - - threadwise_src_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step); - } - - accuValue_buf(I0) = posUnaryOp(accuValue_buf[I0]); - - constexpr auto ReducedDataDesc = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); - - if(!float_equal_one{}(alpha)) - accuValue_buf(I0) *= type_convert{}(alpha); - - StaticBuffer dstValue_buf; - - dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); - - if(!float_equal_zero{}(beta)) - { - auto 
threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2, - Sequence<0>, - 0, - 1, - 1, - true>( - dst1dDesc, make_multi_index(thread_global_1d_id)); - - StaticBuffer priorDstValue_buf; - - threadwise_dst_load.Run( - dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf); - - dstValue_buf(I0) += priorDstValue_buf[I0] * beta; - } - - auto threadwise_dst_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - true>(dst1dDesc, - make_multi_index(thread_global_1d_id)); - - threadwise_dst_store.Run( - ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_buf); - }; - - template <> - __device__ static void Run<2>(const src2dDescType& src2dDesc, - const dst1dDescType& dst1dDesc, - int origReduceLen, - srcDataType alpha, - const srcDataType* const __restrict__ p_src_global, - dstDataType beta, - dstDataType* const __restrict__ p_dst_global, - const int* const __restrict__ ws_indices_global, - int* const __restrict__ indices_global) - { - (void)ws_indices_global; - - const auto zeroVal = opReduce::GetReductionZeroVal(); - - const auto src_global_buf = make_dynamic_buffer( - p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); - auto dst_global_val_buf = make_dynamic_buffer( - p_dst_global, dst1dDesc.GetElementSpaceSize()); - auto dst_global_idx_buf = make_dynamic_buffer( - indices_global, dst1dDesc.GetElementSpaceSize()); - - StaticBuffer - in_thread_buf; - - using threadwise_reduce = ThreadReduce; - - StaticBuffer accuValue_buf; - StaticBuffer accuIndex_buf; - - accuValue_buf(I0) = zeroVal; - accuIndex_buf(I0) = 0; - - const auto toReduceLength = src2dDesc.GetLength(Number<1>{}); - const int divider = origReduceLen; - - const preUnaryOpType preUnaryOp(divider); - - using ThreadBufferLengths = Sequence<1, GredThreadBufferLength>; - constexpr auto ThreadBufferDesc = make_naive_tensor_descriptor_packed( - make_tuple(Number<1>{}, Number{})); - - index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); - - auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2, - 1, - 1, - 1, - false>( - src2dDesc, make_multi_index(thread_global_1d_id, 0)); - - constexpr auto in_thread_copy_step = make_multi_index(0, GredThreadBufferLength); - - index_t indexStart = 0; - for(index_t reducedLength = 0; reducedLength < toReduceLength; - reducedLength += GredThreadBufferLength) - { - threadwise_src_load.Run( - src2dDesc, src_global_buf, ThreadBufferDesc, make_tuple(I0, I0), in_thread_buf); - - // unary operation before reducing, needed by AMAX; For MIN/MAX, nothing is actually - // done here - threadwise_reduce::operate_on_elements(preUnaryOp, in_thread_buf); - - // do the reduction on the Thread Buffer - threadwise_reduce::Reduce2( - in_thread_buf, accuValue_buf(I0), accuIndex_buf(I0), indexStart); - - indexStart += GredThreadBufferLength; - - threadwise_src_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step); - } - - constexpr auto ReducedDataDesc = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); - - if(!float_equal_one{}(alpha)) - accuValue_buf(I0) *= type_convert{}(alpha); - - StaticBuffer dstValue_buf; - - dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); - - if(!float_equal_zero{}(beta)) - { - auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2, - Sequence<0>, - 0, - 1, - 1, - false>( - dst1dDesc, make_multi_index(thread_global_1d_id)); - - StaticBuffer priorDstValue_buf; - - threadwise_dst_load.Run( - dst1dDesc, dst_global_val_buf, 
ReducedDataDesc, make_tuple(I0), priorDstValue_buf); - - dstValue_buf(I0) += priorDstValue_buf[I0] * beta; - } - - auto threadwise_dst_val_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - false>(dst1dDesc, - make_multi_index(thread_global_1d_id)); - - auto threadwise_dst_idx_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - false>(dst1dDesc, - make_multi_index(thread_global_1d_id)); - - threadwise_dst_val_store.Run( - ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf); - threadwise_dst_idx_store.Run( - ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf); - }; - - template <> - __device__ static void Run<3>(const src2dDescType& src2dDesc, - const dst1dDescType& dst1dDesc, - int origReduceLen, - srcDataType alpha, - const srcDataType* const __restrict__ ws_values_global, - dstDataType beta, - dstDataType* const __restrict__ p_dst_global, - const int* const __restrict__ ws_indices_global, - int* const __restrict__ indices_global) - { - (void)origReduceLen; - - const auto zeroVal = opReduce::GetReductionZeroVal(); - - const auto src_global_val_buf = - make_dynamic_buffer(ws_values_global, - src2dDesc.GetElementSpaceSize(), - type_convert{}(zeroVal)); - const auto src_global_idx_buf = make_dynamic_buffer( - ws_indices_global, src2dDesc.GetElementSpaceSize()); - auto dst_global_val_buf = make_dynamic_buffer( - p_dst_global, dst1dDesc.GetElementSpaceSize()); - auto dst_global_idx_buf = make_dynamic_buffer( - indices_global, dst1dDesc.GetElementSpaceSize()); - - StaticBuffer - in_thread_val_buf; - StaticBuffer in_thread_idx_buf; - - using threadwise_reduce = ThreadReduceWithIndicesInput; - - StaticBuffer accuValue_buf; - StaticBuffer accuIndex_buf; - - accuValue_buf(I0) = zeroVal; - accuIndex_buf(I0) = 0; - - const auto toReduceLength = src2dDesc.GetLength(Number<1>{}); - - using ThreadBufferLengths = Sequence<1, GredThreadBufferLength>; - constexpr auto ThreadBufferDesc = make_naive_tensor_descriptor_packed( - make_tuple(Number<1>{}, Number{})); - - index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); - - auto threadwise_src_val_load = ThreadwiseTensorSliceTransfer_v2, - 1, - 1, - 1, - false>( - src2dDesc, make_multi_index(thread_global_1d_id, 0)); - - auto threadwise_src_idx_load = ThreadwiseTensorSliceTransfer_v2, - 1, - 1, - 1, - false>( - src2dDesc, make_multi_index(thread_global_1d_id, 0)); - - constexpr auto in_thread_copy_step = make_multi_index(0, GredThreadBufferLength); - - for(index_t reducedLength = 0; reducedLength < toReduceLength; - reducedLength += GredThreadBufferLength) - { - threadwise_src_val_load.Run(src2dDesc, - src_global_val_buf, - ThreadBufferDesc, - make_tuple(I0, I0), - in_thread_val_buf); - threadwise_src_idx_load.Run(src2dDesc, - src_global_idx_buf, - ThreadBufferDesc, - make_tuple(I0, I0), - in_thread_idx_buf); - - // do the reduction on the Thread Buffer - threadwise_reduce::Reduce( - in_thread_val_buf, in_thread_idx_buf, accuValue_buf(I0), accuIndex_buf(I0)); - - threadwise_src_val_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step); - threadwise_src_idx_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step); - } - - constexpr auto ReducedDataDesc = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); - - if(!float_equal_one{}(alpha)) - accuValue_buf(I0) *= type_convert{}(alpha); - - StaticBuffer dstValue_buf; - - dstValue_buf(I0) = 
type_convert{}(accuValue_buf[I0]); - - if(!float_equal_zero{}(beta)) - { - auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2, - Sequence<0>, - 0, - 1, - 1, - false>( - dst1dDesc, make_multi_index(thread_global_1d_id)); - - StaticBuffer priorDstValue_buf; - - threadwise_dst_load.Run( - dst1dDesc, dst_global_val_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf); - - dstValue_buf(I0) += priorDstValue_buf[I0] * beta; - } - - auto threadwise_dst_val_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - false>(dst1dDesc, - make_multi_index(thread_global_1d_id)); - - auto threadwise_dst_idx_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - false>(dst1dDesc, - make_multi_index(thread_global_1d_id)); - - threadwise_dst_val_store.Run( - ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf); - threadwise_dst_idx_store.Run( - ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf); - }; -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_warpwise.hpp b/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_warpwise.hpp deleted file mode 100644 index 4136dae7..00000000 --- a/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_warpwise.hpp +++ /dev/null @@ -1,544 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
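In the direct threadwise scheme deleted above there is no cooperation at all: each thread owns one output element and walks its entire row, which is why the Run() bodies contain no __syncthreads() and no LDS buffers. The sketch below (hypothetical names, plain HIP) shows the value-plus-index accumulator pair that the deleted Run<2> carries for index-returning reductions such as MAX-with-index.

```cpp
// Direct threadwise reduction: one thread per output element, a serial walk
// over the whole row, no LDS and no synchronization. The (value, index)
// accumulator pair mirrors the deleted Run<2>. Illustrative names only.
#include <hip/hip_runtime.h>
#include <cfloat>

__global__ void rowwise_argmax(const float* src, float* dst_val, int* dst_idx,
                               int num_rows, int row_len)
{
    const int r = blockIdx.x * blockDim.x + threadIdx.x;
    if(r >= num_rows)
        return;

    const float* row = src + r * static_cast<long>(row_len);

    float best_val = -FLT_MAX; // reduction "zero" value for max
    int   best_idx = 0;

    for(int i = 0; i < row_len; ++i)
        if(row[i] > best_val)
        {
            best_val = row[i];
            best_idx = i;
        }

    dst_val[r] = best_val;
    dst_idx[r] = best_idx;
}
```

This variant wins when there are many short rows, since it avoids all synchronization; the blockwise and warpwise variants in the neighboring deleted headers target the opposite regime of few, long rows.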
- * - *******************************************************************************/ -#ifndef CK_GRIDWISE_GENERIC_2D_REDUCTION_DIRECT_WARPWISE_HPP -#define CK_GRIDWISE_GENERIC_2D_REDUCTION_DIRECT_WARPWISE_HPP - -#include "data_type.hpp" -#include "reduction_common.hpp" -#include "reduction_operator.hpp" -#include "reduction_functions_warpwise.hpp" - -#include "threadwise_tensor_slice_transfer.hpp" - -namespace ck { - -template -struct GridwiseReduction_xy_to_x_direct_warpwise -{ - using opReduce = typename reduce_binary_operator::opType; - using preUnaryOpType = - typename reduce_unary_operator::preUnaryOp; - using posUnaryOpType = - typename reduce_unary_operator::posUnaryOp; - - static constexpr auto I0 = Number<0>{}; - - template - __device__ static void Run(const src2dDescType& src2dDesc, - const dst1dDescType& dst1dDesc, - int origReduceLen, - srcDataType alpha, - const srcDataType* const __restrict__ p_src_global, - dstDataType beta, - dstDataType* const __restrict__ p_dst_global, - const int* const __restrict__ ws_indices_global, - int* const __restrict__ indices_global); - - template <> - __device__ static void Run<1>(const src2dDescType& src2dDesc, - const dst1dDescType& dst1dDesc, - int origReduceLen, - srcDataType alpha, - const srcDataType* const __restrict__ p_src_global, - dstDataType beta, - dstDataType* const __restrict__ p_dst_global, - const int* const __restrict__ ws_indices_global, - int* const __restrict__ indices_global) - { - (void)ws_indices_global; - (void)indices_global; - - const auto zeroVal = opReduce::GetReductionZeroVal(); - - const auto src_global_buf = make_dynamic_buffer( - p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); - auto dst_global_buf = make_dynamic_buffer( - p_dst_global, dst1dDesc.GetElementSpaceSize()); - - StaticBuffer - in_thread_buf; - - using warpwise_reduce = - WarpReduce; - - StaticBuffer accuValue_buf; - - accuValue_buf(I0) = zeroVal; - - const auto toReduceLength = src2dDesc.GetLength(Number<1>{}); - const int divider = origReduceLen; - - const preUnaryOpType preUnaryOp(divider); - const posUnaryOpType posUnaryOp(divider); - - using ThreadBufferLengths = Sequence<1, GredAccessesPerThreadInWarp>; - constexpr auto ThreadBufferDesc = make_naive_tensor_descriptor_packed( - make_tuple(Number<1>{}, Number{})); - - index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); - index_t warp_global_1d_id = thread_global_1d_id / warpSize; - index_t thread_inwarp_id = thread_global_1d_id % warpSize; - - auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2, - 1, - 1, - 1, - false>( - src2dDesc, - make_multi_index(warp_global_1d_id, thread_inwarp_id * GredAccessesPerThreadInWarp)); - - constexpr auto in_thread_copy_step = - make_multi_index(0, warpSize * GredAccessesPerThreadInWarp); - - for(index_t reducedLength = 0; reducedLength < toReduceLength; - reducedLength += warpSize * GredAccessesPerThreadInWarp) - { - threadwise_src_load.Run( - src2dDesc, src_global_buf, ThreadBufferDesc, make_tuple(I0, I0), in_thread_buf); - - // do element-wise pre-reduction operation - warpwise_reduce::operate_on_elements(preUnaryOp, in_thread_buf); - - // do the warp-wise reduction on data of all thread buffers - warpwise_reduce::Reduce(in_thread_buf, accuValue_buf(I0)); - - threadwise_src_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step); - } - - accuValue_buf(I0) = posUnaryOp(accuValue_buf[I0]); - - constexpr auto ReducedDataDesc = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); - - 
// The first thread in the warp stores the reduced result to the global location - // representing the Warp - if(thread_inwarp_id == 0) - { - if(!float_equal_one{}(alpha)) - accuValue_buf(I0) *= type_convert{}(alpha); - - StaticBuffer dstValue_buf; - - dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); - - if(!float_equal_zero{}(beta)) - { - auto threadwise_dst_load = - ThreadwiseTensorSliceTransfer_v2, - Sequence<0>, - 0, - 1, - 1, - true>(dst1dDesc, - make_multi_index(warp_global_1d_id)); - - StaticBuffer priorDstValue_buf; - - threadwise_dst_load.Run( - dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf); - - dstValue_buf(I0) += priorDstValue_buf(I0) * beta; - } - - auto threadwise_dst_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - true>(dst1dDesc, - make_multi_index(warp_global_1d_id)); - - threadwise_dst_store.Run( - ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_buf); - } - }; - - template <> - __device__ static void Run<2>(const src2dDescType& src2dDesc, - const dst1dDescType& dst1dDesc, - int origReduceLen, - srcDataType alpha, - const srcDataType* const __restrict__ p_src_global, - dstDataType beta, - dstDataType* const __restrict__ p_dst_global, - const int* const __restrict__ ws_indices_global, - int* const __restrict__ indices_global) - { - (void)ws_indices_global; - - const auto zeroVal = opReduce::GetReductionZeroVal(); - - const auto src_global_buf = make_dynamic_buffer( - p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); - auto dst_global_val_buf = make_dynamic_buffer( - p_dst_global, dst1dDesc.GetElementSpaceSize()); - auto dst_global_idx_buf = make_dynamic_buffer( - indices_global, dst1dDesc.GetElementSpaceSize()); - - StaticBuffer - in_thread_buf; - - using warpwise_reduce = - WarpReduce; - - StaticBuffer accuValue_buf; - StaticBuffer accuIndex_buf; - - accuValue_buf(I0) = zeroVal; - accuIndex_buf(I0) = 0; - - const auto toReduceLength = src2dDesc.GetLength(Number<1>{}); - const int divider = origReduceLen; - - const preUnaryOpType preUnaryOp(divider); - - using ThreadBufferLengths = Sequence<1, GredAccessesPerThreadInWarp>; - constexpr auto ThreadBufferDesc = make_naive_tensor_descriptor_packed( - make_tuple(Number<1>{}, Number{})); - - index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); - index_t warp_global_1d_id = thread_global_1d_id / warpSize; - index_t thread_inwarp_id = thread_global_1d_id % warpSize; - - auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2, - 1, - 1, - 1, - false>( - src2dDesc, - make_multi_index(warp_global_1d_id, thread_inwarp_id * GredAccessesPerThreadInWarp)); - - constexpr auto in_thread_copy_step = - make_multi_index(0, warpSize * GredAccessesPerThreadInWarp); - - index_t indexOffset = 0; - for(index_t reducedLength = 0; reducedLength < toReduceLength; - reducedLength += warpSize * GredAccessesPerThreadInWarp) - { - threadwise_src_load.Run( - src2dDesc, src_global_buf, ThreadBufferDesc, make_tuple(I0, I0), in_thread_buf); - - // unary operation before reducing, needed by AMAX; For MIN/MAX, nothing is actually - // done here - warpwise_reduce::operate_on_elements(preUnaryOp, in_thread_buf); - - // do the warp-wise reduction on data of all thread buffers - warpwise_reduce::Reduce2( - in_thread_buf, accuValue_buf(I0), accuIndex_buf(I0), indexOffset); - - indexOffset += warpSize * GredAccessesPerThreadInWarp; - - threadwise_src_load.MoveSrcSliceWindow(src2dDesc, 
in_thread_copy_step); - } - - constexpr auto ReducedDataDesc = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); - - // The first thread in the warp stores the reduced result to the global location - // representing the Warp - if(thread_inwarp_id == 0) - { - if(!float_equal_one{}(alpha)) - accuValue_buf(I0) *= type_convert{}(alpha); - - StaticBuffer dstValue_buf; - - dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); - - if(!float_equal_zero{}(beta)) - { - auto threadwise_dst_load = - ThreadwiseTensorSliceTransfer_v2, - Sequence<0>, - 0, - 1, - 1, - true>(dst1dDesc, - make_multi_index(warp_global_1d_id)); - - StaticBuffer priorDstValue_buf; - - threadwise_dst_load.Run(dst1dDesc, - dst_global_val_buf, - ReducedDataDesc, - make_tuple(I0), - priorDstValue_buf); - - dstValue_buf(I0) += priorDstValue_buf[I0] * beta; - } - - auto threadwise_dst_val_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - true>(dst1dDesc, - make_multi_index(warp_global_1d_id)); - - auto threadwise_dst_idx_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - true>(dst1dDesc, - make_multi_index(warp_global_1d_id)); - - threadwise_dst_val_store.Run( - ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf); - threadwise_dst_idx_store.Run( - ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf); - } - }; - - template <> - __device__ static void Run<3>(const src2dDescType& src2dDesc, - const dst1dDescType& dst1dDesc, - int origReduceLen, - srcDataType alpha, - const srcDataType* const __restrict__ ws_values_global, - dstDataType beta, - dstDataType* const __restrict__ p_dst_global, - const int* const __restrict__ ws_indices_global, - int* const __restrict__ indices_global) - { - (void)origReduceLen; - - const auto zeroVal = opReduce::GetReductionZeroVal(); - - const auto src_global_val_buf = - make_dynamic_buffer(ws_values_global, - src2dDesc.GetElementSpaceSize(), - type_convert{}(zeroVal)); - const auto src_global_idx_buf = make_dynamic_buffer( - ws_indices_global, src2dDesc.GetElementSpaceSize()); - auto dst_global_val_buf = make_dynamic_buffer( - p_dst_global, dst1dDesc.GetElementSpaceSize()); - auto dst_global_idx_buf = make_dynamic_buffer( - indices_global, dst1dDesc.GetElementSpaceSize()); - - StaticBuffer - in_thread_val_buf; - StaticBuffer - in_thread_idx_buf; - - using warpwise_reduce = WarpReduceWithIndicesInput; - - StaticBuffer accuValue_buf; - StaticBuffer accuIndex_buf; - - accuValue_buf(I0) = zeroVal; - accuIndex_buf(I0) = 0; - - const auto toReduceLength = src2dDesc.GetLength(Number<1>{}); - - using ThreadBufferLengths = Sequence<1, GredAccessesPerThreadInWarp>; - constexpr auto ThreadBufferDesc = make_naive_tensor_descriptor_packed( - make_tuple(Number<1>{}, Number{})); - - index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); - index_t warp_global_1d_id = thread_global_1d_id / warpSize; - index_t thread_inwarp_id = thread_global_1d_id % warpSize; - - auto threadwise_src_val_load = ThreadwiseTensorSliceTransfer_v2, - 1, - 1, - 1, - false>( - src2dDesc, - make_multi_index(warp_global_1d_id, thread_inwarp_id * GredAccessesPerThreadInWarp)); - - auto threadwise_src_idx_load = ThreadwiseTensorSliceTransfer_v2, - 1, - 1, - 1, - false>( - src2dDesc, - make_multi_index(warp_global_1d_id, thread_inwarp_id * GredAccessesPerThreadInWarp)); - - constexpr auto in_thread_copy_step = - 
make_multi_index(0, warpSize * GredAccessesPerThreadInWarp); - - for(index_t reducedLength = 0; reducedLength < toReduceLength; - reducedLength += warpSize * GredAccessesPerThreadInWarp) - { - threadwise_src_val_load.Run(src2dDesc, - src_global_val_buf, - ThreadBufferDesc, - make_tuple(I0, I0), - in_thread_val_buf); - threadwise_src_idx_load.Run(src2dDesc, - src_global_idx_buf, - ThreadBufferDesc, - make_tuple(I0, I0), - in_thread_idx_buf); - - // do the warp-wise reduction on data of all thread buffers - warpwise_reduce::Reduce( - in_thread_val_buf, in_thread_idx_buf, accuValue_buf(I0), accuIndex_buf(I0)); - - threadwise_src_val_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step); - threadwise_src_idx_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step); - } - - constexpr auto ReducedDataDesc = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); - - // The first thread in the warp stores the reduced result to the global location - // representing the Warp - if(thread_inwarp_id == 0) - { - if(!float_equal_one{}(alpha)) - accuValue_buf(I0) *= type_convert{}(alpha); - - StaticBuffer dstValue_buf; - - dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); - - if(!float_equal_zero{}(beta)) - { - auto threadwise_dst_load = - ThreadwiseTensorSliceTransfer_v2, - Sequence<0>, - 0, - 1, - 1, - true>(dst1dDesc, - make_multi_index(warp_global_1d_id)); - - StaticBuffer priorDstValue_buf; - - threadwise_dst_load.Run(dst1dDesc, - dst_global_val_buf, - ReducedDataDesc, - make_tuple(I0), - priorDstValue_buf); - - dstValue_buf(I0) += priorDstValue_buf[I0] * beta; - } - - auto threadwise_dst_val_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - true>(dst1dDesc, - make_multi_index(warp_global_1d_id)); - - auto threadwise_dst_idx_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - true>(dst1dDesc, - make_multi_index(warp_global_1d_id)); - - threadwise_dst_val_store.Run( - ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf); - threadwise_dst_idx_store.Run( - ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf); - } - }; -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_multiblock.hpp b/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_multiblock.hpp deleted file mode 100644 index feee2b59..00000000 --- a/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_multiblock.hpp +++ /dev/null @@ -1,376 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. 
- * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef CK_GRIDWISE_GENERIC_2D_REDUCTION_MULTIBLOCK_HPP -#define CK_GRIDWISE_GENERIC_2D_REDUCTION_MULTIBLOCK_HPP - -#include "reduction_common.hpp" -#include "reduction_operator.hpp" -#include "reduction_functions_blockwise.hpp" - -#include "blockwise_tensor_slice_transfer.hpp" - -namespace ck { - -template -struct GridwiseReduction_xy_to_x_multiblock -{ - using opReduce = typename reduce_binary_operator::opType; - using preUnaryOpType = typename reduce_unary_operator::preUnaryOp; - using posUnaryOpType = typename reduce_unary_operator::posUnaryOp; - - static constexpr auto buffer2dDesc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{})); - using blockwise_reduce = - BlockwiseReduction_2d_block_buffer; - - static constexpr index_t BlockBufferSize = buffer2dDesc.GetElementSize(); - - static constexpr auto I0 = Number<0>{}; - - template - __device__ static void Run(const src2dDescType& src2dDesc, - const dst1dDescType& dst1dDesc, - int origReduceLen, - int BlkGroupSize, - srcDataType alpha, - const srcDataType* const __restrict__ p_src_global, - dstDataType beta, - srcDataType* const __restrict__ ws_values_global, - int* const __restrict__ ws_indices_global); - - template <> - __device__ static void Run<1>(const src2dDescType& src2dDesc, - const dst1dDescType& dst1dDesc, - int origReduceLen, - int BlkGroupSize, - srcDataType alpha, - const srcDataType* const __restrict__ p_src_global, - dstDataType beta, - srcDataType* const __restrict__ ws_values_global, - int* const __restrict__ ws_indices_global) - { - (void)ws_indices_global; - - (void)alpha; // unused - (void)beta; // unused - - const auto zeroVal = opReduce::GetReductionZeroVal(); - - // LDS - __shared__ compType p_in_block_buffer[BlockBufferSize]; - - const auto src_global_buf = make_dynamic_buffer( - p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); - auto workspace_global_buf = make_dynamic_buffer( - ws_values_global, dst1dDesc.GetLength(I0) * BlkGroupSize); - - auto in_block_buf = - make_dynamic_buffer(p_in_block_buffer, BlockBufferSize); - StaticBuffer accuValue_buf; - - accuValue_buf(I0) = zeroVal; - - const auto toReduceLength = src2dDesc.GetLength(Number<1>{}); - const int divider = origReduceLen; - - const preUnaryOpType preUnaryOp(divider); - - const index_t thread_local_id = get_thread_local_1d_id(); - const index_t block_global_id = get_block_1d_id(); - const index_t blkgroup_id = block_global_id / BlkGroupSize; - const index_t block_local_id = block_global_id % BlkGroupSize; - - const index_t reduceSizePerBlock = - (((toReduceLength + BlkGroupSize - 1) / BlkGroupSize + BlockBufferSize - 1) / - BlockBufferSize) * - BlockBufferSize; - - constexpr auto in_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number<1>{}, Number{})); - - using ThreadSliceLengths = Sequence<1, GredAccessesPerThreadInBlock>; - using ThreadClusterLengths = Sequence<1, BlockSize>; - - auto blockwise_src_load = 
BlockwiseTensorSliceTransfer_v4, - ThreadSliceLengths, - ThreadClusterLengths, - Sequence<0, 1>, - srcDataType, - compType, - src2dDescType, - decltype(in_block_desc), - Sequence<0, 1>, - Sequence<0, 1>, - 1, - 1, - 1, - 1, - 1, - 1, - false, - true>( - src2dDesc, - make_multi_index(blkgroup_id, block_local_id * reduceSizePerBlock), - in_block_desc, - make_multi_index(0, 0)); - - constexpr auto in_block_copy_step = make_multi_index(0, BlockBufferSize); - - const index_t toReduceBlocks = (reduceSizePerBlock + BlockSize - 1) / BlockSize; - - for(index_t reducedBlocks = 0; reducedBlocks < toReduceBlocks; - reducedBlocks += GredAccessesPerThreadInBlock) - { - blockwise_src_load.RunRead(src2dDesc, src_global_buf); - blockwise_src_load.RunWrite(in_block_desc, in_block_buf); - __syncthreads(); - - // do element-wise pre-reduction operation - blockwise_reduce::operate_on_elements(preUnaryOp, in_block_buf); - - index_t BlocksInOneOp = (reducedBlocks < toReduceBlocks - GredAccessesPerThreadInBlock) - ? GredAccessesPerThreadInBlock - : toReduceBlocks - reducedBlocks; - blockwise_reduce::Reduce(in_block_buf, BlocksInOneOp, accuValue_buf(I0)); - - blockwise_src_load.MoveSrcSliceWindow(src2dDesc, in_block_copy_step); - } - - constexpr auto ReducedDataDesc = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); - - const auto workspace_desc = - make_naive_tensor_descriptor_packed(make_tuple(dst1dDesc.GetLength(I0) * BlkGroupSize)); - - // The first thread in the block stores the reduced result to the global location - // representing the block - if(thread_local_id == 0) - { - auto threadwise_workspace_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - true>(workspace_desc, - make_multi_index(block_global_id)); - - threadwise_workspace_store.Run(ReducedDataDesc, - make_tuple(I0), - accuValue_buf, - workspace_desc, - workspace_global_buf); - } - }; - - template <> - __device__ static void Run<2>(const src2dDescType& src2dDesc, - const dst1dDescType& dst1dDesc, - int origReduceLen, - int BlkGroupSize, - srcDataType alpha, - const srcDataType* const __restrict__ p_src_global, - dstDataType beta, - srcDataType* const __restrict__ ws_values_global, - int* const __restrict__ ws_indices_global) - { - (void)alpha; // unused - (void)beta; // unused - - const auto zeroVal = opReduce::GetReductionZeroVal(); - - // LDS - __shared__ compType p_in_block_values_buffer[BlockBufferSize]; - __shared__ int p_in_block_indices_buffer[BlockBufferSize]; - - const auto src_global_buf = make_dynamic_buffer( - p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); - auto workspace_global_val_buf = make_dynamic_buffer( - ws_values_global, dst1dDesc.GetLength(I0) * BlkGroupSize); - auto workspace_global_idx_buf = make_dynamic_buffer( - ws_indices_global, dst1dDesc.GetLength(I0) * BlkGroupSize); - - auto in_block_val_buf = - make_dynamic_buffer(p_in_block_values_buffer, BlockBufferSize); - auto in_block_idx_buf = make_dynamic_buffer( - p_in_block_indices_buffer, BlockBufferSize); - StaticBuffer accuValue_buf; - StaticBuffer accuIndex_buf; - - accuValue_buf(I0) = zeroVal; - accuIndex_buf(I0) = 0; - - const auto toReduceLength = src2dDesc.GetLength(Number<1>{}); - const int divider = origReduceLen; - - const preUnaryOpType preUnaryOp(divider); - - const index_t thread_local_id = get_thread_local_1d_id(); - const index_t block_global_id = get_block_1d_id(); - const index_t blkgroup_id = block_global_id / BlkGroupSize; - const index_t 
block_local_id = block_global_id % BlkGroupSize; - - const index_t reduceSizePerBlock = - (((toReduceLength + BlkGroupSize - 1) / BlkGroupSize + BlockBufferSize - 1) / - BlockBufferSize) * - BlockBufferSize; - - constexpr auto in_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number<1>{}, Number{})); - - using ThreadSliceLengths = Sequence<1, GredAccessesPerThreadInBlock>; - using ThreadClusterLengths = Sequence<1, BlockSize>; - - auto blockwise_src_load = BlockwiseTensorSliceTransfer_v4, - ThreadSliceLengths, - ThreadClusterLengths, - Sequence<0, 1>, - srcDataType, - compType, - src2dDescType, - decltype(in_block_desc), - Sequence<0, 1>, - Sequence<0, 1>, - 1, - 1, - 1, - 1, - 1, - 1, - false, - true>( - src2dDesc, - make_multi_index(blkgroup_id, block_local_id * reduceSizePerBlock), - in_block_desc, - make_multi_index(0, 0)); - - constexpr auto in_block_copy_step = make_multi_index(0, BlockBufferSize); - - const index_t toReduceBlocks = (reduceSizePerBlock + BlockSize - 1) / BlockSize; - - int indexOffset = block_local_id * reduceSizePerBlock; - - for(index_t reducedBlocks = 0; reducedBlocks < toReduceBlocks; - reducedBlocks += GredAccessesPerThreadInBlock) - { - blockwise_reduce::init_buffer_indices(in_block_idx_buf, indexOffset); - - blockwise_src_load.RunRead(src2dDesc, src_global_buf); - blockwise_src_load.RunWrite(in_block_desc, in_block_val_buf); - - __syncthreads(); - - // unary operation before reducing, needed by AMAX; For MIN/MAX, nothing is actually - // done here - blockwise_reduce::operate_on_elements(preUnaryOp, in_block_val_buf); - - index_t BlocksInOneOp = (reducedBlocks < toReduceBlocks - GredAccessesPerThreadInBlock) - ? GredAccessesPerThreadInBlock - : toReduceBlocks - reducedBlocks; - - blockwise_reduce::Reduce2(in_block_val_buf, - in_block_idx_buf, - BlocksInOneOp, - accuValue_buf(I0), - accuIndex_buf(I0)); - - indexOffset += BlockBufferSize; - - blockwise_src_load.MoveSrcSliceWindow(src2dDesc, in_block_copy_step); - } - - constexpr auto ReducedDataDesc = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); - - const auto workspace_desc = - make_naive_tensor_descriptor_packed(make_tuple(dst1dDesc.GetLength(I0) * BlkGroupSize)); - - // The first thread in the block stores the reduced result to the global location - // representing the block - if(thread_local_id == 0) - { - auto threadwise_workspace_val_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - true>(workspace_desc, - make_multi_index(block_global_id)); - - auto threadwise_workspace_idx_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - 1, - InMemoryDataOperationEnum_t::Set, - 1, - true>(workspace_desc, - make_multi_index(block_global_id)); - - threadwise_workspace_val_store.Run(ReducedDataDesc, - make_tuple(I0), - accuValue_buf, - workspace_desc, - workspace_global_val_buf); - threadwise_workspace_idx_store.Run(ReducedDataDesc, - make_tuple(I0), - accuIndex_buf, - workspace_desc, - workspace_global_idx_buf); - } - }; -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/reduction_functions_blockwise.hpp b/composable_kernel/include/tensor_operation/reduction_functions_blockwise.hpp deleted file mode 100644 index 046d3311..00000000 --- a/composable_kernel/include/tensor_operation/reduction_functions_blockwise.hpp +++ /dev/null @@ -1,271 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020 
Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef CK_REDUCTION_FUNCTIONS_BLOCKWISE_HPP -#define CK_REDUCTION_FUNCTIONS_BLOCKWISE_HPP - -#include "data_type.hpp" - -#include "reduction_common.hpp" -#include "reduction_operator.hpp" -#include "reduction_functions_binop.hpp" - -namespace ck { - -template -struct BlockwiseReduction_2d_block_buffer -{ - using compType = typename opReduce::dataType; - - static constexpr auto buffer2dDesc = buffer2dDescType{}; - - static constexpr index_t BlockSize = - blockIsOneRow ? buffer2dDesc.GetLength(Number<1>{}) : buffer2dDesc.GetLength(Number<0>{}); - static constexpr index_t NumBlocks = - blockIsOneRow ? buffer2dDesc.GetLength(Number<0>{}) : buffer2dDesc.GetLength(Number<1>{}); - using binop = detail::binop_with_nan_check; - - // This interface does not accumulate on indices - template - __device__ static void - Reduce(BufferType& block_buffer, index_t toReduceBlocks, compType& accuData) - { - const index_t thread_local_id = get_thread_local_1d_id(); - compType lAccuData = opReduce::GetReductionZeroVal(); - - index_t offset; - for(index_t otherDimInd = 0; otherDimInd < toReduceBlocks; otherDimInd++) - { - offset = blockIsOneRow - ? buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_local_id)) - : buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, otherDimInd)); - compType opData = type_convert{}(block_buffer[offset]); - - binop::calculate(lAccuData, opData); - } - - offset = blockIsOneRow ? buffer2dDesc.CalculateOffset(make_tuple(0, thread_local_id)) - : buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, 0)); - - block_buffer(offset) = lAccuData; - - __syncthreads(); - - for(index_t indOffset = BlockSize / 2; indOffset > 0; indOffset /= 2) - { - if(thread_local_id < indOffset) - { - index_t offset1 = - blockIsOneRow ? buffer2dDesc.CalculateOffset(make_tuple(0, thread_local_id)) - : buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, 0)); - - index_t offset2 = - blockIsOneRow - ? 
buffer2dDesc.CalculateOffset(make_tuple(0, thread_local_id + indOffset)) - : buffer2dDesc.CalculateOffset(make_tuple(thread_local_id + indOffset, 0)); - - compType opData1 = type_convert{}(block_buffer[offset1]); - compType opData2 = type_convert{}(block_buffer[offset2]); - binop::calculate(opData1, opData2); - block_buffer(offset1) = type_convert{}(opData1); - } - - __syncthreads(); - } - - if(thread_local_id == 0) - { - compType tmpVal = type_convert{}(block_buffer[0]); - - binop::calculate(accuData, tmpVal); - } - }; - - // This interface accumulates on both data values and indices - template - __device__ static void Reduce2(BufferType& block_buffer, - IdxBufferType& block_indices_buffer, - index_t toReduceBlocks, - compType& accuData, - int& accuIndex) - { - const index_t thread_local_id = get_thread_local_1d_id(); - compType lAccuData = opReduce::GetReductionZeroVal(); - int lAccuIndex = 0; - - if constexpr(blockIsOneRow) - { - for(index_t otherDimInd = 0; otherDimInd < toReduceBlocks; otherDimInd++) - { - for(index_t indOffset = 1; indOffset < BlockSize; indOffset *= 2) - { - if(thread_local_id % (indOffset * 2) == 0) - { - index_t offset1 = - buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_local_id)); - index_t offset2 = buffer2dDesc.CalculateOffset( - make_tuple(otherDimInd, thread_local_id + indOffset)); - - compType currVal1 = type_convert{}(block_buffer[offset1]); - compType currVal2 = type_convert{}(block_buffer[offset2]); - int currIndex1 = block_indices_buffer[offset1]; - int currIndex2 = block_indices_buffer[offset2]; - - binop::calculate(currVal1, currVal2, currIndex1, currIndex2); - block_buffer(offset1) = type_convert{}(currVal1); - block_indices_buffer(offset1) = currIndex1; - } - __syncthreads(); - } - } - - if(thread_local_id == 0) - { - for(index_t otherDimInd = 0; otherDimInd < toReduceBlocks; otherDimInd++) - { - index_t offset = buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, 0)); - - compType tmpVal = type_convert{}(block_buffer[offset]); - int tmpIndex = block_indices_buffer[offset]; - - binop::calculate(lAccuData, tmpVal, lAccuIndex, tmpIndex); - } - - binop::calculate(accuData, lAccuData, accuIndex, lAccuIndex); - } - } - else - { - index_t offset; - - for(index_t otherDimInd = 0; otherDimInd < toReduceBlocks; otherDimInd++) - { - offset = buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, otherDimInd)); - compType currVal = type_convert{}(block_buffer[offset]); - int currIndex = block_indices_buffer[offset]; - - binop::calculate(lAccuData, currVal, lAccuIndex, currIndex); - } - - offset = buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, 0)); - - block_buffer(offset) = lAccuData; - block_indices_buffer(offset) = lAccuIndex; - - __syncthreads(); - - for(index_t indOffset = 1; indOffset < BlockSize; indOffset *= 2) - { - if(thread_local_id % (indOffset * 2) == 0) - { - index_t offset1 = buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, 0)); - index_t offset2 = - buffer2dDesc.CalculateOffset(make_tuple(thread_local_id + indOffset, 0)); - - compType currVal1 = type_convert{}(block_buffer[offset1]); - compType currVal2 = type_convert{}(block_buffer[offset2]); - int currIndex1 = block_indices_buffer[offset1]; - int currIndex2 = block_indices_buffer[offset2]; - - binop::calculate(currVal1, currVal2, currIndex1, currIndex2); - block_buffer(offset1) = type_convert{}(currVal1); - block_indices_buffer(offset1) = currIndex1; - } - - __syncthreads(); - } - - if(thread_local_id == 0) - { - compType tmpVal = 
type_convert{}(block_buffer[0]); - int tmpIndex = block_indices_buffer[0]; - - binop::calculate(accuData, tmpVal, accuIndex, tmpIndex); - } - } - }; - - template - __device__ static void set_buffer_value(BufferType& block_buffer, compType value) - { - index_t thread_id = get_thread_local_1d_id(); - - for(index_t otherDimInd = 0; otherDimInd < NumBlocks; otherDimInd++) - { - index_t offset = blockIsOneRow - ? buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_id)) - : buffer2dDesc.CalculateOffset(make_tuple(thread_id, otherDimInd)); - - block_buffer(offset) = value; - - __syncthreads(); - } - }; - - // Initialize the block-wise indices buffer, the index for each element in the block-wise data - // buffer - // is calculated according to its position in the buffer and the global starting index - template - __device__ static void init_buffer_indices(IdxBufferType& block_indices_buffer, int indexStart) - { - index_t thread_id = get_thread_local_1d_id(); - - for(index_t otherDimInd = 0; otherDimInd < NumBlocks; otherDimInd++) - { - index_t offset = blockIsOneRow - ? buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_id)) - : buffer2dDesc.CalculateOffset(make_tuple(thread_id, otherDimInd)); - - block_indices_buffer(offset) = offset + indexStart; - - __syncthreads(); - } - }; - - // Execute unary operation on the block buffer elements - template - __device__ static void operate_on_elements(unary_op_type& unary_op, BufferType& block_buffer) - { - index_t thread_id = get_thread_local_1d_id(); - - for(index_t otherDimInd = 0; otherDimInd < NumBlocks; otherDimInd++) - { - index_t offset = blockIsOneRow - ? buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_id)) - : buffer2dDesc.CalculateOffset(make_tuple(thread_id, otherDimInd)); - - block_buffer(offset) = unary_op(block_buffer[offset]); - - __syncthreads(); - } - }; -}; - -}; // end of namespace ck - -#endif diff --git a/composable_kernel/include/tensor_operation/reduction_functions_threadwise.hpp b/composable_kernel/include/tensor_operation/reduction_functions_threadwise.hpp deleted file mode 100644 index 2956606a..00000000 --- a/composable_kernel/include/tensor_operation/reduction_functions_threadwise.hpp +++ /dev/null @@ -1,141 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- * - *******************************************************************************/ -#ifndef CK_REDUCTION_FUNCTIONS_THREADWISE_HPP -#define CK_REDUCTION_FUNCTIONS_THREADWISE_HPP - -#include "data_type.hpp" - -#include "reduction_common.hpp" -#include "reduction_operator.hpp" -#include "reduction_functions_binop.hpp" - -namespace ck { - -template -struct ThreadReduce -{ - using compType = typename opReduce::dataType; - - static_assert(BufferType::IsStaticBuffer(), "Thread-wise reduction needs use StaticBuffer!"); - - static_assert( - std::is_same::value, - "Data type of StaticBuffer for Thread-wise reduction should be same as the compType!"); - - static constexpr index_t ThreadBufferLen = BufferType::Size(); - - using binop = detail::binop_with_nan_check; - - // This interface does not accumulate on indices - __device__ static void Reduce(const BufferType& thread_buffer, compType& accuData) - { - static_for<0, ThreadBufferLen, 1>{}( - [&](auto I) { binop::calculate(accuData, thread_buffer[I]); }); - }; - - // This interface accumulates on both data values and indices and - // is called by Direct_ThreadWise reduction method at first-time reduction - __device__ static void - Reduce2(const BufferType& thread_buffer, compType& accuData, int& accuIndex, int indexStart) - { - static_for<0, ThreadBufferLen, 1>{}([&](auto I) { - int currIndex = I + indexStart; - binop::calculate(accuData, thread_buffer[I], accuIndex, currIndex); - }); - }; - - // Set the elements in the per-thread buffer to a specific value - // cppcheck-suppress constParameter - __device__ static void set_buffer_value(BufferType& thread_buffer, compType value) - { - static_for<0, ThreadBufferLen, 1>{}([&](auto I) { thread_buffer(I) = value; }); - }; - - // Execute unary operation on the per-thread buffer elements - template - __device__ static void operate_on_elements(unary_op_type& unary_op, BufferType& thread_buffer) - { - static_for<0, ThreadBufferLen, 1>{}( - [&](auto I) { thread_buffer(I) = unary_op(thread_buffer[I]); }); - }; -}; - -template -struct ThreadReduceWithIndicesInput -{ - using compType = typename opReduce::dataType; - - static_assert(BufferType::IsStaticBuffer(), "Thread-wise reduction needs use StaticBuffer!"); - static_assert(IdxBufferType::IsStaticBuffer(), - "Thread-wise reduction needs use StaticBuffer for indices!"); - - static_assert( - std::is_same::value, - "Data type of StaticBuffer for Thread-wise reduction should be same as the compType!"); - static_assert(std::is_same::value, - "Indices type of StaticBuffer for Thread-wise reduction should be index_t!"); - - static_assert(BufferType::Size() == IdxBufferType::Size(), - "StaticBuffers for data and indices should have the same sizes!"); - - static constexpr index_t ThreadBufferLen = BufferType::Size(); - - using binop = detail::binop_with_nan_check; - - // This interface accumulates on both data values and indices and - // is called by Direct_ThreadWise reduction method at second-time reduction - __device__ static void Reduce(const BufferType& thread_buffer, - const IdxBufferType& thread_indices_buffer, - compType& accuData, - int& accuIndex) - { - static_for<0, ThreadBufferLen, 1>{}([&](auto I) { - binop::calculate(accuData, thread_buffer[I], accuIndex, thread_indices_buffer[I]); - }); - }; - - // Set the elements in the per-thread buffer to a specific value - // cppcheck-suppress constParameter - __device__ static void set_buffer_value(BufferType& thread_buffer, compType value) - { - static_for<0, ThreadBufferLen, 1>{}([&](auto I) { 
thread_buffer(I) = value; }); - }; - - // Execute unary operation on the per-thread buffer elements - template - __device__ static void operate_on_elements(unary_op_type& unary_op, BufferType& thread_buffer) - { - static_for<0, ThreadBufferLen, 1>{}( - [&](auto I) { thread_buffer(I) = unary_op(thread_buffer[I]); }); - }; -}; - -}; // end of namespace ck - -#endif diff --git a/composable_kernel/include/tensor_operation/reduction_functions_warpwise.hpp b/composable_kernel/include/tensor_operation/reduction_functions_warpwise.hpp deleted file mode 100644 index 9687d2d8..00000000 --- a/composable_kernel/include/tensor_operation/reduction_functions_warpwise.hpp +++ /dev/null @@ -1,371 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- * - *******************************************************************************/ -#ifndef CK_REDUCTION_FUNCTIONS_WARPWISE_HPP -#define CK_REDUCTION_FUNCTIONS_WARPWISE_HPP - -#include "data_type.hpp" - -#include "reduction_common.hpp" -#include "reduction_operator.hpp" -#include "reduction_functions_binop.hpp" - -namespace ck { - -template -struct WarpReduce -{ - using compType = typename opReduce::dataType; - using binop = detail::binop_with_nan_check; - - static_assert(BufferType::IsStaticBuffer(), - "Per-thread buffer for WarpWise reduction should be StaticBuffer!"); - static_assert(std::is_same::value, - "Data type of per-thread StaticBuffer for WarpWise reduction should be same as " - "the compType!"); - - static constexpr index_t ThreadBufferLen = BufferType::Size(); - static constexpr bool have_builtin_shuffle = - std::is_same::value || std::is_same::value; - - // This interface does not accumulate on indices - __device__ static void Reduce(const BufferType& thread_buffer, compType& accuData) - { - if constexpr(have_builtin_shuffle) - ReduceImpl1(thread_buffer, accuData); - else - ReduceImpl2(thread_buffer, accuData); - }; - - // This interface implementation uses HIP built-in device shuffling functions - __device__ static void ReduceImpl1(const BufferType& thread_buffer, compType& accuData) - { - compType lAccuData = opReduce::GetReductionZeroVal(); - - static_for<0, ThreadBufferLen, 1>{}( - [&](auto I) { binop::calculate(lAccuData, thread_buffer[I]); }); - - // synchronize among all threads in this warp - __all(1); - - for(index_t stride = warpSize / 2; stride > 0; stride /= 2) - { - compType tmpVal = __shfl_down(lAccuData, stride, warpSize); - binop::calculate(lAccuData, tmpVal); - __all(1); - } - - binop::calculate(accuData, lAccuData); - }; - - // This interface implementation does not use HIP built-in device shuffling functions - // since for fp16, built-in shuffling functions is not provided by HIP - __device__ static void ReduceImpl2(const BufferType& thread_buffer, compType& accuData) - { - compType lAccuData = opReduce::GetReductionZeroVal(); - - static_for<0, ThreadBufferLen, 1>{}( - [&](auto I) { binop::calculate(lAccuData, thread_buffer[I]); }); - - __syncthreads(); - - index_t thread_id = get_thread_local_1d_id(); - index_t warpId = thread_id / warpSize; - index_t thread_inwarp_id = thread_id % warpSize; - - __shared__ compType shuffle_buffer[BlockSize]; - - compType* myBuffer = &shuffle_buffer[warpId * warpSize]; - - myBuffer[thread_inwarp_id] = lAccuData; - - __syncthreads(); - - for(index_t stride = warpSize / 2; stride > 0; stride /= 2) - { - if(thread_inwarp_id < stride) - { - compType currVal1 = myBuffer[thread_inwarp_id]; - compType currVal2 = myBuffer[thread_inwarp_id + stride]; - - binop::calculate(currVal1, currVal2); - - myBuffer[thread_inwarp_id] = currVal1; - } - - __syncthreads(); - } - if(thread_inwarp_id == 0) - binop::calculate(accuData, myBuffer[0]); - }; - - // This interface accumulates on both data values and indices and is called by Direct_WarpWise - // reduction method at first-time reduction - __device__ static void - Reduce2(const BufferType& thread_buffer, compType& accuData, int& accuIndex, int indexStart) - { - if constexpr(have_builtin_shuffle) - Reduce2Impl1(thread_buffer, accuData, accuIndex, indexStart); - else - Reduce2Impl2(thread_buffer, accuData, accuIndex, indexStart); - }; - - // This interface implementation uses HIP built-in device shuffling functions - __device__ static void Reduce2Impl1(const BufferType& thread_buffer, - 
compType& accuData, - int& accuIndex, - int indexStart) - { - compType lAccuData = opReduce::GetReductionZeroVal(); - int lAccuIndex = 0; - index_t thread_inwarp_id = get_thread_local_1d_id() % warpSize; - - static_for<0, ThreadBufferLen, 1>{}([&](auto I) { - int currIndex = thread_inwarp_id * ThreadBufferLen + I + indexStart; - binop::calculate(lAccuData, thread_buffer[I], lAccuIndex, currIndex); - }); - - // synchronize among all threads in this warp - __all(1); - - for(index_t stride = 1; stride < warpSize; stride *= 2) - { - compType tmpVal = __shfl_down(lAccuData, stride, warpSize); - int tmpIndex = __shfl_down(lAccuIndex, stride, warpSize); - - binop::calculate(lAccuData, tmpVal, lAccuIndex, tmpIndex); - __all(1); - } - - if(thread_inwarp_id == 0) - binop::calculate(accuData, lAccuData, accuIndex, lAccuIndex); - }; - - // This interface implementation does not use HIP built-in device shuffling functions since for - // fp16, built-in shuffling functions is not provided by HIP - __device__ static void Reduce2Impl2(const BufferType& thread_buffer, - compType& accuData, - int& accuIndex, - int indexStart) - { - compType lAccuData = opReduce::GetReductionZeroVal(); - int lAccuIndex = 0; - index_t thread_id = get_thread_local_1d_id(); - index_t warpId = thread_id / warpSize; - index_t thread_inwarp_id = thread_id % warpSize; - - static_for<0, ThreadBufferLen, 1>{}([&](auto I) { - int currIndex = thread_inwarp_id * ThreadBufferLen + I + indexStart; - binop::calculate(lAccuData, thread_buffer[I], lAccuIndex, currIndex); - }); - - __shared__ compType shuffle_data_buffer[BlockSize]; - __shared__ int shuffle_indices_buffer[BlockSize]; - - compType* myDataBuffer = &shuffle_data_buffer[warpId * warpSize]; - int* myIndicesBuffer = &shuffle_indices_buffer[warpId * warpSize]; - - myDataBuffer[thread_inwarp_id] = lAccuData; - myIndicesBuffer[thread_inwarp_id] = lAccuIndex; - - __syncthreads(); - - for(index_t stride = 1; stride < warpSize; stride *= 2) - { - compType currVal1 = myDataBuffer[thread_inwarp_id]; - compType currVal2 = myDataBuffer[thread_inwarp_id + stride]; - int currIndex1 = myIndicesBuffer[thread_inwarp_id]; - int currIndex2 = myIndicesBuffer[thread_inwarp_id + stride]; - - binop::calculate(currVal1, currVal2, currIndex1, currIndex2); - - myDataBuffer[thread_inwarp_id] = currVal1; - myIndicesBuffer[thread_inwarp_id] = currIndex1; - - __syncthreads(); - } - - if(thread_inwarp_id == 0) - binop::calculate(accuData, myDataBuffer[0], accuIndex, myIndicesBuffer[0]); - }; - - // cppcheck-suppress constParameter - __device__ static void set_buffer_value(BufferType& thread_buffer, compType value) - { - static_for<0, ThreadBufferLen, 1>{}([&](auto I) { thread_buffer(I) = value; }); - - __all(1); - }; - - // Execute unary operation on the per-thread buffer elements - template - __device__ static void operate_on_elements(unary_op_type& unary_op, BufferType& thread_buffer) - { - static_for<0, ThreadBufferLen, 1>{}( - [&](auto I) { thread_buffer(I) = unary_op(thread_buffer[I]); }); - - __all(1); - }; -}; - -template -struct WarpReduceWithIndicesInput -{ - using compType = typename opReduce::dataType; - using binop = detail::binop_with_nan_check; - - static_assert(BufferType::IsStaticBuffer(), - "Per-thread buffer for WarpWise reduction should be StaticBuffer!"); - static_assert(IdxBufferType::IsStaticBuffer(), - "Per-thread buffer for WarpWise reduction should be StaticBuffer for indices!"); - - static_assert(std::is_same::value, - "Data type of per-thread StaticBuffer for WarpWise reduction should 
be same as " - "the compType!"); - static_assert( - std::is_same::value, - "Indices type per-thread of StaticBuffer for WarpWise reduction should be index_t!"); - - static_assert(BufferType::Size() == IdxBufferType::Size(), - "StaticBuffers for data and indices should have the same sizes!"); - - static constexpr index_t ThreadBufferLen = BufferType::Size(); - static constexpr bool have_builtin_shuffle = - std::is_same::value || std::is_same::value; - - // This interface accumulates on both data values and indices and is called by Direct_WarpWise - // reduction method at second-time reduction - __device__ static void Reduce(const BufferType& thread_buffer, - const IdxBufferType& thread_indices_buffer, - compType& accuData, - int& accuIndex) - { - if constexpr(have_builtin_shuffle) - ReduceImpl1(thread_buffer, thread_indices_buffer, accuData, accuIndex); - else - ReduceImpl2(thread_buffer, thread_indices_buffer, accuData, accuIndex); - }; - - // This interface implementation uses HIP built-in device shuffling functions - __device__ static void ReduceImpl1(const BufferType& thread_buffer, - const IdxBufferType& thread_indices_buffer, - compType& accuData, - int& accuIndex) - { - compType lAccuData = opReduce::GetReductionZeroVal(); - int lAccuIndex = 0; - - static_for<0, ThreadBufferLen, 1>{}([&](auto I) { - binop::calculate(lAccuData, thread_buffer[I], lAccuIndex, thread_indices_buffer[I]); - }); - - // synchronize among all threads in this warp - __all(1); - - for(index_t stride = 1; stride < warpSize; stride *= 2) - { - compType tmpVal = __shfl_down(lAccuData, stride, warpSize); - int tmpIndex = __shfl_down(lAccuIndex, stride, warpSize); - - binop::calculate(lAccuData, tmpVal, lAccuIndex, tmpIndex); - __all(1); - } - - binop::calculate(accuData, lAccuData, accuIndex, lAccuIndex); - }; - - // This interface implementation does not use HIP built-in device shuffling functions - // since for fp16, built-in shuffling functions is not provided by HIP - __device__ static void ReduceImpl2(const BufferType& thread_buffer, - const IdxBufferType& thread_indices_buffer, - compType& accuData, - int& accuIndex) - { - compType lAccuData = opReduce::GetReductionZeroVal(); - int lAccuIndex = 0; - index_t thread_id = get_thread_local_1d_id(); - index_t warpId = thread_id / warpSize; - index_t thread_inwarp_id = thread_id % warpSize; - - static_for<0, ThreadBufferLen, 1>{}([&](auto I) { - binop::calculate(lAccuData, thread_buffer[I], lAccuIndex, thread_indices_buffer[I]); - }); - - __shared__ compType shuffle_data_buffer[BlockSize]; - __shared__ int shuffle_indices_buffer[BlockSize]; - - compType* myDataBuffer = &shuffle_data_buffer[warpId * warpSize]; - int* myIndicesBuffer = &shuffle_indices_buffer[warpId * warpSize]; - - myDataBuffer[thread_inwarp_id] = lAccuData; - myIndicesBuffer[thread_inwarp_id] = lAccuIndex; - - __syncthreads(); - - for(index_t stride = 1; stride < warpSize; stride *= 2) - { - compType currVal1 = myDataBuffer[thread_inwarp_id]; - compType currVal2 = myDataBuffer[thread_inwarp_id + stride]; - int currIndex1 = myIndicesBuffer[thread_inwarp_id]; - int currIndex2 = myIndicesBuffer[thread_inwarp_id + stride]; - - binop::calculate(currVal1, currVal2, currIndex1, currIndex2); - - myDataBuffer[thread_inwarp_id] = currVal1; - myIndicesBuffer[thread_inwarp_id] = currIndex1; - - __syncthreads(); - } - - if(thread_inwarp_id == 0) - binop::calculate(accuData, myDataBuffer[0], accuIndex, myIndicesBuffer[0]); - }; - - // cppcheck-suppress constParameter - __device__ static void 
set_buffer_value(BufferType& thread_buffer, compType value)
-    {
-        static_for<0, ThreadBufferLen, 1>{}([&](auto I) { thread_buffer(I) = value; });
-
-        __all(1);
-    };
-
-    // Execute unary operation on the per-thread buffer elements
-    template <typename unary_op_type>
-    __device__ static void operate_on_elements(unary_op_type& unary_op, BufferType& thread_buffer)
-    {
-        static_for<0, ThreadBufferLen, 1>{}(
-            [&](auto I) { thread_buffer(I) = unary_op(thread_buffer[I]); });
-
-        __all(1);
-    };
-};
-
-}; // end of namespace ck
-
-#endif
diff --git a/composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp b/composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp
deleted file mode 100644
index f6c15fd8..00000000
--- a/composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp
+++ /dev/null
@@ -1,157 +0,0 @@
-#ifndef CK_THREADWISE_GEMM_DLOPS_V3_HPP
-#define CK_THREADWISE_GEMM_DLOPS_V3_HPP
-
-#include "common_header.hpp"
-#include "math.hpp"
-
-namespace ck {
-
-// C[M, N] += transpose(A[K, M]) * B[K, N]
-// Elements of the matrices can be vectorized data
-// Assume:
-// 1. ADesc, BDesc, CDesc are known at compile-time
-// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
-template ::type = false>
-struct ThreadwiseGemmDlops_km_kn_mn_v3
-{
-    template
-    __device__ static void Run(const ABuffer& a_buf,
-                               AOriginIdx,
-                               const BBuffer& b_buf,
-                               BOriginIdx,
-                               CBuffer& c_buf,
-                               COriginIdx)
-    {
-        static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
-                          CDesc::IsKnownAtCompileTime(),
-                      "wrong! Desc should be known at compile-time");
-
-        static_assert(is_known_at_compile_time>::value &&
-                          is_known_at_compile_time>::value &&
-                          is_known_at_compile_time>::value,
-                      "wrong! AOriginIdx, BOriginIdx, COriginIdx should be known at compile-time");
-
-        static_assert(
-            is_same, remove_cvref_t>::value &&
-                is_same, remove_cvref_t>::value &&
-                is_same, remove_cvref_t>::value,
-            "wrong! 
inconsistent type"); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - - constexpr auto E = ADesc{}.GetLength(I0); - constexpr auto K = ADesc{}.GetLength(I1); - - constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); - constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); - constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); - - static_for<0, E, 1>{}([&](auto e) { - static_for<0, K, 1>{}([&](auto k) { - constexpr index_t a_offset = - ADesc{}.CalculateOffset(a_origin_idx + make_tuple(e, k)); - - if constexpr(H == 2 && W == 2) - { - constexpr index_t b_offset_0 = - BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 0)); - constexpr index_t b_offset_1 = - BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 1)); - constexpr index_t b_offset_2 = - BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 0)); - constexpr index_t b_offset_3 = - BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 1)); - - constexpr index_t c_offset_0 = - CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 0)); - constexpr index_t c_offset_1 = - CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 1)); - constexpr index_t c_offset_2 = - CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 0)); - constexpr index_t c_offset_3 = - CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 1)); - - amd_assembly_outer_product_1x4(a_buf[Number{}], - b_buf[Number{}], - b_buf[Number{}], - b_buf[Number{}], - b_buf[Number{}], - c_buf(Number{}), - c_buf(Number{}), - c_buf(Number{}), - c_buf(Number{})); - } - else if constexpr(H == 4 && W == 1) - { - constexpr index_t b_offset_0 = - BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 0)); - constexpr index_t b_offset_1 = - BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 0)); - constexpr index_t b_offset_2 = - BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 2, 0)); - constexpr index_t b_offset_3 = - BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 3, 0)); - - constexpr index_t c_offset_0 = - CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 0)); - constexpr index_t c_offset_1 = - CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 0)); - constexpr index_t c_offset_2 = - CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 2, 0)); - constexpr index_t c_offset_3 = - CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 3, 0)); - - amd_assembly_outer_product_1x4(a_buf[Number{}], - b_buf[Number{}], - b_buf[Number{}], - b_buf[Number{}], - b_buf[Number{}], - c_buf(Number{}), - c_buf(Number{}), - c_buf(Number{}), - c_buf(Number{})); - } - else - { - static_for<0, H, 1>{}([&](auto h) { - static_for<0, W, 1>{}([&](auto w) { - constexpr index_t b_offset = - BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, h, w)); - - constexpr index_t c_offset = - CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, h, w)); - -#if 0 - c_buf(Number{}) += inner_product_with_conversion{}( - a_buf[Number{}], b_buf[Number{}]); -#else - amd_assembly_inner_product(a_buf[Number{}], - b_buf[Number{}], - c_buf(Number{})); -#endif - }); - }); - } - }); - }); - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/utility/amd_xdlops.hpp b/composable_kernel/include/utility/amd_xdlops.hpp deleted file mode 100644 index 083e47fb..00000000 --- a/composable_kernel/include/utility/amd_xdlops.hpp +++ /dev/null @@ -1,390 +0,0 @@ -#ifndef CK_AMD_XDLOPS_HPP -#define CK_AMD_XDLOPS_HPP - -#include "data_type.hpp" - -namespace ck { - -// A, B, C, cbsz, abid, 
blgp -extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x1f32( - float, float, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x1f32"); - -extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x2f32( - float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2f32"); - -extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x4f32( - float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f32"); - -extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x1f32( - float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x1f32"); - -extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x1f32( - float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x1f32"); - -extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x4f16( - half4_t, half4_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4f16"); - -extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x8f16( - half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x8f16"); - -extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x16f16( - half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x16f16"); - -extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x4f16( - half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f16"); - -extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x4f16( - half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x4f16"); - -extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x2bf16( - ushort2_t, ushort2_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2bf16"); - -extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x4bf16( - ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4bf16"); - -extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x8bf16( - ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x8bf16"); - -extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16( - ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x2bf16"); - -extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16( - ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16"); - -template -struct intrin_mfma_f32_32x32x1f32; - -template <> -struct intrin_mfma_f32_32x32x1f32<64, 64> -{ - template - __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) - { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); - reg_c.template AsType()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32( - reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 1, 1, 0); - } -}; - -template <> -struct intrin_mfma_f32_32x32x1f32<32, 64> -{ - template - __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) - { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); - } -}; - -template -struct intrin_mfma_f32_32x32x2f32; - -template <> -struct intrin_mfma_f32_32x32x2f32<32, 32> -{ - template - __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) - { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x2f32( - reg_a, 
reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); - } -}; - -template -struct intrin_mfma_f32_16x16x4f32; - -template <> -struct intrin_mfma_f32_16x16x4f32<16, 16> -{ - template - __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) - { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x4f32( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); - } -}; - -template -struct intrin_mfma_f32_16x16x1f32; - -template <> -struct intrin_mfma_f32_16x16x1f32<16, 64> -{ - template - __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) - { - - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x1f32( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 2, 0, 0); - } -}; - -template -struct intrin_mfma_f32_4x4x1f32; - -template <> -struct intrin_mfma_f32_4x4x1f32<4, 64> -{ - template - __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) - { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); - } -}; - -template <> -struct intrin_mfma_f32_4x4x1f32<8, 64> -{ - template - __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) - { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); - reg_c.template AsType()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32( - reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 4, 1, 0); - } -}; - -template -struct intrin_mfma_f32_32x32x4f16; - -template <> -struct intrin_mfma_f32_32x32x4f16<64, 64> -{ - template - __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) - { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); - reg_c.template AsType()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16( - reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 1, 1, 0); - } -}; - -template <> -struct intrin_mfma_f32_32x32x4f16<32, 64> -{ - template - __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) - { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); - } -}; - -template -struct intrin_mfma_f32_32x32x8f16; - -template <> -struct intrin_mfma_f32_32x32x8f16<32, 32> -{ - template - __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) - { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x8f16( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); - } -}; - -template -struct intrin_mfma_f32_16x16x16f16; - -template <> -struct intrin_mfma_f32_16x16x16f16<16, 16> -{ - template - __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) - { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x16f16( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); - } -}; - -template -struct intrin_mfma_f32_16x16x4f16; - -template <> -struct intrin_mfma_f32_16x16x4f16<16, 64> -{ - template - __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) - { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x4f16( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 2, 0, 0); - } -}; - -template -struct 
intrin_mfma_f32_4x4x4f16; - -template <> -struct intrin_mfma_f32_4x4x4f16<4, 64> -{ - template - __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) - { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); - } -}; - -template <> -struct intrin_mfma_f32_4x4x4f16<8, 64> -{ - template - __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) - { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); - reg_c.template AsType()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16( - reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 4, 1, 0); - } -}; - -#if 0 -template -struct intrin_mfma_f32_32x32x2bf16; - -template -struct intrin_mfma_f32_32x32x2bf16<128, 64, AStride, BStride> -{ - __device__ static c_vec32_4_t::VecType - run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c) - { - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0); - reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0); - - reg_c.s.z = - llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0); - reg_c.s.w = - llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0); - - return reg_c; - } -}; - -template -struct intrin_mfma_f32_32x32x2bf16<64, 128, AStride, BStride> -{ - __device__ static c_vec32_4_t::VecType - run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c) - { - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0); - reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0); - - reg_c.s.z = - llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0); - reg_c.s.w = - llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0); - - return reg_c; - } -}; - -template -struct intrin_mfma_f32_32x32x2bf16<64, 64, AStride, BStride> -{ - __device__ static c_vec32_2_t::VecType - run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_2_t::VecType reg_c) - { - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0); - reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0); - - return reg_c; - } -}; - -template -struct intrin_mfma_f32_32x32x2bf16<64, 32, AStride, BStride> -{ - __device__ static c_vec32_1_t::VecType - run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c) - { - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1); - - return reg_c; - } -}; - -template -struct intrin_mfma_f32_32x32x2bf16<32, 64, AStride, BStride> -{ - __device__ static c_vec32_1_t::VecType - run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c) - { - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0); - return reg_c; - } -}; - -__device__ c_vec16_1_t::VecType intrin_mfma_f32_32x32x4bf16(const ushort2_t* reg_a, - const ushort2_t* reg_b, - c_vec16_1_t::VecType reg_c) -{ - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0); - return reg_c; -} - -__device__ c_vec4_1_t::VecType intrin_mfma_f32_16x16x8bf16(const ushort2_t* reg_a, - const ushort2_t* reg_b, - 
c_vec4_1_t::VecType reg_c) -{ - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0); - return reg_c; -} - -template -__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16(const ushort2_t* reg_a, - const ushort2_t* reg_b, - c_vec16_1_t::VecType reg_c); -template <> -__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t* reg_a, - const ushort2_t* reg_b, - c_vec16_1_t::VecType reg_c) -{ - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 2, 0, 0); - return reg_c; -} - -template <> -__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<64, 16>(const ushort2_t* reg_a, - const ushort2_t* reg_b, - c_vec16_1_t::VecType reg_c) -{ - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 4); - return reg_c; -} - -template -struct intrin_mfma_f32_4x4x2bf16; - -template <> -struct intrin_mfma_f32_4x4x2bf16<4, 64> -{ - __device__ static c_vec4_1_t::VecType - run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_1_t::VecType reg_c) - { - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0); - return reg_c; - } -}; - -template <> -struct intrin_mfma_f32_4x4x2bf16<8, 64> -{ - __device__ static c_vec4_2_t::VecType - run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_2_t::VecType reg_c) - { - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0); - reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 4, 1, 0); - return reg_c; - } -}; - -#endif - -} // namespace ck -#endif diff --git a/composable_kernel/include/utility/common_header.hpp b/composable_kernel/include/utility/common_header.hpp deleted file mode 100644 index 85c02a1b..00000000 --- a/composable_kernel/include/utility/common_header.hpp +++ /dev/null @@ -1,46 +0,0 @@ -#ifndef CK_COMMON_HEADER_HPP -#define CK_COMMON_HEADER_HPP - -#include "config.hpp" -#include "array.hpp" -#include "container_helper.hpp" -#include "statically_indexed_array.hpp" -#include "container_element_picker.hpp" -#include "multi_index.hpp" -#include "data_type.hpp" -#include "data_type_enum.hpp" -#include "data_type_enum_helper.hpp" -#include "functional.hpp" -#include "functional2.hpp" -#include "functional3.hpp" -#include "functional4.hpp" -#include "enable_if.hpp" -#include "integral_constant.hpp" -#include "math.hpp" -#include "number.hpp" -#include "sequence.hpp" -#include "sequence_helper.hpp" -#include "synchronization.hpp" -#include "tuple.hpp" -#include "tuple_helper.hpp" -#include "type.hpp" -#include "magic_division.hpp" -#include "utility.hpp" -#include "c_style_pointer_cast.hpp" -#include "amd_address_space.hpp" -#include "amd_buffer_addressing.hpp" -#include "static_buffer.hpp" -#include "dynamic_buffer.hpp" - -#include "inner_product.hpp" - -// TODO: remove this -#if CK_USE_AMD_INLINE_ASM -#include "amd_inline_asm.hpp" -#endif - -#if CK_USE_AMD_XDLOPS -#include "amd_xdlops.hpp" -#endif - -#endif diff --git a/composable_kernel/include/utility/config.hpp b/composable_kernel/include/utility/config.hpp deleted file mode 100644 index 5ee4bb9c..00000000 --- a/composable_kernel/include/utility/config.hpp +++ /dev/null @@ -1,134 +0,0 @@ -#ifndef CK_CONFIG_AMD_HPP -#define CK_CONFIG_AMD_HPP - -#ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS -#include "hip/hip_runtime.h" -#include "hip/hip_fp16.h" -#endif -#include "bfloat16_dev.hpp" - -// "Constant" address space for kernel parameter -#define CONSTANT 
__attribute__((address_space(4)))
-
-// GPU target
-// should enable one and only one GPU target
-#if !(defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \
-    defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A) || defined(CK_AMD_GPU_GFX1030))
-#error Need to define (only) one GPU target
-#endif
-
-// launch bounds
-#define CK_USE_LAUNCH_BOUNDS 1
-
-#ifdef CK_USE_LAUNCH_BOUNDS
-#define CK_MAX_THREAD_PER_BLOCK 256
-#define CK_MIN_BLOCK_PER_CU 2
-#endif
-
-// buffer resource
-#if defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \
-    defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A)
-#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
-#elif defined(CK_AMD_GPU_GFX1030)
-#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
-#endif
-
-// FMA instruction
-#if defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900)
-#define CK_USE_AMD_V_MAC_F32
-#elif defined(CK_AMD_GPU_GFX906) || defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A) || \
-    defined(CK_AMD_GPU_GFX1030)
-#define CK_USE_AMD_V_FMAC_F32
-#define CK_USE_AMD_V_DOT2_F32_F16
-#define CK_USE_AMD_V_DOT4_I32_I8
-#endif
-
-// multi index
-#define CK_USE_DYNAMICALLY_INDEXED_MULTI_INDEX 0
-
-// AMD inline asm
-#ifndef CK_USE_AMD_INLINE_ASM
-#define CK_USE_AMD_INLINE_ASM 1
-#endif
-
-// AMD inner product (DLOP)
-#ifndef CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
-#define CK_USE_AMD_INNER_PRODUCT_INLINE_ASM 1
-#endif
-
-// AMD buffer addressing
-#ifndef CK_USE_AMD_BUFFER_ADDRESSING
-#define CK_USE_AMD_BUFFER_ADDRESSING 1
-#endif
-
-// only gfx908 supports native floating-point atomic add
-#ifndef CK_USE_AMD_BUFFER_ATOMIC_FADD
-#define CK_USE_AMD_BUFFER_ATOMIC_FADD 0
-#endif
-
-// AMD XDLOPS
-#ifndef CK_USE_AMD_XDLOPS
-#define CK_USE_AMD_XDLOPS 0
-#endif
-
-// block synchronization waits only on s_waitcnt lgkmcnt(0), not vmcnt(0)
-#ifndef CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
-#define CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
-#endif
-
-// experimental implementation
-#ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
-#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
-#endif
-
-#ifndef CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
-#define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
-#endif
-
-#ifndef CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
-#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1
-#endif
-
-// pass tensor descriptor by value or void*
-#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 1
-#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 0
-
-// merge transformation uses magic number division
-#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 0
-
-// hack: has underlying assumptions that need to be satisfied, otherwise it's a bug
-// hack for forcing register to keep idx_diff_low_const in SGPR; idx_diff_low_const must be
-// thread-invariant, otherwise it's a bug
-// TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread"
-#ifndef CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
-#define CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0
-#endif
-
-// workaround for compiler crash when compiling recursive lambda
-#ifndef CK_WORKAROUND_SWDEV_275126
-#define CK_WORKAROUND_SWDEV_275126 1
-#endif
-
-// workaround for compiler crash when using buffer load/store for i8
-#ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
-#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE 1
-#endif
-
-// workaround for compiler crash when using ds_write for i8
-#ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
-#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
-#endif
-
-namespace ck {
-
-enum InMemoryDataOperationEnum_t
-{
-    Set,
-    AtomicAdd
-};
-
-// index type
-using index_t = int32_t;
-
-} // namespace ck
-#endif
diff --git a/composable_kernel/include/utility/data_type_enum_helper.hpp b/composable_kernel/include/utility/data_type_enum_helper.hpp
deleted file mode 100644
index 451ce992..00000000
--- a/composable_kernel/include/utility/data_type_enum_helper.hpp
+++ /dev/null
@@ -1,76 +0,0 @@
-#ifndef CK_DATA_TYPE_ENUM_HELPER_HPP
-#define CK_DATA_TYPE_ENUM_HELPER_HPP
-
-#include "data_type.hpp"
-#include "data_type_enum.hpp"
-
-namespace ck {
-
-template <DataTypeEnum_t DataTypeEnum>
-struct get_datatype_from_enum;
-
-template <>
-struct get_datatype_from_enum<DataTypeEnum_t::Int8>
-{
-    using type = int8_t;
-};
-
-template <>
-struct get_datatype_from_enum<DataTypeEnum_t::Int32>
-{
-    using type = int32_t;
-};
-
-template <>
-struct get_datatype_from_enum<DataTypeEnum_t::Half>
-{
-    using type = half_t;
-};
-
-template <>
-struct get_datatype_from_enum<DataTypeEnum_t::Float>
-{
-    using type = float;
-};
-
-template <>
-struct get_datatype_from_enum<DataTypeEnum_t::Double>
-{
-    using type = double;
-};
-
-template <typename T>
-struct get_datatype_enum_from_type;
-
-template <>
-struct get_datatype_enum_from_type<int8_t>
-{
-    static constexpr DataTypeEnum_t value = DataTypeEnum_t::Int8;
-};
-
-template <>
-struct get_datatype_enum_from_type<int32_t>
-{
-    static constexpr DataTypeEnum_t value = DataTypeEnum_t::Int32;
-};
-
-template <>
-struct get_datatype_enum_from_type<half_t>
-{
-    static constexpr DataTypeEnum_t value = DataTypeEnum_t::Half;
-};
-
-template <>
-struct get_datatype_enum_from_type<float>
-{
-    static constexpr DataTypeEnum_t value = DataTypeEnum_t::Float;
-};
-
-template <>
-struct get_datatype_enum_from_type<double>
-{
-    static constexpr DataTypeEnum_t value = DataTypeEnum_t::Double;
-};
-
-} // namespace ck
-#endif
diff --git a/composable_kernel/include/utility/dynamic_buffer.hpp b/composable_kernel/include/utility/dynamic_buffer.hpp
deleted file mode 100644
index 886737ef..00000000
--- a/composable_kernel/include/utility/dynamic_buffer.hpp
+++ /dev/null
@@ -1,272 +0,0 @@
-#ifndef CK_BUFFER_HPP
-#define CK_BUFFER_HPP
-
-#include "amd_buffer_addressing.hpp"
-#include "c_style_pointer_cast.hpp"
-#include "enable_if.hpp"
-
-namespace ck {
-
-template <AddressSpaceEnum_t BufferAddressSpace,
-          typename T,
-          typename ElementSpaceSize,
-          bool InvalidElementUseNumericalZeroValue>
-struct DynamicBuffer
-{
-    using type = T;
-
-    T* p_data_;
-    ElementSpaceSize element_space_size_;
-    T invalid_element_value_ = T{0};
-
-    __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
-        : p_data_{p_data}, element_space_size_{element_space_size}
-    {
-    }
-
-    __host__ __device__ constexpr DynamicBuffer(T* p_data,
-                                                ElementSpaceSize element_space_size,
-                                                T invalid_element_value)
-        : p_data_{p_data},
-          element_space_size_{element_space_size},
invalid_element_value_{invalid_element_value} - { - } - - __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace() - { - return BufferAddressSpace; - } - - __host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; } - - __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; } - - template >::type, - typename scalar_type>::type>::value, - bool>::type = false> - __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const - { - // X contains multiple T - constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; - - constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; - - static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, - "wrong! X need to be multiple T"); - -#if CK_USE_AMD_BUFFER_ADDRESSING - bool constexpr use_amd_buffer_addressing = true; -#else - bool constexpr use_amd_buffer_addressing = false; -#endif - - if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global && use_amd_buffer_addressing) - { - constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; - - if constexpr(InvalidElementUseNumericalZeroValue) - { - return amd_buffer_load_invalid_element_return_return_zero, - t_per_x>( - p_data_, i, is_valid_element, element_space_size_); - } - else - { - return amd_buffer_load_invalid_element_return_customized_value, - t_per_x>( - p_data_, i, is_valid_element, element_space_size_, invalid_element_value_); - } - } - else - { - if constexpr(InvalidElementUseNumericalZeroValue) - { - return is_valid_element ? *c_style_pointer_cast(&p_data_[i]) : X{0}; - } - else - { - return is_valid_element ? *c_style_pointer_cast(&p_data_[i]) - : X{invalid_element_value_}; - } - } - } - - template >::type, - typename scalar_type>::type>::value, - bool>::type = false> - __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x) - { - // X contains multiple T - constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; - - constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; - - static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, - "wrong! X need to be multiple T"); - - if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global) - { -#if CK_USE_AMD_BUFFER_ADDRESSING - constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; - - amd_buffer_store, t_per_x>( - x, p_data_, i, is_valid_element, element_space_size_); -#else - if(is_valid_element) - { - *c_style_pointer_cast(&p_data_[i]) = x; - } -#endif - } - else if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Lds) - { - if(is_valid_element) - { -#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE - *c_style_pointer_cast(&p_data_[i]) = x; -#else - // HACK: compiler would lower IR "store address_space(3)" into - // inefficient - // ISA, so I try to let compiler emit IR "store" which would be lower to - // ds_write_b128 - // TODO: remove this after compiler fix - if constexpr(is_same>::type, int8_t>::value) - { - static_assert((is_same, int8_t>::value && - is_same, int8_t>::value) || - (is_same, int8_t>::value && - is_same, int8x2_t>::value) || - (is_same, int8_t>::value && - is_same, int8x4_t>::value) || - (is_same, int8x4_t>::value && - is_same, int8x4_t>::value) || - (is_same, int8x8_t>::value && - is_same, int8x8_t>::value) || - (is_same, int8x16_t>::value && - is_same, int8x16_t>::value), - "wrong! 
not implemented for this combination, please add " - "implementation"); - - if constexpr(is_same, int8_t>::value && - is_same, int8_t>::value) - { - // HACK: cast pointer of x is bad - // TODO: remove this after compiler fix - *c_style_pointer_cast(&p_data_[i]) = - *c_style_pointer_cast(&x); - } - else if constexpr(is_same, int8_t>::value && - is_same, int8x2_t>::value) - { - // HACK: cast pointer of x is bad - // TODO: remove this after compiler fix - *c_style_pointer_cast(&p_data_[i]) = - *c_style_pointer_cast(&x); - } - else if constexpr(is_same, int8_t>::value && - is_same, int8x4_t>::value) - { - // HACK: cast pointer of x is bad - // TODO: remove this after compiler fix - *c_style_pointer_cast(&p_data_[i]) = - *c_style_pointer_cast(&x); - } - else if constexpr(is_same, int8x4_t>::value && - is_same, int8x4_t>::value) - { - // HACK: cast pointer of x is bad - // TODO: remove this after compiler fix - *c_style_pointer_cast(&p_data_[i]) = - *c_style_pointer_cast(&x); - } - else if constexpr(is_same, int8x8_t>::value && - is_same, int8x8_t>::value) - { - // HACK: cast pointer of x is bad - // TODO: remove this after compiler fix - *c_style_pointer_cast(&p_data_[i]) = - *c_style_pointer_cast(&x); - } - else if constexpr(is_same, int8x16_t>::value && - is_same, int8x16_t>::value) - { - // HACK: cast pointer of x is bad - // TODO: remove this after compiler fix - *c_style_pointer_cast(&p_data_[i]) = - *c_style_pointer_cast(&x); - } - } - else - { - *c_style_pointer_cast(&p_data_[i]) = x; - } -#endif - } - } - else - { - if(is_valid_element) - { - *c_style_pointer_cast(&p_data_[i]) = x; - } - } - } - - template >::type, - typename scalar_type>::type>::value, - bool>::type = false> - __host__ __device__ void AtomicAdd(index_t i, bool is_valid_element, const X& x) - { - // X contains multiple T - constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; - - constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; - - static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, - "wrong! 
X need to be multiple T"); - - static_assert(GetAddressSpace() == AddressSpaceEnum_t::Global, "only support global mem"); - -#if CK_USE_AMD_BUFFER_ADDRESSING - constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; - - amd_buffer_atomic_add, t_per_x>( - x, p_data_, i, is_valid_element, element_space_size_); -#else - if(is_valid_element) - { - atomicAdd(&p_data_[i], x); - } -#endif - } - - __host__ __device__ static constexpr bool IsStaticBuffer() { return false; } - - __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; } -}; - -template -__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size) -{ - return DynamicBuffer{p, element_space_size}; -} - -template < - AddressSpaceEnum_t BufferAddressSpace, - typename T, - typename ElementSpaceSize, - typename X, - typename enable_if, remove_cvref_t>::value, bool>::type = false> -__host__ __device__ constexpr auto -make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element_value) -{ - return DynamicBuffer{ - p, element_space_size, invalid_element_value}; -} - -} // namespace ck -#endif diff --git a/composable_kernel/include/utility/integral_constant.hpp b/composable_kernel/include/utility/integral_constant.hpp deleted file mode 100644 index 14f3df89..00000000 --- a/composable_kernel/include/utility/integral_constant.hpp +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef CK_INTEGRAL_CONSTANT_HPP -#define CK_INTEGRAL_CONSTANT_HPP - -namespace ck { - -template -struct integral_constant -{ - static constexpr T value = v; - typedef T value_type; - typedef integral_constant type; - __host__ __device__ constexpr operator value_type() const noexcept { return value; } - __host__ __device__ constexpr value_type operator()() const noexcept { return value; } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/utility/multi_index.hpp b/composable_kernel/include/utility/multi_index.hpp deleted file mode 100644 index 0bb34fb1..00000000 --- a/composable_kernel/include/utility/multi_index.hpp +++ /dev/null @@ -1,12 +0,0 @@ -#ifndef CK_MULTI_INDEX_HPP -#define CK_MULTI_INDEX_HPP - -#include "common_header.hpp" - -#if CK_USE_DYNAMICALLY_INDEXED_MULTI_INDEX -#include "array_multi_index.hpp" -#else -#include "statically_indexed_array_multi_index.hpp" -#endif - -#endif diff --git a/composable_kernel/include/utility/number.hpp b/composable_kernel/include/utility/number.hpp deleted file mode 100644 index f8c56436..00000000 --- a/composable_kernel/include/utility/number.hpp +++ /dev/null @@ -1,44 +0,0 @@ -#ifndef CK_NUMBER_HPP -#define CK_NUMBER_HPP - -#include "integral_constant.hpp" - -namespace ck { - -template -using Number = integral_constant; - -template -__host__ __device__ constexpr auto operator+(Number, Number) -{ - return Number{}; -} - -template -__host__ __device__ constexpr auto operator-(Number, Number) -{ - static_assert(Y <= X, "wrong!"); - return Number{}; -} - -template -__host__ __device__ constexpr auto operator*(Number, Number) -{ - return Number{}; -} - -template -__host__ __device__ constexpr auto operator/(Number, Number) -{ - static_assert(Y > 0, "wrong!"); - return Number{}; -} - -template -__host__ __device__ constexpr auto operator%(Number, Number) -{ - static_assert(Y > 0, "wrong!"); - return Number{}; -} -} // namespace ck -#endif diff --git a/composable_kernel/include/utility/reduction_common.hpp b/composable_kernel/include/utility/reduction_common.hpp deleted file mode 100644 index ff574c31..00000000 --- 
a/composable_kernel/include/utility/reduction_common.hpp +++ /dev/null @@ -1,53 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef CK_REDUCTION_COMMON_HPP -#define CK_REDUCTION_COMMON_HPP - -#include "reduction_enums.hpp" - -namespace ck { - -struct float_equal_one -{ - template - __device__ inline bool operator()(T x) - { - return x <= static_cast(1.0f) and x >= static_cast(1.0f); - }; -}; - -struct float_equal_zero -{ - template - __device__ inline bool operator()(T x) - { - return x <= static_cast(0.0f) and x >= static_cast(0.0f); - }; -}; - -}; // end of namespace ck - -#endif diff --git a/composable_kernel/include/utility/reduction_enums.hpp b/composable_kernel/include/utility/reduction_enums.hpp deleted file mode 100644 index e9710817..00000000 --- a/composable_kernel/include/utility/reduction_enums.hpp +++ /dev/null @@ -1,66 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- * - *******************************************************************************/ -#ifndef CK_REDUCTION_ENUMS_HPP -#define CK_REDUCTION_ENUMS_HPP - -namespace ck { - -enum class ReduceTensorOp_t -{ - ADD = 0, - MUL = 1, - MIN = 2, - MAX = 3, - AMAX = 4, - AVG = 5, - NORM1 = 6, - NORM2 = 7, - // MUL_NO_ZEROS = 8, -}; - -enum class NanPropagation_t -{ - NOT_PROPAGATE_NAN = 0, - PROPAGATE_NAN = 1, -}; - -enum class ReduceTensorIndices_t -{ - NO_INDICES = 0, - FLATTENED_INDICES = 1, -}; - -enum class IndicesType_t -{ - INDICES_32BIT = 0, - INDICES_64BIT = 1, - INDICES_16BIT = 2, - INDICES_8BIT = 3, -}; - -}; // end of namespace ck - -#endif diff --git a/composable_kernel/include/utility/reduction_functions_binop.hpp b/composable_kernel/include/utility/reduction_functions_binop.hpp deleted file mode 100644 index 5285abee..00000000 --- a/composable_kernel/include/utility/reduction_functions_binop.hpp +++ /dev/null @@ -1,100 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- * - *******************************************************************************/ -#ifndef CK_REDUCTION_FUNCTIONS_BINOP_HPP -#define CK_REDUCTION_FUNCTIONS_BINOP_HPP - -#include "data_type.hpp" - -#include "reduction_common.hpp" -#include "reduction_operator.hpp" - -namespace ck { -namespace detail { - -static inline __device__ bool isnan(half_t x) { return __hisnan(x); }; - -template -struct binop_with_nan_check; - -template -struct binop_with_nan_check -{ - // cppcheck-suppress constParameter - __device__ static inline void calculate(compType& accuVal, compType currVal) - { - opReduce{}(accuVal, currVal); - }; - - // The method is called when the opReduce is indexable and the user asked for indices - __device__ static inline void - // cppcheck-suppress constParameter - calculate(compType& accuVal, compType currVal, int& accuIndex, int currIndex) - { - bool changed = false; - - opReduce{}(accuVal, currVal, changed); - - if(changed) - accuIndex = currIndex; - }; -}; - -template -struct binop_with_nan_check -{ - __device__ static inline void calculate(compType& accuVal, compType currVal) - { - if(isnan(currVal)) - accuVal = currVal; - else - opReduce{}(accuVal, currVal); - }; - - // The method is called when the opReduce is indexable and the user asked for indices - __device__ static inline void - calculate(compType& accuVal, compType currVal, int& accuIndex, int currIndex) - { - if(isnan(currVal)) - { - accuVal = currVal; - accuIndex = currIndex; - } - else - { - bool changed = false; - - opReduce{}(accuVal, currVal, changed); - - if(changed) - accuIndex = currIndex; - } - }; -}; - -}; // namespace detail -}; // end of namespace ck - -#endif diff --git a/composable_kernel/include/utility/reduction_operator.hpp b/composable_kernel/include/utility/reduction_operator.hpp deleted file mode 100644 index c0afbec8..00000000 --- a/composable_kernel/include/utility/reduction_operator.hpp +++ /dev/null @@ -1,419 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef CK_REDUCTION_OPERATOR_HPP -#define CK_REDUCTION_OPERATOR_HPP - -#include "reduction_common.hpp" - -namespace ck { - -namespace reduce { - -// Every binary operator used in reduction is represented by a templated functor class. 
Each functor -// class must provide at least -// three members: -// 1) GetReductionZeroVal() -- the interface to return the "identity element" for the binary -// operator, "identity element" is the unique -// element in the algebraic space that doesn't affect the value of other elements -// when operated against them, and the concept is similar to zero vector in -// vector space -// (http://pages.cs.wisc.edu/~matthewb/pages/notes/pdf/linearalgebra/VectorSpaces.pdf). -// 2) indexable -- boolean value indicating whether indices of the operated elements could be -// recorded. Usually, Min/Max operator could -// need to record the indices of elements. For operator like Add/Mul, no need to -// record the indices. -// 3) operator() -- the first argument of the operator must be both an input & output, and the -// corresponding variable usually stores -// the accumulated result of many operator() calls; the second argument is only an -// input. For indexable binary -// operator, the second version of operator() has third argument (which is an -// output) to indicate whether the -// accumulated value (the first argument) has changed, in which case the recorded -// accumulated index also need be -// changed. - -template -struct Add -{ - using dataType = T; - - __device__ static constexpr T GetReductionZeroVal() { return static_cast(0.0f); }; - - __device__ inline constexpr void operator()(T& a, T b) const { a = a + b; } - - static constexpr bool indexable = false; -}; - -template -struct Mul -{ - using dataType = T; - - __device__ static constexpr T GetReductionZeroVal() { return static_cast(1.0f); }; - - __device__ inline constexpr void operator()(T& a, T b) const { a = a * b; } - - static constexpr bool indexable = false; -}; - -template -struct Max -{ - using dataType = T; - - __device__ static constexpr T GetReductionZeroVal() { return NumericLimits::Lowest(); }; - - __device__ inline constexpr void operator()(T& a, T b) const - { - if(a < b) - a = b; - } - - __device__ inline constexpr void operator()(T& a, T b, bool& changed) const - { - if(a < b) - { - a = b; - changed = true; - } - } - - static constexpr bool indexable = true; -}; - -template -struct Min -{ - using dataType = T; - - __device__ static constexpr T GetReductionZeroVal() { return NumericLimits::Max(); }; - - __device__ inline constexpr void operator()(T& a, T b) const - { - if(a > b) - a = b; - } - - __device__ inline constexpr void operator()(T& a, T b, bool& changed) const - { - if(a > b) - { - a = b; - changed = true; - } - } - - static constexpr bool indexable = true; -}; - -template -struct AMax -{ - using dataType = T; - - __device__ static constexpr T GetReductionZeroVal() { return static_cast(0.0f); }; - - __device__ inline constexpr void operator()(T& a, T b) const - { - if(a < b) - a = b; - } - - __device__ inline constexpr void operator()(T& a, T b, bool& changed) const - { - if(a < b) - { - a = b; - changed = true; - } - } - - static constexpr bool indexable = true; -}; - -// Unary operators are usually called element-wisely before the reduction is executed on the -// elements. 
-// They are needed for easy implementation of reduction types of AVG, NRM1, NRM2 -template -struct unary_identic -{ - __device__ unary_identic(const int divider = 1) - { - scaler = 1.0f / static_cast(divider); - }; - - __device__ inline constexpr T operator()(T a) const { return a * type_convert{}(scaler); }; - - float scaler = 1.0f; -}; - -template -struct unary_identic -{ - __device__ unary_identic(const int divider = 1) { (void)divider; }; - - __device__ inline constexpr T operator()(T a) const { return a; }; -}; - -template -struct unary_square -{ - __device__ unary_square(const int divider = 1) { scaler = 1.0f / static_cast(divider); }; - - __device__ inline constexpr T operator()(T a) const - { - a = a * a; - - return a * type_convert{}(scaler); - }; - - float scaler = 1.0f; -}; - -template -struct unary_square -{ - __device__ unary_square(const int divider = 1) { (void)divider; }; - - __device__ inline constexpr T operator()(T a) const { return a * a; }; -}; - -template -struct unary_abs -{ - __device__ unary_abs(const int divider = 1) { scaler = 1.0f / static_cast(divider); }; - - __device__ inline constexpr T operator()(T a) const - { - a = abs(a); - - return a * type_convert{}(scaler); - }; - - float scaler = 1.0f; -}; - -template -struct unary_abs -{ - __device__ unary_abs(const int divider = 1) { (void)divider; }; - - __device__ inline constexpr T operator()(T a) const { return abs(a); }; -}; - -// We know for sure that 4.0 has __habs(), but 3.0 does not have it. -// Let's assume that __habs() exists since 3.5. -#if HIP_PACKAGE_VERSION_FLAT < 3005000000 -inline __device__ __half __habs(__half x) -{ - union - { - __half half; - unsigned short u16; - } val; - val.half = x; - val.u16 = val.u16 & 0x7fff; - return val.half; -} -#endif - -template -struct unary_abs -{ - __device__ unary_abs(const int divider = 1) { scaler = 1.0f / static_cast(divider); }; - - __device__ inline half_t operator()(half_t a) const - { - a = static_cast(__habs(a)); - - return a * type_convert{}(scaler); - }; - - float scaler = 1.0f; -}; - -template <> -struct unary_abs -{ - __device__ unary_abs(const int divider = 1) { (void)divider; }; - - __device__ inline half_t operator()(half_t a) const { return static_cast(__habs(a)); }; -}; - -template -struct unary_sqrt -{ - __device__ unary_sqrt(const int divider = 1) { (void)divider; }; - - __device__ inline T operator()(T a) const { return sqrtf(a); }; -}; - -template <> -struct unary_sqrt -{ - __device__ unary_sqrt(const int divider = 1) { (void)divider; }; - - __device__ inline half_t operator()(half_t a) const { return static_cast(hsqrt(a)); }; -}; - -}; // end of namespace reduce - -// The templated struct reduce_binary_operator maps the enum Ids of binary operators to their -// respective functor classes. -// The "GetReductionZeroVal()" interface and boolean member "indexable" are also provided in -// reduce_binary_operactor for -// easier checking by the upper-layer codes in the kernels. 
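The contract spelled out above (GetReductionZeroVal(), the accumulating operator(), and the
indexable flag) is easiest to see in action outside the kernel code. The following is a minimal,
host-only C++ sketch of how an upper layer folds a buffer through such a functor; MiniAdd,
MiniMax and reduce_accumulate are illustrative stand-ins for this note, not part of the library,
and the real device code instead goes through the reduce_binary_operator mapping declared just
below.

#include <cassert>
#include <limits>
#include <vector>

// Host-only mock of the contract described above: an identity value, an
// accumulating operator(), and an `indexable` flag. (The non-indexed overload
// of MiniMax::operator() is omitted to keep the sketch short.)
template <typename T>
struct MiniAdd
{
    static constexpr T GetReductionZeroVal() { return T{0}; }
    void operator()(T& a, T b) const { a = a + b; }
    static constexpr bool indexable = false;
};

template <typename T>
struct MiniMax
{
    static constexpr T GetReductionZeroVal() { return std::numeric_limits<T>::lowest(); }
    void operator()(T& a, T b, bool& changed) const
    {
        changed = a < b; // report whether the accumulated value moved
        if(changed)
            a = b;
    }
    static constexpr bool indexable = true;
};

// Upper-layer usage: start from the identity element, fold every element in,
// and (for indexable ops) track the index of the winning element.
template <typename Op, typename T>
T reduce_accumulate(const std::vector<T>& data, int* winner_index = nullptr)
{
    Op op{};
    T acc = Op::GetReductionZeroVal();
    for(int i = 0; i < static_cast<int>(data.size()); ++i)
    {
        if constexpr(Op::indexable)
        {
            bool changed = false;
            op(acc, data[i], changed);
            if(changed && winner_index != nullptr)
                *winner_index = i;
        }
        else
        {
            op(acc, data[i]);
        }
    }
    return acc;
}

int main()
{
    std::vector<float> v{3.0f, -1.0f, 7.0f, 2.0f};
    int idx = -1;
    assert(reduce_accumulate<MiniAdd<float>>(v) == 11.0f);
    assert(reduce_accumulate<MiniMax<float>>(v, &idx) == 7.0f && idx == 2);
    return 0;
}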
-
-template <typename T, ReduceTensorOp_t op>
-struct reduce_binary_operator;
-
-template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp_t::ADD>
-{
-    using opType   = reduce::Add<T>;
-    using dataType = T;
-
-    static constexpr bool indexable = reduce::Add<T>::indexable;
-};
-
-template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp_t::MUL>
-{
-    using opType   = reduce::Mul<T>;
-    using dataType = T;
-
-    static constexpr bool indexable = reduce::Mul<T>::indexable;
-};
-
-template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp_t::MIN>
-{
-    using opType   = reduce::Min<T>;
-    using dataType = T;
-
-    static constexpr bool indexable = reduce::Min<T>::indexable;
-};
-
-template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp_t::MAX>
-{
-    using opType   = reduce::Max<T>;
-    using dataType = T;
-
-    static constexpr bool indexable = reduce::Max<T>::indexable;
-};
-
-template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp_t::AMAX>
-{
-    using opType   = reduce::AMax<T>;
-    using dataType = T;
-
-    static constexpr bool indexable = reduce::AMax<T>::indexable;
-};
-
-template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp_t::AVG>
-{
-    using opType   = reduce::Add<T>;
-    using dataType = T;
-
-    static constexpr bool indexable = reduce::Add<T>::indexable;
-};
-
-template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp_t::NORM1>
-{
-    using opType   = reduce::Add<T>;
-    using dataType = T;
-
-    static constexpr bool indexable = reduce::Add<T>::indexable;
-};
-
-template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp_t::NORM2>
-{
-    using opType   = reduce::Add<T>;
-    using dataType = T;
-
-    static constexpr bool indexable = reduce::Add<T>::indexable;
-};
-
-// The templated struct reduce_unary_operator maps the enum ids of Reduce operators to two unary
-// functor classes.
-// The two unary functors are called before and after the Reduction is executed, respectively.
-template <typename T, ReduceTensorOp_t op, bool isFirstReduce, bool isLastReduce>
-struct reduce_unary_operator
-{
-    using preUnaryOp = reduce::unary_identic<T, false>;
-    using posUnaryOp = reduce::unary_identic<T, false>;
-};
-
-template <typename T, bool isFirstReduce>
-struct reduce_unary_operator<T, ReduceTensorOp_t::AVG, isFirstReduce, true>
-{
-    using preUnaryOp = reduce::unary_identic<T, false>;
-    using posUnaryOp = reduce::unary_identic<T, true>;
-};
-
-template <typename T, bool isLastReduce>
-struct reduce_unary_operator<T, ReduceTensorOp_t::NORM1, true, isLastReduce>
-{
-    using preUnaryOp = reduce::unary_abs<T, false>;
-    using posUnaryOp = reduce::unary_identic<T, false>;
-};
-
-template <typename T, bool isLastReduce>
-struct reduce_unary_operator<T, ReduceTensorOp_t::AMAX, true, isLastReduce>
-{
-    using preUnaryOp = reduce::unary_abs<T, false>;
-    using posUnaryOp = reduce::unary_identic<T, false>;
-};
-
-template <typename T>
-struct reduce_unary_operator<T, ReduceTensorOp_t::NORM2, true, false>
-{
-    using preUnaryOp = reduce::unary_square<T, false>;
-    using posUnaryOp = reduce::unary_identic<T, false>;
-};
-
-template <typename T>
-struct reduce_unary_operator<T, ReduceTensorOp_t::NORM2, true, true>
-{
-    using preUnaryOp = reduce::unary_square<T, false>;
-    using posUnaryOp = reduce::unary_sqrt<T>;
-};
-
-template <typename T>
-struct reduce_unary_operator<T, ReduceTensorOp_t::NORM2, false, true>
-{
-    using preUnaryOp = reduce::unary_identic<T, false>;
-    using posUnaryOp = reduce::unary_sqrt<T>;
-};
-
-} // end of namespace ck
-
-#endif
diff --git a/composable_kernel/include/utility/static_buffer.hpp b/composable_kernel/include/utility/static_buffer.hpp
deleted file mode 100644
index 9615d10c..00000000
--- a/composable_kernel/include/utility/static_buffer.hpp
+++ /dev/null
@@ -1,163 +0,0 @@
-#ifndef CK_STATIC_BUFFER_HPP
-#define CK_STATIC_BUFFER_HPP
-
-#include "statically_indexed_array.hpp"
-
-namespace ck {
-
-template <AddressSpaceEnum_t BufferAddressSpace,
-          typename T,
-          index_t N,
-          bool InvalidElementUseNumericalZeroValue>
-struct StaticBuffer : public StaticallyIndexedArray<T, N>
-{
-    using type = T;
-    using base = StaticallyIndexedArray<T, N>;
-
-    T invalid_element_value_ = T{0};
-
-    __host__ __device__ constexpr StaticBuffer() : base{} {}
-
-    __host__ __device__ constexpr StaticBuffer(T invalid_element_value)
-        : base{}, invalid_element_value_{invalid_element_value}
-    {
-    }
-
-    __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
-    {
-        return BufferAddressSpace;
-    }
-
-    template <index_t I>
-    __host__ __device__ constexpr auto Get(Number<I> i, bool is_valid_element) const
-    {
-        if constexpr(InvalidElementUseNumericalZeroValue)
- { - return is_valid_element ? At(i) : T{0}; - } - else - { - return is_valid_element ? At(i) : invalid_element_value_; - } - } - - template - __host__ __device__ void Set(Number i, bool is_valid_element, const T& x) - { - if(is_valid_element) - { - At(i) = x; - } - } - - __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } - - __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } -}; - -template -struct StaticBufferV2 : public StaticallyIndexedArray -{ - using type = T; - using base = StaticallyIndexedArray; - - using VecBaseType = typename T::d1_t; - - __host__ __device__ static constexpr index_t GetVectorSize() - { - return sizeof(typename T::type) / sizeof(VecBaseType); - } - - static constexpr index_t vector_size = GetVectorSize(); - - VecBaseType invalid_element_value_ = VecBaseType{0}; - - T invalid_vec_value_ = T{0}; - - __host__ __device__ constexpr StaticBufferV2() : base{} {} - - __host__ __device__ constexpr StaticBufferV2(VecBaseType invalid_element_value) - : base{}, - invalid_vec_value_{invalid_element_value}, - invalid_element_value_{invalid_element_value} - { - } - - __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace() - { - return BufferAddressSpace; - } - - template - __host__ __device__ constexpr auto& GetVector(Number vec_id) - { - return this->At(vec_id); - } - - template - __host__ __device__ constexpr const auto& GetVector(Number vec_id) const - { - return this->At(vec_id); - } - - template - __host__ __device__ constexpr auto& GetElement(Number i, bool) - { - constexpr auto vec_id = Number{}; - constexpr auto vec_off = Number{}; - - return this->At(vec_id).template AsType()(vec_off); - } - - template - __host__ __device__ constexpr auto GetElement(Number i, bool is_valid_element) const - { - constexpr auto vec_id = Number{}; - constexpr auto vec_off = Number{}; - - if constexpr(InvalidElementUseNumericalZeroValue) - { - return is_valid_element ? this->At(vec_id).template AsType()[vec_off] - : VecBaseType{0}; - } - else - { - return is_valid_element ? 
this->At(vec_id).template AsType()[vec_off] - : invalid_element_value_; - } - } - - template - __host__ __device__ constexpr auto operator[](Number i) const - { - return GetElement(i, true); - } - - template - __host__ __device__ constexpr auto& operator()(Number i) - { - return GetElement(i, true); - } - - __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } - - __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } -}; - -template -__host__ __device__ constexpr auto make_static_buffer(Number) -{ - return StaticBuffer{}; -} - -template -__host__ __device__ constexpr auto make_static_buffer(Number, T invalid_element_value) -{ - return StaticBuffer{invalid_element_value}; -} - -} // namespace ck -#endif diff --git a/composable_kernel/include/utility/statically_indexed_array.hpp b/composable_kernel/include/utility/statically_indexed_array.hpp deleted file mode 100644 index f30a3a9e..00000000 --- a/composable_kernel/include/utility/statically_indexed_array.hpp +++ /dev/null @@ -1,40 +0,0 @@ -#ifndef CK_STATICALLY_INDEXED_ARRAY_HPP -#define CK_STATICALLY_INDEXED_ARRAY_HPP - -#include "functional2.hpp" -#include "sequence.hpp" -#include "tuple.hpp" - -namespace ck { - -namespace detail { - -template -__host__ __device__ constexpr auto generate_same_type_tuple() -{ - return generate_tuple([](auto) -> T { return T{}; }, Number{}); -} - -template -using same_type_tuple = decltype(generate_same_type_tuple()); - -} // namespace detail - -template -using StaticallyIndexedArray = detail::same_type_tuple; - -template -__host__ __device__ constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs) -{ - return StaticallyIndexedArray(x, static_cast(xs)...); -} - -// make empty StaticallyIndexedArray -template -__host__ __device__ constexpr auto make_statically_indexed_array() -{ - return StaticallyIndexedArray(); -} - -} // namespace ck -#endif diff --git a/composable_kernel/include/utility/synchronization.hpp b/composable_kernel/include/utility/synchronization.hpp deleted file mode 100644 index da74f207..00000000 --- a/composable_kernel/include/utility/synchronization.hpp +++ /dev/null @@ -1,21 +0,0 @@ -#ifndef CK_SYNCHRONIZATION_AMD_HPP -#define CK_SYNCHRONIZATION_AMD_HPP - -#include "config.hpp" - -namespace ck { - -__device__ void block_sync_lds() -{ -#if CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM - asm volatile("\ - s_waitcnt lgkmcnt(0) \n \ - s_barrier \ - " ::); -#else - __syncthreads(); -#endif -} - -} // namespace ck -#endif diff --git a/composable_kernel/include/utility/utility.hpp b/composable_kernel/include/utility/utility.hpp deleted file mode 100644 index 9f34e044..00000000 --- a/composable_kernel/include/utility/utility.hpp +++ /dev/null @@ -1,14 +0,0 @@ -#ifndef CK_UTILITY_HPP -#define CK_UTILITY_HPP - -#include "config.hpp" - -namespace ck { - -__device__ index_t get_thread_local_1d_id() { return threadIdx.x; } - -__device__ index_t get_block_1d_id() { return blockIdx.x; } - -} // namespace ck - -#endif diff --git a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp deleted file mode 100644 index 09a7fffa..00000000 --- a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp +++ /dev/null @@ -1,370 +0,0 @@ -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include 
"gridwise_gemm_dlops_v1r2.hpp" -#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp" - -using namespace ck; - -constexpr DataTypeEnum_t ABDataTypeEnum = static_cast(CK_PARAM_ABDataTypeEnum); -constexpr DataTypeEnum_t AccDataTypeEnum = static_cast(CK_PARAM_AccDataTypeEnum); -constexpr DataTypeEnum_t CDataTypeEnum = static_cast(CK_PARAM_CDataTypeEnum); - -using FloatAB = typename get_datatype_from_enum::type; -using FloatAcc = typename get_datatype_from_enum::type; -using FloatC = typename get_datatype_from_enum::type; - -constexpr index_t BlockSize = CK_PARAM_BlockSize; - -constexpr index_t MPerBlock = CK_PARAM_MPerBlock; -constexpr index_t NPerBlock = CK_PARAM_NPerBlock; -constexpr index_t KPerBlock = CK_PARAM_KPerBlock; -constexpr index_t M1PerThread = CK_PARAM_M1PerThread; -constexpr index_t N1PerThread = CK_PARAM_N1PerThread; -constexpr index_t KPerThread = CK_PARAM_KPerThread; -constexpr index_t M1N1ThreadClusterM10 = CK_PARAM_M1N1ThreadClusterM10; -constexpr index_t M1N1ThreadClusterN10 = CK_PARAM_M1N1ThreadClusterN10; -constexpr index_t M1N1ThreadClusterM11 = CK_PARAM_M1N1ThreadClusterM11; -constexpr index_t M1N1ThreadClusterN11 = CK_PARAM_M1N1ThreadClusterN11; - -using ABlockTransferThreadSliceLengths_K_M0_M1 = - Sequence; -using ABlockTransferThreadClusterLengths_K_M0_M1 = - Sequence; -using ABlockTransferThreadClusterArrangeOrder = - Sequence; -using ABlockTransferSrcAccessOrder = Sequence; - -constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim; -constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector; -constexpr index_t ABlockTransferDstScalarPerVector_M1 = - CK_PARAM_ABlockTransferDstScalarPerVector_M1; -constexpr bool AThreadTransferSrcResetCoordinateAfterRun = - static_cast(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun); - -using BBlockTransferThreadSliceLengths_K_N0_N1 = - Sequence; -using BBlockTransferThreadClusterLengths_K_N0_N1 = - Sequence; -using BBlockTransferThreadClusterArrangeOrder = - Sequence; -using BBlockTransferSrcAccessOrder = Sequence; - -constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim; -constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector; -constexpr index_t BBlockTransferDstScalarPerVector_N1 = - CK_PARAM_BBlockTransferDstScalarPerVector_N1; -constexpr bool BThreadTransferSrcResetCoordinateAfterRun = - static_cast(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun); - -using CThreadTransferSrcDstAccessOrder = Sequence; -constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim; -constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector; - -constexpr bool HasMainKBlockLoop = static_cast(CK_PARAM_HAS_MAIN_KBLOCK_LOOP); -constexpr bool HasDoubleTailKBlockLoop = static_cast(CK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP); - -extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare( - int n, - int c, - int hi, - int wi, - int k, - int y, - int x, - int convStrideH, - int convStrideW, - int convDilationY, - int convDilationX, - int leftPadH, - int leftPadW, - int rightPadH, - int rightPadW, - void* p_a_k_m0_m1_grid_desc, - void* p_b_k_n0_n1_grid_desc, - void* p_c_m0_m10_m11_n0_n10_n11_grid_desc, - void* p_c_blockid_to_m0_n0_block_cluster_adaptor) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - const index_t ho = (hi + leftPadH + 
rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1; - const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1; - - const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(n, c, hi, wi)); - const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(k, c, y, x)); - const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(n, k, ho, wo)); - - const auto descs = transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad( - wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - make_tuple(convStrideH, convStrideW), - make_tuple(convDilationY, convDilationX), - make_tuple(leftPadH, leftPadW), - make_tuple(rightPadH, rightPadW)); - - const auto a_k_m_grid_desc = descs[I0]; - const auto b_k_n_grid_desc = descs[I1]; - const auto c_m_n_grid_desc = descs[I2]; - - using AKMGridDesc = decltype(a_k_m_grid_desc); - using BKNGridDesc = decltype(b_k_n_grid_desc); - using CMNGridDesc = decltype(c_m_n_grid_desc); - - using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}))); - - using BGridStepHacks = - decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}))); - - using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}))); - - using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>; - using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; - - using GridwiseGemm = - GridwiseGemmDlops_km_kn_mn_v1r2; - - auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); - auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc); - auto c_m0_m10_m11_n0_n10_n11_grid_desc = - GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); - auto c_blockid_to_m0_n0_block_cluster_adaptor = - GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); - - if(hipThreadIdx_x == 0) - { - *static_cast(p_a_k_m0_m1_grid_desc) = a_k_m0_m1_grid_desc; - *static_cast(p_b_k_n0_n1_grid_desc) = b_k_n0_n1_grid_desc; - *static_cast( - p_c_m0_m10_m11_n0_n10_n11_grid_desc) = c_m0_m10_m11_n0_n10_n11_grid_desc; - *static_cast( - p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor; - }; -}; - -extern "C" __global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const void CONSTANT* p_a_k_m0_m1_grid_desc, - const void CONSTANT* p_b_k_n0_n1_grid_desc, - const void CONSTANT* 
p_c_m0_m10_m11_n0_n10_n11_grid_desc, - const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - constexpr auto in_n_c_hi_wi_desc = - make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28)); - constexpr auto wei_k_c_y_x_desc = - make_naive_tensor_descriptor_packed(make_tuple(256, 256, 3, 3)); - constexpr auto out_n_k_ho_wo_desc = - make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28)); - - constexpr auto descs = - transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - make_tuple(1, 1), - make_tuple(1, 1), - make_tuple(1, 1), - make_tuple(1, 1)); - - constexpr auto a_k_m_grid_desc = descs[I0]; - constexpr auto b_k_n_grid_desc = descs[I1]; - constexpr auto c_m_n_grid_desc = descs[I2]; - - using AKMGridDesc = decltype(a_k_m_grid_desc); - using BKNGridDesc = decltype(b_k_n_grid_desc); - using CMNGridDesc = decltype(c_m_n_grid_desc); - - using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}))); - - using BGridStepHacks = - decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}))); - - using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}))); - - using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>; - using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; - - using GridwiseGemm = - GridwiseGemmDlops_km_kn_mn_v1r2; - - constexpr auto a_k_m0_m1_grid_desc_tmp = - GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); - constexpr auto b_k_n0_n1_grid_desc_tmp = - GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc); - constexpr auto c_m0_m10_m11_n0_n10_n11_grid_desc_tmp = - GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); - constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp = - GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); - - using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc_tmp); - using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc_tmp); - using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc_tmp); - using CBlockIdToM0N0BlockClusterAdaptor = - decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp); - - const auto a_k_m0_m1_grid_desc = - *reinterpret_cast((const void*)p_a_k_m0_m1_grid_desc); - const auto b_k_n0_n1_grid_desc = - *reinterpret_cast((const void*)p_b_k_n0_n1_grid_desc); - const auto c_m0_m10_m11_n0_n10_n11_grid_desc = - *reinterpret_cast( - (const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc); - const auto c_blockid_to_m0_n0_block_cluster_adaptor = - *reinterpret_cast( - 
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); - - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_k_m0_m1_grid_desc, - b_k_n0_n1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor, - integral_constant{}, - integral_constant{}); -}; diff --git a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp deleted file mode 100644 index 51d85261..00000000 --- a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp +++ /dev/null @@ -1,358 +0,0 @@ -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v2r3.hpp" -#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp" - -using namespace ck; - -constexpr DataTypeEnum_t ABDataTypeEnum = static_cast(CK_PARAM_ABDataTypeEnum); -constexpr DataTypeEnum_t AccDataTypeEnum = static_cast(CK_PARAM_AccDataTypeEnum); -constexpr DataTypeEnum_t CDataTypeEnum = static_cast(CK_PARAM_CDataTypeEnum); - -using FloatAB = typename get_datatype_from_enum::type; -using FloatAcc = typename get_datatype_from_enum::type; -using FloatC = typename get_datatype_from_enum::type; - -constexpr index_t BlockSize = CK_PARAM_BlockSize; - -constexpr index_t MPerBlock = CK_PARAM_MPerBlock; -constexpr index_t NPerBlock = CK_PARAM_NPerBlock; -constexpr index_t KPerBlock = CK_PARAM_KPerBlock; - -constexpr index_t MPerWave = CK_PARAM_MPerWave; -constexpr index_t NPerWave = CK_PARAM_NPerWave; -constexpr index_t MRepeat = CK_PARAM_MRepeat; -constexpr index_t NRepeat = CK_PARAM_NRepeat; -constexpr index_t K1 = CK_PARAM_K1; - -using ABlockTransferThreadSliceLengths_K0_M_K1 = - Sequence; -using ABlockTransferThreadClusterLengths_K0_M_K1 = - Sequence; -using ABlockTransferThreadClusterArrangeOrder = - Sequence; -using ABlockTransferSrcAccessOrder = Sequence; - -constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim; -constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector; -constexpr index_t ABlockTransferDstScalarPerVector_K1 = - CK_PARAM_ABlockTransferDstScalarPerVector_K1; -constexpr bool AThreadTransferSrcResetCoordinateAfterRun = - static_cast(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun); - -using BBlockTransferThreadSliceLengths_K0_N_K1 = - Sequence; -using BBlockTransferThreadClusterLengths_K0_N_K1 = - Sequence; -using BBlockTransferThreadClusterArrangeOrder = - Sequence; -using BBlockTransferSrcAccessOrder = Sequence; - -constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim; -constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector; -constexpr index_t BBlockTransferDstScalarPerVector_K1 = - CK_PARAM_BBlockTransferDstScalarPerVector_K1; -constexpr bool BThreadTransferSrcResetCoordinateAfterRun = - static_cast(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun); - -using CThreadTransferSrcDstAccessOrder = Sequence; -constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim; -constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector; - -extern "C" 
__global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_prepare( - int n, - int c, - int hi, - int wi, - int k, - int y, - int x, - int convStrideH, - int convStrideW, - int convDilationY, - int convDilationX, - int leftPadH, - int leftPadW, - int rightPadH, - int rightPadW, - void* p_a_k0_m_k1_grid_desc, - void* p_b_k0_n_k1_grid_desc, - void* p_c_m0_m1_m2_n_grid_desc, - void* p_c_blockid_to_m0_n0_block_cluster_adaptor) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1; - const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1; - - const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(n, c, hi, wi)); - const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(k, c, y, x)); - const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(n, k, ho, wo)); - - const auto descs = transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( - wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - make_tuple(convStrideH, convStrideW), - make_tuple(convDilationY, convDilationX), - make_tuple(leftPadH, leftPadW), - make_tuple(rightPadH, rightPadW), - Number{}); - - const auto a_k0_m_k1_grid_desc = descs[I0]; - const auto b_k0_n_k1_grid_desc = descs[I1]; - const auto c_m_n_grid_desc = descs[I2]; - - using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc); - using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc); - using CMNGridDesc = decltype(c_m_n_grid_desc); - - using AGridStepHacks = decltype(make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); - - using BGridStepHacks = - decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); - - using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}))); - - using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>; - using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; - - using GridwiseGemm = - GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3; - - auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); - - auto c_blockid_to_m0_n0_block_cluster_adaptor = - GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); - - if(hipThreadIdx_x == 0) - { - *static_cast*>(p_a_k0_m_k1_grid_desc) = - a_k0_m_k1_grid_desc; - *static_cast*>(p_b_k0_n_k1_grid_desc) = - b_k0_n_k1_grid_desc; - *static_cast(p_c_m0_m1_m2_n_grid_desc) = - c_m0_m1_m2_n_grid_desc; - *static_cast( - p_c_blockid_to_m0_n0_block_cluster_adaptor) = 
c_blockid_to_m0_n0_block_cluster_adaptor; - } -}; - -extern "C" __global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const void CONSTANT* p_a_k0_m_k1_grid_desc, - const void CONSTANT* p_b_k0_n_k1_grid_desc, - const void CONSTANT* p_c_m0_m1_m2_n_grid_desc, - const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) -{ - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - constexpr auto in_n_c_hi_wi_desc = - make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28)); - constexpr auto wei_k_c_y_x_desc = - make_naive_tensor_descriptor_packed(make_tuple(256, 256, 3, 3)); - constexpr auto out_n_k_ho_wo_desc = - make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28)); - - constexpr auto descs = - transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - make_tuple(1, 1), - make_tuple(1, 1), - make_tuple(1, 1), - make_tuple(1, 1), - Number{}); - - constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0]; - constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1]; - constexpr auto c_m_n_grid_desc = descs[I2]; - - using AGridStepHacks = decltype(make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); - - using BGridStepHacks = - decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); - - using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}))); - - using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>; - using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; - - using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp); - using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp); - using CMNGridDesc = decltype(c_m_n_grid_desc); - - using GridwiseGemm = - GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3; - - constexpr auto c_m0_m1_m2_n_grid_desc_tmp = - GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); - constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp = - GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); - - using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp); - using CBlockIdToM0N0BlockClusterAdaptor = - decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp); - - const auto a_k0_m_k1_grid_desc = - *reinterpret_cast((const void*)p_a_k0_m_k1_grid_desc); - const auto b_k0_n_k1_grid_desc = - *reinterpret_cast((const void*)p_b_k0_n_k1_grid_desc); - 
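-    // the *_prepare kernel above serialized these descriptors into device buffers;
-    // reinterpret the raw bytes back into their concrete descriptor types before use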
const auto c_m0_m1_m2_n_grid_desc = - *reinterpret_cast((const void*)p_c_m0_m1_m2_n_grid_desc); - const auto c_blockid_to_m0_n0_block_cluster_adaptor = - *reinterpret_cast( - (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); - - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m0_m1_m2_n_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor); -}; diff --git a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp deleted file mode 100644 index 30e4c518..00000000 --- a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp +++ /dev/null @@ -1,357 +0,0 @@ -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v2r3.hpp" -#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp" - -using namespace ck; - -constexpr DataTypeEnum_t ABDataTypeEnum = static_cast(CK_PARAM_ABDataTypeEnum); -constexpr DataTypeEnum_t AccDataTypeEnum = static_cast(CK_PARAM_AccDataTypeEnum); -constexpr DataTypeEnum_t CDataTypeEnum = static_cast(CK_PARAM_CDataTypeEnum); - -using FloatAB = typename get_datatype_from_enum::type; -using FloatAcc = typename get_datatype_from_enum::type; -using FloatC = typename get_datatype_from_enum::type; - -constexpr index_t BlockSize = CK_PARAM_BlockSize; - -constexpr index_t MPerBlock = CK_PARAM_MPerBlock; -constexpr index_t NPerBlock = CK_PARAM_NPerBlock; -constexpr index_t KPerBlock = CK_PARAM_KPerBlock; - -constexpr index_t MPerWave = CK_PARAM_MPerWave; -constexpr index_t NPerWave = CK_PARAM_NPerWave; -constexpr index_t MRepeat = CK_PARAM_MRepeat; -constexpr index_t NRepeat = CK_PARAM_NRepeat; -constexpr index_t K1 = CK_PARAM_K1; - -using ABlockTransferThreadSliceLengths_K0_M_K1 = - Sequence; -using ABlockTransferThreadClusterLengths_K0_M_K1 = - Sequence; -using ABlockTransferThreadClusterArrangeOrder = - Sequence; -using ABlockTransferSrcAccessOrder = Sequence; - -constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim; -constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector; -constexpr index_t ABlockTransferDstScalarPerVector_K1 = - CK_PARAM_ABlockTransferDstScalarPerVector_K1; -constexpr bool AThreadTransferSrcResetCoordinateAfterRun = - static_cast(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun); - -using BBlockTransferThreadSliceLengths_K0_N_K1 = - Sequence; -using BBlockTransferThreadClusterLengths_K0_N_K1 = - Sequence; -using BBlockTransferThreadClusterArrangeOrder = - Sequence; -using BBlockTransferSrcAccessOrder = Sequence; - -constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim; -constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector; -constexpr index_t BBlockTransferDstScalarPerVector_K1 = - CK_PARAM_BBlockTransferDstScalarPerVector_K1; -constexpr bool BThreadTransferSrcResetCoordinateAfterRun = - static_cast(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun); - -using CThreadTransferSrcDstAccessOrder = Sequence; -constexpr index_t CThreadTransferSrcDstVectorDim = 
CK_PARAM_CThreadTransferSrcDstVectorDim; -constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector; - -extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare( - int n, - int hi, - int wi, - int c, - int k, - int y, - int x, - int convStrideH, - int convStrideW, - int convDilationY, - int convDilationX, - int leftPadH, - int leftPadW, - int rightPadH, - int rightPadW, - void* p_a_k0_m_k1_grid_desc, - void* p_b_k0_n_k1_grid_desc, - void* p_c_m0_m1_m2_n_grid_desc, - void* p_c_blockid_to_m0_n0_block_cluster_adaptor) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1; - const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1; - - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(make_tuple(n, hi, wi, c)); - const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(make_tuple(k, y, x, c)); - const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(make_tuple(n, ho, wo, k)); - - const auto descs = transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( - in_n_hi_wi_c_desc, - wei_k_y_x_c_desc, - out_n_ho_wo_k_desc, - make_tuple(convStrideH, convStrideW), - make_tuple(convDilationY, convDilationX), - make_tuple(leftPadH, leftPadW), - make_tuple(rightPadH, rightPadW), - Number{}); - - const auto a_k0_m_k1_grid_desc = descs[I0]; - const auto b_k0_n_k1_grid_desc = descs[I1]; - const auto c_m_n_grid_desc = descs[I2]; - - using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc); - using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc); - using CMNGridDesc = decltype(c_m_n_grid_desc); - - using BGridStepHacks = decltype(make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); - - using AGridStepHacks = - decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); - - using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}))); - - using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; - using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>; - - using GridwiseGemm = - GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3; - - auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); - - auto c_blockid_to_m0_n0_block_cluster_adaptor = - GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); - - if(hipThreadIdx_x == 0) - { - *static_cast*>(p_a_k0_m_k1_grid_desc) = - a_k0_m_k1_grid_desc; - *static_cast*>(p_b_k0_n_k1_grid_desc) = - 
b_k0_n_k1_grid_desc; - *static_cast(p_c_m0_m1_m2_n_grid_desc) = - c_m0_m1_m2_n_grid_desc; - *static_cast( - p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor; - } -}; - -extern "C" __global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const void CONSTANT* p_a_k0_m_k1_grid_desc, - const void CONSTANT* p_b_k0_n_k1_grid_desc, - const void CONSTANT* p_c_m0_m1_m2_n_grid_desc, - const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) -{ - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - constexpr auto in_n_hi_wi_c_desc = - make_naive_tensor_descriptor_packed(make_tuple(256, 28, 28, 256)); - constexpr auto wei_k_y_x_c_desc = - make_naive_tensor_descriptor_packed(make_tuple(256, 3, 3, 256)); - constexpr auto out_n_ho_wo_k_desc = - make_naive_tensor_descriptor_packed(make_tuple(256, 28, 28, 256)); - - constexpr auto descs = - transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc, - wei_k_y_x_c_desc, - out_n_ho_wo_k_desc, - make_tuple(1, 1), - make_tuple(1, 1), - make_tuple(1, 1), - make_tuple(1, 1), - Number{}); - - constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0]; - constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1]; - constexpr auto c_m_n_grid_desc = descs[I2]; - - using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp); - using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp); - using CMNGridDesc = decltype(c_m_n_grid_desc); - - using BGridStepHacks = decltype(make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); - - using AGridStepHacks = - decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); - - using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}))); - - using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; - using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>; - - using GridwiseGemm = - GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3; - constexpr auto c_m0_m1_m2_n_grid_desc_tmp = - GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); - constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp = - GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); - - using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp); - using CBlockIdToM0N0BlockClusterAdaptor = - decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp); - - const auto a_k0_m_k1_grid_desc = - 
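-        /* same handshake as the NCHW variant: descriptors computed by the _prepare kernel arrive here as opaque bytes */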
*reinterpret_cast((const void*)p_a_k0_m_k1_grid_desc); - const auto b_k0_n_k1_grid_desc = - *reinterpret_cast((const void*)p_b_k0_n_k1_grid_desc); - const auto c_m0_m1_m2_n_grid_desc = - *reinterpret_cast((const void*)p_c_m0_m1_m2_n_grid_desc); - const auto c_blockid_to_m0_n0_block_cluster_adaptor = - *reinterpret_cast( - (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); - - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m0_m1_m2_n_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor); -}; diff --git a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp deleted file mode 100644 index 71239e0e..00000000 --- a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp +++ /dev/null @@ -1,400 +0,0 @@ -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_contraction_dlops_v1r2.hpp" -#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp" - -using namespace ck; - -constexpr DataTypeEnum_t ABDataTypeEnum = static_cast(CK_PARAM_ABDataTypeEnum); -constexpr DataTypeEnum_t AccDataTypeEnum = static_cast(CK_PARAM_AccDataTypeEnum); -constexpr DataTypeEnum_t CDataTypeEnum = static_cast(CK_PARAM_CDataTypeEnum); - -using FloatAB = typename get_datatype_from_enum::type; -using FloatAcc = typename get_datatype_from_enum::type; -using FloatC = typename get_datatype_from_enum::type; - -constexpr index_t BlockSize = CK_PARAM_BlockSize; - -constexpr auto GN0 = Number{}; -constexpr auto GK1 = Number{}; - -constexpr index_t GM1PerBlockGM11 = CK_PARAM_GM1PerBlockGM11; -constexpr index_t GN1PerBlockGN11 = CK_PARAM_GN1PerBlockGN11; -constexpr index_t GK0PerBlock = CK_PARAM_GK0PerBlock; - -constexpr index_t BM1PerThreadBM11 = CK_PARAM_BM1PerThreadBM11; -constexpr index_t BN1PerThreadBN11 = CK_PARAM_BN1PerThreadBN11; -constexpr index_t BK0PerThread = CK_PARAM_BK0PerThread; - -using BM10BN10ThreadClusterBM10Xs = Sequence; -using BM10BN10ThreadClusterBN10Xs = Sequence; - -using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = - Sequence; -using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = - Sequence; -using ABlockTransferThreadClusterArrangeOrder = Sequence<1, 2, 3, 0, 4>; -using ABlockTransferSrcAccessOrder = Sequence<3, 2, 1, 0, 4>; -using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = - Sequence; -using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = - Sequence; -using ABlockTransferSrcVectorTensorContiguousDimOrder = Sequence<0, 1, 2, 3, 4>; - -using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = - Sequence; -using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = - Sequence; -using BBlockTransferThreadClusterArrangeOrder = Sequence<0, 4, 1, 2, 3>; -using BBlockTransferSrcAccessOrder = Sequence<4, 3, 2, 0, 1>; -using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = - Sequence; -using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = - Sequence; -using BBlockTransferSrcVectorTensorContiguousDimOrder = Sequence<0, 1, 2, 3, 4>; - -using CThreadTransferSrcDstAccessOrder = Sequence<3, 4, 5, 0, 1, 2>; -constexpr 
index_t CThreadTransferSrcDstVectorDim = 5; -constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector; - -constexpr bool HasMainKBlockLoop = static_cast(CK_PARAM_HasMainKBlockLoop); -constexpr bool HasDoubleTailKBlockLoop = static_cast(CK_PARAM_HasDoubleTailKBlockLoop); - -extern "C" __global__ void -convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(int N_, - int C_, - int Hi_, - int Wi_, - int K_, - int Y_, - int X_, - int ConvStrideH_, - int ConvStrideW_, - int ConvDilationH_, - int ConvDilationW_, - int InLeftPadH_, - int InLeftPadW_, - int InRightPadH_, - int InRightPadW_, - void* p_desc_tuple) -{ - index_t N = static_cast(N_); - index_t C = static_cast(C_); - index_t Hi = static_cast(Hi_); - index_t Wi = static_cast(Wi_); - index_t K = static_cast(K_); - index_t Y = static_cast(Y_); - index_t X = static_cast(X_); - index_t ConvStrideH = static_cast(ConvStrideH_); - index_t ConvStrideW = static_cast(ConvStrideW_); - index_t ConvDilationH = static_cast(ConvDilationH_); - index_t ConvDilationW = static_cast(ConvDilationW_); - index_t InLeftPadH = static_cast(InLeftPadH_); - index_t InLeftPadW = static_cast(InLeftPadW_); - index_t InRightPadH = static_cast(InRightPadH_); - index_t InRightPadW = static_cast(InRightPadW_); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - const index_t Ho = - (Hi + InLeftPadH + InRightPadH - ConvDilationH * (Y - 1) - 1) / ConvStrideH + 1; - const index_t Wo = - (Wi + InLeftPadW + InRightPadW - ConvDilationW * (X - 1) - 1) / ConvStrideW + 1; - - const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(N, C, Hi, Wi)); - const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(K, C, Y, X)); - const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho, Wo)); - - const auto descs = transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad( - wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - make_tuple(ConvStrideH, ConvStrideW), - make_tuple(ConvDilationH, ConvDilationW), - make_tuple(InLeftPadH, InLeftPadW), - make_tuple(InRightPadH, InRightPadW), - GN0, - GK1); - - const auto a_grid_desc_gk0_gm0_gm1_gk1 = descs[I0]; - const auto b_grid_desc_gk0_gn0_gn1_gk1 = descs[I1]; - const auto c_grid_desc_gm0_gm1_gn0_gn1 = descs[I2]; - - using AGridDesc_GK0_GM0_GM1_GK1 = decltype(a_grid_desc_gk0_gm0_gm1_gk1); - using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1); - using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1); - - using AGridStepHacks = - decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3+: GM11 - Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GM0 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GM10 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11 - Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1 - - using BGridStepHacks = decltype(make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 3+: GN11 - Sequence<0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GN0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GN10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1 - - using CGridStepHacks = decltype(make_tuple( - make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 2+: BM1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: GN10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 4+: BN0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 5+: GN1 - make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GM10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: BM0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 2-: BM1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: GN10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1 - - using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0>; - - using BGridMoveSliceWindowStepHacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>; - - using GridwiseContraction = - GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1< - BlockSize, - FloatAB, - FloatAcc, - FloatC, - InMemoryDataOperationEnum_t::Set, - AGridDesc_GK0_GM0_GM1_GK1, - BGridDesc_GK0_GN0_GN1_GK1, - CGridDesc_GM0_GM1_GN0_GN1, - GM1PerBlockGM11, - GN1PerBlockGN11, - GK0PerBlock, - BM1PerThreadBM11, - BN1PerThreadBN11, - BK0PerThread, - BM10BN10ThreadClusterBM10Xs, - BM10BN10ThreadClusterBN10Xs, - ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferSrcVectorTensorContiguousDimOrder, - BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferSrcVectorTensorContiguousDimOrder, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - AGridStepHacks, - BGridStepHacks, - CGridStepHacks, - AGridMoveSliceWindowStepHacks, - BGridMoveSliceWindowStepHacks>; - - if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0) - { - auto desc_tuple = - make_tuple(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1( - a_grid_desc_gk0_gm0_gm1_gk1), - GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1( - b_grid_desc_gk0_gn0_gn1_gk1), - GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1( - c_grid_desc_gm0_gm1_gn0_gn1), - 
GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10( - c_grid_desc_gm0_gm1_gn0_gn1)); - - *static_cast(p_desc_tuple) = desc_tuple; - } -}; - -extern "C" __global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const void CONSTANT* p_desc_tuple) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_n_c_hi_wi_desc = - make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28)); - constexpr auto wei_k_c_y_x_desc = - make_naive_tensor_descriptor_packed(make_tuple(256, 256, 3, 3)); - constexpr auto out_n_k_ho_wo_desc = - make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28)); - - constexpr auto descs = - transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - make_tuple(1, 1), - make_tuple(1, 1), - make_tuple(1, 1), - make_tuple(1, 1), - GN0, - GK1); - - constexpr auto a_grid_desc_gk0_gm0_gm1_gk1 = descs[I0]; - constexpr auto b_grid_desc_gk0_gn0_gn1_gk1 = descs[I1]; - constexpr auto c_grid_desc_gm0_gm1_gn0_gn1 = descs[I2]; - - using AGridDesc_GK0_GM0_GM1_GK1 = decltype(a_grid_desc_gk0_gm0_gm1_gk1); - using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1); - using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1); - - using AGridStepHacks = - decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3+: GM11 - Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GM0 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GM10 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11 - Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1 - - using BGridStepHacks = decltype(make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 3+: GN11 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GN0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GN10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1 - - using CGridStepHacks = decltype(make_tuple( - make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 2+: BM1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: GN10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 4+: BN0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), 
// 5+: GN1 - make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GM10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: BM0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 2-: BM1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: GN10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1 - - using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0>; - - using BGridMoveSliceWindowStepHacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>; - - using GridwiseContraction = - GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1< - BlockSize, - FloatAB, - FloatAcc, - FloatC, - InMemoryDataOperationEnum_t::Set, - AGridDesc_GK0_GM0_GM1_GK1, - BGridDesc_GK0_GN0_GN1_GK1, - CGridDesc_GM0_GM1_GN0_GN1, - GM1PerBlockGM11, - GN1PerBlockGN11, - GK0PerBlock, - BM1PerThreadBM11, - BN1PerThreadBN11, - BK0PerThread, - BM10BN10ThreadClusterBM10Xs, - BM10BN10ThreadClusterBN10Xs, - ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferSrcVectorTensorContiguousDimOrder, - BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferSrcVectorTensorContiguousDimOrder, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - AGridStepHacks, - BGridStepHacks, - CGridStepHacks, - AGridMoveSliceWindowStepHacks, - BGridMoveSliceWindowStepHacks>; - - using AGridDesc_GK0_GM0_GM10_GM11_GK1 = - decltype(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1( - a_grid_desc_gk0_gm0_gm1_gk1)); - using BGridDesc_GK0_GN0_GN10_GN11_GK1 = - decltype(GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1( - b_grid_desc_gk0_gn0_gn1_gk1)); - using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 = - decltype(GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1( - c_grid_desc_gm0_gm1_gn0_gn1)); - using CGridBlockCluster_BlockId_To_GM10_GN10 = - decltype(GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10( - c_grid_desc_gm0_gm1_gn0_gn1)); - - using DescTuple = decltype(make_tuple(AGridDesc_GK0_GM0_GM10_GM11_GK1{}, - BGridDesc_GK0_GN0_GN10_GN11_GK1{}, - CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1{}, - CGridBlockCluster_BlockId_To_GM10_GN10{})); - - const auto desc_tuple = - *reinterpret_cast(cast_pointer_to_generic_address_space(p_desc_tuple)); - - const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = desc_tuple[I0]; - const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = desc_tuple[I1]; - const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = desc_tuple[I2]; - const auto c_grid_block_cluster_blockid_to_gm10_gn10 = desc_tuple[I3]; - - constexpr index_t shared_block_size = - GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseContraction::Run(p_a_grid, - p_b_grid, 
- p_c_grid, - p_shared_block, - a_grid_desc_gk0_gm0_gm10_gm11_gk1, - b_grid_desc_gk0_gn0_gn10_gn11_gk1, - c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, - c_grid_block_cluster_blockid_to_gm10_gn10, - integral_constant{}, - integral_constant{}); -}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_blockwise_reduce_all_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_blockwise_reduce_all_dims.cpp deleted file mode 100644 index ca6b4159..00000000 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_blockwise_reduce_all_dims.cpp +++ /dev/null @@ -1,271 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2021 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#include "config.hpp" -#include "number.hpp" -#include "sequence.hpp" -#include "tensor_descriptor_helper.hpp" -#include "data_type_enum_helper.hpp" -#include "reduction_common.hpp" -#include "gridwise_generic_2d_reduction_blockwise.hpp" - -using namespace ck; - -using srcDataType = - typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; -using dstDataType = - typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; -using compType = - typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; - -constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable - -constexpr index_t srcDims = CK_PARAM_IN_DIMS; - -constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); -constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 - ? NanPropagation_t::NOT_PROPAGATE_NAN - : NanPropagation_t::PROPAGATE_NAN; -constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 - ? 
ReduceTensorIndices_t::NO_INDICES - : ReduceTensorIndices_t::FLATTENED_INDICES; - -constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); -constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); - -constexpr bool indexable = reduce_binary_operator::indexable; -constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); - -constexpr index_t GredAccessesPerThreadInBlock = CK_PARAM_ACCESSES_PER_THREAD_INBLOCK; // tunable - -// helper functions using variadic template arguments -template -__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence) -{ - return make_tuple(static_cast(lengths[Ns])...); -}; - -template -__device__ static auto make_tuple_from_array(const int* lengths, Number) -{ - static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions"); - - constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{}; - - return make_tuple_from_array_and_index_seq(lengths, index_seq); -}; - -template -__device__ static constexpr auto make_tuple_from_seq(Sequence) -{ - return make_tuple(Ns...); -}; - -extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, - int BlkGroupSize, - int inLength0, - int inLength1, - int inLength2, - int inLength3, - int inLength4, - int inLength5, - int inStride0, - int inStride1, - int inStride2, - int inStride3, - int inStride4, - int inStride5, - void* __restrict__ ws_global) -{ - (void)GridSize; - (void)BlkGroupSize; - - void* p_src2dDesc = ws_global; - void* p_dst1dDesc = static_cast(ws_global) + 2048; - - const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; - const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; - - const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number{}); - const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number{}); - const auto tupleDstLengths = make_tuple(1); - const auto tupleDstStrides = make_tuple(1); - - const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); - auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); - - const auto one_dim_srcDesc = transform_tensor_descriptor( - srcDesc, - make_tuple(make_merge_transform(tupleSrcLengths)), - make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - auto src2dDesc = transform_tensor_descriptor( - one_dim_srcDesc, - make_tuple(make_unmerge_transform(make_tuple(1, one_dim_srcDesc.GetLength(Number<0>{})))), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1>{})); - - constexpr int invariantLen = 1; - const auto toReduceLen = src2dDesc.GetLength(Number<1>{}); - - constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock; - - if constexpr(src2d_need_padding) - { - const auto srcPad = - ((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen; - - auto src2dDesc_2 = - transform_tensor_descriptor(src2dDesc, - make_tuple(make_pass_through_transform(invariantLen), - make_pad_transform(toReduceLen, 0, srcPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc; - } - - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dstDesc; -}; - -template -struct 
get_ref_desc_types
-{
-    static constexpr auto ref_srcLengths = typename uniform_sequence_gen::type{};
-
-    // don't have to use accurate strides to get an expected reference type
-    static constexpr auto ref_srcDesc = make_naive_tensor_descriptor(
-        make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths));
-    static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(make_tuple(1), make_tuple(1));
-
-    static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor(
-        ref_srcDesc,
-        make_tuple(make_merge_transform(make_tuple_from_seq(ref_srcLengths))),
-        make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
-        make_tuple(Sequence<0>{}));
-
-    static constexpr auto ref_src2dDesc =
-        transform_tensor_descriptor(ref_one_dim_srcDesc,
-                                    make_tuple(make_unmerge_transform(
-                                        make_tuple(1, ref_one_dim_srcDesc.GetLength(Number<0>{})))),
-                                    make_tuple(Sequence<0>{}),
-                                    make_tuple(Sequence<0, 1>{}));
-
-    static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{});
-    static constexpr auto ref_toReduceLen  = ref_src2dDesc.GetLength(Number<1>{});
-
-    // used by the BlockWise and MultiBlock method
-    using refType_src2dDesc_padded_34 = decltype(
-        transform_tensor_descriptor(ref_src2dDesc,
-                                    make_tuple(make_pass_through_transform(ref_invariantLen),
-                                               make_pad_transform(ref_toReduceLen, 0, 2)),
-                                    make_tuple(Sequence<0>{}, Sequence<1>{}),
-                                    make_tuple(Sequence<0>{}, Sequence<1>{})));
-
-    using refType_dst1dDesc_padded =
-        decltype(transform_tensor_descriptor(ref_dstDesc,
-                                             make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
-                                             make_tuple(Sequence<0>{}),
-                                             make_tuple(Sequence<0>{})));
-
-    using refType_src2dDesc = decltype(ref_src2dDesc);
-    using refType_dst1dDesc = decltype(ref_dstDesc);
-};
-
-using refType_src2dDesc = typename get_ref_desc_types<srcDims>::refType_src2dDesc;
-using refType_dst1dDesc = typename get_ref_desc_types<srcDims>::refType_dst1dDesc;
-using refType_src2dDesc_padded_34 =
-    typename get_ref_desc_types<srcDims>::refType_src2dDesc_padded_34;
-using refType_dst1dDesc_padded = typename get_ref_desc_types<srcDims>::refType_dst1dDesc_padded;
-
-template <bool need_padding>
-static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
-{
-    if constexpr(need_padding)
-        return (*reinterpret_cast<const refType_src2dDesc_padded_34*>(p_src2dDesc));
-    else
-        return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
-};
-
-template <bool need_padding>
-static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
-{
-    if constexpr(need_padding)
-        return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
-    else
-        return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
-};
-
-extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen,
-                                                     int BlkGroupSize,
-                                                     float alpha,
-                                                     const void* __restrict__ p_src_global,
-                                                     float beta,
-                                                     void* __restrict__ p_dst_global,
-                                                     const void CONSTANT* ws_global,
-                                                     long ws_buf2_bytes_offset,
-                                                     void* __restrict__ indices_global)
-{
-    (void)BlkGroupSize;
-    (void)ws_buf2_bytes_offset;
-
-    const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
-    const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
-
-    const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
-    const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
-
-    using gridwise_2d_reduce = GridwiseReduction_xy_to_x_blockwise;
-
-    constexpr int RunId = need_indices ?
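- /* RunId == 2 instantiates the index-recording reduction path, RunId == 1 the values-only path */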
2 : 1; - gridwise_2d_reduce::template Run( - src2dDesc, - dst1dDesc, - origReduceLen, - alpha, - static_cast(p_src_global), - beta, - static_cast(p_dst_global), - static_cast(nullptr), - static_cast(indices_global)); -}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_blockwise_reduce_partial_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_blockwise_reduce_partial_dims.cpp deleted file mode 100644 index a3daeaf1..00000000 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_blockwise_reduce_partial_dims.cpp +++ /dev/null @@ -1,305 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2021 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#include "config.hpp" -#include "number.hpp" -#include "sequence.hpp" -#include "tensor_descriptor_helper.hpp" -#include "data_type_enum_helper.hpp" -#include "reduction_common.hpp" -#include "gridwise_generic_2d_reduction_blockwise.hpp" - -using namespace ck; - -using srcDataType = - typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; -using dstDataType = - typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; -using compType = - typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; - -constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable - -constexpr index_t srcDims = CK_PARAM_IN_DIMS; -constexpr index_t dstDims = CK_PARAM_OUT_DIMS; - -constexpr index_t num_toReduceDims = CK_PARAM_NUM_TOREDUCE_DIMS; -constexpr index_t num_invariantDims = srcDims - num_toReduceDims; - -using invariantDims = typename arithmetic_sequence_gen<0, num_invariantDims, 1>::type; -using toReduceDims = typename arithmetic_sequence_gen::type; - -constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); -constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 - ? NanPropagation_t::NOT_PROPAGATE_NAN - : NanPropagation_t::PROPAGATE_NAN; -constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 - ? 
ReduceTensorIndices_t::NO_INDICES - : ReduceTensorIndices_t::FLATTENED_INDICES; - -constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); -constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); - -static_assert(num_invariantDims > 0, "Not all dimensins are reduced for this kernel !!"); - -constexpr bool indexable = reduce_binary_operator::indexable; -constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); - -constexpr index_t GredAccessesPerThreadInBlock = CK_PARAM_ACCESSES_PER_THREAD_INBLOCK; // tunable - -// helper functions using variadic template arguments -template -__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence) -{ - return make_tuple(static_cast(lengths[Ns])...); -}; - -template -__device__ static auto make_tuple_from_array(const int* lengths, Number) -{ - static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions"); - - constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{}; - - return make_tuple_from_array_and_index_seq(lengths, index_seq); -}; - -template -__device__ static constexpr auto make_tuple_from_seq(Sequence) -{ - return make_tuple(Ns...); -}; - -extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, - int BlkGroupSize, - int inLength0, - int inLength1, - int inLength2, - int inLength3, - int inLength4, - int inLength5, - int inStride0, - int inStride1, - int inStride2, - int inStride3, - int inStride4, - int inStride5, - int outStride0, - int outStride1, - int outStride2, - int outStride3, - int outStride4, - int outStride5, - void* __restrict__ ws_global) -{ - (void)GridSize; - (void)BlkGroupSize; - - void* p_src2dDesc = ws_global; - void* p_dst1dDesc = static_cast(ws_global) + 2048; - - const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; - const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; - const int dstStrides[6] = { - outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; - - const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number{}); - const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number{}); - const auto tupleDstLengths = make_tuple_from_array(srcLengths, Number{}); - const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number{}); - - const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); - const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); - - const auto toReduceDimLengths = make_tuple_from_array_and_index_seq(srcLengths, toReduceDims{}); - const auto invariantDimLengths = - make_tuple_from_array_and_index_seq(srcLengths, invariantDims{}); - - auto src2dDesc = - transform_tensor_descriptor(srcDesc, - make_tuple(make_merge_transform(invariantDimLengths), - make_merge_transform(toReduceDimLengths)), - make_tuple(invariantDims{}, toReduceDims{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - auto dst1dDesc = transform_tensor_descriptor( - dstDesc, - make_tuple(make_merge_transform(tupleDstLengths)), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - const auto invariantLen = src2dDesc.GetLength(Number<0>{}); - const auto toReduceLen = src2dDesc.GetLength(Number<1>{}); - - constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock; - - if constexpr(src2d_need_padding) - { - const auto srcPad = - 
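- /* round toReduceLen up to a whole number of copy slices; the excess becomes right padding */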
((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen; - - auto src2dDesc_2 = - transform_tensor_descriptor(src2dDesc, - make_tuple(make_pass_through_transform(invariantLen), - make_pad_transform(toReduceLen, 0, srcPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc; - } - - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dst1dDesc; -}; - -template -struct get_ref_desc_types -{ - static constexpr auto ref_toReduceDimLengths = - typename uniform_sequence_gen::type{}; - static constexpr auto ref_invariantDimLengths = - typename uniform_sequence_gen::type{}; - - static constexpr auto ref_srcLengths = typename uniform_sequence_gen::type{}; - static constexpr auto ref_dstLengths = typename uniform_sequence_gen::type{}; - - // don't have to use accurate strides to get an expected referrence type - static constexpr auto ref_srcDesc = make_naive_tensor_descriptor( - make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths)); - static constexpr auto ref_dstDesc = make_naive_tensor_descriptor( - make_tuple_from_seq(ref_dstLengths), make_tuple_from_seq(ref_dstLengths)); - - static constexpr auto ref_src2dDesc = transform_tensor_descriptor( - ref_srcDesc, - make_tuple(make_merge_transform(make_tuple_from_seq(ref_invariantDimLengths)), - make_merge_transform(make_tuple_from_seq(ref_toReduceDimLengths))), - make_tuple(invariantDims{}, toReduceDims{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - static constexpr auto ref_dst1dDesc = transform_tensor_descriptor( - ref_dstDesc, - make_tuple(make_merge_transform(make_tuple_from_seq(ref_dstLengths))), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{}); - static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{}); - - // used by the BlockWise and MultiBlock method - using refType_src2dDesc_padded_34 = decltype( - transform_tensor_descriptor(ref_src2dDesc, - make_tuple(make_pass_through_transform(ref_invariantLen), - make_pad_transform(ref_toReduceLen, 0, 2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}))); - - using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dst1dDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{}))); - - using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dst1dDesc); -}; - -using refType_src2dDesc = - typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = - typename get_ref_desc_types::refType_dst1dDesc; -using refType_src2dDesc_padded_34 = - typename get_ref_desc_types:: - refType_src2dDesc_padded_34; -using refType_dst1dDesc_padded = - typename get_ref_desc_types:: - refType_dst1dDesc_padded; - -template -static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_src2dDesc)); - else - return (*reinterpret_cast(p_src2dDesc)); -}; - -template -static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_dst1dDesc)); - else - return (*reinterpret_cast(p_dst1dDesc)); -}; - 
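-// workspace layout written by gridwise_generic_reduce_1_prepare above: the (possibly padded)
-// src2dDesc sits at ws_global + 0 and the dst1dDesc at ws_global + 2048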
-extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, - int BlkGroupSize, - float alpha, - const void* __restrict__ p_src_global, - float beta, - void* __restrict__ p_dst_global, - const void CONSTANT* ws_global, - long ws_buf2_bytes_offset, - void* __restrict__ indices_global) -{ - (void)BlkGroupSize; - (void)ws_buf2_bytes_offset; - - const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); - const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; - - const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); - const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); - - using gridwise_2d_reduce = GridwiseReduction_xy_to_x_blockwise; - - constexpr int RunId = need_indices ? 2 : 1; - gridwise_2d_reduce::template Run( - src2dDesc, - dst1dDesc, - origReduceLen, - alpha, - static_cast(p_src_global), - beta, - static_cast(p_dst_global), - static_cast(nullptr), - static_cast(indices_global)); -}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_multiblock_reduce_all_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_multiblock_reduce_all_dims.cpp deleted file mode 100644 index 81899dfb..00000000 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_multiblock_reduce_all_dims.cpp +++ /dev/null @@ -1,276 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2021 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#include "config.hpp" -#include "number.hpp" -#include "sequence.hpp" -#include "tensor_descriptor_helper.hpp" -#include "data_type_enum_helper.hpp" -#include "reduction_common.hpp" -#include "gridwise_generic_2d_reduction_multiblock.hpp" - -using namespace ck; - -using srcDataType = - typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; -using dstDataType = - typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; -using compType = - typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; - -constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable - -constexpr index_t srcDims = CK_PARAM_IN_DIMS; - -constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); -constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 - ? 
NanPropagation_t::NOT_PROPAGATE_NAN - : NanPropagation_t::PROPAGATE_NAN; -constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 - ? ReduceTensorIndices_t::NO_INDICES - : ReduceTensorIndices_t::FLATTENED_INDICES; - -constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); -constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); - -constexpr bool indexable = reduce_binary_operator::indexable; -constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); - -constexpr index_t GredAccessesPerThreadInBlock = CK_PARAM_ACCESSES_PER_THREAD_INBLOCK; // tunable - -// helper functions using variadic template arguments -template -__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence) -{ - return make_tuple(static_cast(lengths[Ns])...); -}; - -template -__device__ static auto make_tuple_from_array(const int* lengths, Number) -{ - static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions"); - - constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{}; - - return make_tuple_from_array_and_index_seq(lengths, index_seq); -}; - -template -__device__ static constexpr auto make_tuple_from_seq(Sequence) -{ - return make_tuple(Ns...); -}; - -extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, - int BlkGroupSize, - int inLength0, - int inLength1, - int inLength2, - int inLength3, - int inLength4, - int inLength5, - int inStride0, - int inStride1, - int inStride2, - int inStride3, - int inStride4, - int inStride5, - void* __restrict__ ws_global) -{ - (void)GridSize; - - void* p_src2dDesc = ws_global; - void* p_dst1dDesc = static_cast(ws_global) + 2048; - - const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; - const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; - - const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number{}); - const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number{}); - const auto tupleDstLengths = make_tuple(1); - const auto tupleDstStrides = make_tuple(1); - - const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); - auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); - - const auto one_dim_srcDesc = transform_tensor_descriptor( - srcDesc, - make_tuple(make_merge_transform(tupleSrcLengths)), - make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - auto src2dDesc = transform_tensor_descriptor( - one_dim_srcDesc, - make_tuple(make_unmerge_transform(make_tuple(1, one_dim_srcDesc.GetLength(Number<0>{})))), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1>{})); - - constexpr int invariantLen = 1; - const auto toReduceLen = src2dDesc.GetLength(Number<1>{}); - - constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock; - const index_t reduceSizePerBlock = - (((toReduceLen + BlkGroupSize - 1) / BlkGroupSize + copySliceLen - 1) / copySliceLen) * - copySliceLen; - - if constexpr(src2d_need_padding) - { - const auto srcPad = reduceSizePerBlock * BlkGroupSize - toReduceLen; - - auto src2dDesc_2 = - transform_tensor_descriptor(src2dDesc, - make_tuple(make_pass_through_transform(invariantLen), - make_pad_transform(toReduceLen, 0, srcPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - if(get_thread_local_1d_id() == 0) - 
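-        // a single thread publishes the padded descriptor; the whole grid reads the same bytes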
*static_cast(p_src2dDesc) = src2dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc; - } - - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dstDesc; -}; - -template -struct get_ref_desc_types -{ - static constexpr auto ref_srcLengths = typename uniform_sequence_gen::type{}; - - // don't have to use accurate strides to get an expected referrence type - static constexpr auto ref_srcDesc = make_naive_tensor_descriptor( - make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths)); - static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(make_tuple(1), make_tuple(1)); - - static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor( - ref_srcDesc, - make_tuple(make_merge_transform(make_tuple_from_seq(ref_srcLengths))), - make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - static constexpr auto ref_src2dDesc = - transform_tensor_descriptor(ref_one_dim_srcDesc, - make_tuple(make_unmerge_transform( - make_tuple(1, ref_one_dim_srcDesc.GetLength(Number<0>{})))), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1>{})); - - static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{}); - static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{}); - - // used by the BlockWise and MultiBlock method - using refType_src2dDesc_padded_34 = decltype( - transform_tensor_descriptor(ref_src2dDesc, - make_tuple(make_pass_through_transform(ref_invariantLen), - make_pad_transform(ref_toReduceLen, 0, 2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}))); - - using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dstDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{}))); - - using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dstDesc); -}; - -using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; -using refType_src2dDesc_padded_34 = - typename get_ref_desc_types::refType_src2dDesc_padded_34; -using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; - -template -static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_src2dDesc)); - else - return (*reinterpret_cast(p_src2dDesc)); -}; - -template -static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_dst1dDesc)); - else - return (*reinterpret_cast(p_dst1dDesc)); -}; - -extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, - int BlkGroupSize, - float alpha, - const void* __restrict__ p_src_global, - float beta, - void* __restrict__ p_dst_global, - const void CONSTANT* ws_global, - long ws_buf2_bytes_offset, - void* __restrict__ indices_global) -{ - (void)p_dst_global; - (void)indices_global; - - const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); - const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; - void* ws_buf1_global = const_cast(static_cast(p_src2dDesc) + 4096); - - const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); - const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); - - using gridwise_2d_reduce = GridwiseReduction_xy_to_x_multiblock; - 
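All of the deleted wrappers share one workspace layout: the `_prepare` kernel writes the source 2-d descriptor at byte offset 0 of `ws_global` and the destination 1-d descriptor at byte offset 2048, the reduction's data area (`ws_buf1_global`) starts at byte offset 4096, and the optional second buffer used for index bookkeeping is carved out of the data area at `ws_buf2_bytes_offset`. A minimal sketch of that pointer arithmetic, assuming the fixed 2048/4096-byte slots seen above; the `WorkspaceView` name is illustrative, not from the original source:

    // Hypothetical view over the workspace layout used by these kernels.
    struct WorkspaceView
    {
        const char* base; // == ws_global

        const void* src2d_desc() const { return base; }        // byte offset 0
        const void* dst1d_desc() const { return base + 2048; } // byte offset 2048
        void* buf1() const { return const_cast<char*>(base) + 4096; } // data area

        // ws_buf2 lives inside buf1; a non-positive offset means "no second buffer".
        void* buf2(long ws_buf2_bytes_offset) const
        {
            return ws_buf2_bytes_offset > 0
                       ? static_cast<void*>(static_cast<char*>(buf1()) + ws_buf2_bytes_offset)
                       : nullptr;
        }
    };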
-
-    void* const ws_buf2_global =
-        ws_buf2_bytes_offset > 0
-            ? static_cast<void*>(static_cast<char*>(ws_buf1_global) + ws_buf2_bytes_offset)
-            : nullptr;
-
-    constexpr int RunId = need_indices ? 2 : 1;
-    gridwise_2d_reduce::template Run<RunId>(
-        src2dDesc,
-        dst1dDesc,
-        origReduceLen,
-        BlkGroupSize,
-        alpha,
-        static_cast<const srcDataType*>(p_src_global),
-        beta,
-        static_cast<srcDataType*>(ws_buf1_global),
-        static_cast<int*>(ws_buf2_global));
-};
diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_multiblock_reduce_partial_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_multiblock_reduce_partial_dims.cpp
deleted file mode 100644
index 0e578f4d..00000000
--- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_multiblock_reduce_partial_dims.cpp
+++ /dev/null
@@ -1,310 +0,0 @@
-/*******************************************************************************
- *
- * MIT License
- *
- * Copyright (c) 2021 Advanced Micro Devices, Inc.
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- *
- *******************************************************************************/
-#include "config.hpp"
-#include "number.hpp"
-#include "sequence.hpp"
-#include "tensor_descriptor_helper.hpp"
-#include "data_type_enum_helper.hpp"
-#include "reduction_common.hpp"
-#include "gridwise_generic_2d_reduction_multiblock.hpp"
-
-using namespace ck;
-
-using srcDataType =
-    typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
-using dstDataType =
-    typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
-using compType =
-    typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;
-
-constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
-
-constexpr index_t srcDims = CK_PARAM_IN_DIMS;
-constexpr index_t dstDims = CK_PARAM_OUT_DIMS;
-
-constexpr index_t num_toReduceDims  = CK_PARAM_NUM_TOREDUCE_DIMS;
-constexpr index_t num_invariantDims = srcDims - num_toReduceDims;
-
-using invariantDims = typename arithmetic_sequence_gen<0, num_invariantDims, 1>::type;
-using toReduceDims  = typename arithmetic_sequence_gen<num_invariantDims, srcDims, 1>::type;
-
-constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
-constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
-                                             ? NanPropagation_t::NOT_PROPAGATE_NAN
-                                             : NanPropagation_t::PROPAGATE_NAN;
-constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
-                                                       ?
ReduceTensorIndices_t::NO_INDICES - : ReduceTensorIndices_t::FLATTENED_INDICES; - -constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); -constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); - -static_assert(num_invariantDims > 0, "Not all dimensins are reduced for this kernel !!"); - -constexpr bool indexable = reduce_binary_operator::indexable; -constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); - -constexpr index_t GredAccessesPerThreadInBlock = CK_PARAM_ACCESSES_PER_THREAD_INBLOCK; // tunable - -// helper functions using variadic template arguments -template -__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence) -{ - return make_tuple(static_cast(lengths[Ns])...); -}; - -template -__device__ static auto make_tuple_from_array(const int* lengths, Number) -{ - static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions"); - - constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{}; - - return make_tuple_from_array_and_index_seq(lengths, index_seq); -}; - -template -__device__ static constexpr auto make_tuple_from_seq(Sequence) -{ - return make_tuple(Ns...); -}; - -extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, - int BlkGroupSize, - int inLength0, - int inLength1, - int inLength2, - int inLength3, - int inLength4, - int inLength5, - int inStride0, - int inStride1, - int inStride2, - int inStride3, - int inStride4, - int inStride5, - int outStride0, - int outStride1, - int outStride2, - int outStride3, - int outStride4, - int outStride5, - void* __restrict__ ws_global) -{ - (void)GridSize; - - void* p_src2dDesc = ws_global; - void* p_dst1dDesc = static_cast(ws_global) + 2048; - - const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; - const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; - const int dstStrides[6] = { - outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; - - const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number{}); - const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number{}); - const auto tupleDstLengths = make_tuple_from_array(srcLengths, Number{}); - const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number{}); - - const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); - const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); - - const auto toReduceDimLengths = make_tuple_from_array_and_index_seq(srcLengths, toReduceDims{}); - const auto invariantDimLengths = - make_tuple_from_array_and_index_seq(srcLengths, invariantDims{}); - - auto src2dDesc = - transform_tensor_descriptor(srcDesc, - make_tuple(make_merge_transform(invariantDimLengths), - make_merge_transform(toReduceDimLengths)), - make_tuple(invariantDims{}, toReduceDims{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - auto dst1dDesc = transform_tensor_descriptor( - dstDesc, - make_tuple(make_merge_transform(tupleDstLengths)), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - const auto invariantLen = src2dDesc.GetLength(Number<0>{}); - const auto toReduceLen = src2dDesc.GetLength(Number<1>{}); - - constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock; - const index_t reduceSizePerBlock = - (((toReduceLen + BlkGroupSize - 1) / BlkGroupSize 
+ copySliceLen - 1) / copySliceLen) * - copySliceLen; - - if constexpr(src2d_need_padding) - { - const auto srcPad = reduceSizePerBlock * BlkGroupSize - toReduceLen; - - auto src2dDesc_2 = - transform_tensor_descriptor(src2dDesc, - make_tuple(make_pass_through_transform(invariantLen), - make_pad_transform(toReduceLen, 0, srcPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc; - } - - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dst1dDesc; -}; - -template -struct get_ref_desc_types -{ - static constexpr auto ref_toReduceDimLengths = - typename uniform_sequence_gen::type{}; - static constexpr auto ref_invariantDimLengths = - typename uniform_sequence_gen::type{}; - - static constexpr auto ref_srcLengths = typename uniform_sequence_gen::type{}; - static constexpr auto ref_dstLengths = typename uniform_sequence_gen::type{}; - - // don't have to use accurate strides to get an expected referrence type - static constexpr auto ref_srcDesc = make_naive_tensor_descriptor( - make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths)); - static constexpr auto ref_dstDesc = make_naive_tensor_descriptor( - make_tuple_from_seq(ref_dstLengths), make_tuple_from_seq(ref_dstLengths)); - - static constexpr auto ref_src2dDesc = transform_tensor_descriptor( - ref_srcDesc, - make_tuple(make_merge_transform(make_tuple_from_seq(ref_invariantDimLengths)), - make_merge_transform(make_tuple_from_seq(ref_toReduceDimLengths))), - make_tuple(invariantDims{}, toReduceDims{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - static constexpr auto ref_dst1dDesc = transform_tensor_descriptor( - ref_dstDesc, - make_tuple(make_merge_transform(make_tuple_from_seq(ref_dstLengths))), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{}); - static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{}); - - // used by the BlockWise and MultiBlock method - using refType_src2dDesc_padded_34 = decltype( - transform_tensor_descriptor(ref_src2dDesc, - make_tuple(make_pass_through_transform(ref_invariantLen), - make_pad_transform(ref_toReduceLen, 0, 2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}))); - - using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dst1dDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{}))); - - using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dst1dDesc); -}; - -using refType_src2dDesc = - typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = - typename get_ref_desc_types::refType_dst1dDesc; -using refType_src2dDesc_padded_34 = - typename get_ref_desc_types:: - refType_src2dDesc_padded_34; -using refType_dst1dDesc_padded = - typename get_ref_desc_types:: - refType_dst1dDesc_padded; - -template -static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_src2dDesc)); - else - return (*reinterpret_cast(p_src2dDesc)); -}; - -template -static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) -{ - if constexpr(need_padding) - return 
(*reinterpret_cast(p_dst1dDesc)); - else - return (*reinterpret_cast(p_dst1dDesc)); -}; - -extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, - int BlkGroupSize, - float alpha, - const void* __restrict__ p_src_global, - float beta, - void* __restrict__ p_dst_global, - const void CONSTANT* ws_global, - long ws_buf2_bytes_offset, - void* __restrict__ indices_global) -{ - (void)p_dst_global; - (void)indices_global; - - const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); - const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; - void* ws_buf1_global = const_cast(static_cast(p_src2dDesc) + 4096); - - const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); - const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); - - using gridwise_2d_reduce = GridwiseReduction_xy_to_x_multiblock; - - void* const ws_buf2_global = - ws_buf2_bytes_offset > 0 - ? static_cast(static_cast(ws_buf1_global) + ws_buf2_bytes_offset) - : nullptr; - - constexpr int RunId = need_indices ? 2 : 1; - gridwise_2d_reduce::template Run( - src2dDesc, - dst1dDesc, - origReduceLen, - BlkGroupSize, - alpha, - static_cast(p_src_global), - beta, - static_cast(ws_buf1_global), - static_cast(ws_buf2_global)); -}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_threadwise_reduce_all_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_threadwise_reduce_all_dims.cpp deleted file mode 100644 index e63a1254..00000000 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_threadwise_reduce_all_dims.cpp +++ /dev/null @@ -1,284 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2021 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
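Every one of these wrappers also re-declares the same `make_tuple_from_array` helpers for lifting a runtime `int` array of lengths into a tuple that `make_naive_tensor_descriptor` can consume. The same idea expressed with standard C++ `std::index_sequence` in place of CK's `Sequence`/`Number` (a sketch, not the original helper):

    #include <cstddef>
    #include <tuple>
    #include <utility>

    // Expand the first arraySize elements of 'lengths' into a tuple.
    template <std::size_t... Ns>
    static auto make_tuple_from_array_impl(const int* lengths, std::index_sequence<Ns...>)
    {
        return std::make_tuple(lengths[Ns]...);
    }

    // arraySize plays the role of CK's Number<arraySize> argument.
    template <std::size_t arraySize>
    static auto make_tuple_from_array(const int* lengths)
    {
        static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions");
        return make_tuple_from_array_impl(lengths, std::make_index_sequence<arraySize>{});
    }

    // Usage: {4, 8, 16} becomes std::tuple<int, int, int>{4, 8, 16}.
    // const int lengths[3] = {4, 8, 16};
    // auto t = make_tuple_from_array<3>(lengths);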
- * - *******************************************************************************/ -#include "config.hpp" -#include "number.hpp" -#include "sequence.hpp" -#include "tensor_descriptor_helper.hpp" -#include "data_type_enum_helper.hpp" -#include "reduction_common.hpp" -#include "gridwise_generic_2d_reduction_direct_threadwise.hpp" - -using namespace ck; - -using srcDataType = - typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; -using dstDataType = - typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; -using compType = - typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; - -constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable - -constexpr index_t srcDims = CK_PARAM_IN_DIMS; - -constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); -constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 - ? NanPropagation_t::NOT_PROPAGATE_NAN - : NanPropagation_t::PROPAGATE_NAN; -constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 - ? ReduceTensorIndices_t::NO_INDICES - : ReduceTensorIndices_t::FLATTENED_INDICES; - -constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); -constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); - -constexpr bool indexable = reduce_binary_operator::indexable; -constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); - -constexpr index_t GredThreadBufferLength = CK_PARAM_THREAD_BUFFER_LENGTH; // tunable - -// helper functions using variadic template arguments -template -__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence) -{ - return make_tuple(static_cast(lengths[Ns])...); -}; - -template -__device__ static auto make_tuple_from_array(const int* lengths, Number) -{ - static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions"); - - constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{}; - - return make_tuple_from_array_and_index_seq(lengths, index_seq); -}; - -template -__device__ static constexpr auto make_tuple_from_seq(Sequence) -{ - return make_tuple(Ns...); -}; - -extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, - int BlkGroupSize, - int inLength0, - int inLength1, - int inLength2, - int inLength3, - int inLength4, - int inLength5, - int inStride0, - int inStride1, - int inStride2, - int inStride3, - int inStride4, - int inStride5, - void* __restrict__ ws_global) -{ - (void)BlkGroupSize; - - void* p_src2dDesc = ws_global; - void* p_dst1dDesc = static_cast(ws_global) + 2048; - - const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; - const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; - - const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number{}); - const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number{}); - const auto tupleDstLengths = make_tuple(1); - const auto tupleDstStrides = make_tuple(1); - - const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); - auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); - - const auto one_dim_srcDesc = transform_tensor_descriptor( - srcDesc, - make_tuple(make_merge_transform(tupleSrcLengths)), - make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - auto src2dDesc = transform_tensor_descriptor( - one_dim_srcDesc, - 
make_tuple(make_unmerge_transform(make_tuple(1, one_dim_srcDesc.GetLength(Number<0>{})))), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1>{})); - - constexpr int invariantLen = 1; - const auto toReduceLen = src2dDesc.GetLength(Number<1>{}); - - constexpr auto copySliceLen = GredThreadBufferLength; - - if constexpr(src2d_need_padding) - { - const auto srcPad1 = GridSize * BlockSize - invariantLen; - const auto srcPad2 = - ((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen; - auto src2dDesc_2 = - transform_tensor_descriptor(src2dDesc, - make_tuple(make_pad_transform(invariantLen, 0, srcPad1), - make_pad_transform(toReduceLen, 0, srcPad2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc; - } - - if constexpr(dst1d_need_padding) - { - const auto dstPad = GridSize * BlockSize - invariantLen; - auto dst1dDesc_2 = - transform_tensor_descriptor(dstdDesc, - make_tuple(make_pad_transform(invariantLen, 0, dstPad)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dst1dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dstDesc; - } -}; - -template -struct get_ref_desc_types -{ - static constexpr auto ref_srcLengths = typename uniform_sequence_gen::type{}; - - // don't have to use accurate strides to get an expected referrence type - static constexpr auto ref_srcDesc = make_naive_tensor_descriptor( - make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths)); - static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(make_tuple(1), make_tuple(1)); - - static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor( - ref_srcDesc, - make_tuple(make_merge_transform(make_tuple_from_seq(ref_srcLengths))), - make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - static constexpr auto ref_src2dDesc = - transform_tensor_descriptor(ref_one_dim_srcDesc, - make_tuple(make_unmerge_transform( - make_tuple(1, ref_one_dim_srcDesc.GetLength(Number<0>{})))), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1>{})); - - static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{}); - static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{}); - - // used by the DirectThreadWise and DirectWarpWise method - using refType_src2dDesc_padded_12 = - decltype(transform_tensor_descriptor(ref_src2dDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2), - make_pad_transform(ref_toReduceLen, 0, 2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}))); - - using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dstDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{}))); - - using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dstDesc); -}; - -using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; -using refType_src2dDesc_padded_12 = - typename get_ref_desc_types::refType_src2dDesc_padded_12; -using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; - -template -static 
__device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_src2dDesc)); - else - return (*reinterpret_cast(p_src2dDesc)); -}; - -template -static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_dst1dDesc)); - else - return (*reinterpret_cast(p_dst1dDesc)); -}; - -extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, - int BlkGroupSize, - float alpha, - const void* __restrict__ p_src_global, - float beta, - void* __restrict__ p_dst_global, - const void CONSTANT* ws_global, - long ws_buf2_bytes_offset, - void* __restrict__ indices_global) -{ - (void)BlkGroupSize; - (void)ws_buf2_bytes_offset; - - const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); - const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; - - const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); - const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); - - using gridwise_2d_reduce = GridwiseReduction_xy_to_x_direct_threadwise; - - constexpr int RunId = need_indices ? 2 : 1; - gridwise_2d_reduce::template Run( - src2dDesc, - dst1dDesc, - origReduceLen, - alpha, - static_cast(p_src_global), - beta, - static_cast(p_dst_global), - static_cast(nullptr), - static_cast(indices_global)); -}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_threadwise_reduce_partial_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_threadwise_reduce_partial_dims.cpp deleted file mode 100644 index 698f7400..00000000 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_threadwise_reduce_partial_dims.cpp +++ /dev/null @@ -1,318 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2021 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
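The `get_ref_desc_types` boilerplate in each wrapper exists because the `_prepare` kernel stores a descriptor object in the workspace while the main kernel needs the descriptor's exact type to read it back. Since a transformed descriptor's type depends only on the structure of the transform chain, not on the runtime lengths, the wrappers rebuild the same chain from dummy lengths and take `decltype` of the result. A stripped-down illustration of the pattern; `Desc2D` is a stand-in, not CK's descriptor template:

    // The template parameter models the transform structure that CK encodes
    // in a descriptor's type; the members hold the runtime lengths.
    template <bool Padded>
    struct Desc2D
    {
        int invariantLen;
        int toReduceLen;
    };

    // Reference types recovered from dummy instances, mirroring get_ref_desc_types:
    // the dummy values never matter, only the resulting type does.
    using RefDesc       = decltype(Desc2D<false>{});
    using RefDescPadded = decltype(Desc2D<true>{});

    // Reader side, mirroring get_reduction_src2d_descriptor: select the matching
    // reference type at compile time, then reinterpret the stored bytes.
    template <bool need_padding>
    auto load_desc(const void* p)
    {
        if constexpr(need_padding)
            return *reinterpret_cast<const RefDescPadded*>(p);
        else
            return *reinterpret_cast<const RefDesc*>(p);
    }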
- * - *******************************************************************************/ -#include "config.hpp" -#include "number.hpp" -#include "sequence.hpp" -#include "tensor_descriptor_helper.hpp" -#include "data_type_enum_helper.hpp" -#include "reduction_common.hpp" -#include "gridwise_generic_2d_reduction_direct_threadwise.hpp" - -using namespace ck; - -using srcDataType = - typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; -using dstDataType = - typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; -using compType = - typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; - -constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable - -constexpr index_t srcDims = CK_PARAM_IN_DIMS; -constexpr index_t dstDims = CK_PARAM_OUT_DIMS; - -constexpr index_t num_toReduceDims = CK_PARAM_NUM_TOREDUCE_DIMS; -constexpr index_t num_invariantDims = srcDims - num_toReduceDims; - -using invariantDims = typename arithmetic_sequence_gen<0, num_invariantDims, 1>::type; -using toReduceDims = typename arithmetic_sequence_gen::type; - -constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); -constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 - ? NanPropagation_t::NOT_PROPAGATE_NAN - : NanPropagation_t::PROPAGATE_NAN; -constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 - ? ReduceTensorIndices_t::NO_INDICES - : ReduceTensorIndices_t::FLATTENED_INDICES; - -constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); -constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); - -static_assert(num_invariantDims > 0, "Not all dimensins are reduced for this kernel !!"); - -constexpr bool indexable = reduce_binary_operator::indexable; -constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); - -constexpr index_t GredThreadBufferLength = CK_PARAM_THREAD_BUFFER_LENGTH; // tunable - -// helper functions using variadic template arguments -template -__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence) -{ - return make_tuple(static_cast(lengths[Ns])...); -}; - -template -__device__ static auto make_tuple_from_array(const int* lengths, Number) -{ - static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions"); - - constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{}; - - return make_tuple_from_array_and_index_seq(lengths, index_seq); -}; - -template -__device__ static constexpr auto make_tuple_from_seq(Sequence) -{ - return make_tuple(Ns...); -}; - -extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, - int BlkGroupSize, - int inLength0, - int inLength1, - int inLength2, - int inLength3, - int inLength4, - int inLength5, - int inStride0, - int inStride1, - int inStride2, - int inStride3, - int inStride4, - int inStride5, - int outStride0, - int outStride1, - int outStride2, - int outStride3, - int outStride4, - int outStride5, - void* __restrict__ ws_global) -{ - (void)BlkGroupSize; - - void* p_src2dDesc = ws_global; - void* p_dst1dDesc = static_cast(ws_global) + 2048; - - const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; - const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; - const int dstStrides[6] = { - outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; - - const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number{}); - const 
auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number{}); - const auto tupleDstLengths = make_tuple_from_array(srcLengths, Number{}); - const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number{}); - - const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); - const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); - - const auto toReduceDimLengths = make_tuple_from_array_and_index_seq(srcLengths, toReduceDims{}); - const auto invariantDimLengths = - make_tuple_from_array_and_index_seq(srcLengths, invariantDims{}); - - auto src2dDesc = - transform_tensor_descriptor(srcDesc, - make_tuple(make_merge_transform(invariantDimLengths), - make_merge_transform(toReduceDimLengths)), - make_tuple(invariantDims{}, toReduceDims{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - auto dst1dDesc = transform_tensor_descriptor( - dstDesc, - make_tuple(make_merge_transform(tupleDstLengths)), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - const auto invariantLen = src2dDesc.GetLength(Number<0>{}); - const auto toReduceLen = src2dDesc.GetLength(Number<1>{}); - - constexpr auto copySliceLen = GredThreadBufferLength; - - if constexpr(src2d_need_padding) - { - const auto srcPad1 = GridSize * BlockSize - invariantLen; - const auto srcPad2 = - ((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen; - auto src2dDesc_2 = - transform_tensor_descriptor(src2dDesc, - make_tuple(make_pad_transform(invariantLen, 0, srcPad1), - make_pad_transform(toReduceLen, 0, srcPad2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc; - } - - if constexpr(dst1d_need_padding) - { - const auto dstPad = GridSize * BlockSize - invariantLen; - auto dst1dDesc_2 = - transform_tensor_descriptor(dst1dDesc, - make_tuple(make_pad_transform(invariantLen, 0, dstPad)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dst1dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dst1dDesc; - } -}; - -template -struct get_ref_desc_types -{ - static constexpr auto ref_toReduceDimLengths = - typename uniform_sequence_gen::type{}; - static constexpr auto ref_invariantDimLengths = - typename uniform_sequence_gen::type{}; - - static constexpr auto ref_srcLengths = typename uniform_sequence_gen::type{}; - static constexpr auto ref_dstLengths = typename uniform_sequence_gen::type{}; - - // don't have to use accurate strides to get an expected referrence type - static constexpr auto ref_srcDesc = make_naive_tensor_descriptor( - make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths)); - static constexpr auto ref_dstDesc = make_naive_tensor_descriptor( - make_tuple_from_seq(ref_dstLengths), make_tuple_from_seq(ref_dstLengths)); - - static constexpr auto ref_src2dDesc = transform_tensor_descriptor( - ref_srcDesc, - make_tuple(make_merge_transform(make_tuple_from_seq(ref_invariantDimLengths)), - make_merge_transform(make_tuple_from_seq(ref_toReduceDimLengths))), - make_tuple(invariantDims{}, toReduceDims{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - static constexpr auto ref_dst1dDesc = transform_tensor_descriptor( - ref_dstDesc, - 
make_tuple(make_merge_transform(make_tuple_from_seq(ref_dstLengths))), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{}); - static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{}); - - // used by the DirectThreadWise and DirectWarpWise method - using refType_src2dDesc_padded_12 = - decltype(transform_tensor_descriptor(ref_src2dDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2), - make_pad_transform(ref_toReduceLen, 0, 2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}))); - - using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dst1dDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{}))); - - using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dst1dDesc); -}; - -using refType_src2dDesc = - typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = - typename get_ref_desc_types::refType_dst1dDesc; -using refType_src2dDesc_padded_12 = - typename get_ref_desc_types:: - refType_src2dDesc_padded_12; -using refType_dst1dDesc_padded = - typename get_ref_desc_types:: - refType_dst1dDesc_padded; - -template -static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_src2dDesc)); - else - return (*reinterpret_cast(p_src2dDesc)); -}; - -template -static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_dst1dDesc)); - else - return (*reinterpret_cast(p_dst1dDesc)); -}; - -extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, - int BlkGroupSize, - float alpha, - const void* __restrict__ p_src_global, - float beta, - void* __restrict__ p_dst_global, - const void CONSTANT* ws_global, - long ws_buf2_bytes_offset, - void* __restrict__ indices_global) -{ - (void)BlkGroupSize; - (void)ws_buf2_bytes_offset; - - const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); - const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; - - const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); - const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); - - using gridwise_2d_reduce = GridwiseReduction_xy_to_x_direct_threadwise; - - constexpr int RunId = need_indices ? 2 : 1; - gridwise_2d_reduce::template Run( - src2dDesc, - dst1dDesc, - origReduceLen, - alpha, - static_cast(p_src_global), - beta, - static_cast(p_dst_global), - static_cast(nullptr), - static_cast(indices_global)); -}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_warpwise_reduce_all_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_warpwise_reduce_all_dims.cpp deleted file mode 100644 index 4a607372..00000000 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_warpwise_reduce_all_dims.cpp +++ /dev/null @@ -1,285 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2021 Advanced Micro Devices, Inc. 
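The pad transforms in these prepare kernels only ever add round-up slack: the reduce dimension is padded to a multiple of the per-iteration slice length (`copySliceLen`), and the invariant dimension is padded to the number of independent workers in the grid (threads for the threadwise path, warps for the warpwise path). The arithmetic isolated as plain functions, with worked numbers; the helper names are mine:

    // Round len up to the next multiple and report the slack the pad adds.
    constexpr int round_up(int len, int multiple)
    {
        return ((len + multiple - 1) / multiple) * multiple;
    }
    constexpr int pad_amount(int len, int multiple) { return round_up(len, multiple) - len; }

    // Threadwise path with CK_PARAM_THREAD_BUFFER_LENGTH = 8:
    // toReduceLen = 30 rounds up to 32, so srcPad2 = 2.
    static_assert(round_up(30, 8) == 32, "");
    static_assert(pad_amount(30, 8) == 2, "");

    // The multiblock path first splits the reduce length over BlkGroupSize
    // block-groups, then rounds each share up to copySliceLen.
    constexpr int reduce_size_per_block(int toReduceLen, int BlkGroupSize, int copySliceLen)
    {
        return round_up((toReduceLen + BlkGroupSize - 1) / BlkGroupSize, copySliceLen);
    }

    // toReduceLen = 1000, BlkGroupSize = 4, copySliceLen = 256:
    // ceil(1000/4) = 250 rounds up to 256, so srcPad = 256*4 - 1000 = 24.
    static_assert(reduce_size_per_block(1000, 4, 256) == 256, "");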
- * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#include "config.hpp" -#include "number.hpp" -#include "sequence.hpp" -#include "tensor_descriptor_helper.hpp" -#include "data_type_enum_helper.hpp" -#include "reduction_common.hpp" -#include "gridwise_generic_2d_reduction_direct_warpwise.hpp" - -using namespace ck; - -using srcDataType = - typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; -using dstDataType = - typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; -using compType = - typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; - -constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable - -constexpr index_t srcDims = CK_PARAM_IN_DIMS; - -constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); -constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 - ? NanPropagation_t::NOT_PROPAGATE_NAN - : NanPropagation_t::PROPAGATE_NAN; -constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 - ? 
ReduceTensorIndices_t::NO_INDICES - : ReduceTensorIndices_t::FLATTENED_INDICES; - -constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); -constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); - -constexpr bool indexable = reduce_binary_operator::indexable; -constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); - -constexpr index_t GredAccessesPerThreadInWarp = CK_PARAM_ACCESSES_PER_THREAD_INWARP; // tunable - -// helper functions using variadic template arguments -template -__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence) -{ - return make_tuple(static_cast(lengths[Ns])...); -}; - -template -__device__ static auto make_tuple_from_array(const int* lengths, Number) -{ - static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions"); - - constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{}; - - return make_tuple_from_array_and_index_seq(lengths, index_seq); -}; - -template -__device__ static constexpr auto make_tuple_from_seq(Sequence) -{ - return make_tuple(Ns...); -}; - -extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, - int BlkGroupSize, - int inLength0, - int inLength1, - int inLength2, - int inLength3, - int inLength4, - int inLength5, - int inStride0, - int inStride1, - int inStride2, - int inStride3, - int inStride4, - int inStride5, - void* __restrict__ ws_global) -{ - (void)BlkGroupSize; - - void* p_src2dDesc = ws_global; - void* p_dst1dDesc = static_cast(ws_global) + 2048; - - const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; - const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; - - const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number{}); - const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number{}); - const auto tupleDstLengths = make_tuple(1); - const auto tupleDstStrides = make_tuple(1); - - const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); - auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); - - const auto one_dim_srcDesc = transform_tensor_descriptor( - srcDesc, - make_tuple(make_merge_transform(tupleSrcLengths)), - make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - auto src2dDesc = transform_tensor_descriptor( - one_dim_srcDesc, - make_tuple(make_unmerge_transform(make_tuple(1, one_dim_srcDesc.GetLength(Number<0>{})))), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1>{})); - - constexpr int invariantLen = 1; - const auto toReduceLen = src2dDesc.GetLength(Number<1>{}); - - constexpr auto copySliceLen = warpSize * GredAccessesPerThreadInWarp; - - if constexpr(src2d_need_padding) - { - const auto srcPad1 = GridSize * BlockSize / warpSize - invariantLen; - const auto srcPad2 = - ((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen; - - auto src2dDesc_2 = - transform_tensor_descriptor(src2dDesc, - make_tuple(make_pad_transform(invariantLen, 0, srcPad1), - make_pad_transform(toReduceLen, 0, srcPad2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc; - } - - if constexpr(dst1d_need_padding) - { - const auto 
dstPad = GridSize * BlockSize / warpSize - invariantLen; - auto dst1dDesc_2 = - transform_tensor_descriptor(dstDesc, - make_tuple(make_pad_transform(invariantLen, 0, dstPad)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dst1dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dstDesc; - } -}; - -template -struct get_ref_desc_types -{ - static constexpr auto ref_srcLengths = typename uniform_sequence_gen::type{}; - - // don't have to use accurate strides to get an expected referrence type - static constexpr auto ref_srcDesc = make_naive_tensor_descriptor( - make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths)); - static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(make_tuple(1), make_tuple(1)); - - static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor( - ref_srcDesc, - make_tuple(make_merge_transform(make_tuple_from_seq(ref_srcLengths))), - make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - static constexpr auto ref_src2dDesc = - transform_tensor_descriptor(ref_one_dim_srcDesc, - make_tuple(make_unmerge_transform( - make_tuple(1, ref_one_dim_srcDesc.GetLength(Number<0>{})))), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1>{})); - - static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{}); - static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{}); - - // used by the DirectThreadWise and DirectWarpWise method - using refType_src2dDesc_padded_12 = - decltype(transform_tensor_descriptor(ref_src2dDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2), - make_pad_transform(ref_toReduceLen, 0, 2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}))); - - using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dstDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{}))); - - using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dstDesc); -}; - -using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; -using refType_src2dDesc_padded_12 typename get_ref_desc_types::refType_src2dDesc_padded_12; -using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; - -template -static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_src2dDesc)); - else - return (*reinterpret_cast(p_src2dDesc)); -}; - -template -static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_dst1dDesc)); - else - return (*reinterpret_cast(p_dst1dDesc)); -}; - -extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, - int BlkGroupSize, - float alpha, - const void* __restrict__ p_src_global, - float beta, - void* __restrict__ p_dst_global, - const void CONSTANT* ws_global, - long ws_buf2_bytes_offset, - void* __restrict__ indices_global) -{ - (void)BlkGroupSize; - (void)ws_buf2_bytes_offset; - - const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); - const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; - - const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); - const 
auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); - - using gridwise_2d_reduce = - GridwiseReduction_xy_to_x_direct_warpwise; - - constexpr int RunId = need_indices ? 2 : 1; - gridwise_2d_reduce::template Run( - src2dDesc, - dst1dDesc, - origReduceLen, - alpha, - static_cast(p_src_global), - beta, - static_cast(p_dst_global), - static_cast(nullptr), - static_cast(indices_global)); -}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_warpwise_reduce_partial_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_warpwise_reduce_partial_dims.cpp deleted file mode 100644 index a6415279..00000000 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_warpwise_reduce_partial_dims.cpp +++ /dev/null @@ -1,320 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2021 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#include "config.hpp" -#include "number.hpp" -#include "sequence.hpp" -#include "tensor_descriptor_helper.hpp" -#include "data_type_enum_helper.hpp" -#include "reduction_common.hpp" -#include "gridwise_generic_2d_reduction_direct_warpwise.hpp" - -using namespace ck; - -using srcDataType = - typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; -using dstDataType = - typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; -using compType = - typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; - -constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable - -constexpr index_t srcDims = CK_PARAM_IN_DIMS; -constexpr index_t dstDims = CK_PARAM_OUT_DIMS; - -constexpr index_t num_toReduceDims = CK_PARAM_NUM_TOREDUCE_DIMS; -constexpr index_t num_invariantDims = srcDims - num_toReduceDims; - -using invariantDims = typename arithmetic_sequence_gen<0, num_invariantDims, 1>::type; -using toReduceDims = typename arithmetic_sequence_gen::type; - -constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); -constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 - ? NanPropagation_t::NOT_PROPAGATE_NAN - : NanPropagation_t::PROPAGATE_NAN; -constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 - ? 
ReduceTensorIndices_t::NO_INDICES - : ReduceTensorIndices_t::FLATTENED_INDICES; - -constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); -constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); - -static_assert(num_invariantDims > 0, "Not all dimensins are reduced for this kernel !!"); - -constexpr bool indexable = reduce_binary_operator::indexable; -constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); - -constexpr index_t GredAccessesPerThreadInWarp = CK_PARAM_ACCESSES_PER_THREAD_INWARP; // tunable - -// helper functions using variadic template arguments -template -__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence) -{ - return make_tuple(static_cast(lengths[Ns])...); -}; - -template -__device__ static auto make_tuple_from_array(const int* lengths, Number) -{ - static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions"); - - constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{}; - - return make_tuple_from_array_and_index_seq(lengths, index_seq); -}; - -template -__device__ static constexpr auto make_tuple_from_seq(Sequence) -{ - return make_tuple(Ns...); -}; - -extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, - int BlkGroupSize, - int inLength0, - int inLength1, - int inLength2, - int inLength3, - int inLength4, - int inLength5, - int inStride0, - int inStride1, - int inStride2, - int inStride3, - int inStride4, - int inStride5, - int outStride0, - int outStride1, - int outStride2, - int outStride3, - int outStride4, - int outStride5, - void* __restrict__ ws_global) -{ - (void)BlkGroupSize; - - void* p_src2dDesc = ws_global; - void* p_dst1dDesc = static_cast(ws_global) + 2048; - - const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; - const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; - const int dstStrides[6] = { - outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; - - const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number{}); - const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number{}); - const auto tupleDstLengths = make_tuple_from_array(srcLengths, Number{}); - const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number{}); - - const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); - const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); - - const auto toReduceDimLengths = make_tuple_from_array_and_index_seq(srcLengths, toReduceDims{}); - const auto invariantDimLengths = - make_tuple_from_array_and_index_seq(srcLengths, invariantDims{}); - - auto src2dDesc = - transform_tensor_descriptor(srcDesc, - make_tuple(make_merge_transform(invariantDimLengths), - make_merge_transform(toReduceDimLengths)), - make_tuple(invariantDims{}, toReduceDims{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - auto dst1dDesc = transform_tensor_descriptor( - dstDesc, - make_tuple(make_merge_transform(tupleDstLengths)), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - const auto invariantLen = src2dDesc.GetLength(Number<0>{}); - const auto toReduceLen = src2dDesc.GetLength(Number<1>{}); - - constexpr auto copySliceLen = warpSize * GredAccessesPerThreadInWarp; - - if constexpr(src2d_need_padding) - { - const auto srcPad1 = GridSize * BlockSize / 
warpSize - invariantLen; - const auto srcPad2 = - ((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen; - - auto src2dDesc_2 = - transform_tensor_descriptor(src2dDesc, - make_tuple(make_pad_transform(invariantLen, 0, srcPad1), - make_pad_transform(toReduceLen, 0, srcPad2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc; - } - - if constexpr(dst1d_need_padding) - { - const auto dstPad = GridSize * BlockSize / warpSize - invariantLen; - auto dst1dDesc_2 = - transform_tensor_descriptor(dst1dDesc, - make_tuple(make_pad_transform(invariantLen, 0, dstPad)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dst1dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dst1dDesc; - } -}; - -template -struct get_ref_desc_types -{ - static constexpr auto ref_toReduceDimLengths = - typename uniform_sequence_gen::type{}; - static constexpr auto ref_invariantDimLengths = - typename uniform_sequence_gen::type{}; - - static constexpr auto ref_srcLengths = typename uniform_sequence_gen::type{}; - static constexpr auto ref_dstLengths = typename uniform_sequence_gen::type{}; - - // don't have to use accurate strides to get an expected referrence type - static constexpr auto ref_srcDesc = make_naive_tensor_descriptor( - make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths)); - static constexpr auto ref_dstDesc = make_naive_tensor_descriptor( - make_tuple_from_seq(ref_dstLengths), make_tuple_from_seq(ref_dstLengths)); - - static constexpr auto ref_src2dDesc = transform_tensor_descriptor( - ref_srcDesc, - make_tuple(make_merge_transform(make_tuple_from_seq(ref_invariantDimLengths)), - make_merge_transform(make_tuple_from_seq(ref_toReduceDimLengths))), - make_tuple(invariantDims{}, toReduceDims{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - static constexpr auto ref_dst1dDesc = transform_tensor_descriptor( - ref_dstDesc, - make_tuple(make_merge_transform(make_tuple_from_seq(ref_dstLengths))), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{}); - static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{}); - - // used by the DirectThreadWise and DirectWarpWise method - using refType_src2dDesc_padded_12 = - decltype(transform_tensor_descriptor(ref_src2dDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2), - make_pad_transform(ref_toReduceLen, 0, 2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}))); - - using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dst1dDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{}))); - - using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dst1dDesc); -}; - -using refType_src2dDesc = - typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = - typename get_ref_desc_types::refType_dst1dDesc; -using refType_src2dDesc_padded_12 = - typename get_ref_desc_types:: - refType_src2dDesc_padded_12; -using refType_dst1dDesc_padded = - typename get_ref_desc_types:: - 
refType_dst1dDesc_padded; - -template -static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_src2dDesc)); - else - return (*reinterpret_cast(p_src2dDesc)); -}; - -template -static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_dst1dDesc)); - else - return (*reinterpret_cast(p_dst1dDesc)); -}; - -extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, - int BlkGroupSize, - float alpha, - const void* __restrict__ p_src_global, - float beta, - void* __restrict__ p_dst_global, - const void CONSTANT* ws_global, - long ws_buf2_bytes_offset, - void* __restrict__ indices_global) -{ - (void)BlkGroupSize; - (void)ws_buf2_bytes_offset; - - const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); - const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; - - const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); - const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); - - using gridwise_2d_reduce = - GridwiseReduction_xy_to_x_direct_warpwise; - - constexpr int RunId = need_indices ? 2 : 1; - gridwise_2d_reduce::template Run( - src2dDesc, - dst1dDesc, - origReduceLen, - alpha, - static_cast(p_src_global), - beta, - static_cast(p_dst_global), - static_cast(nullptr), - static_cast(indices_global)); -}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_blockwise_reduce_all_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_blockwise_reduce_all_dims.cpp deleted file mode 100644 index 7e9d4661..00000000 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_blockwise_reduce_all_dims.cpp +++ /dev/null @@ -1,205 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2021 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- * - *******************************************************************************/ -#include "config.hpp" -#include "number.hpp" -#include "sequence.hpp" -#include "tensor_descriptor_helper.hpp" -#include "data_type_enum_helper.hpp" -#include "reduction_common.hpp" -#include "gridwise_generic_2d_reduction_blockwise.hpp" - -using namespace ck; - -using srcDataType = - typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; -using dstDataType = - typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; -using compType = - typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; - -constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable - -constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); -constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 - ? NanPropagation_t::NOT_PROPAGATE_NAN - : NanPropagation_t::PROPAGATE_NAN; -constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 - ? ReduceTensorIndices_t::NO_INDICES - : ReduceTensorIndices_t::FLATTENED_INDICES; - -constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); -constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); - -constexpr bool indexable = reduce_binary_operator::indexable; -constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); - -constexpr index_t GredAccessesPerThreadInBlock = CK_PARAM_ACCESSES_PER_THREAD_INBLOCK; // tunable - -extern "C" __global__ void -gridwise_generic_reduce_2_prepare(int GridSize, int BlkGroupSize, void* __restrict__ ws_global) -{ - (void)GridSize; - - void* p_src2dDesc = ws_global; - void* p_dst1dDesc = static_cast(ws_global) + 2048; - - const auto tupleDstLengths = make_tuple(1); - const auto tupleDstStrides = make_tuple(1); - - auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); - - const index_t invariantLen = dstDesc.GetLength(Number<0>{}); - const index_t toReduceLen = BlkGroupSize; - - auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen)); - - constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock; - - if constexpr(src2d_need_padding) - { - const auto srcPad = - ((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen; - - auto src2dDesc_2 = - transform_tensor_descriptor(src2dDesc, - make_tuple(make_pass_through_transform(invariantLen), - make_pad_transform(toReduceLen, 0, srcPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc; - } - - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dstDesc; -}; - -struct get_ref_desc_types -{ - static constexpr auto ref_tupleDstLengths = make_tuple(8); - static constexpr auto ref_dstDesc = - make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths); - - static constexpr index_t ref_invariantLen = ref_dstDesc.GetLength(Number<0>{}); - static constexpr index_t ref_toReduceLen = 8; - - static constexpr auto ref_src2dDesc = - make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen)); - - using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dstDesc); - - // used by the BlockWise and MultiBlock method - using refType_src2dDesc_padded_34 = decltype( - transform_tensor_descriptor(ref_src2dDesc, - 
make_tuple(make_pass_through_transform(ref_invariantLen), - make_pad_transform(ref_toReduceLen, 0, 2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}))); - - using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dstDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{}))); -}; - -using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; -using refType_src2dDesc_padded_34 = typename get_ref_desc_types::refType_src2dDesc_padded_34; -using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; - -template -static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_src2dDesc)); - else - return (*reinterpret_cast(p_src2dDesc)); -}; - -template -static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_dst1dDesc)); - else - return (*reinterpret_cast(p_dst1dDesc)); -}; - -extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen, - float alpha, - const void* __restrict__ p_src_global, - float beta, - void* __restrict__ p_dst_global, - const void CONSTANT* ws_global, - long ws_buf2_bytes_offset, - void* __restrict__ indices_global) -{ - (void)p_src_global; - - const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); - const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; - void* ws_buf1_global = const_cast(static_cast(p_src2dDesc) + 4096); - - const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); - const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); - - using gridwise_2d_reduce = GridwiseReduction_xy_to_x_blockwise; - - void* const ws_buf2_global = - ws_buf2_bytes_offset > 0 - ? static_cast(static_cast(ws_buf1_global) + ws_buf2_bytes_offset) - : nullptr; - - constexpr int RunId = need_indices ? 3 : 1; - gridwise_2d_reduce::template Run( - src2dDesc, - dst1dDesc, - origReduceLen, - alpha, - static_cast(ws_buf1_global), - beta, - static_cast(p_dst_global), - static_cast(ws_buf2_global), - static_cast(indices_global)); -}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_blockwise_reduce_partial_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_blockwise_reduce_partial_dims.cpp deleted file mode 100644 index 3f37d01e..00000000 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_blockwise_reduce_partial_dims.cpp +++ /dev/null @@ -1,263 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2021 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. 
- * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#include "config.hpp" -#include "number.hpp" -#include "sequence.hpp" -#include "tensor_descriptor_helper.hpp" -#include "data_type_enum_helper.hpp" -#include "reduction_common.hpp" -#include "gridwise_generic_2d_reduction_blockwise.hpp" - -using namespace ck; - -using srcDataType = - typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; -using dstDataType = - typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; -using compType = - typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; - -constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable - -constexpr index_t dstDims = CK_PARAM_OUT_DIMS; - -constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); -constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 - ? NanPropagation_t::NOT_PROPAGATE_NAN - : NanPropagation_t::PROPAGATE_NAN; -constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 - ? ReduceTensorIndices_t::NO_INDICES - : ReduceTensorIndices_t::FLATTENED_INDICES; - -constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); -constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); - -constexpr bool indexable = reduce_binary_operator::indexable; -constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); - -constexpr index_t GredAccessesPerThreadInBlock = CK_PARAM_ACCESSES_PER_THREAD_INBLOCK; // tunable - -// helper functions using variadic template arguments -template -__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence) -{ - return make_tuple(static_cast(lengths[Ns])...); -}; - -template -__device__ static auto make_tuple_from_array(const int* lengths, Number) -{ - static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions"); - - constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{}; - - return make_tuple_from_array_and_index_seq(lengths, index_seq); -}; - -template -__device__ static constexpr auto make_tuple_from_seq(Sequence) -{ - return make_tuple(Ns...); -}; - -extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize, - int BlkGroupSize, - int outLength0, - int outLength1, - int outLength2, - int outLength3, - int outLength4, - int outLength5, - int outStride0, - int outStride1, - int outStride2, - int outStride3, - int outStride4, - int outStride5, - void* __restrict__ ws_global) -{ - (void)GridSize; - - void* p_src2dDesc = ws_global; - void* p_dst1dDesc = static_cast(ws_global) + 2048; - - const int dstLengths[6] = { - outLength0, outLength1, outLength2, outLength3, outLength4, outLength5}; - const int dstStrides[6] = { - outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; - - const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number{}); - const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number{}); - - const auto dstDesc = 
make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); - - auto dst1dDesc = transform_tensor_descriptor( - dstDesc, - make_tuple(make_merge_transform(tupleDstLengths)), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - const index_t invariantLen = dst1dDesc.GetLength(Number<0>{}); - const index_t toReduceLen = BlkGroupSize; - - auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen)); - - constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock; - - if constexpr(src2d_need_padding) - { - const auto srcPad = - ((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen; - - auto src2dDesc_2 = - transform_tensor_descriptor(src2dDesc, - make_tuple(make_pass_through_transform(invariantLen), - make_pad_transform(toReduceLen, 0, srcPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc; - } - - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dst1dDesc; -}; - -template -struct get_ref_desc_types -{ - static constexpr auto ref_tupleDstLengths = - make_tuple_from_seq(typename uniform_sequence_gen::type{}); - static constexpr auto ref_dstDesc = - make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths); - - static constexpr auto ref_dst1dDesc = transform_tensor_descriptor( - ref_dstDesc, - make_tuple(make_merge_transform(ref_tupleDstLengths)), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - static constexpr index_t ref_invariantLen = ref_dst1dDesc.GetLength(Number<0>{}); - static constexpr index_t ref_toReduceLen = 8; - - static constexpr auto ref_src2dDesc = - make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen)); - - using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dst1dDesc); - - // used by the BlockWise and MultiBlock method - using refType_src2dDesc_padded_34 = decltype( - transform_tensor_descriptor(ref_src2dDesc, - make_tuple(make_pass_through_transform(ref_invariantLen), - make_pad_transform(ref_toReduceLen, 0, 2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}))); - - using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dst1dDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{}))); -}; - -using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; -using refType_src2dDesc_padded_34 = - typename get_ref_desc_types::refType_src2dDesc_padded_34; -using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; - -template -static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_src2dDesc)); - else - return (*reinterpret_cast(p_src2dDesc)); -}; - -template -static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_dst1dDesc)); - else - return (*reinterpret_cast(p_dst1dDesc)); -}; - -extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen, - float alpha, - const void* __restrict__ 
p_src_global, - float beta, - void* __restrict__ p_dst_global, - const void CONSTANT* ws_global, - long ws_buf2_bytes_offset, - void* __restrict__ indices_global) -{ - (void)p_src_global; - - const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); - const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; - void* ws_buf1_global = const_cast(static_cast(p_src2dDesc) + 4096); - - const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); - const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); - - using gridwise_2d_reduce = GridwiseReduction_xy_to_x_blockwise; - - void* const ws_buf2_global = - ws_buf2_bytes_offset > 0 - ? static_cast(static_cast(ws_buf1_global) + ws_buf2_bytes_offset) - : nullptr; - - constexpr int RunId = need_indices ? 3 : 1; - gridwise_2d_reduce::template Run( - src2dDesc, - dst1dDesc, - origReduceLen, - alpha, - static_cast(ws_buf1_global), - beta, - static_cast(p_dst_global), - static_cast(ws_buf2_global), - static_cast(indices_global)); -}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise_reduce_all_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise_reduce_all_dims.cpp deleted file mode 100644 index 77841d13..00000000 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise_reduce_all_dims.cpp +++ /dev/null @@ -1,222 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2021 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- * - *******************************************************************************/ -#include "config.hpp" -#include "number.hpp" -#include "sequence.hpp" -#include "tensor_descriptor_helper.hpp" -#include "data_type_enum_helper.hpp" -#include "reduction_common.hpp" -#include "gridwise_generic_2d_reduction_direct_threadwise.hpp" - -using namespace ck; - -using srcDataType = - typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; -using dstDataType = - typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; -using compType = - typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; - -constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable - -using toReduceDims = Sequence; -using invariantDims = Sequence; // this could be empty - -constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); -constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 - ? NanPropagation_t::NOT_PROPAGATE_NAN - : NanPropagation_t::PROPAGATE_NAN; -constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 - ? ReduceTensorIndices_t::NO_INDICES - : ReduceTensorIndices_t::FLATTENED_INDICES; - -constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); -constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); - -constexpr bool indexable = reduce_binary_operator::indexable; -constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); - -constexpr index_t GredThreadBufferLength = CK_PARAM_THREAD_BUFFER_LENGTH; // tunable - -extern "C" __global__ void -gridwise_generic_reduce_2_prepare(int GridSize, int BlkGroupSize, void* __restrict__ ws_global) -{ - (void)BlkGroupSize; - - void* p_src2dDesc = ws_global; - void* p_dst1dDesc = static_cast(ws_global) + 2048; - - const auto tupleDstLengths = make_tuple(1); - const auto tupleDstStrides = make_tuple(1); - - auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); - - const index_t invariantLen = dstDesc.GetLength(Number<0>{}); - const index_t toReduceLen = BlkGroupSize; - - auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen)); - - constexpr auto copySliceLen = GredThreadBufferLength; - - if constexpr(src2d_need_padding) - { - const auto srcPad1 = GridSize * BlockSize - invariantLen; - const auto srcPad2 = - ((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen; - auto src2dDesc_2 = - transform_tensor_descriptor(src2dDesc, - make_tuple(make_pad_transform(invariantLen, 0, srcPad1), - make_pad_transform(toReduceLen, 0, srcPad2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc; - } - - if constexpr(dst1d_need_padding) - { - const auto dstPad = GridSize * BlockSize - invariantLen; - auto dst1dDesc_2 = - transform_tensor_descriptor(dstDesc, - make_tuple(make_pad_transform(invariantLen, 0, dstPad)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dst1dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dstDesc; - } -}; - -struct get_ref_desc_types -{ - static constexpr auto ref_tupleDstLengths = make_tuple(8); - static constexpr auto ref_dstDesc = - make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths); - - static 
constexpr index_t ref_invariantLen = ref_dstDesc.GetLength(Number<0>{}); - static constexpr index_t ref_toReduceLen = 8; - - static constexpr auto ref_src2dDesc = - make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen)); - - using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dstDesc); - - // used by the DirectThreadWise and DirectWarpWise method - using refType_src2dDesc_padded_12 = - decltype(transform_tensor_descriptor(ref_src2dDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2), - make_pad_transform(ref_toReduceLen, 0, 2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}))); - - using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dstDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{}))); -}; - -using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; -using refType_src2dDesc_padded_12 = typename get_ref_desc_types::refType_src2dDesc_padded_12; -using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; - -template -static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_src2dDesc)); - else - return (*reinterpret_cast(p_src2dDesc)); -}; - -template -static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_dst1dDesc)); - else - return (*reinterpret_cast(p_dst1dDesc)); -}; - -extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen, - float alpha, - const void* __restrict__ p_src_global, - float beta, - void* __restrict__ p_dst_global, - const void CONSTANT* ws_global, - long ws_buf2_bytes_offset, - void* __restrict__ indices_global) -{ - (void)p_src_global; - - const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); - const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; - void* ws_buf1_global = const_cast(static_cast(p_src2dDesc) + 4096); - - const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); - const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); - - using gridwise_2d_reduce = GridwiseReduction_xy_to_x_direct_threadwise; - - void* const ws_buf2_global = - ws_buf2_bytes_offset > 0 - ? static_cast(static_cast(ws_buf1_global) + ws_buf2_bytes_offset) - : nullptr; - - constexpr int RunId = need_indices ? 3 : 1; - gridwise_2d_reduce::template Run( - src2dDesc, - dst1dDesc, - origReduceLen, - alpha, - static_cast(ws_buf1_global), - beta, - static_cast(p_dst_global), - static_cast(ws_buf2_global), - static_cast(indices_global)); -}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise_reduce_partial_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise_reduce_partial_dims.cpp deleted file mode 100644 index 2de461ad..00000000 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise_reduce_partial_dims.cpp +++ /dev/null @@ -1,277 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2021 Advanced Micro Devices, Inc. 
- * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#include "config.hpp" -#include "number.hpp" -#include "sequence.hpp" -#include "tensor_descriptor_helper.hpp" -#include "data_type_enum_helper.hpp" -#include "reduction_common.hpp" -#include "gridwise_generic_2d_reduction_direct_threadwise.hpp" - -using namespace ck; - -using srcDataType = - typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; -using dstDataType = - typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; -using compType = - typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; - -constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable - -constexpr index_t dstDims = CK_PARAM_OUT_DIMS; - -constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); -constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 - ? NanPropagation_t::NOT_PROPAGATE_NAN - : NanPropagation_t::PROPAGATE_NAN; -constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 - ? 
ReduceTensorIndices_t::NO_INDICES - : ReduceTensorIndices_t::FLATTENED_INDICES; - -constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); -constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); - -constexpr bool indexable = reduce_binary_operator::indexable; -constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); - -constexpr index_t GredThreadBufferLength = CK_PARAM_THREAD_BUFFER_LENGTH; // tunable - -// helper functions using variadic template arguments -template -__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence) -{ - return make_tuple(static_cast(lengths[Ns])...); -}; - -template -__device__ static auto make_tuple_from_array(const int* lengths, Number) -{ - static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions"); - - constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{}; - - return make_tuple_from_array_and_index_seq(lengths, index_seq); -}; - -template -__device__ static constexpr auto make_tuple_from_seq(Sequence) -{ - return make_tuple(Ns...); -}; - -extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize, - int BlkGroupSize, - int outLength0, - int outLength1, - int outLength2, - int outLength3, - int outLength4, - int outLength5, - int outStride0, - int outStride1, - int outStride2, - int outStride3, - int outStride4, - int outStride5, - void* __restrict__ ws_global) -{ - (void)BlkGroupSize; - - void* p_src2dDesc = ws_global; - void* p_dst1dDesc = static_cast(ws_global) + 2048; - - const int dstLengths[6] = { - outLength0, outLength1, outLength2, outLength3, outLength4, outLength5}; - const int dstStrides[6] = { - outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; - - const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number{}); - const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number{}); - - const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); - - auto dst1dDesc = transform_tensor_descriptor( - dstDesc, - make_tuple(make_merge_transform(tupleDstLengths)), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - const index_t invariantLen = dst1dDesc.GetLength(Number<0>{}); - const index_t toReduceLen = BlkGroupSize; - - auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen)); - - constexpr auto copySliceLen = GredThreadBufferLength; - - if constexpr(src2d_need_padding) - { - const auto srcPad1 = GridSize * BlockSize - invariantLen; - const auto srcPad2 = - ((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen; - auto src2dDesc_2 = - transform_tensor_descriptor(src2dDesc, - make_tuple(make_pad_transform(invariantLen, 0, srcPad1), - make_pad_transform(toReduceLen, 0, srcPad2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc; - } - - if constexpr(dst1d_need_padding) - { - const auto dstPad = GridSize * BlockSize - invariantLen; - auto dst1dDesc_2 = - transform_tensor_descriptor(dst1dDesc, - make_tuple(make_pad_transform(invariantLen, 0, dstPad)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dst1dDesc_2; - } - 
else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dst1dDesc; - } -}; - -template -struct get_ref_desc_types -{ - static constexpr auto ref_tupleDstLengths = - make_tuple_from_seq(typename uniform_sequence_gen::type{}); - static constexpr auto ref_dstDesc = - make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths); - - static constexpr auto ref_dst1dDesc = transform_tensor_descriptor( - ref_dstDesc, - make_tuple(make_merge_transform(ref_tupleDstLengths)), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - static constexpr index_t ref_invariantLen = ref_dst1dDesc.GetLength(Number<0>{}); - static constexpr index_t ref_toReduceLen = 8; - - static constexpr auto ref_src2dDesc = - make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen)); - - using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dst1dDesc); - - // used by the DirectThreadWise and DirectWarpWise method - using refType_src2dDesc_padded_12 = - decltype(transform_tensor_descriptor(ref_src2dDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2), - make_pad_transform(ref_toReduceLen, 0, 2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}))); - - using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dst1dDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{}))); -}; - -using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; -using refType_src2dDesc_padded_12 = - typename get_ref_desc_types::refType_src2dDesc_padded_12; -using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; - -template -static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_src2dDesc)); - else - return (*reinterpret_cast(p_src2dDesc)); -}; - -template -static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_dst1dDesc)); - else - return (*reinterpret_cast(p_dst1dDesc)); -}; - -extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen, - float alpha, - const void* __restrict__ p_src_global, - float beta, - void* __restrict__ p_dst_global, - const void CONSTANT* ws_global, - long ws_buf2_bytes_offset, - void* __restrict__ indices_global) -{ - (void)p_src_global; - - const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); - const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; - void* ws_buf1_global = const_cast(static_cast(p_src2dDesc) + 4096); - - const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); - const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); - - using gridwise_2d_reduce = GridwiseReduction_xy_to_x_direct_threadwise; - - void* const ws_buf2_global = - ws_buf2_bytes_offset > 0 - ? static_cast(static_cast(ws_buf1_global) + ws_buf2_bytes_offset) - : nullptr; - - constexpr int RunId = need_indices ? 
3 : 1; - gridwise_2d_reduce::template Run( - src2dDesc, - dst1dDesc, - origReduceLen, - alpha, - static_cast(ws_buf1_global), - beta, - static_cast(p_dst_global), - static_cast(ws_buf2_global), - static_cast(indices_global)); -}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise_reduce_all_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise_reduce_all_dims.cpp deleted file mode 100644 index 1ba5e496..00000000 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise_reduce_all_dims.cpp +++ /dev/null @@ -1,221 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2021 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#include "config.hpp" -#include "number.hpp" -#include "sequence.hpp" -#include "tensor_descriptor_helper.hpp" -#include "data_type_enum_helper.hpp" -#include "reduction_common.hpp" -#include "gridwise_generic_2d_reduction_direct_warpwise.hpp" - -using namespace ck; - -using srcDataType = - typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; -using dstDataType = - typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; -using compType = - typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; - -constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable - -constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); -constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 - ? NanPropagation_t::NOT_PROPAGATE_NAN - : NanPropagation_t::PROPAGATE_NAN; -constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 - ? 
ReduceTensorIndices_t::NO_INDICES - : ReduceTensorIndices_t::FLATTENED_INDICES; - -constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); -constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); - -constexpr bool indexable = reduce_binary_operator::indexable; -constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); - -constexpr index_t GredAccessesPerThreadInWarp = CK_PARAM_ACCESSES_PER_THREAD_INWARP; // tunable - -extern "C" __global__ void -gridwise_generic_reduce_2_prepare(int GridSize, int BlkGroupSize, void* __restrict__ ws_global) -{ - (void)BlkGroupSize; - - void* p_src2dDesc = ws_global; - void* p_dst1dDesc = static_cast(ws_global) + 2048; - - const auto tupleDstLengths = make_tuple(1); - const auto tupleDstStrides = make_tuple(1); - - auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); - - const index_t invariantLen = dstDesc.GetLength(Number<0>{}); - const index_t toReduceLen = BlkGroupSize; - - auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen)); - - constexpr auto copySliceLen = warpSize * GredAccessesPerThreadInWarp; - - if constexpr(src2d_need_padding) - { - const auto srcPad1 = GridSize * BlockSize / warpSize - invariantLen; - const auto srcPad2 = - ((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen; - - auto src2dDesc_2 = - transform_tensor_descriptor(src2dDesc, - make_tuple(make_pad_transform(invariantLen, 0, srcPad1), - make_pad_transform(toReduceLen, 0, srcPad2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc; - } - - if constexpr(dst1d_need_padding) - { - const auto dstPad = GridSize * BlockSize / warpSize - invariantLen; - auto dst1dDesc_2 = - transform_tensor_descriptor(dstDesc, - make_tuple(make_pad_transform(invariantLen, 0, dstPad)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dst1dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dstDesc; - } -}; - -struct get_ref_desc_types -{ - static constexpr auto ref_tupleDstLengths = make_tuple(8); - static constexpr auto ref_dstDesc = - make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths); - - static constexpr index_t ref_invariantLen = ref_dstDesc.GetLength(Number<0>{}); - static constexpr index_t ref_toReduceLen = 8; - - static constexpr auto ref_src2dDesc = - make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen)); - - using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dstDesc); - - // used by the DirectThreadWise and DirectWarpWise method - using refType_src2dDesc_padded_12 = - decltype(transform_tensor_descriptor(ref_src2dDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2), - make_pad_transform(ref_toReduceLen, 0, 2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}))); - - using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dstDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{}))); -}; - -using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = typename 
get_ref_desc_types::refType_dst1dDesc; -using refType_src2dDesc_padded_12 = typename get_ref_desc_types::refType_src2dDesc_padded_12; -using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; - -template -static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_src2dDesc)); - else - return (*reinterpret_cast(p_src2dDesc)); -}; - -template -static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_dst1dDesc)); - else - return (*reinterpret_cast(p_dst1dDesc)); -}; - -extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen, - float alpha, - const void* __restrict__ p_src_global, - float beta, - void* __restrict__ p_dst_global, - const void CONSTANT* ws_global, - long ws_buf2_bytes_offset, - void* __restrict__ indices_global) -{ - (void)p_src_global; - - const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); - const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; - void* ws_buf1_global = const_cast(static_cast(p_src2dDesc) + 4096); - - const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); - const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); - - using gridwise_2d_reduce = - GridwiseReduction_xy_to_x_direct_warpwise; - - void* const ws_buf2_global = - ws_buf2_bytes_offset > 0 - ? static_cast(static_cast(ws_buf1_global) + ws_buf2_bytes_offset) - : nullptr; - - constexpr int RunId = need_indices ? 3 : 1; - gridwise_2d_reduce::template Run( - src2dDesc, - dst1dDesc, - origReduceLen, - alpha, - static_cast(ws_buf1_global), - beta, - static_cast(p_dst_global), - static_cast(ws_buf2_global), - static_cast(indices_global)); -}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise_reduce_partial_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise_reduce_partial_dims.cpp deleted file mode 100644 index aef1545f..00000000 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise_reduce_partial_dims.cpp +++ /dev/null @@ -1,279 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2021 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- * - *******************************************************************************/ -#include "config.hpp" -#include "number.hpp" -#include "sequence.hpp" -#include "tensor_descriptor_helper.hpp" -#include "data_type_enum_helper.hpp" -#include "reduction_common.hpp" -#include "gridwise_generic_2d_reduction_direct_warpwise.hpp" - -using namespace ck; - -using srcDataType = - typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; -using dstDataType = - typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; -using compType = - typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; - -constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable - -constexpr index_t dstDims = CK_PARAM_OUT_DIMS; - -constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); -constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 - ? NanPropagation_t::NOT_PROPAGATE_NAN - : NanPropagation_t::PROPAGATE_NAN; -constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 - ? ReduceTensorIndices_t::NO_INDICES - : ReduceTensorIndices_t::FLATTENED_INDICES; - -constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); -constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); - -constexpr bool indexable = reduce_binary_operator::indexable; -constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); - -constexpr index_t GredAccessesPerThreadInWarp = CK_PARAM_ACCESSES_PER_THREAD_INWARP; // tunable - -// helper functions using variadic template arguments -template -__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence) -{ - return make_tuple(static_cast(lengths[Ns])...); -}; - -template -__device__ static auto make_tuple_from_array(const int* lengths, Number) -{ - static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions"); - - constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{}; - - return make_tuple_from_array_and_index_seq(lengths, index_seq); -}; - -template -__device__ static constexpr auto make_tuple_from_seq(Sequence) -{ - return make_tuple(Ns...); -}; - -extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize, - int BlkGroupSize, - int outLength0, - int outLength1, - int outLength2, - int outLength3, - int outLength4, - int outLength5, - int outStride0, - int outStride1, - int outStride2, - int outStride3, - int outStride4, - int outStride5, - void* __restrict__ ws_global) -{ - (void)BlkGroupSize; - - void* p_src2dDesc = ws_global; - void* p_dst1dDesc = static_cast(ws_global) + 2048; - - const int dstLengths[6] = { - outLength0, outLength1, outLength2, outLength3, outLength4, outLength5}; - const int dstStrides[6] = { - outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; - - const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number{}); - const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number{}); - - const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); - - auto dst1dDesc = transform_tensor_descriptor( - dstDesc, - make_tuple(make_merge_transform(tupleDstLengths)), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - const index_t invariantLen = dst1dDesc.GetLength(Number<0>{}); - const index_t toReduceLen = BlkGroupSize; - - auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen)); - - constexpr auto 
copySliceLen = warpSize * GredAccessesPerThreadInWarp; - - if constexpr(src2d_need_padding) - { - const auto srcPad1 = GridSize * BlockSize / warpSize - invariantLen; - const auto srcPad2 = - ((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen; - - auto src2dDesc_2 = - transform_tensor_descriptor(src2dDesc, - make_tuple(make_pad_transform(invariantLen, 0, srcPad1), - make_pad_transform(toReduceLen, 0, srcPad2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_src2dDesc) = src2dDesc; - } - - if constexpr(dst1d_need_padding) - { - const auto dstPad = GridSize * BlockSize / warpSize - invariantLen; - auto dst1dDesc_2 = - transform_tensor_descriptor(dst1dDesc, - make_tuple(make_pad_transform(invariantLen, 0, dstPad)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{})); - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dst1dDesc_2; - } - else - { - if(get_thread_local_1d_id() == 0) - *static_cast(p_dst1dDesc) = dst1dDesc; - } -}; - -template -struct get_ref_desc_types -{ - static constexpr auto ref_tupleDstLengths = - make_tuple_from_seq(typename uniform_sequence_gen::type{}); - static constexpr auto ref_dstDesc = - make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths); - - static constexpr auto ref_dst1dDesc = transform_tensor_descriptor( - ref_dstDesc, - make_tuple(make_merge_transform(ref_tupleDstLengths)), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - static constexpr index_t ref_invariantLen = ref_dst1dDesc.GetLength(Number<0>{}); - static constexpr index_t ref_toReduceLen = 8; - - static constexpr auto ref_src2dDesc = - make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen)); - - using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dst1dDesc); - - // used by the DirectThreadWise and DirectWarpWise method - using refType_src2dDesc_padded_12 = - decltype(transform_tensor_descriptor(ref_src2dDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2), - make_pad_transform(ref_toReduceLen, 0, 2)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}))); - - using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dst1dDesc, - make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{}))); -}; - -using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; -using refType_src2dDesc_padded_12 = - typename get_ref_desc_types::refType_src2dDesc_padded_12; -using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; - -template -static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_src2dDesc)); - else - return (*reinterpret_cast(p_src2dDesc)); -}; - -template -static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) -{ - if constexpr(need_padding) - return (*reinterpret_cast(p_dst1dDesc)); - else - return (*reinterpret_cast(p_dst1dDesc)); -}; - -extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen, - float alpha, - const void* __restrict__ p_src_global, - float beta, - void* 
__restrict__ p_dst_global,
-                                                      const void CONSTANT* ws_global,
-                                                      long ws_buf2_bytes_offset,
-                                                      void* __restrict__ indices_global)
-{
-    (void)p_src_global;
-
-    const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
-    const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048;
-    void* ws_buf1_global = const_cast(static_cast(p_src2dDesc) + 4096);
-
-    const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc);
-    const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc);
-
-    using gridwise_2d_reduce =
-        GridwiseReduction_xy_to_x_direct_warpwise;
-
-    void* const ws_buf2_global =
-        ws_buf2_bytes_offset > 0
-            ? static_cast(static_cast(ws_buf1_global) + ws_buf2_bytes_offset)
-            : nullptr;
-
-    constexpr int RunId = need_indices ? 3 : 1;
-    gridwise_2d_reduce::template Run(
-        src2dDesc,
-        dst1dDesc,
-        origReduceLen,
-        alpha,
-        static_cast(ws_buf1_global),
-        beta,
-        static_cast(p_dst_global),
-        static_cast(ws_buf2_global),
-        static_cast(indices_global));
-};
diff --git a/dev-requirements.txt b/dev-requirements.txt
new file mode 100644
index 00000000..5d123edb
--- /dev/null
+++ b/dev-requirements.txt
@@ -0,0 +1,3 @@
+ROCmSoftwarePlatform/rocm-recipes
+# 1.90+
+danmar/cppcheck@dd05839a7e63ef04afd34711cb3e1e0ef742882f
\ No newline at end of file
diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt
new file mode 100644
index 00000000..c03c454c
--- /dev/null
+++ b/example/01_gemm/CMakeLists.txt
@@ -0,0 +1,8 @@
+add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp)
+add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp)
+add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp)
+add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp)
+add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
+add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp)
+# FIXME: re-enable this example as a test when SWDEV-335738 is fixed
+add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
diff --git a/example/01_gemm/README.md b/example/01_gemm/README.md
new file mode 100644
index 00000000..226783b0
--- /dev/null
+++ b/example/01_gemm/README.md
@@ -0,0 +1,23 @@
+# Instructions for ```example_gemm_xdl```
+
+## Run ```example_gemm_xdl```
+```bash
+#arg1: verification (0=no, 1=yes)
+#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
+#arg3: run kernel # of times (>1)
+./bin/example_gemm_xdl 0 1 5
+```
+
+Result (MI100 @ 1087 MHz, 133.5 TFlops peak FP16)
+```
+a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
+b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
+c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
+arg.a_grid_desc_k0_m_k1_{512, 3840, 8}
+arg.b_grid_desc_k0_n_k1_{512, 4096, 8}
+arg.c_grid_desc_m_n_{ 3840, 4096}
+launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1}
+Warm up
+Start running 5 times...
+Perf: 1.19685 ms, 107.657 TFlops, 78.8501 GB/s
+```
diff --git a/example/01_gemm/gemm_dl_fp16.cpp b/example/01_gemm/gemm_dl_fp16.cpp
new file mode 100644
index 00000000..d9d6c303
--- /dev/null
+++ b/example/01_gemm/gemm_dl_fp16.cpp
@@ -0,0 +1,210 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
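+
+// This example runs a single fp16 GEMM, C[M, N] = A[M, K] * B[K, N], through
+// the DeviceGemmDl kernel: half-precision A/B/C with fp32 accumulation,
+// column-major A, row-major B and C. It times the kernel and can optionally
+// verify the device result against a host reference GEMM.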
+
+#include
+#include
+#include
+#include
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+using ADataType   = ck::half_t;
+using BDataType   = ck::half_t;
+using CDataType   = ck::half_t;
+using AccDataType = float;
+
+using ALayout = Col;
+using BLayout = Row;
+using CLayout = Row;
+
+using AElementOp = ck::tensor_operation::element_wise::PassThrough;
+using BElementOp = ck::tensor_operation::element_wise::PassThrough;
+using CElementOp = ck::tensor_operation::element_wise::PassThrough;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+
+// clang-format off
+using DeviceGemmInstance = ck::tensor_operation::device::
+    // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
+    // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
+    // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
+    // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+    DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>;
+// clang-format on
+
+using ReferenceGemmInstance = ck::tensor_operation::host::
+    ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
+
+int main(int argc, char* argv[])
+{
+    bool do_verification = true;
+    int init_method      = 1;
+    bool time_kernel     = false;
+
+    // GEMM shape
+    ck::index_t M = 3840;
+    ck::index_t N = 4096;
+    ck::index_t K = 4096;
+
+    ck::index_t StrideA = 4096;
+    ck::index_t StrideB = 4096;
+    ck::index_t StrideC = 4096;
+
+    if(argc == 1)
+    {
+        // do nothing
+    }
+    else if(argc == 4)
4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(1); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + bool pass = true; 
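+    // The verification below re-runs the same GEMM through a CPU reference
+    // kernel and compares the two outputs element-wise. Conceptually it does
+    // something like the rough sketch here (illustrative only; check_err is
+    // the library's own comparison, and atol/rtol stand in for whatever
+    // fp16-aware tolerances it applies):
+    //
+    //     for(std::size_t i = 0; i < device.size(); ++i)
+    //         pass = pass && (std::abs(device[i] - host[i]) <= atol + rtol * std::abs(host[i]));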
+ + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + pass = ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + } + + return pass ? 0 : 1; +} diff --git a/example/01_gemm/gemm_dl_fp32.cpp b/example/01_gemm/gemm_dl_fp32.cpp new file mode 100644 index 00000000..09150185 --- /dev/null +++ b/example/01_gemm/gemm_dl_fp32.cpp @@ -0,0 +1,209 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = float; +using BDataType = float; +using CDataType = float; +using AccDataType = float; + +using ALayout = Col; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device:: + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 
1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 1) + { + // do nothing + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(1); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << 
gemm.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + bool pass = true; + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + pass = ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + } + + return pass ? 0 : 1; +} diff --git a/example/01_gemm/gemm_dl_int8.cpp b/example/01_gemm/gemm_dl_int8.cpp new file mode 100644 index 00000000..eec3ce16 --- /dev/null +++ b/example/01_gemm/gemm_dl_int8.cpp @@ -0,0 +1,207 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = int8_t; +using BDataType = int8_t; +using CDataType = int8_t; +using AccDataType = int32_t; + +using ALayout = Col; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device:: + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| 
DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 1) + { + // do nothing + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(1); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * 
c_m_n_device_result.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + bool pass = true; + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + pass = ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + } + + return pass ? 0 : 1; +} diff --git a/example/01_gemm/gemm_xdl_bf16.cpp b/example/01_gemm/gemm_xdl_bf16.cpp new file mode 100644 index 00000000..24a7183c --- /dev/null +++ b/example/01_gemm/gemm_xdl_bf16.cpp @@ -0,0 +1,238 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
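+//
+// This example runs C = A * B in bf16 (ck::bhalf_t) through the XDL
+// matrix-core pipeline with the CShuffle output stage. Accumulation happens
+// in fp32 (AccDataType) and only the final C tensor is stored as bf16. Note
+// that init_method 2 falls into the `default` branch of the init switch in
+// this file, so there is no GeneratorTensor_Sequential path here.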
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = BF16; +using BDataType = BF16; +using CDataType = BF16; +using AccDataType = F32; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle + , // typename ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder + 2, // index_t ABlockTransferSrcVectorDim + 8, // index_t ABlockTransferSrcScalarPerVector + 8, // index_t ABlockTransferDstScalarPerVector_AK1 + 1, // index_t ABlockLdsExtraM + S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder + 2, // index_t BBlockTransferSrcVectorDim + 8, // index_t BBlockTransferSrcScalarPerVector + 8, // index_t BBlockTransferDstScalarPerVector_BK1 + 1, // index_t BBlockLdsExtraN + 1, // index_t CShuffleMXdlPerWavePerShuffle + 1, // index_t CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return 
HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = PassThrough{}; + auto b_element_op = PassThrough{}; + auto c_element_op = PassThrough{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1; + } + + return 0; +} diff --git a/example/01_gemm/gemm_xdl_fp16.cpp b/example/01_gemm/gemm_xdl_fp16.cpp new file mode 100644 index 00000000..233c8d92 --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp16.cpp @@ -0,0 +1,221 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
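+//
+// This example builds two fp16 device GEMM instances, the plain DeviceGemmXdl
+// and the CShuffle variant DeviceGemm_Xdl_CShuffle, and selects one through
+// the alias `using DeviceGemmInstance = DeviceGemmInstance0;` further down.
+// Pointing the alias at DeviceGemmInstance1 instead is enough to benchmark
+// the other pipeline on the same problem:
+//
+//     using DeviceGemmInstance = DeviceGemmInstance1;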
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmXdl + // clang-format off +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| +//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>; +// clang-format on + +using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle + // clang-format off +//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| 
SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +using DeviceGemmInstance = DeviceGemmInstance0; + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * 
c_m_n_device_result.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1; + } + + return 0; +} diff --git a/example/01_gemm/gemm_xdl_fp64.cpp b/example/01_gemm/gemm_xdl_fp64.cpp new file mode 100644 index 00000000..3d49a3f3 --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp64.cpp @@ -0,0 +1,239 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
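+//
+// fp64 variant: all tensors and the accumulator use double. Two tunings of
+// DeviceGemmXdl are kept below behind an `#if 0 / #else` pair (a 64-thread
+// block and a 256-thread block); only the `#else` branch is compiled. The
+// file also defines a show_2d_matrix() helper for dumping small matrices,
+// left disabled by another `#if 0` in the verification path.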
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using F64 = double; + +using ADataType = double; +using BDataType = double; +using CDataType = double; +using AccDataType = double; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl +//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| +//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| +//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| +//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if 0 + < F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 4, 1, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 7, 1>; +#else + < F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>; +#endif + // clang-format on + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +template +std::ostream& show_2d_matrix(std::ostream& os, Tensor& matrix) +{ + os << "[" << std::endl; + for(int x = 0; x < matrix.mDesc.GetLengths()[0]; x++) + { + os << "["; + for(int y = 0; y < matrix.mDesc.GetLengths()[1]; y++) + { + os << std::setw(4) << static_cast(matrix(x, y)); + } + os << "]" << std::endl; + } + os << "]"; + return os; +} + +int main(int argc, char* argv[]) +{ 
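+    // Unlike the other examples, the fp64 defaults leave both verification
+    // and tensor initialization off; enable them through argv as documented
+    // in the usage message below.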
+    bool do_verification = false;
+    int init_method      = 0;
+    bool time_kernel     = false;
+
+    // GEMM shape
+    ck::index_t M = 3840;
+    ck::index_t N = 4096;
+    ck::index_t K = 4096;
+
+    ck::index_t StrideA = 4096;
+    ck::index_t StrideB = 4096;
+    ck::index_t StrideC = 4096;
+
+    if(argc == 4)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+    }
+    else if(argc == 10)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+
+        M = std::stoi(argv[4]);
+        N = std::stoi(argv[5]);
+        K = std::stoi(argv[6]);
+
+        StrideA = std::stoi(argv[7]);
+        StrideB = std::stoi(argv[8]);
+        StrideC = std::stoi(argv[9]);
+    }
+    else
+    {
+        printf("arg1: verification (0=no, 1=yes)\n");
+        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
+        printf("arg3: time kernel (0=no, 1=yes)\n");
+        printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
+        exit(0);
+    }
+
+    auto f_host_tensor_descriptor =
+        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+            if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
+            {
+                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
+                                            std::vector<std::size_t>({stride, 1}));
+            }
+            else
+            {
+                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
+                                            std::vector<std::size_t>({1, stride}));
+            }
+        };
+
+    Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
+    Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
+    Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+    Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+
+    std::cout << "data type: " << typeid(ADataType{}).name() << std::endl;
+    std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
+    std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
+    std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
+
+    switch(init_method)
+    {
+    case 0: break;
+    case 1:
+        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
+        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
+        break;
+    case 2:
+        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
+        b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
+        break;
+    default:
+        a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
+        b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
+    }
+
+    DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
+    DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
+    DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
+
+    a_m_k_device_buf.ToDevice(a_m_k.mData.data());
+    b_k_n_device_buf.ToDevice(b_k_n.mData.data());
+
+    auto a_element_op = AElementOp{};
+    auto b_element_op = BElementOp{};
+    auto c_element_op = CElementOp{};
+
+    // do GEMM
+    auto gemm     = DeviceGemmInstance{};
+    auto invoker  = gemm.MakeInvoker();
+    auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
+                                      static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
+                                      static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
+                                      M,
+                                      N,
+                                      K,
+                                      StrideA,
+                                      StrideB,
+                                      StrideC,
+                                      a_element_op,
+                                      b_element_op,
+                                      c_element_op);
+
+    if(!gemm.IsSupportedArgument(argument))
+    {
+        std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
+
+        return 0;
+    }
+
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
+
+    std::size_t flop = std::size_t(2) * M * N * K;
+    std::size_t num_btype =
+        sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
+
+    float 
tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + +#if 0 + { + show_2d_matrix(std::cout << "a : ", a_m_k) << std::endl; + show_2d_matrix(std::cout << "b: ", b_k_n) << std::endl; + show_2d_matrix(std::cout << "c_device: ", c_m_n_device_result) << std::endl; + show_2d_matrix(std::cout << "c_host :", c_m_n_host_result) << std::endl; + } +#endif + return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1; + } + + return 0; +} diff --git a/example/01_gemm/gemm_xdl_int8.cpp b/example/01_gemm/gemm_xdl_int8.cpp new file mode 100644 index 00000000..639ae271 --- /dev/null +++ b/example/01_gemm/gemm_xdl_int8.cpp @@ -0,0 +1,233 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = int8_t; +using BDataType = int8_t; +using CDataType = int8_t; +using AccDataType = int32_t; +using CShuffleDataType = int8_t; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle< + ALayout, // typename ALayout + BLayout, // typename BLayout + CLayout, // typename CLayout + ADataType, // typename ADataType + BDataType, // typename BDataType + CDataType, // typename CDataType + AccDataType, // typename GemmAccDataType + CShuffleDataType, // typename CShuffleDataType + PassThrough, // typename AElementwiseOperation + PassThrough, // typename BElementwiseOperation + PassThrough, // typename CElementwiseOperation + GemmDefault, // GemmSpecialization GemmSpec + 1, // index_t NumGemmKPrefetchStage + 256, // index_t BlockSize + 256, // index_t MPerBlock + 128, // index_t NPerBlock + 64, // index_t KPerBlock + 16, // index_t AK1 + 16, // index_t BK1 + 32, // index_t MPerXDL + 32, // index_t NPerXDL + 4, // index_t MXdlPerWave + 2, // index_t NXdlPerWave + S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder + 2, // index_t ABlockTransferSrcVectorDim + 
16, // index_t ABlockTransferSrcScalarPerVector + 16, // index_t ABlockTransferDstScalarPerVector_AK1 + 1, // index_t ABlockLdsExtraM + S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder + 2, // index_t BBlockTransferSrcVectorDim + 8, // index_t BBlockTransferSrcScalarPerVector + 8, // index_t BBlockTransferDstScalarPerVector_BK1 + 1, // index_t BBlockLdsExtraN + 1, // index_t CShuffleMXdlPerWavePerShuffle + 1, // index_t CShuffleNXdlPerWavePerShuffle + S<1, 64, 1, 4>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = PassThrough{}; + auto b_element_op = PassThrough{}; + auto c_element_op = PassThrough{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = 
gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1; + } + + return 0; +} diff --git a/example/02_gemm_bilinear/CMakeLists.txt b/example/02_gemm_bilinear/CMakeLists.txt new file mode 100644 index 00000000..10ec0f1a --- /dev/null +++ b/example/02_gemm_bilinear/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp) diff --git a/example/02_gemm_bilinear/README.md b/example/02_gemm_bilinear/README.md new file mode 100644 index 00000000..9eb87e1e --- /dev/null +++ b/example/02_gemm_bilinear/README.md @@ -0,0 +1,28 @@ +# Instructions for ```example_gemm_bilinear_xdl_fp16``` + +## Run ```example_gemm_bilinear_xdl_fp16``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: time kernel (0=no, 1=yes) +#arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE +#arg11 to 12: alpha, beta +./bin/example_gemm_bilinear_xdl_fp16 1 1 1 3840 4096 4096 4096 4096 4096 4096 0.5 0.5 +``` +Result (MI100 @ 1502Mhz, 184.6TFlops peak FP16) +``` +a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} +b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} +c0_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} +c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} +arg.a_grid_desc_k0_m_k1_{512, 3840, 8} +arg.b_grid_desc_k0_n_k1_{512, 4096, 8} +arg.c0_grid_desc_m_n_{ 3840, 4096} +arg.c_grid_desc_m_n_{ 3840, 4096} +launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... +Perf: 0.936965 ms, 137.517 TFlops, 102.959 GB/s +error: 0 +max_diff: 0, 558.5, 558.5 +``` diff --git a/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp new file mode 100644 index 00000000..83eed76f --- /dev/null +++ b/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp @@ -0,0 +1,305 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
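+//
+// Bilinear GEMM: with C = A * B, the kernel fuses the epilogue
+//
+//     E[m][n] = alpha * C[m][n] + beta * D[m][n]
+//
+// through the AlphaBetaAdd functor below, which blends the fp32 accumulator
+// with D and converts the result back to fp16. D enters the kernel via the
+// DsDataType tuple of DeviceGemmMultipleD_Xdl_CShuffle, so the blend costs no
+// second kernel launch.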
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +struct AlphaBetaAdd +{ + AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const float& c, const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * c + beta_ * ck::type_convert(d)); + }; + + float alpha_; + float beta_; +}; + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AlphaBetaAdd; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD = 4096; + ck::index_t StrideE = 4096; + + float alpha = 1.0f; + float beta = 1.0f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + alpha = std::stof(argv[4]); + beta = std::stof(argv[5]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + + alpha = std::stof(argv[11]); + beta = std::stof(argv[12]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, alpha, " + "beta\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + 
if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DELayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, DELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, DELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpace()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{alpha, beta}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
+
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
+
+    std::size_t flop = std::size_t(2) * M * N * K;
+    std::size_t num_btype =
+        sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
+
+    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+    float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
+              << std::endl;
+
+    e_device_buf.FromDevice(e_m_n_device_result.mData.data());
+
+    if(do_verification)
+    {
+        Tensor<CShuffleDataType> c_m_n(HostTensorDescriptor(
+            std::vector<std::size_t>{static_cast<std::size_t>(M), static_cast<std::size_t>(N)}));
+
+        using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
+                                                                                BDataType,
+                                                                                CShuffleDataType,
+                                                                                AccDataType,
+                                                                                AElementOp,
+                                                                                BElementOp,
+                                                                                PassThrough>;
+        auto ref_gemm    = ReferenceGemmInstance{};
+        auto ref_invoker = ref_gemm.MakeInvoker();
+
+        auto ref_argument =
+            ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
+
+        ref_invoker.Run(ref_argument);
+
+        for(int m = 0; m < M; ++m)
+        {
+            for(int n = 0; n < N; ++n)
+            {
+                cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n));
+            }
+        }
+
+        e_device_buf.FromDevice(e_m_n_device_result.mData.data());
+
+        return ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData) ? 0 : 1;
+    }
+
+    return 0;
+}
diff --git a/example/03_gemm_bias_relu/CMakeLists.txt b/example/03_gemm_bias_relu/CMakeLists.txt
new file mode 100644
index 00000000..35c54aba
--- /dev/null
+++ b/example/03_gemm_bias_relu/CMakeLists.txt
@@ -0,0 +1 @@
+add_example_executable(example_gemm_bias_relu_xdl_fp16 gemm_bias_relu_xdl_fp16.cpp)
diff --git a/example/03_gemm_bias_relu/README.md b/example/03_gemm_bias_relu/README.md
new file mode 100644
index 00000000..f28a9a07
--- /dev/null
+++ b/example/03_gemm_bias_relu/README.md
@@ -0,0 +1,10 @@
+# Instructions for ```example_gemm_bias_relu_xdl_fp16```
+
+## Run ```example_gemm_bias_relu_xdl_fp16```
+```bash
+#arg1: verification (0=no, 1=yes)
+#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
+#arg3: time kernel (0=no, 1=yes)
+#arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE
+./bin/example_gemm_bias_relu_xdl_fp16 1 1 1 3840 4096 4096 4096 4096 4096
+```
diff --git a/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp b/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp
new file mode 100644
index 00000000..ff0a8220
--- /dev/null
+++ b/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp
@@ -0,0 +1,281 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
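+
+// This example computes E = ReLU(A * B + D), where D is a length-N bias row
+// broadcast over M (note the zero row stride given to d_m_n below). A rough
+// host-side sketch of the fused epilogue, assuming fp32 accumulation as
+// configured in this file:
+//
+//   float c = 0;
+//   for(int k = 0; k < K; ++k) { c += a(m, k) * b(k, n); } // accumulate in fp32
+//   ck::half_t x = ck::type_convert<ck::half_t>(c) + d(n); // add broadcast bias
+//   e(m, n) = x > 0 ? x : ck::half_t(0);                   // ReLU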
+
+#include <cstdlib>
+#include <iostream>
+#include <numeric>
+#include <type_traits>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
+#include "ck/library/utility/check_err.hpp"
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+// C = A * B
+// E = Relu(C + D);
+struct AddRelu
+{
+    __host__ __device__ void
+    operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d) const
+    {
+        const ck::half_t x = c + d;
+
+        e = x > 0 ? x : 0;
+    }
+};
+
+using ADataType        = F16;
+using BDataType        = F16;
+using AccDataType      = F32;
+using CShuffleDataType = F16;
+using DDataType        = F16;
+using DsDataType       = ck::Tuple<DDataType>;
+using EDataType        = F16;
+
+using ALayout = Row;
+using BLayout = Col;
+using ELayout = Row;
+
+using AElementOp   = PassThrough;
+using BElementOp   = PassThrough;
+using CDEElementOp = AddRelu;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+using DeviceOpInstance =
+    ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle<ALayout,
+                                                                   BLayout,
+                                                                   ELayout,
+                                                                   ADataType,
+                                                                   BDataType,
+                                                                   AccDataType,
+                                                                   CShuffleDataType,
+                                                                   DsDataType,
+                                                                   EDataType,
+                                                                   AElementOp,
+                                                                   BElementOp,
+                                                                   CDEElementOp,
+                                                                   GemmDefault,
+                                                                   1,
+                                                                   256,
+                                                                   256,
+                                                                   128,
+                                                                   32,
+                                                                   8,
+                                                                   8,
+                                                                   32,
+                                                                   32,
+                                                                   4,
+                                                                   2,
+                                                                   S<4, 64, 1>,
+                                                                   S<1, 0, 2>,
+                                                                   S<1, 0, 2>,
+                                                                   2,
+                                                                   8,
+                                                                   8,
+                                                                   1,
+                                                                   S<4, 64, 1>,
+                                                                   S<1, 0, 2>,
+                                                                   S<1, 0, 2>,
+                                                                   2,
+                                                                   8,
+                                                                   8,
+                                                                   1,
+                                                                   1,
+                                                                   1,
+                                                                   S<1, 32, 1, 8>,
+                                                                   8>;
+
+int main(int argc, char* argv[])
+{
+    bool do_verification = true;
+    int init_method      = 1;
+    bool time_kernel     = false;
+
+    // GEMM shape
+    ck::index_t M = 3840;
+    ck::index_t N = 4096;
+    ck::index_t K = 4096;
+
+    ck::index_t StrideA = 4096;
+    ck::index_t StrideB = 4096;
+    ck::index_t StrideE = 4096;
+
+    if(argc == 1)
+    {
+        // use default case
+    }
+    else if(argc == 4)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+    }
+    else if(argc == 10)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+
+        M = std::stoi(argv[4]);
+        N = std::stoi(argv[5]);
+        K = std::stoi(argv[6]);
+
+        StrideA = std::stoi(argv[7]);
+        StrideB = std::stoi(argv[8]);
+        StrideE = std::stoi(argv[9]);
+    }
+    else
+    {
+        printf("arg1: verification (0=no, 1=yes)\n");
+        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
+        printf("arg3: time kernel (0=no, 1=yes)\n");
+        printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n");
+        exit(0);
+    }
+
+    auto f_host_tensor_descriptor =
+        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+            if(std::is_same<decltype(layout), Row>::value)
+            {
+                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
+                                            std::vector<std::size_t>({stride, 1}));
+            }
+            else
+            {
+                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
+                                            std::vector<std::size_t>({1, stride}));
+            }
+        };
+
+    Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
+    Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
+    Tensor<DDataType> d_m_n(f_host_tensor_descriptor(M, N, 0, ELayout{}));
+    Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
+    Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
+
+    std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
+    std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
+    std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
+    std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
+
+    switch(init_method)
+    {
+    case 0: break;
+    case 1:
+        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
+        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
+        d_m_n.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
+        break;
+    default:
+        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
+        b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
+        d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
+    }
+
+    DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
+    DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
+    DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpace());
+    DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace());
+
+    a_device_buf.ToDevice(a_m_k.mData.data());
+    b_device_buf.ToDevice(b_k_n.mData.data());
+    d_device_buf.ToDevice(d_m_n.mData.data());
+
+    auto a_element_op   = AElementOp{};
+    auto b_element_op   = BElementOp{};
+    auto cde_element_op = CDEElementOp{};
+
+    // do GEMM
+    auto device_op = DeviceOpInstance{};
+
+    auto invoker = device_op.MakeInvoker();
+
+    auto argument =
+        device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
+                               b_device_buf.GetDeviceBuffer(),
+                               std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
+                               e_device_buf.GetDeviceBuffer(),
+                               M,
+                               N,
+                               K,
+                               StrideA,
+                               StrideB,
+                               std::array<ck::index_t, 1>{0},
+                               StrideE,
+                               a_element_op,
+                               b_element_op,
+                               cde_element_op);
+
+    if(!device_op.IsSupportedArgument(argument))
+    {
+        throw std::runtime_error("wrong! this device_op instance does not support this problem");
+    }
+
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
+
+    std::size_t flop = std::size_t(2) * M * N * K;
+
+    std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
+                            sizeof(EDataType) * M * N + sizeof(EDataType) * N;
+
+    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+    float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
+              << std::endl;
+
+    if(do_verification)
+    {
+        e_device_buf.FromDevice(e_m_n_device_result.mData.data());
+
+        Tensor<CShuffleDataType> c_m_n(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
+
+        using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
+                                                                                BDataType,
+                                                                                CShuffleDataType,
+                                                                                AccDataType,
+                                                                                AElementOp,
+                                                                                BElementOp,
+                                                                                PassThrough>;
+
+        auto ref_gemm    = ReferenceGemmInstance{};
+        auto ref_invoker = ref_gemm.MakeInvoker();
+
+        auto ref_argument =
+            ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
+
+        ref_invoker.Run(ref_argument);
+
+        for(int m = 0; m < M; ++m)
+        {
+            for(int n = 0; n < N; ++n)
+            {
+                cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n));
+            }
+        }
+
+        return ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData) ?
0 : 1; + } + + return 0; +} diff --git a/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/example/04_gemm_add_add_fastgelu/CMakeLists.txt new file mode 100644 index 00000000..754de47c --- /dev/null +++ b/example/04_gemm_add_add_fastgelu/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp) diff --git a/example/04_gemm_add_add_fastgelu/README.md b/example/04_gemm_add_add_fastgelu/README.md new file mode 100644 index 00000000..08a55fb9 --- /dev/null +++ b/example/04_gemm_add_add_fastgelu/README.md @@ -0,0 +1,23 @@ +# Instructions for ```example_gemm_add_add_fastgelu_xdl_fp16``` + +## Run ```example_gemm_add_add_fastgelu_xdl_fp16``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: time kernel (0=no, 1=yes) +#arg4 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD0, StrideD1, StrideE" +./bin/example_gemm_add_add_fastgelu_xdl_fp16 1 1 1 +``` + +Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) +``` +a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} +b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} +d0_m_n: dim 2, lengths {3840, 4096}, strides {0, 1} +d1_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} +e_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} +launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} +Warm up 1 time +Start running 10 times... +Perf: 1.26914 ms, 101.525 TFlops, 100.804 GB/s, DeviceGemmMultipleD_Xdl_CShuffle<256, 256, 128, 32, 8, 8> +``` diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp new file mode 100644 index 00000000..5f054069 --- /dev/null +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp @@ -0,0 +1,248 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
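+
+// This example fuses two element-wise adds and a FastGeLU into the GEMM
+// epilogue: E = FastGelu(C + D0 + D1), with D0 broadcast over M (its default
+// stride below is 0) and D1 a full M x N tensor. As a rough sketch, FastGeLU
+// is commonly the tanh approximation of GeLU (the exact polynomial CK uses may
+// differ in detail):
+//
+//   gelu(x) ~= 0.5 * x * (1 + tanh(0.797884 * (x + 0.044715 * x * x * x)))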
+
+#include <cstdlib>
+#include <iostream>
+#include <numeric>
+#include <type_traits>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
+#include "ck/library/utility/check_err.hpp"
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using PassThrough    = ck::tensor_operation::element_wise::PassThrough;
+using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
+
+using ADataType        = F16;
+using BDataType        = F16;
+using AccDataType      = F32;
+using CShuffleDataType = F32;
+using D0DataType       = F16;
+using D1DataType       = F16;
+using DsDataType       = ck::Tuple<D0DataType, D1DataType>;
+using EDataType        = F16;
+
+using ALayout  = Row;
+using BLayout  = Col;
+using D0Layout = Row;
+using D1Layout = Row;
+using ELayout  = Row;
+
+using AElementOp   = PassThrough;
+using BElementOp   = PassThrough;
+using CDEElementOp = AddAddFastGelu;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+
+// clang-format off
+using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle
+//######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+//######| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
+//######| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
+//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        < ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
+// clang-format on
+
+int main(int argc, char* argv[])
+{
+    bool do_verification = true;
+    int init_method      = 1;
+    bool time_kernel     = false;
+
+    // GEMM shape
+    ck::index_t M = 3840;
+    ck::index_t N = 4096;
+    ck::index_t K = 4096;
+
+    ck::index_t StrideA  = 4096;
+    ck::index_t StrideB  = 4096;
+    ck::index_t StrideD0 = 0;
+    ck::index_t StrideD1 = 4096;
+    ck::index_t StrideE  = 4096;
+
+    if(argc == 1)
+    {
+        // use default case
+    }
+    else if(argc == 4)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+    }
+    else if(argc == 12)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+
+        M = std::stoi(argv[4]);
+        N = std::stoi(argv[5]);
+        K = std::stoi(argv[6]);
+
+        StrideA  = std::stoi(argv[7]);
+        StrideB  = std::stoi(argv[8]);
+        StrideD0 = std::stoi(argv[9]);
+        StrideD1 = std::stoi(argv[10]);
+        StrideE  = std::stoi(argv[11]);
+    }
+    else
+    {
+        printf("arg1: verification (0=no, 1=yes)\n");
+        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
+        printf("arg3: time kernel (0=no, 1=yes)\n");
+        printf("arg4 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD0, StrideD1, "
+               "StrideE\n");
+        exit(0);
+    }
+
+    auto f_host_tensor_descriptor =
+        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+            if(std::is_same<decltype(layout), Row>::value)
+            {
+                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
+                                            std::vector<std::size_t>({stride, 1}));
+            }
+            else
+            {
+                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
+                                            std::vector<std::size_t>({1, stride}));
+            }
+        };
+
+    Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
+    Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
+    Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{}));
+    Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, D1Layout{}));
+    Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
+    Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
+
+    std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
+    std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
+    std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl;
+    std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
+    std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
+
+    switch(init_method)
+    {
+    case 0: break;
+    case 1:
+        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
+        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
+        d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-5, 5});
+        d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-5, 5});
+        break;
+    default:
+        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
+        b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
+        d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
+        d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
+    }
+
+    DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
+    DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
+    DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpace());
+    DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpace());
+    DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace());
+
+    a_device_buf.ToDevice(a_m_k.mData.data());
+    b_device_buf.ToDevice(b_k_n.mData.data());
+    d0_device_buf.ToDevice(d0_m_n.mData.data());
+    d1_device_buf.ToDevice(d1_m_n.mData.data());
+
+    auto a_element_op   = AElementOp{};
+    auto b_element_op   = BElementOp{};
+    auto cde_element_op = CDEElementOp{};
+
+    // do GEMM
+    auto device_op = DeviceOpInstance{};
+    auto invoker   = device_op.MakeInvoker();
+    auto argument =
+        device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
+                               b_device_buf.GetDeviceBuffer(),
+                               std::array<const void*, 2>{d0_device_buf.GetDeviceBuffer(),
+                                                          d1_device_buf.GetDeviceBuffer()},
+                               e_device_buf.GetDeviceBuffer(),
+                               M,
+                               N,
+                               K,
+                               StrideA,
+                               StrideB,
+                               std::array<ck::index_t, 2>{StrideD0, StrideD1},
+                               StrideE,
+                               a_element_op,
+                               b_element_op,
+                               cde_element_op);
+
+    if(!device_op.IsSupportedArgument(argument))
+    {
+        throw std::runtime_error("wrong! this device_op instance does not support this problem");
+    }
+
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
+
+    std::size_t flop = std::size_t(2) * M * N * K;
+    std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
+                            sizeof(D0DataType) * N + sizeof(D1DataType) * M * N +
+                            sizeof(EDataType) * M * N;
+
+    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+    float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
+              << device_op.GetTypeString() << std::endl;
+
+    if(do_verification)
+    {
+        Tensor<CShuffleDataType> c_m_n(HostTensorDescriptor(
+            std::vector<std::size_t>{static_cast<std::size_t>(M), static_cast<std::size_t>(N)}));
+
+        using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
+                                                                                BDataType,
+                                                                                CShuffleDataType,
+                                                                                AccDataType,
+                                                                                AElementOp,
+                                                                                BElementOp,
+                                                                                PassThrough>;
+
+        auto ref_gemm    = ReferenceGemmInstance{};
+        auto ref_invoker = ref_gemm.MakeInvoker();
+
+        auto ref_argument =
+            ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
+
+        ref_invoker.Run(ref_argument);
+
+        for(int m = 0; m < M; ++m)
+        {
+            for(int n = 0; n < N; ++n)
+            {
+                cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n));
+            }
+        }
+
+        e_device_buf.FromDevice(e_m_n_device_result.mData.data());
+
+        return ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData) ? 0 : 1;
+    }
+
+    return 0;
+}
diff --git a/example/06_conv2d_fwd_bias_relu/CMakeLists.txt b/example/06_conv2d_fwd_bias_relu/CMakeLists.txt
new file mode 100644
index 00000000..4e1dd1f3
--- /dev/null
+++ b/example/06_conv2d_fwd_bias_relu/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_example_executable(example_conv2d_fwd_xdl_bias_relu conv2d_fwd_xdl_bias_relu.cpp)
+target_link_libraries(example_conv2d_fwd_xdl_bias_relu PRIVATE conv_util)
diff --git a/example/06_conv2d_fwd_bias_relu/README.md b/example/06_conv2d_fwd_bias_relu/README.md
new file mode 100644
index 00000000..4c30563e
--- /dev/null
+++ b/example/06_conv2d_fwd_bias_relu/README.md
@@ -0,0 +1,22 @@
+# Instructions for ```example_conv_xdl_bias_relu```
+
+## Run ```example_conv_xdl_bias_relu```
+```bash
+#arg1: verification (0=no, 1=yes)
+#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
+#arg3: run kernel # of times (>1)
+#arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx
+./bin/example_conv_xdl_bias_relu 0 1 5
+```
+
+Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16)
+```
+in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192}
+wei_k_c_y_x: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192}
+out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256}
+bias_k: dim 1, lengths {256}, strides {1}
+launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1}
+Warm up
+Start running 5 times...
+Perf: 1.39009 ms, 105.581 TFlops, 239.981 GB/s
+```
diff --git a/example/06_conv2d_fwd_bias_relu/conv2d_fwd_xdl_bias_relu.cpp b/example/06_conv2d_fwd_bias_relu/conv2d_fwd_xdl_bias_relu.cpp
new file mode 100644
index 00000000..0f24045c
--- /dev/null
+++ b/example/06_conv2d_fwd_bias_relu/conv2d_fwd_xdl_bias_relu.cpp
@@ -0,0 +1,313 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
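+
+// Forward convolution is lowered here to an implicit GEMM: with NHWC input and
+// KYXC weights, every output pixel is a dot product of length C * Y * X, so the
+// problem maps roughly to GEMM dimensions M = N * Ho * Wo, N = K and
+// K = C * Y * X. The bias is a length-K vector added per output channel before
+// the ReLU.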
+
+#include <cstdlib>
+#include <iostream>
+#include <numeric>
+#include <type_traits>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/conv_util.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp"
+
+namespace {
+
+using InDataType  = ck::half_t;
+using WeiDataType = ck::half_t;
+using OutDataType = ck::half_t;
+using AccDataType = float;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using InLayout  = ck::tensor_layout::convolution::NHWC;
+using WeiLayout = ck::tensor_layout::convolution::KYXC;
+using OutLayout = ck::tensor_layout::convolution::NHWK;
+
+using InElementOp  = ck::tensor_operation::element_wise::PassThrough;
+using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
+using OutElementOp = ck::tensor_operation::element_wise::AddRelu;
+
+static constexpr auto MemorySet = ck::InMemoryDataOperationEnum::Set;
+
+static constexpr auto ConvFwdDefault =
+    ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
+
+// clang-format off
+using DeviceConvFwdInstance = ck::tensor_operation::device::
+    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
+        InDataType,           // InDataType
+        WeiDataType,          // WeiDataType
+        OutDataType,          // OutDataType
+        AccDataType,          // AccDataType
+        InElementOp,          // InElementwiseOperation
+        WeiElementOp,         // WeiElementwiseOperation
+        OutElementOp,         // OutElementwiseOperation
+        MemorySet,            // OutGlobalMemoryDataOperation
+        ConvFwdDefault,       // ConvForwardSpecialization
+        256,                  // BlockSize
+        128,                  // MPerBlock
+        256,                  // NPerBlock
+        4,                    // K0PerBlock
+        8,                    // K1
+        32,                   // MPerXdl
+        32,                   // NPerXdl
+        2,                    // MXdlPerWave
+        4,                    // NXdlPerWave
+        S<4, 64, 1>,          // ABlockTransferThreadClusterLengths_K0_M_K1
+        S<1, 0, 2>,           // ABlockTransferThreadClusterArrangeOrder
+        S<1, 0, 2>,           // ABlockTransferSrcAccessOrder
+        2,                    // ABlockTransferSrcVectorDim
+        8,                    // ABlockTransferSrcScalarPerVector
+        8,                    // ABlockTransferDstScalarPerVector_K1
+        true,                 // ABlockLdsAddExtraM
+        S<4, 64, 1>,          // BBlockTransferThreadClusterLengths_K0_N_K1
+        S<1, 0, 2>,           // BBlockTransferThreadClusterArrangeOrder
+        S<1, 0, 2>,           // BBlockTransferSrcAccessOrder
+        2,                    // BBlockTransferSrcVectorDim
+        8,                    // BBlockTransferSrcScalarPerVector
+        8,                    // BBlockTransferDstScalarPerVector_K1
+        true,                 // BBlockLdsAddExtraN
+        1,                    // CShuffleMXdlPerWavePerShuffle
+        1,                    // CShuffleNXdlPerWavePerShuffle
+        S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
+        8>;                   // CBlockTransferScalarPerVector_NWaveNPerXdl
+// clang-format on
+
+using ReferenceConvFwdInstance =
+    ck::tensor_operation::host::ReferenceConvFwd_Bias_Activation<InDataType,
+                                                                 WeiDataType,
+                                                                 OutDataType,
+                                                                 InElementOp,
+                                                                 WeiElementOp,
+                                                                 OutElementOp>;
+
+void PrintUseMsg()
+{
+    std::cout << "arg1: verification (0=no, 1=yes)\n"
+              << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
+              << "arg3: time kernel (0=no, 1=yes)\n"
+              << "Following arguments:\n"
+              << " N, K, C, \n"
+              << " <filter spatial dimensions>, (ie Y, X for 2D)\n"
+              << " <input image spatial dimensions>, (ie Hi, Wi for 2D)\n"
+              << " <convolution strides>, (ie Sy, Sx for 2D)\n"
+              << " <convolution dilations>, (ie Dy, Dx for 2D)\n"
+              << " <input left pads>, (ie LeftPy, LeftPx for 2D)\n"
+              << " <input right pads>, (ie RightPy, RightPx for 2D)\n"
+              << std::endl;
+}
+
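+
+// ParseConvParams below expects 3 + 6 * num_dim_spatial positional values after
+// the three flags, i.e. 15 values in the 2D case. A hypothetical invocation
+// matching the README's default problem:
+//
+//   ./bin/example_conv2d_fwd_xdl_bias_relu 0 1 1  128 256 192  3 3  71 71  2 2  1 1  1 1  1 1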
+ck::utils::conv::ConvParams ParseConvParams(int argc, char* argv[]) +{ + // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) + int num_dim_spatial = 2; + int conv_args = 3 + num_dim_spatial * 6; + int cmdline_nargs = conv_args + 4; + if(cmdline_nargs != argc) + { + PrintUseMsg(); + exit(0); + } + + ck::utils::conv::ConvParams params; + int arg_idx = 4; + + params.num_dim_spatial_ = num_dim_spatial; + params.N_ = std::stoi(argv[arg_idx++]); + params.K_ = std::stoi(argv[arg_idx++]); + params.C_ = std::stoi(argv[arg_idx++]); + + params.filter_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.input_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_strides_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_dilations_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]); + } + params.input_left_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_left_pads_[i] = std::stoi(argv[arg_idx++]); + } + params.input_right_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_right_pads_[i] = std::stoi(argv[arg_idx++]); + } + + return params; +} + +} // anonymous namespace + +int main(int argc, char* argv[]) +{ + using namespace ck::utils::conv; + + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + const int num_dim_spatial = 2; + + ck::utils::conv::ConvParams params; + + if(argc >= 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + + if(argc >= 5) + { + params = ParseConvParams(argc, argv); + } + + std::vector input_dims{static_cast(params.N_), + static_cast(params.C_)}; + input_dims.insert(std::end(input_dims), + std::begin(params.input_spatial_lengths_), + std::end(params.input_spatial_lengths_)); + + std::vector filter_dims{static_cast(params.K_), + static_cast(params.C_)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params.filter_spatial_lengths_), + std::end(params.filter_spatial_lengths_)); + + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + std::vector output_dims{static_cast(params.N_), + static_cast(params.K_)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor input(get_input_host_tensor_descriptor(input_dims, num_dim_spatial)); + Tensor weights(get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial)); + Tensor host_output( + get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); + Tensor device_output( + get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); + // bias: assume contiguous 1d vector + Tensor bias( + HostTensorDescriptor(std::vector({static_cast(params.K_)}))); + + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weights: " << weights.mDesc << std::endl; + std::cout << "output: " << host_output.mDesc << std::endl; + std::cout << "bias: " << bias.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + 
input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + weights.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + bias.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + weights.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + bias.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpace()); + DeviceMem bias_device_buf(sizeof(OutDataType) * bias.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weights.mData.data()); + bias_device_buf.ToDevice(bias.mData.data()); + + auto conv = DeviceConvFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = + conv.MakeArgument(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(bias_device_buf.GetDeviceBuffer()), + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device operator with the specified compilation parameters does " + "not support this problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = get_flops( + params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); + std::size_t num_btype = + get_btype(params.N_, + params.C_, + params.K_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths) + + sizeof(OutDataType) * (params.K_); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + auto ref_conv = ReferenceConvFwdInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(input, + weights, + host_output, + bias, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + ref_invoker.Run(ref_argument); + out_device_buf.FromDevice(device_output.mData.data()); + return ck::utils::check_err(device_output.mData, host_output.mData) ? 
0 : 1; + } + + return 0; +} diff --git a/example/07_conv2d_fwd_bias_relu_add/CMakeLists.txt b/example/07_conv2d_fwd_bias_relu_add/CMakeLists.txt new file mode 100644 index 00000000..b4dd39d8 --- /dev/null +++ b/example/07_conv2d_fwd_bias_relu_add/CMakeLists.txt @@ -0,0 +1,3 @@ +# FIXME: should fix validation failure +add_example_executable_no_testing(example_conv2d_fwd_xdl_bias_relu_add conv2d_fwd_xdl_bias_relu_add.cpp) +target_link_libraries(example_conv2d_fwd_xdl_bias_relu_add PRIVATE conv_util) diff --git a/example/07_conv2d_fwd_bias_relu_add/README.md b/example/07_conv2d_fwd_bias_relu_add/README.md new file mode 100644 index 00000000..99afcae9 --- /dev/null +++ b/example/07_conv2d_fwd_bias_relu_add/README.md @@ -0,0 +1,24 @@ +# Instructions for ```example_conv_xdl_bias_relu_add``` + + +## Run ```example_conv_xdl_bias_relu_add``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: run kernel # of times (>1) +#arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx +./bin/example_conv_xdl_bias_relu_add 0 1 5 +``` + +Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) +``` +in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192} +wei_k_c_y_x: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192} +out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} +bias_k: dim 1, lengths {256}, strides {1} +resi_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} +launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 5 times... +Perf: 1.44711 ms, 101.421 TFlops, 289.218 GB/s +``` diff --git a/example/07_conv2d_fwd_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp b/example/07_conv2d_fwd_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp new file mode 100644 index 00000000..27559665 --- /dev/null +++ b/example/07_conv2d_fwd_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp @@ -0,0 +1,328 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
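+
+// This variant extends the bias + ReLU fusion with a residual add:
+// Out = ReLU(Conv(In, Wei) + Bias) + Residual, matching the AddReluAdd
+// element-wise operation selected below. The residual tensor uses the same
+// NHWK layout as the output, so the extra input only adds element-wise work
+// and bandwidth, no additional GEMM flops.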
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp" + +namespace { + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +using AccDataType = float; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::NHWC; +using WeiLayout = ck::tensor_layout::convolution::KYXC; +using OutLayout = ck::tensor_layout::convolution::NHWK; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +// clang-format off +using DeviceConvFwdInstance = ck::tensor_operation::device:: + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvFwdDefault, // ConvForwardSpecialization + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl +// clang-format on + +using ReferenceConvFwdInstance = + ck::tensor_operation::host::ReferenceConvFwd_Bias_Activation_Add; + +void PrintUseMsg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=n0, 1=yes)\n" + << "Following arguments:\n" + << " N, K, C, \n" + << " , (ie Y, X for 2D)\n" + << " , (ie Hi, Wi for 2D)\n" + << " , (ie Sy, Sx for 2D)\n" + << " , (ie Dy, Dx for 2D)\n" + << " , (ie LeftPy, LeftPx for 2D)\n" + << " , (ie RightPy, RightPx for 2D)\n" + << std::endl; +} + +ck::utils::conv::ConvParams ParseConvParams(int argc, char* argv[]) +{ + // (N, K, C) + num_dim_spatial * 6 
(filter, input, strides, dilations, pad left, pad right) + int num_dim_spatial = 2; + int conv_args = 3 + num_dim_spatial * 6; + int cmdline_nargs = conv_args + 4; + if(cmdline_nargs != argc) + { + PrintUseMsg(); + exit(0); + } + + ck::utils::conv::ConvParams params; + int arg_idx = 4; + + params.num_dim_spatial_ = num_dim_spatial; + params.N_ = std::stoi(argv[arg_idx++]); + params.K_ = std::stoi(argv[arg_idx++]); + params.C_ = std::stoi(argv[arg_idx++]); + + params.filter_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.input_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_strides_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_dilations_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]); + } + params.input_left_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_left_pads_[i] = std::stoi(argv[arg_idx++]); + } + params.input_right_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_right_pads_[i] = std::stoi(argv[arg_idx++]); + } + + return params; +} + +} // anonymous namespace + +int main(int argc, char* argv[]) +{ + using namespace ck::utils::conv; + + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + const int num_dim_spatial = 2; + + ck::utils::conv::ConvParams params; + + if(argc >= 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + + if(argc >= 5) + { + params = ParseConvParams(argc, argv); + } + + std::vector input_dims{static_cast(params.N_), + static_cast(params.C_)}; + input_dims.insert(std::end(input_dims), + std::begin(params.input_spatial_lengths_), + std::end(params.input_spatial_lengths_)); + + std::vector filter_dims{static_cast(params.K_), + static_cast(params.C_)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params.filter_spatial_lengths_), + std::end(params.filter_spatial_lengths_)); + + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + std::vector output_dims{static_cast(params.N_), + static_cast(params.K_)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor input(get_input_host_tensor_descriptor(input_dims, num_dim_spatial)); + Tensor weights(get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial)); + Tensor host_output( + get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); + Tensor device_output( + get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); + + // bias: assume contiguous 1d vector + Tensor bias( + HostTensorDescriptor(std::vector({static_cast(params.K_)}))); + + // residual: assume same layout as output tensor + Tensor residual(get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); + + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weights: " << weights.mDesc << std::endl; + std::cout << "output: " << host_output.mDesc << std::endl; + std::cout << "bias: " << bias.mDesc << std::endl; + std::cout << "residual: " << residual.mDesc << std::endl; + 
+ switch(init_method) + { + case 0: break; + case 1: + input.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + weights.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + bias.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + residual.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + weights.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + bias.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + residual.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpace()); + DeviceMem bias_device_buf(sizeof(OutDataType) * bias.mDesc.GetElementSpace()); + DeviceMem resi_device_buf(sizeof(OutDataType) * residual.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weights.mData.data()); + bias_device_buf.ToDevice(bias.mData.data()); + resi_device_buf.ToDevice(residual.mData.data()); + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + auto conv = DeviceConvFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = + conv.MakeArgument(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(bias_device_buf.GetDeviceBuffer()), + static_cast(resi_device_buf.GetDeviceBuffer()), + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device operator with the specified compilation parameters does " + "not support this problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = get_flops( + params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); + std::size_t num_btype = + get_btype(params.N_, + params.C_, + params.K_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths) + + sizeof(OutDataType) * (params.K_) + + sizeof(OutDataType) * + (params.N_ * params.K_ * output_spatial_lengths[0] * output_spatial_lengths[1]); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + auto ref_conv = ReferenceConvFwdInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(input, + weights, + host_output, + bias, + residual, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op); + + ref_invoker.Run(ref_argument); + out_device_buf.FromDevice(device_output.mData.data()); + return ck::utils::check_err(device_output.mData, host_output.mData) ? 
0 : 1; + } + + return 0; +} diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt new file mode 100644 index 00000000..b373d1d6 --- /dev/null +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -0,0 +1,6 @@ +add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp) +add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) +add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp) +add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) +# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed +add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) diff --git a/example/09_convnd_fwd/README.md b/example/09_convnd_fwd/README.md new file mode 100644 index 00000000..9ab5fee5 --- /dev/null +++ b/example/09_convnd_fwd/README.md @@ -0,0 +1,32 @@ +# Instructions for ```example_convnd_fwd_xdl``` + +## Run ```example_convnd_fwd_xdl``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: run kernel # of times (>1) +#arg4: N spatial dimensions (default 2) +#Following arguments (depending on number of spatial dims): +# N, K, C, +# , (ie Y, X for 2D) +# , (ie Hi, Wi for 2D) +# , (ie Sy, Sx for 2D) +# , (ie Dy, Dx for 2D) +# , (ie LeftPy, LeftPx for 2D) +# , (ie RightPy, RightPx for 2D) +./bin/example_convnd_fwd_xdl 0 1 100 +``` + +Result (MI100 @ 1087Mhz, 33.4TFlops peak FP32) +``` +input: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192} +weights: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192} +output: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} +arg.a_grid_desc_k0_m_k1_{432, 165888, 4} +arg.b_grid_desc_k0_n_k1_{432, 256, 4} +arg.c_grid_desc_m_n_{ 165888, 256} +launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 100 times... +Perf: 4.43736 ms, 33.0753 TFlops, 150.357 GB/s +``` diff --git a/example/09_convnd_fwd/convnd_fwd_common.hpp b/example/09_convnd_fwd/convnd_fwd_common.hpp new file mode 100644 index 00000000..6ff4bfbe --- /dev/null +++ b/example/09_convnd_fwd/convnd_fwd_common.hpp @@ -0,0 +1,218 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
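+
+// Shared driver for the N-dimensional forward-convolution examples. As a rough
+// sanity check of the README numbers above: for 2D convolution,
+// flop = 2 * N * K * C * (Y * X) * (Ho * Wo), so the default case gives
+// 2 * 128 * 256 * 192 * 9 * 1296 ~= 146.8 GFLOP; at the reported 4.43736 ms
+// that is ~33.1 TFlops, matching the sample output.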
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_fwd_nwc_kxc_nwk_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +ck::tensor_operation::device::ConvParams +parse_conv_params(int num_dim_spatial, int arg_idx, char* const argv[]) +{ + const ck::index_t N = std::stoi(argv[arg_idx++]); + const ck::index_t K = std::stoi(argv[arg_idx++]); + const ck::index_t C = std::stoi(argv[arg_idx++]); + + std::vector filter_spatial_lengths(num_dim_spatial); + std::vector input_spatial_lengths(num_dim_spatial); + std::vector conv_filter_strides(num_dim_spatial); + std::vector conv_filter_dilations(num_dim_spatial); + std::vector input_left_pads(num_dim_spatial); + std::vector input_right_pads(num_dim_spatial); + + for(int i = 0; i < num_dim_spatial; ++i) + { + filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + } + + for(int i = 0; i < num_dim_spatial; ++i) + { + input_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + } + + for(int i = 0; i < num_dim_spatial; ++i) + { + conv_filter_strides[i] = std::stoi(argv[arg_idx++]); + } + + for(int i = 0; i < num_dim_spatial; ++i) + { + conv_filter_dilations[i] = std::stoi(argv[arg_idx++]); + } + + for(int i = 0; i < num_dim_spatial; ++i) + { + input_left_pads[i] = std::stoi(argv[arg_idx++]); + } + + for(int i = 0; i < num_dim_spatial; ++i) + { + input_right_pads[i] = std::stoi(argv[arg_idx++]); + } + + return ck::tensor_operation::device::ConvParams{num_dim_spatial, + N, + K, + C, + filter_spatial_lengths, + input_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; +} + +void print_helper_msg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << "arg4: N spatial dimensions (default 2)\n" + << "Following arguments (depending on number of spatial dims):\n" + << " N, K, C, \n" + << " , (ie Y, X for 2D)\n" + << " , (ie Hi, Wi for 2D)\n" + << " , (ie Sy, Sx for 2D)\n" + << " , (ie Dy, Dx for 2D)\n" + << " , (ie LeftPy, LeftPx for 2D)\n" + << " , (ie RightPy, RightPx for 2D)\n" + << std::endl; +} + +template +int run_conv_fwd_nhwc(const ck::tensor_operation::device::ConvParams& params, + bool do_verification, + int init_method, + bool time_kernel) +{ + auto f_nhwc_host_tensor_descriptor = + [](ck::index_t n, ck::index_t c, std::vector spatial_lengths) { + std::vector nhwc_lengths{static_cast(n), + static_cast(c)}; + nhwc_lengths.insert( + nhwc_lengths.begin() + 1, spatial_lengths.begin(), spatial_lengths.end()); + + return HostTensorDescriptor(nhwc_lengths); + }; + + Tensor input( + f_nhwc_host_tensor_descriptor(params.N_, params.C_, params.input_spatial_lengths_)); + Tensor weight( + f_nhwc_host_tensor_descriptor(params.K_, params.C_, params.filter_spatial_lengths_)); + Tensor host_output( + f_nhwc_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths())); + Tensor device_output( + f_nhwc_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths())); + + 
std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weight: " << weight.mDesc << std::endl; + std::cout << "output: " << host_output.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + weight.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + weight.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weight.mData.data()); + + // do GEMM + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + params.GetOutputSpatialLengths(), + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = params.GetFlops(); + std::size_t num_btype = params.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + auto ref_conv = ReferenceConvNDFwdInstance(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(input, + weight, + host_output, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); + + out_device_buf.FromDevice(device_output.mData.data()); + + return ck::utils::check_err(host_output.mData, + device_output.mData, + "Error: incorrect results!", + 1e-5f, + 1e-4f) + ? 0 + : 1; + } + + return 0; +} diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp new file mode 100644 index 00000000..4882166e --- /dev/null +++ b/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp @@ -0,0 +1,142 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
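+
+// bfloat16 variant of the N-d forward convolution. Inputs, weights and outputs
+// are ck::bhalf_t while the accumulator (AccDataType) stays in fp32;
+// CThreadTransferDstScalarPerVector is 1, so the output is stored element by
+// element rather than vectorized.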
+ +#include "convnd_fwd_common.hpp" + +using InDataType = ck::bhalf_t; +using WeiDataType = ck::bhalf_t; +using OutDataType = ck::bhalf_t; +using AccDataType = float; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +template +using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwcKxcNwk_Xdl< + InDataType, // + WeiDataType, // + OutDataType, // + AccDataType, // + InElementOp, // Input Elementwise Operation + WeiElementOp, // Weights Elementwise Operation + OutElementOp, // Output Elementwise Operation + ConvFwdDefault, // ConvForwardSpecialization + NumDimSpatial, // NumDimSpatial + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 7, // CThreadTransferSrcDstVectorDim + 1>; // CThreadTransferDstScalarPerVector + +template +using ReferenceConvNDFwdInstance = ck::tensor_operation::host::ReferenceConvFwd; + +int main(int argc, char* argv[]) +{ + print_helper_msg(); + + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int num_dim_spatial = 2; + + ck::tensor_operation::device::ConvParams params; + + if(argc == 1) + { + // use default + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + num_dim_spatial = std::stoi(argv[4]); + + params = parse_conv_params(num_dim_spatial, argc, argv); + } + + if(num_dim_spatial == 1) + { + return run_conv_fwd<1, + InDataType, + WeiDataType, + OutDataType, + AccDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceConvNDFwdInstance<1>, + ReferenceConvNDFwdInstance<1>>( + params, do_verification, init_method, time_kernel); + } + else if(num_dim_spatial == 2) + { + return run_conv_fwd<2, + InDataType, + WeiDataType, + OutDataType, + AccDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceConvNDFwdInstance<2>, + ReferenceConvNDFwdInstance<2>>( + params, do_verification, init_method, time_kernel); + } + else if(num_dim_spatial == 3) + { + return run_conv_fwd<3, + InDataType, + WeiDataType, + OutDataType, + AccDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceConvNDFwdInstance<3>, + ReferenceConvNDFwdInstance<3>>( + params, do_verification, init_method, time_kernel); + } + + return 0; +} diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp new file mode 100644 index 
00000000..6bfbd20f --- /dev/null +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp @@ -0,0 +1,147 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_common.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +using AccDataType = float; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +template +using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwcKxcNwk_Xdl< + NumDimSpatial, // NumDimSpatial + InDataType, // + WeiDataType, // + OutDataType, // + AccDataType, // + InElementOp, // Input Elementwise Operation + WeiElementOp, // Weights Elementwise Operation = + OutElementOp, // Output Elementwise Operation + ConvFwdDefault, // ConvForwardSpecialization + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 7, // CThreadTransferSrcDstVectorDim + 1>; // CThreadTransferDstScalarPerVector + +template +using ReferenceConvNDFwdInstance = + ck::tensor_operation::host::ReferenceConvFwd; + +int main(int argc, char* argv[]) +{ + print_helper_msg(); + + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int num_dim_spatial = 2; + + ck::tensor_operation::device::ConvParams params{ + 2, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}; + + if(argc == 1) + { + // use default + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + num_dim_spatial = std::stoi(argv[4]); + + params = parse_conv_params(num_dim_spatial, 5, argv); + } + + if(num_dim_spatial == 1) + { + return run_conv_fwd_nhwc<1, + InDataType, + WeiDataType, + OutDataType, + AccDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceConvNDFwdInstance<1>, + ReferenceConvNDFwdInstance<1>>( + params, do_verification, init_method, time_kernel); + } + else if(num_dim_spatial == 2) + { + return run_conv_fwd_nhwc<2, + InDataType, + WeiDataType, + OutDataType, + AccDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceConvNDFwdInstance<2>, + ReferenceConvNDFwdInstance<2>>( + params, do_verification, init_method, time_kernel); + } + else if(num_dim_spatial == 3) + { + return run_conv_fwd_nhwc<3, + InDataType, + WeiDataType, + OutDataType, + AccDataType, + InElementOp, + 
WeiElementOp, + OutElementOp, + DeviceConvNDFwdInstance<3>, + ReferenceConvNDFwdInstance<3>>( + params, do_verification, init_method, time_kernel); + } + + return 0; +} diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp new file mode 100644 index 00000000..cd70ca66 --- /dev/null +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp @@ -0,0 +1,142 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_common.hpp" + +using InDataType = float; +using WeiDataType = float; +using OutDataType = float; +using AccDataType = float; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +template +using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwcKxcNwk_Xdl< + InDataType, // + WeiDataType, // + OutDataType, // + AccDataType, // + InElementOp, // Input Elementwise Operation + WeiElementOp, // Weights Elementwise Operation + OutElementOp, // Output Elementwise Operation + ConvFwdDefault, // ConvForwardSpecialization + NumDimSpatial, // NumDimSpatial + 256, // BlockSize + 256, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 4, // K1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 2, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 4, // ABlockTransferSrcScalarPerVector + 4, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 4, // BBlockTransferSrcScalarPerVector + 4, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockTransferAddExtraN + 7, // CThreadTransferSrcDstVectorDim + 1>; // CThreadTransferDstScalarPerVector + +template +using ReferenceConvNDFwdInstance = ck::tensor_operation::host::ReferenceConvFwd; + +int main(int argc, char* argv[]) +{ + print_helper_msg(); + + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int num_dim_spatial = 2; + + ck::tensor_operation::device::ConvParams params; + + if(argc == 1) + { + // use default + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + num_dim_spatial = std::stoi(argv[4]); + + params = parse_conv_params(num_dim_spatial, argc, argv); + } + + if(num_dim_spatial == 1) + { + return run_conv_fwd<1, + InDataType, + WeiDataType, + OutDataType, + AccDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceConvNDFwdInstance<1>, + ReferenceConvNDFwdInstance<1>>( + params, do_verification, init_method, time_kernel); + } + else if(num_dim_spatial == 2) + { + return run_conv_fwd<2, + InDataType, + WeiDataType, + OutDataType, + AccDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceConvNDFwdInstance<2>, + 
ReferenceConvNDFwdInstance<2>>( + params, do_verification, init_method, time_kernel); + } + else if(num_dim_spatial == 3) + { + return run_conv_fwd<3, + InDataType, + WeiDataType, + OutDataType, + AccDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceConvNDFwdInstance<3>, + ReferenceConvNDFwdInstance<3>>( + params, do_verification, init_method, time_kernel); + } + + return 0; +} diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp new file mode 100644 index 00000000..242265be --- /dev/null +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp @@ -0,0 +1,142 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_common.hpp" + +using InDataType = double; +using WeiDataType = double; +using OutDataType = double; +using AccDataType = double; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +template +using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwcKxcNwk_Xdl< + InDataType, // + WeiDataType, // + OutDataType, // + AccDataType, // + InElementOp, // Input Elementwise Operation + WeiElementOp, // Weights Elementwise Operation + OutElementOp, // Output Elementwise Operation + ConvFwdDefault, // ConvForwardSpecialization + NumDimSpatial, // NumDimSpatial + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 2, // K1 + 16, // MPerXDL + 16, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 2, // ABlockTransferSrcScalarPerVector + 2, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 2, // BBlockTransferSrcScalarPerVector + 2, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockTransferAddExtraN + 7, // CThreadTransferSrcDstVectorDim + 1>; // CThreadTransferDstScalarPerVector + +template +using ReferenceConvNDFwdInstance = ck::tensor_operation::host::ReferenceConvFwd; + +int main(int argc, char* argv[]) +{ + print_helper_msg(); + + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int num_dim_spatial = 2; + + ck::tensor_operation::device::ConvParams params; + + if(argc == 1) + { + // use default + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + num_dim_spatial = std::stoi(argv[4]); + + params = parse_conv_params(num_dim_spatial, argc, argv); + } + + if(num_dim_spatial == 1) + { + return run_conv_fwd<1, + InDataType, + WeiDataType, + OutDataType, + AccDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceConvNDFwdInstance<1>, + ReferenceConvNDFwdInstance<1>>( + params, do_verification, init_method, 
time_kernel); + } + else if(num_dim_spatial == 2) + { + return run_conv_fwd<2, + InDataType, + WeiDataType, + OutDataType, + AccDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceConvNDFwdInstance<2>, + ReferenceConvNDFwdInstance<2>>( + params, do_verification, init_method, time_kernel); + } + else if(num_dim_spatial == 3) + { + return run_conv_fwd<3, + InDataType, + WeiDataType, + OutDataType, + AccDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceConvNDFwdInstance<3>, + ReferenceConvNDFwdInstance<3>>( + params, do_verification, init_method, time_kernel); + } + + return 0; +} diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp new file mode 100644 index 00000000..cd71a384 --- /dev/null +++ b/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp @@ -0,0 +1,142 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_common.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using OutDataType = int8_t; +using AccDataType = int32_t; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +template +using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwcKxcNwk_Xdl< + InDataType, // + WeiDataType, // + OutDataType, // + AccDataType, // + InElementOp, // Input Elementwise Operation + WeiElementOp, // Weights Elementwise Operation + OutElementOp, // Output Elementwise Operation + ConvFwdDefault, // ConvForwardSpecialization + NumDimSpatial, // NumDimSpatial + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 4, // K0PerBlock + 16, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 7, // CThreadTransferSrcDstVectorDim + 1>; // CThreadTransferDstScalarPerVector + +template +using ReferenceConvNDFwdInstance = ck::tensor_operation::host::ReferenceConvFwd; + +int main(int argc, char* argv[]) +{ + print_helper_msg(); + + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int num_dim_spatial = 2; + + ck::tensor_operation::device::ConvParams params; + + if(argc == 1) + { + // use default + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + num_dim_spatial = std::stoi(argv[4]); + + params = parse_conv_params(num_dim_spatial, argc, argv); + } + + if(num_dim_spatial == 1) + { + return run_conv_fwd<1, + 
InDataType, + WeiDataType, + OutDataType, + AccDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceConvNDFwdInstance<1>, + ReferenceConvNDFwdInstance<1>>( + params, do_verification, init_method, time_kernel); + } + else if(num_dim_spatial == 2) + { + return run_conv_fwd<2, + InDataType, + WeiDataType, + OutDataType, + AccDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceConvNDFwdInstance<2>, + ReferenceConvNDFwdInstance<2>>( + params, do_verification, init_method, time_kernel); + } + else if(num_dim_spatial == 3) + { + return run_conv_fwd<3, + InDataType, + WeiDataType, + OutDataType, + AccDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceConvNDFwdInstance<3>, + ReferenceConvNDFwdInstance<3>>( + params, do_verification, init_method, time_kernel); + } + + return 0; +} diff --git a/example/10_conv2d_bwd_data/CMakeLists.txt b/example/10_conv2d_bwd_data/CMakeLists.txt new file mode 100644 index 00000000..17aca148 --- /dev/null +++ b/example/10_conv2d_bwd_data/CMakeLists.txt @@ -0,0 +1,2 @@ +add_example_executable(example_conv2d_bwd_data_xdl conv2d_bwd_data_xdl.cpp) +target_link_libraries(example_conv2d_bwd_data_xdl PRIVATE conv_util) diff --git a/example/10_conv2d_bwd_data/README.md b/example/10_conv2d_bwd_data/README.md new file mode 100644 index 00000000..7503ff6d --- /dev/null +++ b/example/10_conv2d_bwd_data/README.md @@ -0,0 +1,47 @@ +# Instructions for ```example_conv2d_bwd_data_xdl``` Example + + +## Run ```example_conv2d_bwd_data_xdl``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: run kernel # of times (>1) +#arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx +./bin/example_conv2d_bwd_data_xdl 0 1 5 +``` + +Result +``` +in_n_c_hi_wi: dim 4, lengths {128, 256, 71, 71}, strides {1290496, 1, 18176, 256} +wei_k_c_y_x: dim 4, lengths {256, 256, 3, 3}, strides {2304, 1, 768, 256} +out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} +arg.a_grid_desc_k0_m_k1_container_{128, 175232, 8} +arg.b_grid_desc_k0_n_k1_container_{128, 256, 8} +arg.c_grid_desc_m_n_container_{ 175232, 256} +arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 4, 2, 2, 4, 2 ) +launch_and_time_kernel: grid_dim {2738, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... +arg.a_grid_desc_k0_m_k1_container_{64, 175232, 8} +arg.b_grid_desc_k0_n_k1_container_{64, 256, 8} +arg.c_grid_desc_m_n_container_{ 175232, 256} +arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 4, 2, 2, 4, 2 ) +launch_and_time_kernel: grid_dim {2738, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... +arg.a_grid_desc_k0_m_k1_container_{64, 175232, 8} +arg.b_grid_desc_k0_n_k1_container_{64, 256, 8} +arg.c_grid_desc_m_n_container_{ 175232, 256} +arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 4, 2, 2, 4, 2 ) +launch_and_time_kernel: grid_dim {2738, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... +arg.a_grid_desc_k0_m_k1_container_{32, 175232, 8} +arg.b_grid_desc_k0_n_k1_container_{32, 256, 8} +arg.c_grid_desc_m_n_container_{ 175232, 256} +arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 4, 2, 2, 4, 2 ) +launch_and_time_kernel: grid_dim {2738, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... 
+Perf: 2.45966 ms, 79.5597 TFlops, 169.325 GB/s +``` diff --git a/example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp b/example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp new file mode 100644 index 00000000..108e2d6f --- /dev/null +++ b/example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp @@ -0,0 +1,259 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +using AccDataType = float; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +using DeviceConvBwdDataInstance = ck::tensor_operation::device:: + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdDefault, // ConvolutionBackwardDataSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 2, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1>, // BBlockTransferSrcAccessOrder + 1, // BBlockTransferSrcVectorDim + 2, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 7, + 1>; // GemmCThreadTransferDstScalarPerVector + +using ReferenceConvBwdInstance = ck::tensor_operation::host::ReferenceConvBwdData; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // Conv shape + ck::index_t N = 128; + ck::index_t K = 256; + ck::index_t C = 256; + ck::index_t Y = 3; + ck::index_t X = 3; + ck::index_t Hi = 71; + ck::index_t Wi = 71; + ck::index_t conv_stride_h = 2; + ck::index_t conv_stride_w = 2; + ck::index_t conv_dilation_h = 1; + ck::index_t conv_dilation_w = 1; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = 
std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+    }
+    else if(argc == 19)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+
+        N               = std::stoi(argv[4]);
+        K               = std::stoi(argv[5]);
+        C               = std::stoi(argv[6]);
+        Y               = std::stoi(argv[7]);
+        X               = std::stoi(argv[8]);
+        Hi              = std::stoi(argv[9]);
+        Wi              = std::stoi(argv[10]);
+        conv_stride_h   = std::stoi(argv[11]);
+        conv_stride_w   = std::stoi(argv[12]);
+        conv_dilation_h = std::stoi(argv[13]);
+        conv_dilation_w = std::stoi(argv[14]);
+        in_left_pad_h   = std::stoi(argv[15]);
+        in_left_pad_w   = std::stoi(argv[16]);
+        in_right_pad_h  = std::stoi(argv[17]);
+        in_right_pad_w  = std::stoi(argv[18]);
+    }
+    else
+    {
+        printf("arg1: verification (0=no, 1=yes)\n");
+        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
+        printf("arg3: time kernel (0=no, 1=yes)\n");
+        printf("arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
+               "RightPx\n");
+        exit(0);
+    }
+
+    // effective filter spans with dilation, used to derive the output spatial size
+    const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1;
+    const ck::index_t XEff = (X - 1) * conv_dilation_w + 1;
+
+    const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
+    const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
+
+    const std::vector<ck::index_t> conv_filter_strides{{conv_stride_h, conv_stride_w}};
+    const std::vector<ck::index_t> conv_filter_dilations{{conv_dilation_h, conv_dilation_w}};
+    const std::vector<ck::index_t> input_left_pads{{in_left_pad_h, in_left_pad_w}};
+    const std::vector<ck::index_t> input_right_pads{{in_right_pad_h, in_right_pad_w}};
+
+    // tensor layout (NHWC memory order for an N, C, H, W descriptor)
+    auto f_host_tensor_descriptor =
+        [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W) {
+            return HostTensorDescriptor(std::vector<std::size_t>({N_, C_, H, W}),
+                                        std::vector<std::size_t>({C_ * H * W, 1, W * C_, C_}));
+        };
+
+    Tensor<OutDataType> out_n_k_ho_wo(f_host_tensor_descriptor(N, K, Ho, Wo));
+    Tensor<WeiDataType> wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X));
+    Tensor<InDataType> in_n_c_hi_wi_host_result(f_host_tensor_descriptor(N, C, Hi, Wi));
+    Tensor<InDataType> in_n_c_hi_wi_device_result(f_host_tensor_descriptor(N, C, Hi, Wi));
+
+    std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi_host_result.mDesc << std::endl;
+    std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl;
+    std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl;
+
+    switch(init_method)
+    {
+    case 0: break;
+    case 1:
+        out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
+        wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
+        break;
+    default:
+        out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
+        wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
+    }
+
+    DeviceMem in_device_buf(sizeof(InDataType) *
+                            in_n_c_hi_wi_device_result.mDesc.GetElementSpace());
+    DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace());
+    DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace());
+
+    out_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
+    wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
+
+    // reset input to zero
+    in_n_c_hi_wi_device_result.GenerateTensorValue(GeneratorTensor_1<InDataType>{0});
+    in_device_buf.ToDevice(in_n_c_hi_wi_device_result.mData.data());
+
+    // run the backward-data convolution (lowered to an implicit GEMM)
+    auto conv     = DeviceConvBwdDataInstance{};
+    auto invoker  = conv.MakeInvoker();
+    auto argument = conv.MakeArgument(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
+                                      static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
+                                      static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
+                                      N,
+                                      K,
+                                      C,
+                                      std::vector<ck::index_t>{{Hi, Wi}},
+                                      std::vector<ck::index_t>{{Y, X}},
+                                      std::vector<ck::index_t>{{Ho, Wo}},
+                                      conv_filter_strides,
+                                      conv_filter_dilations,
+                                      input_left_pads,
+                                      input_right_pads,
+                                      InElementOp{},
+                                      WeiElementOp{},
+                                      OutElementOp{});
+
+    if(!conv.IsSupportedArgument(argument))
+    {
+        throw std::runtime_error(
+            "wrong! device_conv with the specified compilation parameters does "
+            "not support this Conv problem");
+    }
+
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
+
+    std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X;
+
+    std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) +
+                            sizeof(WeiDataType) * (K * C * Y * X) +
+                            sizeof(OutDataType) * (N * K * Ho * Wo);
+
+    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+    float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
+              << std::endl;
+
+    if(do_verification)
+    {
+        auto ref_conv    = ReferenceConvBwdInstance{};
+        auto ref_invoker = ref_conv.MakeInvoker();
+
+        auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi_host_result,
+                                                  wei_k_c_y_x,
+                                                  out_n_k_ho_wo,
+                                                  conv_filter_strides,
+                                                  conv_filter_dilations,
+                                                  input_left_pads,
+                                                  input_right_pads,
+                                                  InElementOp{},
+                                                  WeiElementOp{},
+                                                  OutElementOp{});
+
+        ref_invoker.Run(ref_argument);
+
+        in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data());
+
+        return ck::utils::check_err(in_n_c_hi_wi_device_result.mData,
+                                    in_n_c_hi_wi_host_result.mData)
+                   ? 0
+                   : 1;
+    }
+    return 0;
+}
diff --git a/example/11_conv2d_bwd_weight/CMakeLists.txt b/example/11_conv2d_bwd_weight/CMakeLists.txt
new file mode 100644
index 00000000..3d771b55
--- /dev/null
+++ b/example/11_conv2d_bwd_weight/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_example_executable(example_conv2d_bwd_weight_xdl conv2d_bwd_weight_xdl.cpp)
+target_link_libraries(example_conv2d_bwd_weight_xdl PRIVATE conv_util)
diff --git a/example/11_conv2d_bwd_weight/README.md b/example/11_conv2d_bwd_weight/README.md
new file mode 100644
index 00000000..c7627427
--- /dev/null
+++ b/example/11_conv2d_bwd_weight/README.md
@@ -0,0 +1,25 @@
+# Instructions for ```example_conv2d_bwd_weight_xdl``` Example
+
+## Run ```example_conv2d_bwd_weight_xdl```
+```bash
+#arg1: verification (0=no, 1=yes)
+#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
+#arg3: run kernel # of times (>1)
+#arg4: show log (0=no, 1=yes)
+#arg5: split-k
+#arg6 to 20: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx
+./bin/example_conv2d_bwd_weight_xdl 0 1 5 0 4
+```
+
+Result
+```
+in_n_c_hi_wi: dim 4, lengths {128, 1024, 14, 14}, strides {200704, 1, 14336, 1024}
+wei_k_c_y_x: dim 4, lengths {256, 1024, 3, 3}, strides {9216, 1, 3072, 1024}
+out_n_k_ho_wo: dim 4, lengths {128, 256, 6, 6}, strides {9216, 1, 1536, 256}
+arg.a_grid_desc_kbatch_k0_m_k1_{4, 144, 256, 8}
+arg.b_grid_desc_kbatch_k0_n_k1_{4, 144, 9216, 8}
+arg.c_grid_desc_m_n_{ 256, 9216}
+launch_and_time_kernel: grid_dim {576, 1, 1}, block_dim {256, 1, 1}
+Warm up
+Start running 5 times...
+Perf: 0.401084 ms, 54.2112 TFlops, 145.75 GB/s
+```
diff --git a/example/11_conv2d_bwd_weight/conv2d_bwd_weight_xdl.cpp b/example/11_conv2d_bwd_weight/conv2d_bwd_weight_xdl.cpp
new file mode 100644
index 00000000..8b4f9c21
--- /dev/null
+++ b/example/11_conv2d_bwd_weight/conv2d_bwd_weight_xdl.cpp
@@ -0,0 +1,300 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
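+
+// Note on the split_k parameter used in this example (an explanatory sketch; the
+// mapping is inferred from the kbatch_k0_m_k1 grid descriptors printed in the README
+// above): the weight gradient is computed as an implicit GEMM with M = K (output
+// channels), N = C * Y * X, and a reduction length of N * Ho * Wo. split_k partitions
+// that reduction dimension into kbatch slices handled by separate workgroups whose
+// partial results are combined, which helps occupancy when N * Ho * Wo is large
+// relative to the GEMM M and N extents.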
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +using AccDataType = float; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::NHWC; +using WeiLayout = ck::tensor_layout::convolution::KYXC; +using OutLayout = ck::tensor_layout::convolution::NHWK; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +// clang-format off +using DeviceConvBwdWeightInstance = ck::tensor_operation::device:: + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 2, // NXdlPerWave + S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 2, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 2, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl +// clang-format on + +using ReferenceConvBwdWeightInstance = + ck::tensor_operation::host::ReferenceConvBwdWeight; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int do_log = 0; + int split_k = 4; + + // Conv shape + ck::index_t N = 128; + ck::index_t K = 256; + ck::index_t C = 1024; + ck::index_t Y = 3; + ck::index_t X = 3; + ck::index_t Hi = 14; + ck::index_t Wi = 14; + ck::index_t conv_stride_h = 2; + ck::index_t conv_stride_w = 2; + ck::index_t conv_dilation_h = 1; + ck::index_t conv_dilation_w = 1; + ck::index_t in_left_pad_h = 0; + ck::index_t in_left_pad_w = 0; + ck::index_t in_right_pad_h = 0; + ck::index_t in_right_pad_w = 0; + + if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + do_log = std::stoi(argv[4]); + split_k = std::stoi(argv[5]); + } 
+    else if(argc == 21)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+        do_log          = std::stoi(argv[4]);
+        split_k         = std::stoi(argv[5]);
+
+        N               = std::stoi(argv[6]);
+        K               = std::stoi(argv[7]);
+        C               = std::stoi(argv[8]);
+        Y               = std::stoi(argv[9]);
+        X               = std::stoi(argv[10]);
+        Hi              = std::stoi(argv[11]);
+        Wi              = std::stoi(argv[12]);
+        conv_stride_h   = std::stoi(argv[13]);
+        conv_stride_w   = std::stoi(argv[14]);
+        conv_dilation_h = std::stoi(argv[15]);
+        conv_dilation_w = std::stoi(argv[16]);
+        in_left_pad_h   = std::stoi(argv[17]);
+        in_left_pad_w   = std::stoi(argv[18]);
+        in_right_pad_h  = std::stoi(argv[19]);
+        in_right_pad_w  = std::stoi(argv[20]);
+    }
+    else
+    {
+        printf("arg1: verification (0=no, 1=yes)\n");
+        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
+        printf("arg3: time kernel (0=no, 1=yes)\n");
+        printf("arg4: show log (0=no, 1=yes)\n");
+        printf("arg5: split-k\n");
+        printf("arg6 to 20: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
+               "RightPx\n");
+        exit(0);
+    }
+
+    const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1;
+    const ck::index_t XEff = (X - 1) * conv_dilation_w + 1;
+
+    const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
+    const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
+
+    const std::vector<ck::index_t> conv_filter_strides{{conv_stride_h, conv_stride_w}};
+    const std::vector<ck::index_t> conv_filter_dilations{{conv_dilation_h, conv_dilation_w}};
+    const std::vector<ck::index_t> input_left_pads{{in_left_pad_h, in_left_pad_w}};
+    const std::vector<ck::index_t> input_right_pads{{in_right_pad_h, in_right_pad_w}};
+
+    // tensor layout
+    auto f_host_tensor_descriptor = [](std::size_t N_,
+                                       std::size_t C_,
+                                       std::size_t H,
+                                       std::size_t W,
+                                       auto layout) {
+        if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NCHW>::value ||
+                     ck::is_same<decltype(layout), ck::tensor_layout::convolution::KCYX>::value ||
+                     ck::is_same<decltype(layout), ck::tensor_layout::convolution::NKHW>::value)
+        {
+            return HostTensorDescriptor(std::vector<std::size_t>({N_, C_, H, W}),
+                                        std::vector<std::size_t>({C_ * H * W, H * W, W, 1}));
+        }
+        else if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NHWC>::value ||
+                          ck::is_same<decltype(layout), ck::tensor_layout::convolution::KYXC>::value ||
+                          ck::is_same<decltype(layout), ck::tensor_layout::convolution::NHWK>::value)
+        {
+            return HostTensorDescriptor(std::vector<std::size_t>({N_, C_, H, W}),
+                                        std::vector<std::size_t>({C_ * H * W, 1, W * C_, C_}));
+        }
+    };
+
+    Tensor<InDataType> in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{}));
+    Tensor<WeiDataType> wei_k_c_y_x_host_result(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{}));
+    Tensor<WeiDataType> wei_k_c_y_x_device_result(
+        f_host_tensor_descriptor(K, C, Y, X, WeiLayout{}));
+    Tensor<OutDataType> out_n_k_ho_wo(f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{}));
+
+    std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl;
+    std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_host_result.mDesc << std::endl;
+    std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl;
+
+    switch(init_method)
+    {
+    case 0: break;
+    case 1:
+        in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
+        out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
+        break;
+    default:
+        in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{1});
+        out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
+    }
+    // the weight tensor is the kernel's output; start it from zero
+    wei_k_c_y_x_device_result.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{0});
+
+    DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace());
+    DeviceMem wei_device_buf(sizeof(WeiDataType) *
+                             wei_k_c_y_x_device_result.mDesc.GetElementSpace());
+    DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace());
+
+    in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
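+    // out_n_k_ho_wo plays the role of the output gradient here; the zero-filled
+    // weight buffer is uploaded as well, since with split_k > 1 the kernel is
+    // expected to accumulate partial weight results into it rather than overwrite it.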
out_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + wei_device_buf.ToDevice(wei_k_c_y_x_device_result.mData.data()); + + // do GEMM + auto conv = DeviceConvBwdWeightInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + N, + K, + C, + std::vector{{Hi, Wi}}, + std::vector{{Y, X}}, + std::vector{{Ho, Wo}}, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}, + split_k); + + if(!conv.IsSupportedArgument(argument)) + { + std::cout << "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem" + << std::endl; + return 1; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; + + std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) + + sizeof(WeiDataType) * (K * C * Y * X) + + sizeof(OutDataType) * (N * K * Ho * Wo); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + auto ref_conv = ReferenceConvBwdWeightInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi, + wei_k_c_y_x_host_result, + out_n_k_ho_wo, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); + + wei_device_buf.FromDevice(wei_k_c_y_x_device_result.mData.data()); + + if(do_log) + { + LogRangeAsType(std::cout << "out: ", out_n_k_ho_wo.mData, ",") << std::endl; + LogRangeAsType(std::cout << "in : ", in_n_c_hi_wi.mData, ",") << std::endl; + LogRangeAsType( + std::cout << "wei_device(after): ", wei_k_c_y_x_device_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "wei_host : ", wei_k_c_y_x_host_result.mData, ",") + << std::endl; + } + return ck::utils::check_err(wei_k_c_y_x_device_result.mData, wei_k_c_y_x_host_result.mData) + ? 0 + : 1; + } + return 0; +} diff --git a/example/12_reduce/CMakeLists.txt b/example/12_reduce/CMakeLists.txt new file mode 100644 index 00000000..9045a78a --- /dev/null +++ b/example/12_reduce/CMakeLists.txt @@ -0,0 +1,2 @@ +add_example_executable(example_reduce_blockwise reduce_blockwise.cpp) +add_example_executable(example_reduce_blockwise_two_call reduce_blockwise_two_call.cpp) diff --git a/example/12_reduce/README.md b/example/12_reduce/README.md new file mode 100644 index 00000000..826d2f6c --- /dev/null +++ b/example/12_reduce/README.md @@ -0,0 +1,41 @@ +# Instructions for ```example_reduce_blockwise``` + +## Run ```example_reduce_blockwise``` +```bash +# -D : input 4-d tensor lengths +# -v : verification (0=no, 1=yes) +#arg1: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value) +#arg2: time kernel (0=no, 1=yes) +./bin/example_reduce_blockwise -D 16,64,32,960 -v 1 1 1 +``` + +Result +``` +./bin/example_reduce_blockwise -D 16,64,32,960 -v 1 1 1 +launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1} +Warm up 1 time +Start running 10 times... 
+Perf: 0.282592 ms, 222.641 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSrcVectorDim_0_InSrcVectorSize_1_OutDstVectorSize_1>
+```
+
+# Instructions for ```example_reduce_blockwise_two_call```
+
+## Run ```example_reduce_blockwise_two_call```
+```bash
+#arg1: verification (0=no, 1=yes)
+#arg2: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)
+#arg3: time kernel (0=no, 1=yes)
+./bin/example_reduce_blockwise_two_call 1 2 1
+```
+
+Result
+```
+./bin/example_reduce_blockwise_two_call 1 2 1
+launch_and_time_kernel: grid_dim {204800, 1, 1}, block_dim {256, 1, 1}
+Warm up 1 time
+Start running 10 times...
+launch_and_time_kernel: grid_dim {6400, 1, 1}, block_dim {256, 1, 1}
+Warm up 1 time
+Start running 10 times...
+Perf: 2.1791 ms, 771.42 GB/s, DeviceReduceBlockWise<256,M_C32_S1,K_C8_S1,InSrcVectorDim_1_InSrcVectorSize_1_OutDstVectorSize_1> => DeviceReduceBlockWise<256,M_C256_S1,K_C1_S1,InSrcVectorDim_1_InSrcVectorSize_1_OutDstVectorSize_1>
+```
diff --git a/example/12_reduce/reduce_blockwise.cpp b/example/12_reduce/reduce_blockwise.cpp
new file mode 100644
index 00000000..5f127608
--- /dev/null
+++ b/example/12_reduce/reduce_blockwise.cpp
@@ -0,0 +1,342 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include
+#include
+#include
+#include
+#include
+
+#include "ck/ck.hpp"
+#include "ck/utility/reduction_enums.hpp"
+#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
+#include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp"
+
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/host_common_util.hpp"
+#include "ck/library/utility/host_reduction.hpp"
+
+using namespace ck;
+using namespace ck::tensor_operation::device;
+
+using InDataType  = ck::half_t;
+using OutDataType = ck::half_t;
+using AccDataType = float;
+
+constexpr int Rank         = 4;
+constexpr int NumReduceDim = 3;
+
+constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::NORM2;
+constexpr bool PropagateNan         = true;
+constexpr bool OutputIndex          = false;
+
+using ReduceOperation = typename reduce_binary_operator::opType;
+using InElementwiseOperation =
+    typename reduce_unary_operator::InElementwiseOperation;
+using AccElementwiseOperation =
+    typename reduce_unary_operator::AccElementwiseOperation;
+
+using DeviceReduceInstance = DeviceReduceMultiBlock;
+
+static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'},
+                                       {"verify", required_argument, nullptr, 'v'},
+                                       {"help", no_argument, nullptr, '?'},
+                                       {nullptr, 0, nullptr, 0}};
+
+class SimpleAppArgs
+{
+    private:
+    int option_index = 0;
+
+    public:
+    std::vector<size_t> inLengths = {16, 64, 32, 960};
+    std::vector<float> scales     = {1.0f, 0.0f};
+
+    bool do_verification = true;
+    int init_method      = 1;
+    bool time_kernel     = true;
+
+    public:
+    void show_usage(const char* cmd)
+    {
+        std::cout << "Usage of " << cmd << std::endl;
+        std::cout << "--inLengths or -D, comma separated list of input tensor dimension lengths"
+                  << std::endl;
+        std::cout << "--verify or -v, 1/0 to indicate whether to verify the reduction result by "
+                     "comparing with the host-based reduction"
+                  << std::endl;
+        std::cout << "Arg1 -- init method (0=no init, 1=single integer value, 2=scope integer "
+                     "value, 3=decimal value)"
+                  << std::endl;
+        std::cout << "Arg2 -- time kernel (0=no, 1=yes)"
<< std::endl;
+    };
+
+    int processArgs(int argc, char* argv[])
+    {
+        using ck::host_common::getTypeValuesFromString;
+
+        int ch;
+
+        while(1)
+        {
+            ch = getopt_long(argc, argv, "D:v:l:", long_options, &option_index);
+            if(ch == -1)
+                break;
+            switch(ch)
+            {
+            case 'D':
+                if(!optarg)
+                    throw std::runtime_error("Invalid option format!");
+
+                inLengths = getTypeValuesFromString<size_t>(optarg);
+                break;
+            case 'v':
+                if(!optarg)
+                    throw std::runtime_error("Invalid option format!");
+
+                do_verification = static_cast<bool>(std::atoi(optarg));
+                break;
+            case '?':
+                if(std::string(long_options[option_index].name) == "help")
+                {
+                    show_usage(argv[0]);
+                    return (-1);
+                };
+                break;
+            default: show_usage(argv[0]); return (-1);
+            };
+        };
+
+        if(optind + 2 > argc)
+            throw std::runtime_error("Invalid cmd-line arguments, more arguments are needed!");
+
+        init_method = std::atoi(argv[optind++]);
+        time_kernel = static_cast<bool>(std::atoi(argv[optind]));
+
+        if(scales.empty())
+        {
+            scales.push_back(1.0f);
+            scales.push_back(0.0f);
+        };
+
+        return (0);
+    };
+};
+
+int main(int argc, char* argv[])
+{
+    const std::vector<int> reduceDims{0, 1, 2};
+    const std::vector<int> invariantDims{3};
+
+    SimpleAppArgs args;
+
+    if(argc > 1)
+    {
+        if(args.processArgs(argc, argv) < 0)
+            return (-1);
+    };
+
+    constexpr bool op_support_indices =
+        (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
+         ReduceOpId == ReduceTensorOp::AMAX);
+
+    // if the input is half type, there is no reason to use float for an indexed reduction
+    // operation, and float must be used for a non-indexed reduction operation for accuracy
+    constexpr bool invalid_reduce_1 =
+        std::is_same<InDataType, ck::half_t>::value &&
+        ((!op_support_indices && !std::is_same<AccDataType, float>::value) ||
+         (op_support_indices && !std::is_same<AccDataType, ck::half_t>::value));
+
+    // if the input is float type, there is no reason to use double for an indexed reduction
+    // operation
+    constexpr bool invalid_reduce_2 =
+        std::is_same<InDataType, float>::value &&
+        (op_support_indices && !std::is_same<AccDataType, float>::value);
+
+    // the indices option can only be used when it is really needed
+    constexpr bool invalid_reduce_3 = (!op_support_indices && OutputIndex);
+
+    constexpr bool invalid_reduce = (invalid_reduce_1 || invalid_reduce_2 || invalid_reduce_3);
+
+    if constexpr(invalid_reduce)
+        std::cout << "Reduction setting is not supported, exiting!"
<< std::endl; + + Tensor in(args.inLengths); + + std::vector outLengths; + + if(invariantDims.empty()) + outLengths.push_back(1); + else + for(auto dim : invariantDims) + outLengths.push_back(args.inLengths[dim]); + + Tensor out_ref(outLengths); + Tensor out(outLengths); + Tensor out_indices_ref(outLengths); + Tensor out_indices(outLengths); + + auto inStrides = in.mDesc.GetStrides(); + auto outStrides = out.mDesc.GetStrides(); + + size_t invariant_total_length = out.mDesc.GetElementSize(); + size_t reduce_total_length = in.mDesc.GetElementSize() / invariant_total_length; + + float alpha = args.scales[0]; + float beta = args.scales[1]; + + std::size_t num_thread = 1; + + if(args.do_verification) + { + switch(args.init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + break; + case 2: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, num_thread); + } + + if(beta != 0.0f) + for(size_t i = 0; i < out_ref.mDesc.GetElementSpace(); i++) + out.mData[i] = out_ref.mData[i]; + }; + + // these buffers are usually provided by the user application + DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpace()); + DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpace()); + + in_dev.ToDevice(in.mData.data()); + + if(beta != 0.0f) + out_dev.ToDevice(out.mData.data()); + + size_t indicesSizeInBytes = OutputIndex ? out.mDesc.GetElementSize() * sizeof(int32_t) : 0; + + DeviceMem out_index_dev(indicesSizeInBytes); + + InElementwiseOperation in_elementwise_op; + AccElementwiseOperation acc_elementwise_op; + + std::tie(in_elementwise_op, acc_elementwise_op) = + reduce_unary_operator::GetElementwiseOperator( + static_cast(reduce_total_length)); + + if(args.do_verification) + { + ReductionHost + hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims); + + hostReduce.Run(alpha, + in.mData.data(), + beta, + out_ref.mData.data(), + out_indices_ref.mData.data(), + in_elementwise_op, + acc_elementwise_op); + }; + + std::vector i_inLengths; + std::vector i_inStrides; + std::vector i_outLengths; + std::vector i_outStrides; + + i_inLengths.assign(args.inLengths.begin(), args.inLengths.end()); + i_inStrides.assign(inStrides.begin(), inStrides.end()); + i_outLengths.assign(outLengths.begin(), outLengths.end()); + i_outStrides.assign(outStrides.begin(), outStrides.end()); + + auto reduce = DeviceReduceInstance{}; + + auto argument_ptr = reduce.MakeArgumentPointer(i_inLengths, + i_inStrides, + i_outLengths, + i_outStrides, + reduceDims, + alpha, + beta, + in_dev.GetDeviceBuffer(), + nullptr, + out_dev.GetDeviceBuffer(), + out_index_dev.GetDeviceBuffer(), + in_elementwise_op, + acc_elementwise_op); + + if(!reduce.IsSupportedArgument(argument_ptr.get())) + { + std::cout + << "The runtime parameters seems not supported by the DeviceReduce instance, exiting!" 
+            << std::endl;
+    };
+
+    std::string reduce_name = reduce.GetTypeString();
+
+    auto invoker_ptr = reduce.MakeInvokerPointer();
+
+    float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, args.time_kernel});
+
+    std::size_t num_bytes = invariant_total_length * reduce_total_length * sizeof(InDataType) +
+                            invariant_total_length * sizeof(OutDataType);
+
+    float gb_per_sec = num_bytes / 1.E6 / avg_time;
+
+    std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " << reduce_name
+              << std::endl;
+
+    bool pass = true;
+
+    if(args.do_verification)
+    {
+        out_dev.FromDevice(out.mData.data());
+        pass = pass && ck::utils::check_err(out.mData, out_ref.mData);
+
+        if(OutputIndex)
+        {
+            out_index_dev.FromDevice(out_indices.mData.data());
+            pass = pass && ck::utils::check_err(out_indices.mData, out_indices_ref.mData);
+        };
+    };
+
+    return (pass ? 0 : 1);
+}
diff --git a/example/12_reduce/reduce_blockwise_two_call.cpp b/example/12_reduce/reduce_blockwise_two_call.cpp
new file mode 100644
index 00000000..27f89832
--- /dev/null
+++ b/example/12_reduce/reduce_blockwise_two_call.cpp
@@ -0,0 +1,301 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "ck/ck.hpp"
+#include "ck/utility/reduction_enums.hpp"
+#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
+#include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp"
+
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/host_common_util.hpp"
+#include "ck/library/utility/host_reduction.hpp"
+
+using namespace ck;
+using namespace ck::tensor_operation::device;
+
+using InOutDataType = ck::half_t;
+using AccDataType   = float;
+
+constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::NORM2;
+constexpr bool PropagateNan         = true;
+constexpr bool OutputIndex          = false;
+
+using ReduceOperation = typename reduce_binary_operator::opType;
+using InElementwiseOperation =
+    typename reduce_unary_operator::InElementwiseOperation;
+using AccElementwiseOperation =
+    typename reduce_unary_operator::AccElementwiseOperation;
+
+using PassThroughOp = tensor_operation::element_wise::PassThrough;
+
+using DeviceReduceInstance_1 = DeviceReduceMultiBlock;
+
+using DeviceReduceInstance_2 = DeviceReduceMultiBlock;
+
+static bool do_verify;
+static int init_method;
+static float alpha;
+static float beta;
+static bool time_kernel;
+
+int main(int argc, char* argv[])
+{
+    // used by the device reduction
+    const std::vector<int> reduceDims_1    = {4};
+    const std::vector<int> invariantDims_1 = {0, 1, 2, 3};
+
+    const std::vector<int> reduceDims_2    = {3};
+    const std::vector<int> invariantDims_2 = {0, 1, 2};
+
+    // used by the host reduction
+    const std::vector<int> reduceDims    = {3, 4};
+    const std::vector<int> invariantDims = {0, 1, 2};
+
+    const std::vector<size_t> inLengths_1 = {64, 320, 80, 4, 128};
+
+    // input lengths of the second reduction, which are also the output lengths of the first
+    // reduction
+    const std::vector<size_t> inLengths_2 = {64, 320, 80, 4};
+
+    const std::vector<size_t> outLengths = {64, 320, 80};
+
+    if(argc == 1)
+    {
+        do_verify   = true;
+        init_method = 2;
+        time_kernel = true;
+    }
+    else if(argc == 4)
+    {
+        do_verify   = static_cast<bool>(std::atoi(argv[1]));
+        init_method = atoi(argv[2]);
+        time_kernel = static_cast<bool>(atoi(argv[3]));
+
} + else + { + std::ostringstream ostr; + + ostr << "Wrong parameter! " << std::endl + << "Usage: " << argv[0] << "[verify 0/1] init_method time_kernel" << std::endl; + + throw std::runtime_error(ostr.str()); + }; + + alpha = 1.0f; + beta = 0.0f; + + Tensor in_1(inLengths_1); + + Tensor out_ref(outLengths); + Tensor in_2(inLengths_2); // also the output tensor of the first reduction + Tensor out(outLengths); + + auto inStrides_1 = in_1.mDesc.GetStrides(); + auto inStrides_2 = in_2.mDesc.GetStrides(); + auto outStrides = out.mDesc.GetStrides(); + + size_t invariant_total_length = out.mDesc.GetElementSize(); + size_t reduce_total_length = in_1.mDesc.GetElementSize() / invariant_total_length; + + std::size_t num_thread = 1; + + if(do_verify) + { + switch(init_method) + { + case 0: break; + case 1: + in_1.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + break; + case 2: + in_1.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + in_1.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, + num_thread); + } + + if(beta != 0.0f) + for(size_t i = 0; i < out_ref.mDesc.GetElementSpace(); i++) + out.mData[i] = out_ref.mData[i]; + }; + + DeviceMem in_1_dev(sizeof(InOutDataType) * in_1.mDesc.GetElementSpace()); + DeviceMem in_2_dev(sizeof(InOutDataType) * in_2.mDesc.GetElementSpace()); + DeviceMem out_dev(sizeof(InOutDataType) * out.mDesc.GetElementSpace()); + + in_1_dev.ToDevice(in_1.mData.data()); + + if(beta != 0.0f) + out_dev.ToDevice(out.mData.data()); + + InElementwiseOperation in_elementwise_op; + AccElementwiseOperation acc_elementwise_op; + + std::tie(in_elementwise_op, acc_elementwise_op) = + reduce_unary_operator::GetElementwiseOperator( + static_cast(reduce_total_length)); + + if(do_verify) + { + ReductionHost + hostReduce(in_1.mDesc, out_ref.mDesc, invariantDims, reduceDims); + + hostReduce.Run(alpha, + in_1.mData.data(), + beta, + out_ref.mData.data(), + nullptr, + in_elementwise_op, + acc_elementwise_op); + }; + + std::vector i_inLengths_1; + std::vector i_inStrides_1; + std::vector i_inLengths_2; + std::vector i_inStrides_2; + std::vector i_outLengths; + std::vector i_outStrides; + + i_inLengths_1.assign(inLengths_1.begin(), inLengths_1.end()); + i_inStrides_1.assign(inStrides_1.begin(), inStrides_1.end()); + i_inLengths_2.assign(inLengths_2.begin(), inLengths_2.end()); + i_inStrides_2.assign(inStrides_2.begin(), inStrides_2.end()); + i_outLengths.assign(outLengths.begin(), outLengths.end()); + i_outStrides.assign(outStrides.begin(), outStrides.end()); + + auto reduce_1 = DeviceReduceInstance_1{}; + + auto argument_ptr_1 = reduce_1.MakeArgumentPointer(i_inLengths_1, + i_inStrides_1, + i_inLengths_2, + i_inStrides_2, + reduceDims_1, + 1.0f, + 0.0f, + in_1_dev.GetDeviceBuffer(), + nullptr, + in_2_dev.GetDeviceBuffer(), + nullptr, + in_elementwise_op, + PassThroughOp{}); + + if(!reduce_1.IsSupportedArgument(argument_ptr_1.get())) + { + std::cout + << "The runtime parameters seems not supported by the DeviceReduce instance, exiting!" 
+ << std::endl; + }; + + auto invoker_ptr_1 = reduce_1.MakeInvokerPointer(); + + auto reduce_2 = DeviceReduceInstance_2{}; + + auto argument_ptr_2 = reduce_2.MakeArgumentPointer(i_inLengths_2, + i_inStrides_2, + i_outLengths, + i_outStrides, + reduceDims_2, + alpha, + beta, + in_2_dev.GetDeviceBuffer(), + nullptr, + out_dev.GetDeviceBuffer(), + nullptr, + PassThroughOp{}, + acc_elementwise_op); + + if(!reduce_2.IsSupportedArgument(argument_ptr_2.get())) + { + std::cout + << "The runtime parameters seems not supported by the DeviceReduce instance, exiting!" + << std::endl; + }; + + auto invoker_ptr_2 = reduce_2.MakeInvokerPointer(); + + float avg_time_1 = invoker_ptr_1->Run(argument_ptr_1.get(), StreamConfig{nullptr, time_kernel}); + float avg_time_2 = invoker_ptr_2->Run(argument_ptr_2.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t num_bytes = invariant_total_length * reduce_total_length * sizeof(InOutDataType) + + invariant_total_length * sizeof(InOutDataType); + + float gb_per_sec = num_bytes / 1.E6 / (avg_time_1 + avg_time_2); + + std::cout << "Perf: " << avg_time_1 + avg_time_2 << " ms, " << gb_per_sec << " GB/s, " + << reduce_1.GetTypeString() << " => " << reduce_2.GetTypeString() << std::endl; + + bool pass = true; + + if(do_verify) + { + out_dev.FromDevice(out.mData.data()); + pass = pass && ck::utils::check_err(out.mData, out_ref.mData); + }; + + return (pass ? 0 : 1); +} diff --git a/example/13_pool2d_fwd/CMakeLists.txt b/example/13_pool2d_fwd/CMakeLists.txt new file mode 100644 index 00000000..db09c033 --- /dev/null +++ b/example/13_pool2d_fwd/CMakeLists.txt @@ -0,0 +1,3 @@ +add_example_executable(example_pool2d_fwd_fp16 pool2d_fwd_fp16.cpp) +add_example_executable(example_pool2d_fwd_fp32 pool2d_fwd_fp32.cpp) + diff --git a/example/13_pool2d_fwd/README.md b/example/13_pool2d_fwd/README.md new file mode 100644 index 00000000..9b017734 --- /dev/null +++ b/example/13_pool2d_fwd/README.md @@ -0,0 +1,41 @@ +# Instructions for ```example_pool2d_fwd``` Examples + +## Run ```example_pool2d_fwd_fp16``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value) +#arg3: time kernel (0=no, 1=yes) +#arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, RightPx +./bin/example_pool2d_fwd_fp16 1 1 1 +``` + +Result +``` +in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192} +out_n_c_ho_wo: dim 4, lengths {128, 192, 36, 36}, strides {248832, 1, 6912, 192} +launch_and_time_kernel: grid_dim {124416, 1, 1}, block_dim {64, 1, 1} +Warm up 1 time +Start running 10 times... +Perf: 0.397436 ms, 1.44252 TFlops, 783.713 GB/s +``` + +## Run ```example_pool2d_fwd_fp32``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value) +#arg3: time kernel (0=no, 1=yes) +#arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, RightPx +./bin/example_pool2d_fwd_fp32 1 1 1 +``` + + +Result +``` +./bin/example_pool2d_fwd_fp32 1 1 1 +in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192} +out_n_c_ho_wo: dim 4, lengths {128, 192, 36, 36}, strides {248832, 1, 6912, 192} +launch_and_time_kernel: grid_dim {124416, 1, 1}, block_dim {64, 1, 1} +Warm up 1 time +Start running 10 times... 
+Perf: 1.01823 ms, 0.563045 TFlops, 611.8 GB/s +``` diff --git a/example/13_pool2d_fwd/pool2d_fwd_common.hpp b/example/13_pool2d_fwd/pool2d_fwd_common.hpp new file mode 100644 index 00000000..b7ef1dd9 --- /dev/null +++ b/example/13_pool2d_fwd/pool2d_fwd_common.hpp @@ -0,0 +1,281 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/device_pool2d_fwd_nhwc_nhwc.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +template +static void pool_host_verify(const Tensor& in, + Tensor& out, + Tensor& out_indices, + const std::array& window_spatial_lengths, + const std::array& window_strides, + const std::array& in_left_pads, + const std::array& /*in_right_pads*/) +{ + const int32_t reduceLength = window_spatial_lengths[0] * window_spatial_lengths[1]; + + using ReduceOperation = typename ck::reduce_binary_operator::opType; + + auto elementwise_ops = + ck::reduce_unary_operator::GetElementwiseOperator(reduceLength); + + auto in_elementwise_op = std::get<0>(elementwise_ops); + auto acc_elementwise_op = std::get<1>(elementwise_ops); + + if constexpr(!OutputIndex) + { + using Accumulation = + ck::detail::AccumulateWithNanCheck; + + auto f_nchw = [&](auto n, auto c, auto ho, auto wo) { + auto accuVal = ReduceOperation::template GetIdentityValue(); + + for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y) + { + ck::index_t hi = ho * window_strides[0] + y - in_left_pads[0]; + for(ck::index_t x = 0; x < window_spatial_lengths[1]; ++x) + { + ck::index_t wi = wo * window_strides[1] + x - in_left_pads[1]; + if(hi >= 0 && hi < static_cast(in.mDesc.GetLengths()[2]) && + wi >= 0 && wi < static_cast(in.mDesc.GetLengths()[3])) + { + AccDataType currVal = static_cast(in(n, c, hi, wi)); + + in_elementwise_op(currVal, currVal); + + Accumulation::Calculate(accuVal, currVal); + } + } + } + + acc_elementwise_op(accuVal, accuVal); + + out(n, c, ho, wo) = accuVal; + }; + + make_ParallelTensorFunctor(f_nchw, + out.mDesc.GetLengths()[0], + out.mDesc.GetLengths()[1], + out.mDesc.GetLengths()[2], + out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + } + else + { + using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck; + auto f_nchw = [&](auto n, auto c, auto ho, auto wo) { + auto accuVal = ReduceOperation::template GetIdentityValue(); + IndexDataType accuIndex = 0; + + for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y) + { + ck::index_t hi = ho * window_strides[0] + y - in_left_pads[0]; + for(ck::index_t x = 0; x < window_spatial_lengths[1]; ++x) + { + ck::index_t wi = wo * window_strides[1] + x - in_left_pads[1]; + if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && + wi < in.mDesc.GetLengths()[3]) + { + AccDataType currVal = static_cast(in(n, c, hi, wi)); + IndexDataType currIndex = y * window_spatial_lengths[1] + x; + + in_elementwise_op(currVal, currVal); + + Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex); + } + } + } + + acc_elementwise_op(accuVal, accuVal); + + out(n, c, ho, wo) = accuVal; + out_indices(n, c, ho, 
wo) = accuIndex; + }; + + make_ParallelTensorFunctor(f_nchw, + out.mDesc.GetLengths()[0], + out.mDesc.GetLengths()[1], + out.mDesc.GetLengths()[2], + out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + }; +} + +template +bool pool_test(bool do_verification, + int init_method, + bool time_kernel, + ck::index_t N, + ck::index_t C, + ck::index_t Y, + ck::index_t X, + ck::index_t Hi, + ck::index_t Wi, + ck::index_t window_stride_h, + ck::index_t window_stride_w, + ck::index_t in_left_pad_h, + ck::index_t in_left_pad_w, + ck::index_t in_right_pad_h, + ck::index_t in_right_pad_w) +{ + using DevicePoolFwdInstance = + ck::tensor_operation::device::DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C< + InDataType, // InDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + ReduceOpId, + OutputIndex, + 64, // BlockSize + 64, // ReduceMThreadClusterSize + 1, // ReduceKThreadClusterSize + 4, // ReduceMThreadSliceSize + 1, // ReduceKThreadSliceSize + 4>; // InSrcOutDstVectorSize + + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1; + + const std::array window_spatial_lengths{{Y, X}}; + const std::array window_strides{{window_stride_h, window_stride_w}}; + const std::array input_left_pads{{in_left_pad_h, in_left_pad_w}}; + const std::array input_right_pads{{in_right_pad_h, in_right_pad_w}}; + + // tensor layout + auto f_host_tensor_descriptor = + [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W, auto layout) { + if constexpr(ck::is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, H * W, W, 1})); + } + else if constexpr(ck::is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, 1, W * C_, C_})); + } + }; + + Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); + Tensor out_n_c_ho_wo_host(f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{})); + Tensor out_indices_n_c_ho_wo_host( + f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{})); + Tensor out_n_c_ho_wo_device(f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{})); + Tensor out_indices_n_c_ho_wo_device( + f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{})); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "out_n_c_ho_wo: " << out_n_c_ho_wo_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1{1}); break; + case 2: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; + default: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_n_c_ho_wo_device.mDesc.GetElementSpace()); + DeviceMem out_indices_device_buf(sizeof(IndexDataType) * + out_indices_n_c_ho_wo_device.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + + auto pool = DevicePoolFwdInstance{}; + auto invoker_ptr = pool.MakeInvokerPointer(); + auto argument_ptr = pool.MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(out_indices_device_buf.GetDeviceBuffer()), + N, + C, + std::array{{Hi, Wi}}, + std::array{{Y, X}}, + std::array{{Ho, Wo}}, + window_strides, + input_left_pads, + input_right_pads); + + 
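+    // A note on the instance above (inferred from its template arguments, not
+    // from documented guarantees): InSrcOutDstVectorSize = 4 makes the kernel
+    // read/write 4 channels at a time along C, so shapes whose C is not a
+    // multiple of 4 are presumably rejected by IsSupportedArgument() below.
+    // For the default example shape, the output-size formula above gives
+    // Ho = (71 + 1 + 1 - 3) / 2 + 1 = 36, matching the README output.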
if(!pool.IsSupportedArgument(argument_ptr.get())) + { + throw std::runtime_error("wrong! device_op with the specified compilation parameters does " + "not support this problem"); + } + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * N * C * Ho * Wo * Y * X; + + std::size_t num_btype = + sizeof(InDataType) * (N * C * Hi * Wi) + sizeof(OutDataType) * (N * C * Ho * Wo); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + bool pass = true; + + if(do_verification) + { + pool_host_verify(in_n_c_hi_wi, + out_n_c_ho_wo_host, + out_indices_n_c_ho_wo_host, + window_spatial_lengths, + window_strides, + input_left_pads, + input_right_pads); + + out_device_buf.FromDevice(out_n_c_ho_wo_device.mData.data()); + + pass = pass && ck::utils::check_err(out_n_c_ho_wo_device.mData, out_n_c_ho_wo_host.mData); + + if constexpr(OutputIndex) + { + out_indices_device_buf.FromDevice(out_indices_n_c_ho_wo_device.mData.data()); + + pass = pass && ck::utils::check_err(out_indices_n_c_ho_wo_device.mData, + out_indices_n_c_ho_wo_host.mData); + }; + } + + return (pass); +}; diff --git a/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp b/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp new file mode 100644 index 00000000..659f3251 --- /dev/null +++ b/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "pool2d_fwd_common.hpp" + +using InDataType = ck::half_t; +using OutDataType = ck::half_t; +using AccDataType = float; + +using IndexDataType = int32_t; + +using InLayout = ck::tensor_layout::convolution::NHWC; +using OutLayout = ck::tensor_layout::convolution::NHWC; + +#if 1 +static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; +#else +static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; +#endif + +static constexpr bool OutputIndex = false; +static constexpr bool PropagateNan = false; + +int main(int argc, char* argv[]) +{ + bool do_verification; + int init_method; + bool time_kernel; + + // Pool shape + ck::index_t N = 128; + ck::index_t C = 192; + ck::index_t Y = 3; + ck::index_t X = 3; + ck::index_t Hi = 71; + ck::index_t Wi = 71; + ck::index_t window_stride_h = 2; + ck::index_t window_stride_w = 2; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + + if(argc == 1) + { + do_verification = true; + init_method = 1; + time_kernel = true; + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = static_cast(std::stoi(argv[3])); + } + else if(argc == 16) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = static_cast(std::stoi(argv[3])); + + N = std::stoi(argv[4]); + C = std::stoi(argv[5]); + Y = std::stoi(argv[6]); + X = std::stoi(argv[7]); + Hi = std::stoi(argv[8]); + Wi = std::stoi(argv[9]); + window_stride_h = std::stoi(argv[10]); + window_stride_w = std::stoi(argv[11]); + in_left_pad_h = std::stoi(argv[12]); + in_left_pad_w = std::stoi(argv[13]); + in_right_pad_h = std::stoi(argv[14]); + in_right_pad_w 
= std::stoi(argv[15]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(0); + } + + bool pass = pool_test(do_verification, + init_method, + time_kernel, + N, + C, + Y, + X, + Hi, + Wi, + window_stride_h, + window_stride_w, + in_left_pad_h, + in_left_pad_w, + in_right_pad_h, + in_right_pad_w); + + return (pass ? 0 : 1); +} diff --git a/example/13_pool2d_fwd/pool2d_fwd_fp32.cpp b/example/13_pool2d_fwd/pool2d_fwd_fp32.cpp new file mode 100644 index 00000000..f47c7ff1 --- /dev/null +++ b/example/13_pool2d_fwd/pool2d_fwd_fp32.cpp @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "pool2d_fwd_common.hpp" + +using InDataType = float; +using OutDataType = float; +using AccDataType = float; + +using IndexDataType = int32_t; + +using InLayout = ck::tensor_layout::convolution::NHWC; +using OutLayout = ck::tensor_layout::convolution::NHWC; + +#if 1 +static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; +#else +static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; +#endif + +static constexpr bool OutputIndex = false; +static constexpr bool PropagateNan = false; + +int main(int argc, char* argv[]) +{ + bool do_verification; + int init_method; + bool time_kernel; + + // Pool shape + ck::index_t N = 128; + ck::index_t C = 192; + ck::index_t Y = 3; + ck::index_t X = 3; + ck::index_t Hi = 71; + ck::index_t Wi = 71; + ck::index_t window_stride_h = 2; + ck::index_t window_stride_w = 2; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + + if(argc == 1) + { + do_verification = true; + init_method = 1; + time_kernel = true; + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = static_cast(std::stoi(argv[3])); + } + else if(argc == 16) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = static_cast(std::stoi(argv[3])); + + N = std::stoi(argv[4]); + C = std::stoi(argv[5]); + Y = std::stoi(argv[6]); + X = std::stoi(argv[7]); + Hi = std::stoi(argv[8]); + Wi = std::stoi(argv[9]); + window_stride_h = std::stoi(argv[10]); + window_stride_w = std::stoi(argv[11]); + in_left_pad_h = std::stoi(argv[12]); + in_left_pad_w = std::stoi(argv[13]); + in_right_pad_h = std::stoi(argv[14]); + in_right_pad_w = std::stoi(argv[15]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(0); + } + + bool pass = pool_test(do_verification, + init_method, + time_kernel, + N, + C, + Y, + X, + Hi, + Wi, + window_stride_h, + window_stride_w, + in_left_pad_h, + in_left_pad_w, + in_right_pad_h, + in_right_pad_w); + + return (pass ? 
0 : 1); +} diff --git a/example/14_gemm_xdl_requant_relu_requant/CMakeLists.txt b/example/14_gemm_xdl_requant_relu_requant/CMakeLists.txt new file mode 100644 index 00000000..0f5b8e1b --- /dev/null +++ b/example/14_gemm_xdl_requant_relu_requant/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_gemm_xdl_requant_relu_requant_int8 gemm_xdl_requant_relu_requant_int8.cpp) \ No newline at end of file diff --git a/example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp b/example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp new file mode 100644 index 00000000..a29254da --- /dev/null +++ b/example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp @@ -0,0 +1,256 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +struct RequantReluRequant +{ + // FIXME: We just need one scale for Relu / Leaky Relu / PRelu + RequantReluRequant(float scaleGemm, float scaleRelu) + : scaleGemm_(scaleGemm), scaleRelu_(scaleRelu) + { + } + + __host__ __device__ constexpr void operator()(float& y, const float& x) const + { + float gemm_requant = scaleGemm_ * x; + float relu = gemm_requant > 0 ? gemm_requant : 0; + float relu_requant = scaleRelu_ * relu; + y = relu_requant > 127 ? 127 : relu_requant < -128 ? 
-128 : relu_requant; + } + + float scaleGemm_; + float scaleRelu_; +}; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = int8_t; +using BDataType = int8_t; +using CDataType = int8_t; +using AccDataType = int32_t; +using CShuffleDataType = float; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle< + ALayout, // typename ALayout, + BLayout, // typename BLayout, + CLayout, // typename CLayout, + ADataType, // typename ADataType, + BDataType, // typename BDataType, + CDataType, // typename CDataType, + AccDataType, // typename GemmAccDataType, + CShuffleDataType, // typename CShuffleDataType, + PassThrough, // typename AElementwiseOperation, + PassThrough, // typename BElementwiseOperation, + RequantReluRequant, // typename CElementwiseOperation, + GemmDefault, // GemmSpecialization GemmSpec, + 1, // index_t NumGemmKPrefetchStage, + 256, // index_t BlockSize, + 256, // index_t MPerBlock, + 128, // index_t NPerBlock, + 64, // index_t KPerBlock, + 16, // index_t AK1, + 16, // index_t BK1, + 32, // index_t MPerXDL, + 32, // index_t NPerXDL, + 4, // index_t MXdlPerWave, + 2, // index_t NXdlPerWave, + S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1, + S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder, + S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder, + 2, // index_t ABlockTransferSrcVectorDim, + 16, // index_t ABlockTransferSrcScalarPerVector, + 16, // index_t ABlockTransferDstScalarPerVector_AK1, + 1, // bool ABlockLdsExtraM, + S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1, + S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder, + S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder, + 2, // index_t BBlockTransferSrcVectorDim, + 8, // index_t BBlockTransferSrcScalarPerVector, + 8, // index_t BBlockTransferDstScalarPerVector_BK1, + 1, // bool BBlockLdsExtraN, + 1, // index_t CShuffleMXdlPerWavePerShuffle, + 1, // index_t CShuffleNXdlPerWavePerShuffle, + S<1, 64, 1, 4>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + 16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock> +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + float scale_gemm = 0.03; + float scale_relu = 1; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal 
value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = PassThrough{}; + auto b_element_op = PassThrough{}; + auto c_element_op = RequantReluRequant{scale_gemm, scale_relu}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 
0 : 1; + } + + return 0; +} diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt new file mode 100644 index 00000000..a8cac069 --- /dev/null +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp) diff --git a/example/15_grouped_gemm/README.md b/example/15_grouped_gemm/README.md new file mode 100644 index 00000000..c83b23e0 --- /dev/null +++ b/example/15_grouped_gemm/README.md @@ -0,0 +1,25 @@ +# Instructions for ```example_grouped_gemm_xdl``` + +## Run ```example_grouped_gemm_xdl``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: run kernel # of times (>1) +./bin/example_grouped_gemm_xdl_fp16 0 1 5 +``` + +Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) +``` +gemm[0] a_m_k: dim 2, lengths {256, 64}, strides {64, 1} b_k_n: dim 2, lengths {64, 128}, strides {1, 64} c_m_n: dim 2, lengths {256, 128}, strides {128, 1} +gemm[1] a_m_k: dim 2, lengths {512, 128}, strides {128, 1} b_k_n: dim 2, lengths {128, 256}, strides {1, 128} c_m_n: dim 2, lengths {512, 256}, strides {256, 1} +gemm[2] a_m_k: dim 2, lengths {768, 192}, strides {192, 1} b_k_n: dim 2, lengths {192, 384}, strides {1, 192} c_m_n: dim 2, lengths {768, 384}, strides {384, 1} +gemm[3] a_m_k: dim 2, lengths {1024, 256}, strides {256, 1} b_k_n: dim 2, lengths {256, 512}, strides {1, 256} c_m_n: dim 2, lengths {1024, 512}, strides {512, 1} +group: 0 arg.a_grid_desc_k0_m_k1_{8, 256, 8}, arg.b_grid_desc_k0_n_k1_{8, 128, 8}, arg.c_grid_desc_m_n_{ 256, 128} +group: 1 arg.a_grid_desc_k0_m_k1_{16, 512, 8}, arg.b_grid_desc_k0_n_k1_{16, 256, 8}, arg.c_grid_desc_m_n_{ 512, 256} +group: 2 arg.a_grid_desc_k0_m_k1_{24, 768, 8}, arg.b_grid_desc_k0_n_k1_{24, 384, 8}, arg.c_grid_desc_m_n_{ 768, 384} +group: 3 arg.a_grid_desc_k0_m_k1_{32, 1024, 8}, arg.b_grid_desc_k0_n_k1_{32, 512, 8}, arg.c_grid_desc_m_n_{ 1024, 512} +launch_and_time_kernel: grid_dim {30, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 5 times... +Perf: 0.037887 ms, 11.0706 TFlops, 90.8132 GB/s, DeviceGroupedGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> +``` diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp new file mode 100644 index 00000000..5bc5316f --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp @@ -0,0 +1,241 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using CDataType = ck::half_t; +using AccDataType = float; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +// static constexpr auto GemmMNPadding = +// ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdl +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| Num| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| +//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + exit(0); + } + + int group_count 
= rand() % 16 + 1; + + // GEMM shape + std::vector gemm_shapes; + std::vector p_a, p_b; + std::vector p_c; + + gemm_shapes.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + int M = 256 + 256 * i; + int N = 128 + 128 * i; + int K = 64 + 64 * i; + + gemm_shapes.push_back({M, N, K, K, K, N}); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + std::vector> a_tensors; + ; + std::vector> b_tensors; + std::vector> c_host_tensors; + std::vector> c_device_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, c_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(std::size_t i = 0; i < gemm_shapes.size(); i++) + { + a_tensors.push_back(Tensor(f_host_tensor_descriptor( + gemm_shapes[i].M, gemm_shapes[i].K, gemm_shapes[i].StrideA, ALayout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + gemm_shapes[i].K, gemm_shapes[i].N, gemm_shapes[i].StrideB, BLayout{}))); + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{}))); + c_device_tensors.push_back(Tensor(f_host_tensor_descriptor( + gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{}))); + + std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc + << std::endl; + + flop += std::size_t(2) * gemm_shapes[i].M * gemm_shapes[i].K * gemm_shapes[i].N; + num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() + + sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() + + sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize(); + + switch(init_method) + { + case 0: break; + case 1: + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + } + + for(std::size_t i = 0; i < gemm_shapes.size(); i++) + { + a_tensors_device.emplace_back( + std::make_unique(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSpace())); + b_tensors_device.emplace_back( + std::make_unique(sizeof(BDataType) * b_tensors[i].mDesc.GetElementSpace())); + c_tensors_device.emplace_back(std::make_unique( + sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSpace())); + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); + + p_a.push_back(a_tensors_device[i]->GetDeviceBuffer()); + p_b.push_back(b_tensors_device[i]->GetDeviceBuffer()); + p_c.push_back(c_tensors_device[i]->GetDeviceBuffer()); + } + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + 
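+    // Rather than passing every group's GEMM descriptor as a kernel argument,
+    // the grouped instance stages them in a device-side workspace:
+    // GetWorkSpaceSize() reports how much memory the descriptors need and
+    // SetWorkSpacePointer() attaches the buffer before the launch (this is a
+    // reading of the calls below, not separately documented behavior).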
auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + // do GEMM + auto argument = + gemm.MakeArgument(p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op); + + DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument)); + + gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer()); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + bool pass = true; + if(do_verification) + { + for(std::size_t i = 0; i < gemm_shapes.size(); i++) + { + c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data()); + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + c_host_tensors[i], + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + pass &= ck::utils::check_err(c_device_tensors[i].mData, c_host_tensors[i].mData); + } + } + + return pass ? 0 : 1; +} diff --git a/example/16_gemm_reduce/CMakeLists.txt b/example/16_gemm_reduce/CMakeLists.txt new file mode 100644 index 00000000..90ff5897 --- /dev/null +++ b/example/16_gemm_reduce/CMakeLists.txt @@ -0,0 +1,2 @@ +add_example_executable(example_gemm_reduce_xdl_max_fp16 gemm_reduce_xdl_max_fp16.cpp) +add_example_executable(example_gemm_reduce_xdl_mean_squaremean_fp16 gemm_reduce_xdl_mean_squaremean_fp16.cpp) diff --git a/example/16_gemm_reduce/gemm_reduce_xdl_max_fp16.cpp b/example/16_gemm_reduce/gemm_reduce_xdl_max_fp16.cpp new file mode 100644 index 00000000..24797ecc --- /dev/null +++ b/example/16_gemm_reduce/gemm_reduce_xdl_max_fp16.cpp @@ -0,0 +1,276 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
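+// This example fuses a row-wise reduction into the GEMM: besides C = A * B it
+// produces a length-M vector d with d(m) = max_n C(m, n). The global reduce
+// writes presumably use an atomic max, which is why the result buffer is
+// pre-set to the lowest representable value before the launch; the host-side
+// verification recomputes the same reduction for comparison.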
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; +using F64 = double; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using ADataType = F16; +using BDataType = F16; +using CDataType = F16; +using GemmAccDataType = F32; +using ReduceAccDataType = F32; +using ReduceDataType = F64; +using ReducePtrsGlobal = ck::Tuple; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; +using ReduceOps = ck::Tuple; +using ReduceElementOps = ck::Tuple; +using ReduceGlobalMemOps = + ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmSpecialization = + ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle +//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| +//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < Row, Col, Row, F16, F16, F16, F32, F32, ReduceAccDataType, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceElementOps, ReduceElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 
4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +template <typename ADataType, typename BDataType, typename CDataType, typename ReduceDataType> +void DumpGemmLayerNormPerf(float gemm_reduce_time, int M, int N, int K) +{ + std::size_t gemm_flop = std::size_t(2) * M * N * K; + std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M; + + float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time; + float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time; + + std::cout << "gemm + reduceMax Perf: " << gemm_reduce_time << " ms, " << tflops << " TFlops, " + << gemm_gb_per_sec << " GB/s" << std::endl; +} + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 1) + { + // do nothing + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value) + { + return HostTensorDescriptor(std::vector<std::size_t>({row, col}), + std::vector<std::size_t>({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector<std::size_t>({row, col}), + std::vector<std::size_t>({1, stride})); + } + }; + + Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor<ReduceDataType> reduce_m_host_result( + HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)}))); + + Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor<ReduceDataType> reduce_m_device_result( + HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)}))); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "reduce_m: " << reduce_m_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + DeviceMem 
reduce_device_buf(sizeof(ReduceDataType) * + reduce_m_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + auto reduce_element_op = ReduceElementOps{}[ck::Number<0>{}]; + std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op}; + std::array<void*, 1> reduce_element_ops = {&reduce_element_op}; + std::array<void*, 1> p_reduces = {reduce_device_buf.GetDeviceBuffer()}; + + // do GEMM + auto gemm = DeviceGemmReduceInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + nullptr, + {}, + c_device_buf.GetDeviceBuffer(), + p_reduces, + M, + N, + K, + StrideA, + StrideB, + StrideC, + {}, + gemm_element_ops, + {}, + reduce_element_ops, + reduce_element_ops); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + // [CAUTION]: launch_and_time_kernel does not initialize D; if we run the kernel + // multiple times without re-initializing D, verification will fail. + reduce_device_buf.SetValue(ck::NumericLimits<ReduceDataType>::Lowest()); + invoker.Run(argument, StreamConfig{nullptr, false}); + + bool pass = true; + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + reduce_device_buf.FromDevice(reduce_m_device_result.mData.data()); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + auto reduce_op = ReduceOps{}[ck::Number<0>{}]; + + for(int m = 0; m < M; ++m) + { + ReduceAccDataType reduce_acc = reduce_op.GetIdentityValue<ReduceAccDataType>(); + + for(int n = 0; n < N; ++n) + { + ReduceAccDataType curr_val = + ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n)); + reduce_op(reduce_acc, curr_val); + }; + + reduce_m_host_result(m) = reduce_acc; + } + + pass = ck::utils::check_err(c_m_n_device_result.mData, + c_m_n_host_result.mData, + "Error: Incorrect results c") && + ck::utils::check_err(reduce_m_device_result.mData, + reduce_m_host_result.mData, + "Error: Incorrect results d", + 1e-3, + 1e-3); + } + + if(time_kernel) + { + float gemm_reduceMax_ave_time = invoker.Run(argument, StreamConfig{nullptr, true}); + + DumpGemmLayerNormPerf<ADataType, BDataType, CDataType, ReduceDataType>( + gemm_reduceMax_ave_time, M, N, K); + } + + return pass ? 0 : 1; +} diff --git a/example/16_gemm_reduce/gemm_reduce_xdl_mean_squaremean_fp16.cpp b/example/16_gemm_reduce/gemm_reduce_xdl_mean_squaremean_fp16.cpp new file mode 100644 index 00000000..a31f8549 --- /dev/null +++ b/example/16_gemm_reduce/gemm_reduce_xdl_mean_squaremean_fp16.cpp @@ -0,0 +1,314 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
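+// This example fuses two row-wise reductions into the GEMM: d0(m) =
+// sum_n C(m, n) / N (the mean) and d1(m) = sum_n C(m, n)^2 / N (the mean of
+// squares), i.e. the statistics a layernorm needs. Both use the Add reduction,
+// with a UnarySquare input op for d1 and a UnaryDivide(N) output op for both,
+// accumulated via atomic adds (hence the SetZero() calls before the launch).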
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/utility/reduction_operator.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using ADataType = F16; +using BDataType = F16; +using CDataType = F16; +using GemmAccDataType = F32; +using ReduceAccDataType = F32; +using ReduceDataType = F32; +using ReducePtrsGlobal = ck::Tuple; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; +using ReduceOp0 = ck::reduce::Add; +using ReduceOp1 = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough; +using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide; +using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceGlobalMemOps = + ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmSpecialization = + ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle +//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceDData| A| B| C| Reduce| ReduceInEleOp| ReduceOutEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| +//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| 
_MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +template +void DumpGemmLayerNormPerf(float gemm_reduce_time, int M, int N, int K) +{ + std::size_t gemm_flop = std::size_t(2) * M * N * K; + std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M + + sizeof(ReduceDataType) * M; + + float tflops = static_cast(gemm_flop) / 1.E9 / gemm_reduce_time; + float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time; + + std::cout << "gemm + reduce_mean + reduce_mean_square Perf: " << gemm_reduce_time << " ms, " + << tflops << " TFlops, " << gemm_gb_per_sec << " GB/s, " << std::endl; +} + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 1) + { + // do nothing + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor reduce0_m_host_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + Tensor reduce1_m_host_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor reduce0_m_device_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + Tensor reduce1_m_device_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "reduce0_m: " << 
reduce0_m_host_result.mDesc << std::endl; + std::cout << "reduce1_m: " << reduce1_m_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + DeviceMem reduce0_device_buf(sizeof(ReduceDataType) * + reduce0_m_device_result.mDesc.GetElementSpace()); + DeviceMem reduce1_device_buf(sizeof(ReduceDataType) * + reduce1_m_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op}; + + auto passthrough = UnaryIdenticElementOp{}; + auto square = UnarySquareElementOp{}; + auto div = UnaryDivElementOp{N}; + std::array<void*, 2> reduce_in_element_ops = {&passthrough, &square}; + std::array<void*, 2> reduce_out_element_ops = {&div, &div}; + + std::array<void*, 2> p_reduces = {reduce0_device_buf.GetDeviceBuffer(), + reduce1_device_buf.GetDeviceBuffer()}; + + // do GEMM + auto gemm = DeviceGemmReduceInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + nullptr, + {}, + c_device_buf.GetDeviceBuffer(), + p_reduces, + M, + N, + K, + StrideA, + StrideB, + StrideC, + {}, + gemm_element_ops, + {}, + reduce_in_element_ops, + reduce_out_element_ops); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + // initialize the reduction buffers to 0 + reduce0_device_buf.SetZero(); + reduce1_device_buf.SetZero(); + + // if time_kernel == true, the kernel will run multiple times. This kernel uses atomic-add, + // so the results will not be correct; time_kernel must be false for the correctness test.
+ invoker.Run(argument, StreamConfig{nullptr, false}); + bool pass = true; + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + reduce0_device_buf.FromDevice(reduce0_m_device_result.mData.data()); + reduce1_device_buf.FromDevice(reduce1_m_device_result.mData.data()); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + auto reduce0_op = ReduceOp0{}; + auto reduce1_op = ReduceOp1{}; + + for(int m = 0; m < M; ++m) + { + auto reduce0_acc = reduce0_op.GetIdentityValue<ReduceAccDataType>(); + auto reduce1_acc = reduce1_op.GetIdentityValue<ReduceAccDataType>(); + + for(int n = 0; n < N; ++n) + { + auto c_val = ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n)); + ReduceAccDataType square_c_val; + square(square_c_val, c_val); + + reduce0_op(reduce0_acc, c_val); + reduce1_op(reduce1_acc, square_c_val); + } + + div(reduce0_acc, reduce0_acc); + div(reduce1_acc, reduce1_acc); + reduce0_m_host_result(m) = ck::type_convert<ReduceDataType>(reduce0_acc); + reduce1_m_host_result(m) = ck::type_convert<ReduceDataType>(reduce1_acc); + } + + pass = ck::utils::check_err(c_m_n_device_result.mData, + c_m_n_host_result.mData, + "Error: Incorrect results c") && + ck::utils::check_err(reduce0_m_device_result.mData, + reduce0_m_host_result.mData, + "Error: Incorrect results d0", + 1e-4, + 1e-5) && + ck::utils::check_err(reduce1_m_device_result.mData, + reduce1_m_host_result.mData, + "Error: Incorrect results d1", + 1e-3, + 1e-5); + } + + if(time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, true}); + + DumpGemmLayerNormPerf<ADataType, BDataType, CDataType, ReduceDataType>(ave_time, M, N, K); + } + + return pass ? 0 : 1; +} diff --git a/example/17_convnd_bwd_data_xdl/CMakeLists.txt b/example/17_convnd_bwd_data_xdl/CMakeLists.txt new file mode 100644 index 00000000..963f3117 --- /dev/null +++ b/example/17_convnd_bwd_data_xdl/CMakeLists.txt @@ -0,0 +1,2 @@ +add_example_executable(example_convnd_bwd_data_xdl convnd_bwd_data_xdl.cpp) +target_link_libraries(example_convnd_bwd_data_xdl PRIVATE conv_util) diff --git a/example/17_convnd_bwd_data_xdl/README.md b/example/17_convnd_bwd_data_xdl/README.md new file mode 100644 index 00000000..b5c8281e --- /dev/null +++ b/example/17_convnd_bwd_data_xdl/README.md @@ -0,0 +1,47 @@ +# Instructions for ```example_convnd_bwd_data_xdl``` + +## Run ```example_convnd_bwd_data_xdl``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=random value, 2=init to 1) +#arg3: time kernel (0=no, 1=yes) +#arg4: num_dim_spatial(1|2|3) +#arg5 to ...: N, K, C, [Z,] [Y,] X, [Di,] [Hi,] Wi, [Sz,] [Sy,] Sx, [Dz,] [Dy,] Dx, [LeftPz,] [LeftPy,] LeftPx, [RightPz,] [RightPy,] RightPx +./bin/example_convnd_bwd_data_xdl 0 1 5 +``` + +Result +``` +in_n_c_hi_wi: dim 4, lengths {128, 128, 71, 71}, strides {645248, 1, 9088, 128} +wei_k_c_y_x: dim 4, lengths {256, 128, 3, 3}, strides {1152, 1, 384, 128} +out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} +arg.a_grid_desc_k0_m_k1_container_{128, 175232, 8} +arg.b_grid_desc_k0_n_k1_container_{128, 128, 8} +arg.c_grid_desc_m_n_container_{ 175232, 128} +arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 2, 2, 2, 4, 2 ) +launch_and_time_kernel: grid_dim {1369, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... 
+arg.a_grid_desc_k0_m_k1_container_{64, 175232, 8} +arg.b_grid_desc_k0_n_k1_container_{64, 128, 8} +arg.c_grid_desc_m_n_container_{ 175232, 128} +arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 2, 2, 2, 4, 2 ) +launch_and_time_kernel: grid_dim {1369, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... +arg.a_grid_desc_k0_m_k1_container_{64, 175232, 8} +arg.b_grid_desc_k0_n_k1_container_{64, 128, 8} +arg.c_grid_desc_m_n_container_{ 175232, 128} +arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 2, 2, 2, 4, 2 ) +launch_and_time_kernel: grid_dim {1369, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... +arg.a_grid_desc_k0_m_k1_container_{32, 175232, 8} +arg.b_grid_desc_k0_n_k1_container_{32, 128, 8} +arg.c_grid_desc_m_n_container_{ 175232, 128} +arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 2, 2, 2, 4, 2 ) +launch_and_time_kernel: grid_dim {1369, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... +Perf: 1.40031 ms, 69.8734 TFlops, 179.037 GB/s +``` diff --git a/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp b/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp new file mode 100644 index 00000000..2fd15c1d --- /dev/null +++ b/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp @@ -0,0 +1,352 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +using AccDataType = float; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +using DeviceConvBwdDataBasePtr = + ck::tensor_operation::device::DeviceConvBwdDataPtr; + +template +using DeviceConvNDBwdDataInstance = ck::tensor_operation::device:: + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdDefault, // ConvolutionBackwardDataSpecialization + NumDimSpatial, // NumDimSpatial + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 2, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // 
ABlockTransferDstScalarPerVector_K1
+        true,        // ABlockLdsAddExtraM
+        S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
+        S<2, 0, 1>,  // BBlockTransferThreadClusterArrangeOrder
+        S<0, 2, 1>,  // BBlockTransferSrcAccessOrder
+        1,           // BBlockTransferSrcVectorDim
+        2,           // BBlockTransferSrcScalarPerVector
+        8,           // BBlockTransferDstScalarPerVector_K1
+        true,        // BBlockLdsAddExtraN
+        7,
+        1>;          // GemmCThreadTransferDstScalarPerVector
+
+template <ck::index_t NumDimSpatial>
+using ReferenceConvBwdDataInstance =
+    ck::tensor_operation::host::ReferenceConvBwdData;
+
+void print_use_msg()
+{
+    std::cout << "arg1: verification (0=no, 1=yes)\n"
+              << "arg2: initialization (0=no init, 1=random value, 2=init to 1)\n"
+              << "arg3: time kernel (0=no, 1=yes)\n"
+              << "arg4: N spatial dimensions (default 2)\n"
+              << "Following arguments (depending on number of spatial dims):\n"
+              << " N, K, C, \n"
+              << " <filter spatial lengths>, (ie Y, X for 2D)\n"
+              << " <input spatial lengths>, (ie Hi, Wi for 2D)\n"
+              << " <strides>, (ie Sy, Sx for 2D)\n"
+              << " <dilations>, (ie Dy, Dx for 2D)\n"
+              << " <left pads>, (ie LeftPy, LeftPx for 2D)\n"
+              << " <right pads>, (ie RightPy, RightPx for 2D)\n"
+              << std::endl;
+}
+ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[])
+{
+    // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right)
+    ck::utils::conv::ConvParams params;
+    int arg_idx = 5;
+
+    params.num_dim_spatial_ = num_dim_spatial;
+    params.N_               = std::stoi(argv[arg_idx++]);
+    params.K_               = std::stoi(argv[arg_idx++]);
+    params.C_               = std::stoi(argv[arg_idx++]);
+
+    params.filter_spatial_lengths_.resize(num_dim_spatial);
+    for(int i = 0; i < num_dim_spatial; ++i)
+    {
+        params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
+    }
+    params.input_spatial_lengths_.resize(num_dim_spatial);
+    for(int i = 0; i < num_dim_spatial; ++i)
+    {
+        params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
+    }
+    params.conv_filter_strides_.resize(num_dim_spatial);
+    for(int i = 0; i < num_dim_spatial; ++i)
+    {
+        params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]);
+    }
+    params.conv_filter_dilations_.resize(num_dim_spatial);
+    for(int i = 0; i < num_dim_spatial; ++i)
+    {
+        params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]);
+    }
+    params.input_left_pads_.resize(num_dim_spatial);
+    for(int i = 0; i < num_dim_spatial; ++i)
+    {
+        params.input_left_pads_[i] = std::stoi(argv[arg_idx++]);
+    }
+    params.input_right_pads_.resize(num_dim_spatial);
+    for(int i = 0; i < num_dim_spatial; ++i)
+    {
+        params.input_right_pads_[i] = std::stoi(argv[arg_idx++]);
+    }
+
+    return params;
+}
+
+DeviceConvBwdDataBasePtr get_conv_instance(int num_dim_spatial)
+{
+    switch(num_dim_spatial)
+    {
+    case 3: {
+        return std::make_unique<DeviceConvNDBwdDataInstance<3>>();
+    }
+    case 2: {
+        return std::make_unique<DeviceConvNDBwdDataInstance<2>>();
+    }
+    case 1: {
+        return std::make_unique<DeviceConvNDBwdDataInstance<1>>();
+    }
+    default: {
+        throw std::runtime_error("Unsupported number of spatial dimensions provided!");
+    }
+    }
+}
+
+int main(int argc, char* argv[])
+{
+    bool do_verification = true;
+    int init_method      = 1;
+    bool time_kernel     = false;
+    int num_dim_spatial  = 2;
+
+    ck::utils::conv::ConvParams params;
+    params.C_ = 128;
+
+    if(argc == 4)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+    }
+    else if(argc > 4)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+        num_dim_spatial = std::stoi(argv[4]);
+        // check args number
+        int conv_args     = 3 + num_dim_spatial * 6;
+        int cmdline_nargs = conv_args + 5;
+        if(cmdline_nargs != argc)
+        {
+            print_use_msg();
+
exit(1); + } + + params = parse_conv_params(num_dim_spatial, argv); + } + else if(argc != 1) + { + print_use_msg(); + exit(1); + } + + std::vector input_dims{static_cast(params.N_), + static_cast(params.C_)}; + input_dims.insert(std::end(input_dims), + std::begin(params.input_spatial_lengths_), + std::end(params.input_spatial_lengths_)); + + std::vector filter_dims{static_cast(params.K_), + static_cast(params.C_)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params.filter_spatial_lengths_), + std::end(params.filter_spatial_lengths_)); + + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + std::vector output_dims{static_cast(params.N_), + static_cast(params.K_)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor in_n_c_hi_wi_host_result( + ck::utils::conv::get_input_host_tensor_descriptor(input_dims, num_dim_spatial)); + Tensor in_n_c_hi_wi_device_result( + ck::utils::conv::get_input_host_tensor_descriptor(input_dims, num_dim_spatial)); + Tensor wei_k_c_y_x( + ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial)); + Tensor out_n_k_ho_wo( + ck::utils::conv::get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi_host_result.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_3{-0.2, 0.2}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3{-0.2, 0.2}); + break; + default: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1{1}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1{1}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * + in_n_c_hi_wi_device_result.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + out_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + // reset input to zero + in_device_buf.SetZero(); + + // do GEMM + auto conv = get_conv_instance(num_dim_spatial); + auto invoker = conv->MakeInvokerPointer(); + auto argument = + conv->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(!conv->IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "wrong! 
device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = ck::utils::conv::get_flops( + params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); + std::size_t num_btype = ck::utils::conv::get_btype( + params.N_, + params.C_, + params.K_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + auto verify_f = [&](const auto& ref_conv) { + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi_host_result, + wei_k_c_y_x, + out_n_k_ho_wo, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); + + in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data()); + + return ck::utils::check_err(in_n_c_hi_wi_device_result.mData, + in_n_c_hi_wi_host_result.mData) + ? 0 + : 1; + }; + + switch(num_dim_spatial) + { + case 3: { + auto ref_conv = ReferenceConvBwdDataInstance<3>(); + return verify_f(ref_conv); + } + case 2: { + auto ref_conv = ReferenceConvBwdDataInstance<2>(); + return verify_f(ref_conv); + } + case 1: { + auto ref_conv = ReferenceConvBwdDataInstance<1>(); + return verify_f(ref_conv); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } + } + return 0; +} diff --git a/example/18_batched_gemm_reduce/CMakeLists.txt b/example/18_batched_gemm_reduce/CMakeLists.txt new file mode 100644 index 00000000..99fc0043 --- /dev/null +++ b/example/18_batched_gemm_reduce/CMakeLists.txt @@ -0,0 +1,2 @@ +add_example_executable(example_batched_gemm_reduce_xdl_fp16 batched_gemm_reduce_xdl_fp16.cpp) + diff --git a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp new file mode 100644 index 00000000..e1db19e2 --- /dev/null +++ b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp @@ -0,0 +1,309 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
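
A note before the listing: the example that follows verifies a batched GEMM whose epilogue also emits two row-wise reductions, D0 = sum over n of C(g, m, n) and D1 = sum over n of C(g, m, n)^2. Below is a minimal host-side sketch of that contract, assuming a row-major [G, M, N] layout for C; the function and buffer names are illustrative, not library API:

```cpp
// Illustrative host reference for the GEMM + reduce epilogue contract.
// Assumes C is laid out as [G, M, N]; names are not from the library.
#include <cstddef>
#include <vector>

void reduce_rows(const std::vector<float>& C,
                 std::vector<float>& D0, // [G, M] row sums
                 std::vector<float>& D1, // [G, M] row sums of squares
                 std::size_t G, std::size_t M, std::size_t N)
{
    for(std::size_t g = 0; g < G; ++g)
        for(std::size_t m = 0; m < M; ++m)
        {
            float acc0 = 0.f, acc1 = 0.f; // identity value of Add is 0
            for(std::size_t n = 0; n < N; ++n)
            {
                const float c = C[(g * M + m) * N + n];
                acc0 += c;     // PassThrough, then Add
                acc1 += c * c; // UnarySquare, then Add
            }
            D0[g * M + m] = acc0;
            D1[g * M + m] = acc1;
        }
}
```

Because the identity of Add is 0 and the device kernel accumulates the reductions with atomic adds, the example zero-initializes the reduce buffers with SetZero() before launch.
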
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using ADataType = F16; +using BDataType = F16; +using CDataType = F16; +using ReduceAccDataType = F32; +using ReduceDataType = F32; +using ReducePtrsGlobal = ck::Tuple; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; +using ReduceOp0 = ck::reduce::Add; +using ReduceOp1 = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough; +using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceGlobalMemOps = + ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmSpecialization = + ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatchedGemmReduce_Xdl_CShuffle +//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| +//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < Row, 
Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; +// clang-format on + +using ReferenceBatchedGemmInstance = ck::tensor_operation::host:: + ReferenceBatchedGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 2048; + ck::index_t N = 1920; + ck::index_t K = 2048; + + ck::index_t StrideA = 2048; + ck::index_t StrideB = 2048; + ck::index_t StrideC = 1920; + + ck::index_t BatchCount = 4; + + if(argc == 1) + { + // do nothing + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 11) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + + BatchCount = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, BatchCount\n"); + exit(0); + } + + auto f_host_tensor_descriptor = [](std::size_t batch_count, + std::size_t row, + std::size_t col, + std::size_t stride, + auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({batch_count, row, col}), + std::vector({row * stride, stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({batch_count, row, col}), + std::vector({col * stride, 1, stride})); + } + }; + + Tensor a_g_m_k(f_host_tensor_descriptor(BatchCount, M, K, StrideA, ALayout{})); + Tensor b_g_k_n(f_host_tensor_descriptor(BatchCount, K, N, StrideB, BLayout{})); + + Tensor c_g_m_n_host_result( + f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); + Tensor d0_g_m_host_result(HostTensorDescriptor(std::vector( + {static_cast(BatchCount), static_cast(M)}))); + Tensor d1_g_m_host_result(HostTensorDescriptor(std::vector( + {static_cast(BatchCount), static_cast(M)}))); + + Tensor c_g_m_n_device_result( + f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); + Tensor d0_g_m_device_result(HostTensorDescriptor(std::vector( + {static_cast(BatchCount), static_cast(M)}))); + Tensor d1_g_m_device_result(HostTensorDescriptor(std::vector( + {static_cast(BatchCount), static_cast(M)}))); + + std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; + std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl; + std::cout << "c_g_m_n: " << c_g_m_n_host_result.mDesc << std::endl; + std::cout << "d0_g_m: " << d0_g_m_host_result.mDesc << std::endl; + std::cout << "d1_g_m: " << d1_g_m_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem 
a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace());
+    DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace());
+    DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpace());
+    DeviceMem reduce0_device_buf(sizeof(ReduceDataType) *
+                                 d0_g_m_device_result.mDesc.GetElementSpace());
+    DeviceMem reduce1_device_buf(sizeof(ReduceDataType) *
+                                 d1_g_m_device_result.mDesc.GetElementSpace());
+
+    a_device_buf.ToDevice(a_g_m_k.mData.data());
+    b_device_buf.ToDevice(b_g_k_n.mData.data());
+
+    auto a_element_op = AElementOp{};
+    auto b_element_op = BElementOp{};
+    auto c_element_op = CElementOp{};
+    std::array gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op};
+
+    auto passthrough = UnaryIdenticElementOp{};
+    auto square      = UnarySquareElementOp{};
+    std::array reduce_in_element_ops  = {&passthrough, &square};
+    std::array reduce_out_element_ops = {&passthrough, &passthrough};
+
+    std::array p_reduces = {reduce0_device_buf.GetDeviceBuffer(),
+                            reduce1_device_buf.GetDeviceBuffer()};
+
+    // do GEMM
+    auto batched_gemm = DeviceBatchedGemmReduceInstance{};
+    auto invoker      = batched_gemm.MakeInvoker();
+    auto argument     = batched_gemm.MakeArgument(a_device_buf.GetDeviceBuffer(),
+                                                  b_device_buf.GetDeviceBuffer(),
+                                                  nullptr,
+                                                  {},
+                                                  c_device_buf.GetDeviceBuffer(),
+                                                  p_reduces,
+                                                  M,
+                                                  N,
+                                                  K,
+                                                  StrideA,
+                                                  StrideB,
+                                                  StrideC,
+                                                  {},
+                                                  gemm_element_ops,
+                                                  {},
+                                                  reduce_in_element_ops,
+                                                  reduce_out_element_ops,
+                                                  BatchCount);
+
+    if(!batched_gemm.IsSupportedArgument(argument))
+    {
+        throw std::runtime_error(
+            "wrong! device_gemm with the specified compilation parameters does "
+            "not support this GEMM problem");
+    }
+
+    // init D0, D1 to 0
+    reduce0_device_buf.SetZero();
+    reduce1_device_buf.SetZero();
+
+    // if time_kernel == true, the kernel will run multiple times. This kernel uses
+    // atomic add, so the result will not be correct; we
need to set time_kernel = false for correctness test + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * BatchCount * M * N * K; + std::size_t num_btype = sizeof(ADataType) * BatchCount * M * K + + sizeof(BDataType) * BatchCount * K * N + + sizeof(CDataType) * BatchCount * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << batched_gemm.GetTypeString() << std::endl; + + bool pass = true; + if(do_verification) + { + c_device_buf.FromDevice(c_g_m_n_device_result.mData.data()); + reduce0_device_buf.FromDevice(d0_g_m_device_result.mData.data()); + reduce1_device_buf.FromDevice(d1_g_m_device_result.mData.data()); + + auto ref_batched_gemm = ReferenceBatchedGemmInstance{}; + auto ref_invoker = ref_batched_gemm.MakeInvoker(); + + auto ref_argument = ref_batched_gemm.MakeArgument( + a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + auto reduce0_op = ReduceOp0{}; + auto reduce1_op = ReduceOp1{}; + + for(int batch = 0; batch < BatchCount; ++batch) + { + for(int m = 0; m < M; ++m) + { + auto reduce0_acc = reduce0_op.GetIdentityValue(); + auto reduce1_acc = reduce1_op.GetIdentityValue(); + + for(int n = 0; n < N; ++n) + { + auto c_val = + ck::type_convert(c_g_m_n_host_result(batch, m, n)); + ReduceAccDataType d0_val; + ReduceAccDataType d1_val; + + UnaryIdenticElementOp{}(d0_val, c_val); + UnarySquareElementOp{}(d1_val, c_val); + reduce0_op(reduce0_acc, d0_val); + reduce1_op(reduce1_acc, d1_val); + } + + d0_g_m_host_result(batch, m) = ck::type_convert(reduce0_acc); + d1_g_m_host_result(batch, m) = ck::type_convert(reduce1_acc); + } + } + + pass = ck::utils::check_err(c_g_m_n_host_result.mData, + c_g_m_n_device_result.mData, + "Error: Incorrect results c") && + ck::utils::check_err(d0_g_m_device_result.mData, + d0_g_m_host_result.mData, + "Error: Incorrect results! D0", + 1e-4, + 1e-5) && + ck::utils::check_err(d1_g_m_device_result.mData, + d1_g_m_host_result.mData, + "Error: Incorrect results! D1", + 1e-3, + 1e-5); + } + + return pass ? 0 : 1; +} diff --git a/example/19_binary_elementwise/CMakeLists.txt b/example/19_binary_elementwise/CMakeLists.txt new file mode 100644 index 00000000..39646e0a --- /dev/null +++ b/example/19_binary_elementwise/CMakeLists.txt @@ -0,0 +1,4 @@ +add_example_executable(example_broadcast_add_2d_amn_bn broadcast_add_2d_amn_bn.cpp) +add_example_executable(example_broadcast_add_3d_am_bmnk broadcast_add_3d_am_bmnk.cpp) +add_example_executable(example_elementwise_add_1d elementwise_add_1d.cpp) +add_example_executable(example_elementwise_add_4d elementwise_add_4d.cpp) \ No newline at end of file diff --git a/example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp b/example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp new file mode 100644 index 00000000..b21d314d --- /dev/null +++ b/example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp @@ -0,0 +1,144 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
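
The example below adds a 1-D tensor B of length N to every row of a 2-D tensor A by giving B a zero stride along M (b_strides = {0, 1} in the code). A self-contained sketch of that indexing rule, with illustrative names and sizes:

```cpp
// With strides {s0, s1}, index (m, n) reads data[m * s0 + n * s1];
// s0 == 0 makes every row m read the same 1-D vector. Illustrative only.
#include <cstddef>
#include <vector>

static float load_2d(const std::vector<float>& data, std::size_t m,
                     std::size_t n, std::size_t s0, std::size_t s1)
{
    return data[m * s0 + n * s1];
}

int main()
{
    const std::size_t M = 2, N = 3;
    std::vector<float> a = {0, 1, 2, 3, 4, 5}; // M x N, strides {N, 1}
    std::vector<float> b = {10, 20, 30};       // N,     strides {0, 1}
    std::vector<float> c(M * N);
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
            c[m * N + n] = load_2d(a, m, n, N, 1) + load_2d(b, m, n, 0, 1);
    // c(0,0) = 0 + 10, c(1,0) = 3 + 10
    return (c[0] == 10.f && c[N] == 13.f) ? 0 : 1;
}
```

Shapes and strides are plain runtime arguments here, which is why the same elementwise instance can serve both the broadcast and the non-broadcast adds in this directory.
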
+ +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_binary_elementwise.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ABDataType = F16; +using CDataType = F16; +using EltwiseComputeDataType = F32; + +using Add = ck::tensor_operation::element_wise::Add; + +using DeviceElementwiseAddInstance = + ck::tensor_operation::device::DeviceBinaryElementwise; + +template +void host_broadcast2D( + HostTensorC& C, const HostTensorA& A, const HostTensorB& B, int M, int N, Functor functor) +{ + using ctype = ck::remove_reference_t; + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + ComputeDataType Amn = ck::type_convert(A(m, n)); + ComputeDataType Cmn = 0; + if constexpr(broadcastDim == 0) + { + ComputeDataType Bn = ck::type_convert(B(n)); + functor(Cmn, Amn, Bn); + } + else + { + ComputeDataType Bm = ck::type_convert(B(m)); + functor(Cmn, Amn, Bm); + } + C(m, n) = ck::type_convert(Cmn); + } + } +} + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t Stride = 1024; + + auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor(std::vector({len}), + std::vector({stride})); + }; + + auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride) { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + }; + + Tensor a_m_n(f_host_tensor_descriptor2d(M, N, Stride)); + Tensor b_n(f_host_tensor_descriptor1d(N, 1)); + Tensor c_m_n(f_host_tensor_descriptor2d(M, N, Stride)); + + a_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_m_n_device_buf(sizeof(ABDataType) * a_m_n.mDesc.GetElementSpace()); + DeviceMem b_n_device_buf(sizeof(ABDataType) * b_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace()); + + a_m_n_device_buf.ToDevice(a_m_n.mData.data()); + b_n_device_buf.ToDevice(b_n.mData.data()); + + std::array input = {a_m_n_device_buf.GetDeviceBuffer(), + b_n_device_buf.GetDeviceBuffer()}; + std::array output = {c_m_n_device_buf.GetDeviceBuffer()}; + + std::vector a_strides = {Stride, 1}; + std::vector b_strides = {0, 1}; + std::vector c_strides = {Stride, 1}; + + auto broadcastAdd = DeviceElementwiseAddInstance{}; + auto argument = broadcastAdd.MakeArgumentPointer( + input, output, {M, N}, {a_strides, b_strides}, {c_strides}, Add{}); + + if(!broadcastAdd.IsSupportedArgument(argument.get())) + { + throw std::runtime_error("The runtime parameters seems not supported by the " + "DeviceBinaryElementwise instance, exiting!"); + }; + + auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer(); + float ave_time = + broadcastAdd_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::cout << "Perf: " << ave_time << " ms" << std::endl; + + bool pass = true; + if(do_verification) + { + c_m_n_device_buf.FromDevice(c_m_n.mData.data()); + Tensor host_c_m_n(f_host_tensor_descriptor2d(M, N, Stride)); + + host_broadcast2D, + Tensor, + Tensor, + EltwiseComputeDataType, + Add, + 0>(host_c_m_n, a_m_n, b_n, M, N, Add{}); + + pass &= 
ck::utils::check_err( + c_m_n.mData, host_c_m_n.mData, "Error: Incorrect results c", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp b/example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp new file mode 100644 index 00000000..43762ad0 --- /dev/null +++ b/example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp @@ -0,0 +1,132 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_binary_elementwise.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ABDataType = F16; +using CDataType = F16; +using EltwiseComputeDataType = F32; + +using Add = ck::tensor_operation::element_wise::Add; + +using DeviceElementwiseAddInstance = + ck::tensor_operation::device::DeviceBinaryElementwise; + +template +void host_broadcast3D_am_bmnk(HostTensorC& C, + const HostTensorA& A, + const HostTensorB& B, + const std::vector& shape, + Functor functor) +{ + using ctype = ck::remove_reference_t; + + for(std::size_t m = 0; m < shape[0]; ++m) + for(std::size_t n = 0; n < shape[1]; ++n) + for(std::size_t k = 0; k < shape[2]; ++k) + { + ComputeDataType a_val = ck::type_convert(A(m)); + ComputeDataType b_val = ck::type_convert(B(m, n, k)); + ComputeDataType c_val = 0; + functor(c_val, a_val, b_val); + C(m, n, k) = ck::type_convert(c_val); + } +} + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + std::vector mnk = {4, 16, 32}; + ck::index_t M = mnk[0]; + + Tensor a_m({M}); + Tensor b_m_n_k(mnk); + Tensor c_m_n_k(mnk); + + a_m.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_m_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_m_device_buf(sizeof(ABDataType) * a_m.mDesc.GetElementSpace()); + DeviceMem b_m_n_k_device_buf(sizeof(ABDataType) * b_m_n_k.mDesc.GetElementSpace()); + DeviceMem c_m_n_k_device_buf(sizeof(CDataType) * c_m_n_k.mDesc.GetElementSpace()); + + a_m_device_buf.ToDevice(a_m.mData.data()); + b_m_n_k_device_buf.ToDevice(b_m_n_k.mData.data()); + + std::array input = {a_m_device_buf.GetDeviceBuffer(), + b_m_n_k_device_buf.GetDeviceBuffer()}; + std::array output = {c_m_n_k_device_buf.GetDeviceBuffer()}; + + std::vector a_strides = {1, 0, 0}; + std::vector b_strides{b_m_n_k.mDesc.GetStrides().begin(), + b_m_n_k.mDesc.GetStrides().end()}; + std::vector c_strides{c_m_n_k.mDesc.GetStrides().begin(), + c_m_n_k.mDesc.GetStrides().end()}; + + auto broadcastAdd = DeviceElementwiseAddInstance{}; + auto argument = + broadcastAdd.MakeArgumentPointer(input, + output, + std::vector{mnk.begin(), mnk.end()}, + {a_strides, b_strides}, + {c_strides}, + Add{}); + + if(!broadcastAdd.IsSupportedArgument(argument.get())) + { + throw std::runtime_error("The runtime parameters seems not supported by the " + "DeviceBinaryElementwise instance, exiting!"); + }; + + auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer(); + float ave_time = + broadcastAdd_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::cout << "Perf: " << ave_time << " ms" << std::endl; + + bool pass = true; + if(do_verification) + { + 
c_m_n_k_device_buf.FromDevice(c_m_n_k.mData.data()); + Tensor host_c_m_n_k(mnk); + + host_broadcast3D_am_bmnk, + Tensor, + Tensor, + EltwiseComputeDataType, + Add>(host_c_m_n_k, a_m, b_m_n_k, mnk, Add{}); + + pass &= ck::utils::check_err( + c_m_n_k.mData, host_c_m_n_k.mData, "Error: Incorrect results c", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/example/19_binary_elementwise/elementwise_add_1d.cpp b/example/19_binary_elementwise/elementwise_add_1d.cpp new file mode 100644 index 00000000..d784aec7 --- /dev/null +++ b/example/19_binary_elementwise/elementwise_add_1d.cpp @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_binary_elementwise.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ABDataType = F16; +using CDataType = F16; +using EltwiseComputeDataType = F32; + +using Add = ck::tensor_operation::element_wise::Add; + +using DeviceElementwiseAddInstance = + ck::tensor_operation::device::DeviceBinaryElementwise; + +template +void host_elementwise1D( + HostTensorC& C, const HostTensorA& A, const HostTensorB& B, int M, Functor functor) +{ + using ctype = ck::remove_reference_t; + + for(int m = 0; m < M; ++m) + { + ComputeDataType Am = ck::type_convert(A(m)); + ComputeDataType Bm = ck::type_convert(B(m)); + ComputeDataType Cm = 0; + functor(Cm, Am, Bm); + C(m) = ck::type_convert(Cm); + } +} + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + ck::index_t M = 1024; + + auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor(std::vector({len}), + std::vector({stride})); + }; + + Tensor a_m(f_host_tensor_descriptor1d(M, 1)); + Tensor b_m(f_host_tensor_descriptor1d(M, 1)); + Tensor c_m(f_host_tensor_descriptor1d(M, 1)); + + a_m.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_m.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_m_device_buf(sizeof(ABDataType) * a_m.mDesc.GetElementSpace()); + DeviceMem b_m_device_buf(sizeof(ABDataType) * b_m.mDesc.GetElementSpace()); + DeviceMem c_m_device_buf(sizeof(CDataType) * c_m.mDesc.GetElementSpace()); + + a_m_device_buf.ToDevice(a_m.mData.data()); + b_m_device_buf.ToDevice(b_m.mData.data()); + + std::array input = {a_m_device_buf.GetDeviceBuffer(), + b_m_device_buf.GetDeviceBuffer()}; + std::array output = {c_m_device_buf.GetDeviceBuffer()}; + + std::vector a_strides = {1}; + std::vector b_strides = {1}; + std::vector c_strides = {1}; + + auto broadcastAdd = DeviceElementwiseAddInstance{}; + auto argument = broadcastAdd.MakeArgumentPointer( + input, output, {M}, {{a_strides}, b_strides}, {c_strides}, Add{}); + + if(!broadcastAdd.IsSupportedArgument(argument.get())) + { + throw std::runtime_error("The runtime parameters seems not supported by the " + "DeviceBinaryElementwise instance, exiting!"); + }; + + auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer(); + float ave_time = + broadcastAdd_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::cout << "Perf: " << ave_time << " ms" << std::endl; + + bool pass = true; + if(do_verification) + { + 
c_m_device_buf.FromDevice(c_m.mData.data()); + Tensor host_c_m(f_host_tensor_descriptor1d(M, 1)); + + host_elementwise1D, + Tensor, + Tensor, + EltwiseComputeDataType, + Add>(host_c_m, a_m, b_m, M, Add{}); + + pass &= ck::utils::check_err( + c_m.mData, host_c_m.mData, "Error: Incorrect results c", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/example/19_binary_elementwise/elementwise_add_4d.cpp b/example/19_binary_elementwise/elementwise_add_4d.cpp new file mode 100644 index 00000000..7a585cd7 --- /dev/null +++ b/example/19_binary_elementwise/elementwise_add_4d.cpp @@ -0,0 +1,130 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_binary_elementwise.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ABDataType = F16; +using CDataType = F16; +using EltwiseComputeDataType = F32; + +using Add = ck::tensor_operation::element_wise::Add; + +using DeviceElementwiseAddInstance = + ck::tensor_operation::device::DeviceBinaryElementwise; + +template +void host_elementwise4D(HostTensorC& C, + const HostTensorA& A, + const HostTensorB& B, + const std::vector& shape, + Functor functor) +{ + using ctype = ck::remove_reference_t; + + for(std::size_t n = 0; n < shape[0]; ++n) + for(std::size_t c = 0; c < shape[1]; ++c) + for(std::size_t h = 0; h < shape[2]; ++h) + for(std::size_t w = 0; w < shape[3]; ++w) + { + ComputeDataType a_val = ck::type_convert(A(n, c, h, w)); + ComputeDataType b_val = ck::type_convert(B(n, c, h, w)); + ComputeDataType c_val = 0; + functor(c_val, a_val, b_val); + C(n, c, h, w) = ck::type_convert(c_val); + } +} + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + std::vector nchw = {4, 16, 32, 32}; + + Tensor a(nchw); + Tensor b(nchw); + Tensor c(nchw); + + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_device_buf(sizeof(ABDataType) * a.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(ABDataType) * b.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a.mData.data()); + b_device_buf.ToDevice(b.mData.data()); + + std::array input = {a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer()}; + std::array output = {c_device_buf.GetDeviceBuffer()}; + + std::vector a_strides{a.mDesc.GetStrides().begin(), a.mDesc.GetStrides().end()}; + std::vector b_strides{b.mDesc.GetStrides().begin(), b.mDesc.GetStrides().end()}; + std::vector c_strides{c.mDesc.GetStrides().begin(), c.mDesc.GetStrides().end()}; + + auto broadcastAdd = DeviceElementwiseAddInstance{}; + auto argument = + broadcastAdd.MakeArgumentPointer(input, + output, + std::vector{nchw.begin(), nchw.end()}, + {{a_strides}, b_strides}, + {c_strides}, + Add{}); + + if(!broadcastAdd.IsSupportedArgument(argument.get())) + { + throw std::runtime_error("The runtime parameters seems not supported by the " + "DeviceBinaryElementwise instance, exiting!"); + }; + + auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer(); + float ave_time = + broadcastAdd_invoker_ptr->Run(argument.get(), 
StreamConfig{nullptr, time_kernel}); + + std::cout << "Perf: " << ave_time << " ms" << std::endl; + + bool pass = true; + if(do_verification) + { + c_device_buf.FromDevice(c.mData.data()); + Tensor host_c(nchw); + + host_elementwise4D, + Tensor, + Tensor, + EltwiseComputeDataType, + Add>(host_c, a, b, nchw, Add{}); + + pass &= + ck::utils::check_err(c.mData, host_c.mData, "Error: Incorrect results c", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/example/20_convnd_bwd_weight_xdl/CMakeLists.txt b/example/20_convnd_bwd_weight_xdl/CMakeLists.txt new file mode 100644 index 00000000..66fdef62 --- /dev/null +++ b/example/20_convnd_bwd_weight_xdl/CMakeLists.txt @@ -0,0 +1,4 @@ +add_example_executable(example_convnd_bwd_weight_xdl convnd_bwd_weight_xdl.cpp) +add_example_executable(example_convnd_bwd_weight_xdl_bf16_splitk convnd_bwd_weight_xdl_bf16_splitk.cpp) +target_link_libraries(example_convnd_bwd_weight_xdl PRIVATE conv_util) +target_link_libraries(example_convnd_bwd_weight_xdl_bf16_splitk PRIVATE conv_util) \ No newline at end of file diff --git a/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp b/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp new file mode 100644 index 00000000..b4a78733 --- /dev/null +++ b/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp @@ -0,0 +1,385 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +using AccDataType = float; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +using DeviceConvBwdWeightBasePtr = + ck::tensor_operation::device::DeviceConvBwdWeightPtr; + +// clang-format off +template +using DeviceConvndBwdWeightInstance = ck::tensor_operation::device:: + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization + NumDimSpatial, // NumDimSpatial + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 2, // NXdlPerWave + S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // 
ABlockTransferSrcAccessOrder
+        2,              // ABlockTransferSrcVectorDim
+        8,              // ABlockTransferSrcScalarPerVector
+        2,              // ABlockTransferDstScalarPerVector_K1
+        true,           // ABlockLdsAddExtraM
+        S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
+        S<0, 3, 1, 2>,  // BBlockTransferThreadClusterArrangeOrder
+        S<0, 2, 1, 3>,  // BBlockTransferSrcAccessOrder
+        2,              // BBlockTransferSrcVectorDim
+        8,              // BBlockTransferSrcScalarPerVector
+        2,              // BBlockTransferDstScalarPerVector_K1
+        true,           // BBlockLdsAddExtraN
+        1,              // CShuffleMXdlPerWavePerShuffle
+        1,              // CShuffleNXdlPerWavePerShuffle
+        S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
+        8>;             // CBlockTransferScalarPerVector_NWaveNPerXdl
+// clang-format on
+
+template <ck::index_t NumDimSpatial>
+using ReferenceConvBwdWeightInstance =
+    ck::tensor_operation::host::ReferenceConvBwdWeight;
+
+void print_use_msg()
+{
+    std::cout << "arg1: verification (0=no, 1=yes)\n"
+              << "arg2: initialization (0=no init, 1=random value, 2=init to 1)\n"
+              << "arg3: time kernel (0=no, 1=yes)\n"
+              << "arg4: show log (0=no, 1=yes)\n"
+              << "arg5: split-k\n"
+              << "arg6: N spatial dimensions (default 2)\n"
+              << "Following arguments (depending on number of spatial dims):\n"
+              << " N, K, C, \n"
+              << " <filter spatial lengths>, (ie Y, X for 2D)\n"
+              << " <input spatial lengths>, (ie Hi, Wi for 2D)\n"
+              << " <strides>, (ie Sy, Sx for 2D)\n"
+              << " <dilations>, (ie Dy, Dx for 2D)\n"
+              << " <left pads>, (ie LeftPy, LeftPx for 2D)\n"
+              << " <right pads>, (ie RightPy, RightPx for 2D)\n"
+              << std::endl;
+}
+
+ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[])
+{
+    // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right)
+    ck::utils::conv::ConvParams params;
+    int arg_idx = 7;
+
+    params.num_dim_spatial_ = num_dim_spatial;
+    params.N_               = std::stoi(argv[arg_idx++]);
+    params.K_               = std::stoi(argv[arg_idx++]);
+    params.C_               = std::stoi(argv[arg_idx++]);
+
+    params.filter_spatial_lengths_.resize(num_dim_spatial);
+    for(int i = 0; i < num_dim_spatial; ++i)
+    {
+        params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
+    }
+    params.input_spatial_lengths_.resize(num_dim_spatial);
+    for(int i = 0; i < num_dim_spatial; ++i)
+    {
+        params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
+    }
+    params.conv_filter_strides_.resize(num_dim_spatial);
+    for(int i = 0; i < num_dim_spatial; ++i)
+    {
+        params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]);
+    }
+    params.conv_filter_dilations_.resize(num_dim_spatial);
+    for(int i = 0; i < num_dim_spatial; ++i)
+    {
+        params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]);
+    }
+    params.input_left_pads_.resize(num_dim_spatial);
+    for(int i = 0; i < num_dim_spatial; ++i)
+    {
+        params.input_left_pads_[i] = std::stoi(argv[arg_idx++]);
+    }
+    params.input_right_pads_.resize(num_dim_spatial);
+    for(int i = 0; i < num_dim_spatial; ++i)
+    {
+        params.input_right_pads_[i] = std::stoi(argv[arg_idx++]);
+    }
+
+    return params;
+}
+
+DeviceConvBwdWeightBasePtr get_conv_instance(int num_dim_spatial)
+{
+    switch(num_dim_spatial)
+    {
+    case 3: {
+        return std::make_unique<DeviceConvndBwdWeightInstance<3>>();
+    }
+    case 2: {
+        return std::make_unique<DeviceConvndBwdWeightInstance<2>>();
+    }
+    case 1: {
+        return std::make_unique<DeviceConvndBwdWeightInstance<1>>();
+    }
+    default: {
+        throw std::runtime_error("Unsupported number of spatial dimensions provided!");
+    }
+    }
+}
+
+int main(int argc, char* argv[])
+{
+    bool do_verification = true;
+    int init_method      = 1;
+    bool time_kernel     = false;
+    int num_dim_spatial  = 2;
+    int do_log           = 0;
+    int split_k          = 1;
+
+    ck::utils::conv::ConvParams params;
+    params.C_ = 128;
+
+    if(argc == 6)
+    {
+        do_verification =
std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + do_log = std::stoi(argv[4]); + split_k = std::stoi(argv[5]); + } + else if(argc > 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + do_log = std::stoi(argv[4]); + split_k = std::stoi(argv[5]); + num_dim_spatial = std::stoi(argv[6]); + // check args number + int conv_args = 3 + num_dim_spatial * 6; + int cmdline_nargs = conv_args + 7; + if(cmdline_nargs != argc) + { + print_use_msg(); + exit(1); + } + + params = parse_conv_params(num_dim_spatial, argv); + } + else if(argc != 1) + { + print_use_msg(); + exit(1); + } + + std::vector input_dims{static_cast(params.N_), + static_cast(params.C_)}; + input_dims.insert(std::end(input_dims), + std::begin(params.input_spatial_lengths_), + std::end(params.input_spatial_lengths_)); + + std::vector filter_dims{static_cast(params.K_), + static_cast(params.C_)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params.filter_spatial_lengths_), + std::end(params.filter_spatial_lengths_)); + + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + std::vector output_dims{static_cast(params.N_), + static_cast(params.K_)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor in_n_c_hi_wi( + ck::utils::conv::get_input_host_tensor_descriptor(input_dims, num_dim_spatial)); + Tensor wei_k_c_y_x_host_result( + ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial)); + Tensor wei_k_c_y_x_device_result( + ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial)); + Tensor out_n_k_ho_wo( + ck::utils::conv::get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_device_result.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl; + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_host_result.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1{1}); + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1{1}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * + wei_k_c_y_x_device_result.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + out_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + // reset input to zero + wei_device_buf.SetZero(); + + // do GEMM + auto conv = get_conv_instance(num_dim_spatial); + auto invoker = conv->MakeInvokerPointer(); + auto argument = + conv->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths, + params.conv_filter_strides_, + params.conv_filter_dilations_, + 
params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}, + split_k); + + // alloc work space + float ave_time = 0.f; + if(!conv->IsSupportedArgument(argument.get())) + { + std::cout << "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem" + << std::endl; + return 1; + } + ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = ck::utils::conv::get_flops( + params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); + std::size_t num_btype = ck::utils::conv::get_btype( + params.N_, + params.C_, + params.K_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + auto verify_f = [&](const auto& ref_conv) { + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi, + wei_k_c_y_x_host_result, + out_n_k_ho_wo, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); + + wei_device_buf.FromDevice(wei_k_c_y_x_device_result.mData.data()); + + if(do_log) + { + LogRangeAsType(std::cout << "out: ", out_n_k_ho_wo.mData, ",") << std::endl; + LogRangeAsType(std::cout << "in : ", in_n_c_hi_wi.mData, ",") << std::endl; + LogRangeAsType( + std::cout << "wei_device(after): ", wei_k_c_y_x_device_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "wei_host : ", wei_k_c_y_x_host_result.mData, ",") + << std::endl; + } + + return ck::utils::check_err(wei_k_c_y_x_device_result.mData, + wei_k_c_y_x_host_result.mData) + ? 0 + : 1; + }; + + switch(num_dim_spatial) + { + case 3: { + auto ref_conv = ReferenceConvBwdWeightInstance<3>(); + return verify_f(ref_conv); + } + case 2: { + auto ref_conv = ReferenceConvBwdWeightInstance<2>(); + return verify_f(ref_conv); + } + case 1: { + auto ref_conv = ReferenceConvBwdWeightInstance<1>(); + return verify_f(ref_conv); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } + } + return 0; +} diff --git a/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl_bf16_splitk.cpp b/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl_bf16_splitk.cpp new file mode 100644 index 00000000..5334cfac --- /dev/null +++ b/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl_bf16_splitk.cpp @@ -0,0 +1,427 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
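
This split-K variant cuts the reduction (K) dimension into split_k pieces computed by separate workgroups. Accumulating bf16 partial results directly would lose precision, so the kernel accumulates into an f32 workspace (see GetWorkSpaceSize/SetWorkSpacePointer below) and converts to bf16 in a separate elementwise pass (UnaryTypeConvert). A hedged host-side sketch of that flow; the bf16 emulation and all names are illustrative assumptions:

```cpp
// Split-K accumulation sketch: f32 workspace, narrow cast at the end.
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Crude bf16 emulation: keep the top 16 bits of the f32 bit pattern.
static std::uint16_t float_to_bf16(float f)
{
    std::uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    return static_cast<std::uint16_t>(u >> 16);
}

int main()
{
    const std::size_t K = 8, split_k = 2, k_per_split = K / split_k;
    std::vector<float> a(K, 0.5f), b(K, 2.0f);

    float workspace = 0.f; // f32 accumulator shared by all K splits
    for(std::size_t s = 0; s < split_k; ++s)
    {
        float partial = 0.f;
        for(std::size_t k = s * k_per_split; k < (s + 1) * k_per_split; ++k)
            partial += a[k] * b[k];
        workspace += partial; // the device kernel does this atomically
    }
    const std::uint16_t result = float_to_bf16(workspace); // final convert pass
    return result == float_to_bf16(8.f) ? 0 : 1;           // 8 * (0.5 * 2.0)
}
```
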
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/device/device_unary_elementwise.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp" + +using InDataType = ck::bhalf_t; +using WeiDataType = ck::bhalf_t; +using OutDataType = ck::bhalf_t; +using AccDataType = float; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +using UnaryTypeConvert = ck::tensor_operation::element_wise::UnaryTypeConvert; + +using DeviceUnaryElementwiseTypeConvertInstance = ck::tensor_operation::device:: + DeviceUnaryElementwise; + +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +using DeviceConvBwdWeightBasePtr = + ck::tensor_operation::device::DeviceConvBwdWeightPtr; + +// clang-format off +template +using DeviceConvndBwdWeightInstance_bf16_splitk = ck::tensor_operation::device:: + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + InDataType, // InDataType + AccDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization + NumDimSpatial, // NumDimSpatial + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 2, // NXdlPerWave + S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 2, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 2, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 4>; // CBlockTransferScalarPerVector_NWaveNPerXdl +// clang-format on + +template +using ReferenceConvBwdWeightInstance = + ck::tensor_operation::host::ReferenceConvBwdWeight; + +template +void host_elementwise(HostTensorB& B, + const HostTensorA& A, + const std::vector& shape, + Functor functor) +{ + size_t tensor_size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies{}); + std::cout << __LINE__ << ":" << tensor_size << ", " << A.mData[0] << std::endl; + for(std::size_t n = 0; n < tensor_size; ++n) + { + 
B.mData[n] = functor(A.mData[n]); + } +} + +void print_use_msg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=random value, 2= init to 1 )\n" + << "arg3: time kernel (0=n0, 1=yes)\n" + << "arg4: is show log (0=no, 1=yes)\n" + << "arg5: split-k : in this example split-k must be larger than 1\n" + << "arg6: N spatial dimensions (default 2)\n" + << "Following arguments (depending on number of spatial dims):\n" + << " N, K, C, \n" + << " , (ie Y, X for 2D)\n" + << " , (ie Hi, Wi for 2D)\n" + << " , (ie Sy, Sx for 2D)\n" + << " , (ie Dy, Dx for 2D)\n" + << " , (ie LeftPy, LeftPx for 2D)\n" + << " , (ie RightPy, RightPx for 2D)\n" + << std::endl; +} + +ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[]) +{ + // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) + ck::utils::conv::ConvParams params; + int arg_idx = 7; + + params.num_dim_spatial_ = num_dim_spatial; + params.N_ = std::stoi(argv[arg_idx++]); + params.K_ = std::stoi(argv[arg_idx++]); + params.C_ = std::stoi(argv[arg_idx++]); + + params.filter_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.input_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_strides_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_dilations_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]); + } + params.input_left_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_left_pads_[i] = std::stoi(argv[arg_idx++]); + } + params.input_right_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_right_pads_[i] = std::stoi(argv[arg_idx++]); + } + + return params; +} + +DeviceConvBwdWeightBasePtr get_conv_instance(int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 3: { + return std::make_unique>(); + } + case 2: { + return std::make_unique>(); + } + case 1: { + return std::make_unique>(); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int num_dim_spatial = 2; + int do_log = 0; + int split_k = 2; + + ck::utils::conv::ConvParams params; + params.C_ = 128; + + if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + do_log = std::stoi(argv[4]); + split_k = std::stoi(argv[5]); + } + else if(argc > 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + do_log = std::stoi(argv[4]); + split_k = std::stoi(argv[5]); + num_dim_spatial = std::stoi(argv[6]); + // check args number + int conv_args = 3 + num_dim_spatial * 6; + int cmdline_nargs = conv_args + 7; + if(cmdline_nargs != argc) + { + print_use_msg(); + exit(1); + } + + params = parse_conv_params(num_dim_spatial, argv); + } + else if(argc != 1) + { + print_use_msg(); + exit(1); + } + + if(split_k <= 1) + { + print_use_msg(); + exit(1); + } + + std::vector 
input_dims{static_cast(params.N_), + static_cast(params.C_)}; + input_dims.insert(std::end(input_dims), + std::begin(params.input_spatial_lengths_), + std::end(params.input_spatial_lengths_)); + + std::vector filter_dims{static_cast(params.K_), + static_cast(params.C_)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params.filter_spatial_lengths_), + std::end(params.filter_spatial_lengths_)); + + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + std::vector output_dims{static_cast(params.N_), + static_cast(params.K_)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor in_n_c_hi_wi( + ck::utils::conv::get_input_host_tensor_descriptor(input_dims, num_dim_spatial)); + Tensor wei_k_c_y_x_host_result( + ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial)); + Tensor wei_k_c_y_x_device_result( + ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial)); + Tensor out_n_k_ho_wo( + ck::utils::conv::get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_device_result.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl; + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_host_result.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1{1}); + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1{1}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * + wei_k_c_y_x_device_result.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + out_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + // reset input to zero + wei_device_buf.SetZero(); + + // do GEMM + auto conv = get_conv_instance(num_dim_spatial); + auto invoker = conv->MakeInvokerPointer(); + auto argument = + conv->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}, + split_k); + + // alloc work space + size_t bwd_weight_workspace_size = conv->GetWorkSpaceSize(argument.get()); + if(bwd_weight_workspace_size <= 0) + { + print_use_msg(); + exit(1); + } + + float conv_ave_time = 0.f; + + DeviceMem wei_work_space_device_buf(bwd_weight_workspace_size); + wei_work_space_device_buf.SetZero(); + conv->SetWorkSpacePointer(argument.get(), wei_work_space_device_buf.GetDeviceBuffer()); + + if(!conv->IsSupportedArgument(argument.get())) + { + std::cout << "wrong! 
device_conv with the specified compilation parameters does " + "not support this Conv problem" + << std::endl; + return 1; + } + + conv_ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = ck::utils::conv::get_flops( + params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); + std::size_t num_btype = ck::utils::conv::get_btype( + params.N_, + params.C_, + params.K_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths); + + float tflops = static_cast(flop) / 1.E9 / conv_ave_time; + + float gb_per_sec = num_btype / 1.E6 / conv_ave_time; + + std::cout << "Perf: conv: " << conv_ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << std::endl; + + if(do_verification) + { + auto verify_f = [&](const auto& ref_conv) { + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi, + wei_k_c_y_x_host_result, + out_n_k_ho_wo, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); + + wei_device_buf.FromDevice(wei_k_c_y_x_device_result.mData.data()); + + if(do_log) + { + LogRangeAsType(std::cout << "out: ", out_n_k_ho_wo.mData, ",") << std::endl; + LogRangeAsType(std::cout << "in : ", in_n_c_hi_wi.mData, ",") << std::endl; + LogRangeAsType( + std::cout << "wei_device(after): ", wei_k_c_y_x_device_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "wei_host : ", wei_k_c_y_x_host_result.mData, ",") + << std::endl; + } + + return ck::utils::check_err(wei_k_c_y_x_device_result.mData, + wei_k_c_y_x_host_result.mData) + ? 0 + : 1; + }; + + switch(num_dim_spatial) + { + case 3: { + auto ref_conv = ReferenceConvBwdWeightInstance<3>(); + verify_f(ref_conv); + break; + } + case 2: { + auto ref_conv = ReferenceConvBwdWeightInstance<2>(); + verify_f(ref_conv); + break; + } + case 1: { + auto ref_conv = ReferenceConvBwdWeightInstance<1>(); + verify_f(ref_conv); + break; + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } + } + return 0; +} diff --git a/example/21_gemm_layernorm/CMakeLists.txt b/example/21_gemm_layernorm/CMakeLists.txt new file mode 100644 index 00000000..78d3a5d0 --- /dev/null +++ b/example/21_gemm_layernorm/CMakeLists.txt @@ -0,0 +1,3 @@ +add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_fp16 gemm_bias_relu_add_layernorm_xdl_fp16.cpp) +add_example_executable(example_gemm_layernorm_xdl_fp16 gemm_layernorm_xdl_fp16.cpp) +add_example_executable(example_gemm_xdl_layernorm_single_kernel_fp16 gemm_xdl_layernorm_single_kernel_fp16.cpp) diff --git a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_fp16.cpp b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_fp16.cpp new file mode 100644 index 00000000..531777b1 --- /dev/null +++ b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_fp16.cpp @@ -0,0 +1,432 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
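The example added below chains two device kernels. `DeviceGemmBiasAddReduce_Xdl_CShuffle` computes `C = Relu(A * B + broadcast(bias)) + C1` and, in the same pass, the row-wise moments `E[C]` and `E[C^2]` (two reductions, each divided by N). `Device5AryElementwise` then finishes layer normalization elementwise, recovering the variance from the two moments as `Var[C] = E[C^2] - E[C]^2`. The sketch below restates that math in plain C++ so it can be checked without a GPU; the sizes, `eps`, and row-major layout are illustrative assumptions, not values taken from the example.

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    const int M = 2, N = 4, K = 3;
    const float eps = 1e-4f;
    std::vector<float> A(M * K, 1.0f), B(K * N, 0.5f), bias(N, 0.1f);
    std::vector<float> C1(M * N, 0.2f), gamma(N, 1.0f), beta(N, 0.0f), Y(M * N);

    for(int m = 0; m < M; ++m)
    {
        // Kernel 1 equivalent: C = Relu(A*B + bias) + C1, plus the two row moments.
        std::vector<float> c(N);
        float mean = 0.f, mean_sq = 0.f;
        for(int n = 0; n < N; ++n)
        {
            float acc = bias[n];
            for(int k = 0; k < K; ++k)
                acc += A[m * K + k] * B[k * N + n];
            acc  = acc > 0.f ? acc : 0.f; // Relu
            acc += C1[m * N + n];         // residual add
            c[n] = acc;
            mean += acc / N;
            mean_sq += acc * acc / N;
        }
        // Kernel 2 equivalent: elementwise normalize, Var = E[x^2] - E[x]^2.
        const float inv_std = 1.f / std::sqrt(mean_sq - mean * mean + eps);
        for(int n = 0; n < N; ++n)
            Y[m * N + n] = (c[n] - mean) * inv_std * gamma[n] + beta[n];
    }
    std::printf("Y[0] = %f\n", Y[0]); // every row entry is equal here, so Y is ~0
    return 0;
}
```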
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using ADataType = F16; +using BDataType = F16; +using CDataType = F16; +using BiasDataType = F32; +using D0DataType = F16; +using GemmAccDataType = F32; +using ReduceAccDataType = F32; +using ReduceDataType = F32; +using ReducePtrsGlobal = ck::Tuple; +using GammaDataType = F16; +using BetaDataType = F16; +using LayerNormOutDataType = F16; +using NormalizeComputeDataType = F32; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = ck::tensor_operation::element_wise::Relu; +using D0ElementOp = PassThrough; +using ReduceSumOp = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough; +using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide; +using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceGlobalMemOps = + ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmSpecialization = + ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmBiasAddReduceInstance = ck::tensor_operation::device::DeviceGemmBiasAddReduce_Xdl_CShuffle +//######| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| C1| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| +//######| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| +//######| | | | | | | | | | | | | Operation| 
Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < Row, Col, Row, F16, F16, F16, F32, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, D0ElementOp, ReduceOps,ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize; + +// A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta , F:y +using DeviceNormalizeInstance = + ck::tensor_operation::device::Device5AryElementwise; // scalarPerVector: LayerNorm_out + +auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor(std::vector({len}), + std::vector({stride})); +}; + +auto f_host_tensor_descriptor2d = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + +template +void host_gemm_layernorm(Tensor& out_m_n, + const Tensor& a_m_k, + const Tensor& b_k_n, + const Tensor& bias_n, + const Tensor& c1_m_n, + const Tensor& gamma_n, + const Tensor& beta_n, + A_functor a_element_op, + B_functor b_element_op, + C_functor c_element_op, + C1_functor c1_element_op, + int M, + int N) +{ + + int StrideC = N; + Tensor c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + Tensor mean_m(f_host_tensor_descriptor1d(M, 1)); + Tensor meanSquare_m(f_host_tensor_descriptor1d(M, 1)); + auto averageOpInst = UnaryDivElementOp{N}; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + // c = activation(c + bias) + c1_functor(c1) + for(int m = 0; m < M; ++m) + for(int n = 0; n < N; ++n) + { + AccDataType acc = ck::type_convert(c_m_n(m, n)) + + ck::type_convert(bias_n(n)); + + AccDataType c1 = ck::type_convert(c1_m_n(m, n)); + + c_element_op(acc, acc); + c1_element_op(c1, c1); + acc += c1; + c_m_n(m, n) = ck::type_convert(acc); + } + + // reduce_mean and reduce_square_mean + auto reduceSumOpInst = ReduceSumOp{}; + for(int m = 0; m < M; ++m) + { + auto mean_acc = reduceSumOpInst.GetIdentityValue(); + auto square_mean_acc = reduceSumOpInst.GetIdentityValue(); + + for(int n = 0; n < N; ++n) + { + AccDataType c_val = ck::type_convert(c_m_n(m, n)); + AccDataType square_c_val = 0; + UnarySquareElementOp{}(square_c_val, c_val); + + reduceSumOpInst(mean_acc, c_val); + reduceSumOpInst(square_mean_acc, square_c_val); + } + + averageOpInst(mean_acc, mean_acc); + averageOpInst(square_mean_acc, square_mean_acc); + mean_m(m) = ck::type_convert(mean_acc); + meanSquare_m(m) = ck::type_convert(square_mean_acc); + } + + // LayerNorm + auto layerNormInst = 
NormalizeFunctor{}; + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + AccDataType out_acc = 0; + layerNormInst(out_acc, + ck::type_convert(c_m_n(m, n)), + ck::type_convert(mean_m(m)), + ck::type_convert(meanSquare_m(m)), + ck::type_convert(gamma_n(n)), + ck::type_convert(beta_n(n))); + out_m_n(m, n) = ck::type_convert(out_acc); + } + } +} + +template +void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M, int N, int K) +{ + std::size_t gemm_flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N; + std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * M * N + sizeof(BiasDataType) * M * N + + sizeof(D0DataType) * M * N + sizeof(ReduceDataType) * M + + sizeof(ReduceDataType) * M; + + std::size_t normalize_num_byte = sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M + + sizeof(ReduceDataType) * M + sizeof(GammaDataType) * N + + sizeof(BetaDataType) * N + sizeof(NormalizeDataType) * M * N; + + float tflops = static_cast(gemm_flop) / 1.E9 / gemm_reduce_time; + float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time; + float normalize_gb_per_sec = normalize_num_byte / 1.E6 / normalize_time; + + std::cout << "gemm + reduce_mean + reduce_square_mean Perf: " << gemm_reduce_time << " ms, " + << tflops << " TFlops, " << gemm_gb_per_sec << " GB/s, " << std::endl; + + std::cout << "5-ary elementwise Perf: " << normalize_time << " ms, " << normalize_gb_per_sec + << " GB/s, " << std::endl; +} + +int main() +{ + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = 1024; + ck::index_t StrideB = 1024; + ck::index_t StrideC = 1024; + ck::index_t StrideD0 = 1024; + + Tensor a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{})); + Tensor c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + Tensor bias_n(f_host_tensor_descriptor1d(N, 1)); + Tensor c1_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + Tensor reduceMean_m(f_host_tensor_descriptor1d(M, 1)); + Tensor reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1)); + Tensor gamma_n(f_host_tensor_descriptor1d(N, 1)); + Tensor beta_n(f_host_tensor_descriptor1d(N, 1)); + Tensor layerNorm_m_n( + f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + + a_m_k.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + bias_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + c1_m_n.GenerateTensorValue(GeneratorTensor_3{-5, 5}); + gamma_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + beta_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace()); + DeviceMem bias_device_buf(sizeof(BiasDataType) * bias_n.mDesc.GetElementSpace()); + DeviceMem d0_device_buf(sizeof(D0DataType) * c1_m_n.mDesc.GetElementSpace()); + DeviceMem reduceMean_device_buf(sizeof(ReduceDataType) * reduceMean_m.mDesc.GetElementSpace()); + DeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) * + reduceMeanSquare_m.mDesc.GetElementSpace()); + DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpace()); + DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpace()); + DeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) 
* + layerNorm_m_n.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + bias_device_buf.ToDevice(bias_n.mData.data()); + d0_device_buf.ToDevice(c1_m_n.mData.data()); + gamma_device_buf.ToDevice(gamma_n.mData.data()); + beta_device_buf.ToDevice(beta_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + auto d_element_op = D0ElementOp{}; + std::array gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op}; + + auto passthrough = UnaryIdenticElementOp{}; + auto square = UnarySquareElementOp{}; + auto div = UnaryDivElementOp{N}; + std::array reduce_in_element_ops = {&passthrough, &square}; + std::array reduce_out_element_ops = {&div, &div}; + + std::array p_reduces = {reduceMean_device_buf.GetDeviceBuffer(), + reduceMeanSquare_device_buf.GetDeviceBuffer()}; + + // Prepare GEMM, reduce_mean, reduce_mean_square + auto gemmReduce = DeviceGemmBiasAddReduceInstance{}; + auto gemmReduce_invoker = gemmReduce.MakeInvoker(); + auto gemmReduce_argument = gemmReduce.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + bias_device_buf.GetDeviceBuffer(), + {d0_device_buf.GetDeviceBuffer()}, + c_device_buf.GetDeviceBuffer(), + p_reduces, + M, + N, + K, + StrideA, + StrideB, + StrideC, + {StrideD0}, + gemm_element_ops, + {&d_element_op}, + reduce_in_element_ops, + reduce_out_element_ops); + + if(!gemmReduce.IsSupportedArgument(gemmReduce_argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + reduceMean_device_buf.SetZero(); + reduceMeanSquare_device_buf.SetZero(); + + // Prepare LayerNorm + std::array input = {c_device_buf.GetDeviceBuffer(), + reduceMean_device_buf.GetDeviceBuffer(), + reduceMeanSquare_device_buf.GetDeviceBuffer(), + gamma_device_buf.GetDeviceBuffer(), + beta_device_buf.GetDeviceBuffer()}; + std::array output = {layerNorm_device_buf.GetDeviceBuffer()}; + + auto normalize = DeviceNormalizeInstance{}; + auto normalize_invoker = normalize.MakeInvoker(); + auto normalize_argument = normalize.MakeArgument(input, + output, + {M, N}, + {StrideC, 1}, + {1, 0}, + {1, 0}, + {0, 1}, + {0, 1}, + {StrideC, 1}, + NormalizeFunctor{}); + + if(!normalize.IsSupportedArgument(normalize_argument)) + { + throw std::runtime_error("The runtime parameters seems not supported by the " + "Device5AryElementwise instance, exiting!"); + } + + // run kernel + gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, false}); + normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, false}); + + bool pass = true; + { + // verification + Tensor host_layerNorm_m_n( + f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + + host_gemm_layernorm(host_layerNorm_m_n, + a_m_k, + b_k_n, + bias_n, + c1_m_n, + gamma_n, + beta_n, + a_element_op, + b_element_op, + c_element_op, + d_element_op, + M, + N); + + layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data()); + pass &= ck::utils::check_err(layerNorm_m_n.mData, + host_layerNorm_m_n.mData, + "Error: Incorrect results layerNorm_m_n", + 1e-2, + 1e-2); + } + + { + // evaluate kernel perf + bool time_kernel = true; + + float gemm_reduce_mean_reduce_square_mean_ave_time = + gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, time_kernel}); + float normalize_ave_time = + normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, time_kernel}); + + if(time_kernel) + 
DumpGemmLayerNormPerf( + gemm_reduce_mean_reduce_square_mean_ave_time, normalize_ave_time, M, N, K); + } + + return pass ? 0 : 1; +} diff --git a/example/21_gemm_layernorm/gemm_layernorm_xdl_fp16.cpp b/example/21_gemm_layernorm/gemm_layernorm_xdl_fp16.cpp new file mode 100644 index 00000000..6f1984e1 --- /dev/null +++ b/example/21_gemm_layernorm/gemm_layernorm_xdl_fp16.cpp @@ -0,0 +1,390 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using ADataType = F16; +using BDataType = F16; +using CDataType = F16; +using GemmAccDataType = F32; +using ReduceAccDataType = F32; +using ReduceDataType = F32; +using ReducePtrsGlobal = ck::Tuple; +using GammaDataType = F16; +using BetaDataType = F16; +using LayerNormOutDataType = F16; +using NormalizeComputeDataType = F32; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; +using ReduceSumOp = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough; +using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide; +using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceGlobalMemOps = + ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmSpecialization = + ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle +//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| +//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| 
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps,ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize; + +// A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta , F:y +using DeviceNormalizeInstance = + ck::tensor_operation::device::Device5AryElementwise; // scalarPerVector: LayerNorm_out + +auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor(std::vector({len}), + std::vector({stride})); +}; + +auto f_host_tensor_descriptor2d = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + +template +void host_gemm_layernorm(Tensor& out_m_n, + const Tensor& a_m_k, + const Tensor& b_k_n, + const Tensor& gamma_n, + const Tensor& beta_n, + A_functor a_element_op, + B_functor b_element_op, + C_functor c_element_op, + int M, + int N) +{ + using out_type = ck::remove_reference_t; + + int StrideC = N; + Tensor c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + Tensor mean_m(f_host_tensor_descriptor1d(M, 1)); + Tensor meanSquare_m(f_host_tensor_descriptor1d(M, 1)); + auto averageOpInst = UnaryDivElementOp{N}; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + // reduce_mean and reduce_square_mean + auto reduceSumOpInst = ReduceSumOp{}; + for(int m = 0; m < M; ++m) + { + auto mean_acc = reduceSumOpInst.GetIdentityValue(); + auto square_mean_acc = reduceSumOpInst.GetIdentityValue(); + + for(int n = 0; n < N; ++n) + { + auto c_val = ck::type_convert(c_m_n(m, n)); + auto square_c_val = reduceSumOpInst.GetIdentityValue(); + + UnarySquareElementOp{}(square_c_val, c_val); + + reduceSumOpInst(mean_acc, c_val); + reduceSumOpInst(square_mean_acc, square_c_val); + } + + averageOpInst(mean_acc, mean_acc); + averageOpInst(square_mean_acc, square_mean_acc); + mean_m(m) = ck::type_convert(mean_acc); + meanSquare_m(m) = ck::type_convert(square_mean_acc); + } + + // LayerNorm + auto layerNormInst = NormalizeFunctor{}; + for(int m = 0; m < M; ++m) + { + for(int n 
= 0; n < N; ++n) + { + float out_f32 = 0; + layerNormInst(out_f32, + static_cast(c_m_n(m, n)), + static_cast(mean_m(m)), + static_cast(meanSquare_m(m)), + static_cast(gamma_n(n)), + static_cast(beta_n(n))); + out_m_n(m, n) = static_cast(out_f32); + } + } +} + +template +void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M, int N, int K) +{ + std::size_t gemm_flop = std::size_t(2) * M * N * K; + std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M + + sizeof(ReduceDataType) * M; + + std::size_t normalize_num_btye = sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M + + sizeof(ReduceDataType) * M + sizeof(GammaDataType) * N + + sizeof(BetaDataType) * N + sizeof(NormalizeDataType) * M * N; + + float tflops = static_cast(gemm_flop) / 1.E9 / gemm_reduce_time; + float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time; + float normalize_gb_per_sec = normalize_num_btye / 1.E6 / normalize_time; + + std::cout << "gemm + reduce_mean + reduce_square_mean Perf: " << gemm_reduce_time << " ms, " + << tflops << " TFlops, " << gemm_gb_per_sec << " GB/s, " << std::endl; + + std::cout << "5-ary elementwise Perf: " << normalize_time << " ms, " << normalize_gb_per_sec + << " GB/s, " << std::endl; +} + +int main() +{ + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = 1024; + ck::index_t StrideB = 1024; + ck::index_t StrideC = 1024; + + Tensor a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{})); + Tensor c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + Tensor reduceMean_m(f_host_tensor_descriptor1d(M, 1)); + Tensor reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1)); + Tensor gamma_n(f_host_tensor_descriptor1d(N, 1)); + Tensor beta_n(f_host_tensor_descriptor1d(N, 1)); + Tensor layerNorm_m_n( + f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + + a_m_k.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + gamma_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + beta_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace()); + DeviceMem reduceMean_device_buf(sizeof(ReduceDataType) * reduceMean_m.mDesc.GetElementSpace()); + DeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) * + reduceMeanSquare_m.mDesc.GetElementSpace()); + DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpace()); + DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpace()); + DeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) * + layerNorm_m_n.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + gamma_device_buf.ToDevice(gamma_n.mData.data()); + beta_device_buf.ToDevice(beta_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + std::array gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op}; + + auto passthrough = UnaryIdenticElementOp{}; + auto square = UnarySquareElementOp{}; + auto div = UnaryDivElementOp{N}; + std::array reduce_in_element_ops = {&passthrough, &square}; + 
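+    // The pairing above/below makes one pass over C produce both moments:
+    // the passthrough op feeds sum(x), the square op feeds sum(x^2), and
+    // applying UnaryDivide{N} to both turns them into E[x] and E[x^2], from
+    // which the normalize kernel recovers the variance as E[x^2] - E[x]^2.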
std::array reduce_out_element_ops = {&div, &div}; + + std::array p_reduces = {reduceMean_device_buf.GetDeviceBuffer(), + reduceMeanSquare_device_buf.GetDeviceBuffer()}; + + // Prepare GEMM, reduce_mean, reduce_mean_square + auto gemmReduce = DeviceGemmReduceInstance{}; + auto gemmReduce_invoker = gemmReduce.MakeInvoker(); + auto gemmReduce_argument = gemmReduce.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + nullptr, + {}, + c_device_buf.GetDeviceBuffer(), + p_reduces, + M, + N, + K, + StrideA, + StrideB, + StrideC, + {}, + gemm_element_ops, + {}, + reduce_in_element_ops, + reduce_out_element_ops); + + if(!gemmReduce.IsSupportedArgument(gemmReduce_argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + reduceMean_device_buf.SetZero(); + reduceMeanSquare_device_buf.SetZero(); + + // Prepare LayerNorm + std::array input = {c_device_buf.GetDeviceBuffer(), + reduceMean_device_buf.GetDeviceBuffer(), + reduceMeanSquare_device_buf.GetDeviceBuffer(), + gamma_device_buf.GetDeviceBuffer(), + beta_device_buf.GetDeviceBuffer()}; + std::array output = {layerNorm_device_buf.GetDeviceBuffer()}; + + auto normalize = DeviceNormalizeInstance{}; + auto normalize_invoker = normalize.MakeInvoker(); + auto normalize_argument = normalize.MakeArgument(input, + output, + {M, N}, + {StrideC, 1}, + {1, 0}, + {1, 0}, + {0, 1}, + {0, 1}, + {StrideC, 1}, + NormalizeFunctor{}); + + if(!normalize.IsSupportedArgument(normalize_argument)) + { + throw std::runtime_error("The runtime parameters seems not supported by the " + "Device5AryElementwise instance, exiting!"); + } + + // run kernel + gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, false}); + normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, false}); + + bool pass = true; + { + // verification + Tensor host_layerNorm_m_n( + f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + + host_gemm_layernorm(host_layerNorm_m_n, + a_m_k, + b_k_n, + gamma_n, + beta_n, + a_element_op, + b_element_op, + c_element_op, + M, + N); + + layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data()); + pass &= ck::utils::check_err(layerNorm_m_n.mData, + host_layerNorm_m_n.mData, + "Error: Incorrect results d1", + 1e-3, + 1e-3); + } + + { + // evaluate kernel perf + bool time_kernel = true; + + float gemm_reduce_mean_reduce_square_mean_ave_time = + gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, time_kernel}); + float normalize_ave_time = + normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, time_kernel}); + + if(time_kernel) + DumpGemmLayerNormPerf( + gemm_reduce_mean_reduce_square_mean_ave_time, normalize_ave_time, M, N, K); + } + + return pass ? 0 : 1; +} diff --git a/example/21_gemm_layernorm/gemm_xdl_layernorm_single_kernel_fp16.cpp b/example/21_gemm_layernorm/gemm_xdl_layernorm_single_kernel_fp16.cpp new file mode 100644 index 00000000..f72aa62b --- /dev/null +++ b/example/21_gemm_layernorm/gemm_xdl_layernorm_single_kernel_fp16.cpp @@ -0,0 +1,289 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
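The two-kernel examples above pay for their flexibility with a round trip through global memory: the GEMM kernel stores `C` plus the two moment vectors, and the elementwise kernel loads them back. The single-kernel variant added below avoids that round trip when one workgroup spans the whole `N` extent. A rough, illustrative estimate of the traffic saved at this file's default problem size follows (assuming fp16 `C` and fp32 moments; ballpark figures, not measurements):

```cpp
#include <cstdio>

int main()
{
    const long long M = 3840, N = 128; // defaults used by the example below
    // Two-kernel pipeline: C and both row moments are stored by kernel 1
    // and loaded again by kernel 2.
    const long long c_round_trip      = 2 * (M * N * 2); // fp16 C store + load
    const long long moment_round_trip = 2 * (2 * M * 4); // fp32 E[x], E[x^2] store + load
    // The fused kernel keeps C and the moments in registers/LDS instead.
    std::printf("global-memory bytes avoided: %lld\n", c_round_trip + moment_round_trip);
    return 0;
}
```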
+
+#include <cstdio>
+#include <cstdlib>
+#include <iostream>
+#include <string>
+
+#include "ck/ck.hpp"
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_xdl_layernorm_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/utility/reduction_operator.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+
+// This example demonstrates a single fused kernel that runs a GEMM layer and layernorm together.
+//
+// The GEMM + layernorm implementation is a specialized kernel which allows fusing both layers
+// together, given the condition that the GEMM extent N of MNK is spanned by a single workgroup.
+// For example, a kernel configured with NPerBlock = 128 can operate on all GEMM sizes with N <= 128.
+//
+// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta)
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using ADataType        = F16;
+using BDataType        = F16;
+using CDataType        = F16;
+using C0DataType       = F16;
+using AccDataType      = F32;
+using CShuffleDataType = F16;
+
+using ALayout = ck::tensor_layout::gemm::RowMajor;
+using BLayout = ck::tensor_layout::gemm::ColumnMajor;
+using CLayout = ck::tensor_layout::gemm::RowMajor;
+
+struct Relu
+{
+    template <typename OutT, typename InT>
+    __host__ __device__ void operator()(OutT& y, const InT& x) const
+    {
+        y = x > 0 ?
x : 0; + } +}; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +// Elementwise operation that operates on the output of matrix multiplication +// i.e., AccElementOp(A * B + bias) +using AccElementOp = Relu; +// Elementwise operation that operates on the output of layer normalization +using CElementOp = Relu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmLayerNorm_Xdl_CShuffle +//######| ALayout| BLayout| CLayout| AData| BData| CData| C0Data| GemmAcc| CShuffle| ReduceAcc| A| B| Acc| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadCopy| +//######| | | | Type| Type| Type| Type| DataType| DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < Row, Col, Row, ADataType, BDataType, CDataType, C0DataType, AccDataType, CShuffleDataType, AccDataType, AElementOp, BElementOp, AccElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, S<64, 4>, 4>; +// clang-format on + +using ReferenceInstance = ck::tensor_operation::host::ReferenceGemmLayernorm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 128; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 128; + + if(argc == 1) + { + // do nothing + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(0); 
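+        // The "(256x)/(128x)/(32x)" hints above mirror the selected instance's
+        // tile sizes (MPerBlock = 256, NPerBlock = 128, KPerBlock = 32); with
+        // GemmSpecialization::Default there is no padding, so M/N/K must be
+        // multiples of the tile, and the layernorm fusion additionally needs
+        // N <= NPerBlock so a single workgroup covers a full output row.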
+ } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor acc_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c0_n_bias(HostTensorDescriptor(std::vector({size_t(N)}))); + Tensor c0_m_n_add(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c0_n_gamma(HostTensorDescriptor(std::vector({size_t(N)}))); + Tensor c0_n_beta(HostTensorDescriptor(std::vector({size_t(N)}))); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "c0_n_bias: " << c0_n_bias.mDesc << std::endl; + std::cout << "c0_m_n_add: " << c0_m_n_add.mDesc << std::endl; + std::cout << "c0_n_gamma: " << c0_n_gamma.mDesc << std::endl; + std::cout << "c0_n_beta: " << c0_n_beta.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + c0_n_bias.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + c0_m_n_add.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + c0_n_gamma.GenerateTensorValue(GeneratorTensor_2{0, 2}); + c0_n_beta.GenerateTensorValue(GeneratorTensor_2{0, 5}); + c_m_n_host_result.GenerateTensorValue(GeneratorTensor_1{0}); + acc_m_n_host_result.GenerateTensorValue(GeneratorTensor_1{0}); + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + DeviceMem c0_bias_buf(sizeof(C0DataType) * c0_n_bias.mDesc.GetElementSpace()); + DeviceMem c0_add_buf(sizeof(C0DataType) * c0_m_n_add.mDesc.GetElementSpace()); + DeviceMem c0_gamma_buf(sizeof(C0DataType) * c0_n_gamma.mDesc.GetElementSpace()); + DeviceMem c0_beta_buf(sizeof(C0DataType) * c0_n_beta.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + c0_bias_buf.ToDevice(c0_n_bias.mData.data()); + c0_add_buf.ToDevice(c0_m_n_add.mData.data()); + c0_gamma_buf.ToDevice(c0_n_gamma.mData.data()); + c0_beta_buf.ToDevice(c0_n_beta.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto acc_element_op = AccElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + 
static_cast(c0_add_buf.GetDeviceBuffer()), + static_cast(c0_bias_buf.GetDeviceBuffer()), + static_cast(c0_gamma_buf.GetDeviceBuffer()), + static_cast(c0_beta_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + acc_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + // extra 6MN flops due to: bias + add + gamma + beta + norm_sub + norm_div, + // excluding reduction steps + std::size_t flop = std::size_t(2) * M * N * K + std::size_t(6) * M * N; + // extra MN and 3N due to c0_add (MxN), bias (1xN), gamma (1xN), beta (1xN) + std::size_t bytes = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * 2 * M * N + sizeof(C0DataType) * 3 * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = bytes / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + bool pass = true; + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + auto ref_gemm = ReferenceInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_m_k, + b_k_n, + c_m_n_host_result, + c0_n_bias, + c0_m_n_add, + c0_n_gamma, + c0_n_beta, + a_element_op, + b_element_op, + acc_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + + if constexpr(std::is_same::value) + { + pass &= ck::utils::check_err( + c_m_n_device_result.mData, c_m_n_host_result.mData, "Error: Incorrect results c"); + } + else if constexpr(std::is_same::value) + { + pass &= ck::utils::check_err(c_m_n_device_result.mData, + c_m_n_host_result.mData, + "Error: Incorrect results c", + 1e-2, + 1e-2); + } + } + return pass ? 0 : 1; +} diff --git a/example/22_cgemm/CMakeLists.txt b/example/22_cgemm/CMakeLists.txt new file mode 100644 index 00000000..048df3bb --- /dev/null +++ b/example/22_cgemm/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp) diff --git a/example/22_cgemm/cgemm_xdl_fp16.cpp b/example/22_cgemm/cgemm_xdl_fp16.cpp new file mode 100644 index 00000000..80c5d25e --- /dev/null +++ b/example/22_cgemm/cgemm_xdl_fp16.cpp @@ -0,0 +1,278 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
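The example added below implements complex GEMM with real and imaginary parts kept in separate (planar) buffers. As the instance name `DeviceCGemm_4Gemm_Xdl_CShuffle` and the `8 * M * N * K` flop count used later in the file suggest, the complex product is assembled from four real GEMMs: `C_re = A_re*B_re - A_im*B_im` and `C_im = A_re*B_im + A_im*B_re`. A minimal host sketch of that decomposition (the `gemm_acc` helper and the toy sizes are ours, not CK's):

```cpp
#include <cstdio>
#include <vector>

// One real GEMM accumulating alpha * (A @ B) into C; row-major A (MxK), B (KxN).
static void gemm_acc(const std::vector<float>& A, const std::vector<float>& B,
                     std::vector<float>& C, int M, int N, int K, float alpha)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
                acc += A[m * K + k] * B[k * N + n];
            C[m * N + n] += alpha * acc;
        }
}

int main()
{
    const int M = 2, N = 2, K = 2;
    std::vector<float> Ar(M * K, 1.f), Ai(M * K, 2.f); // A = 1 + 2i everywhere
    std::vector<float> Br(K * N, 3.f), Bi(K * N, 4.f); // B = 3 + 4i everywhere
    std::vector<float> Cr(M * N, 0.f), Ci(M * N, 0.f);

    gemm_acc(Ar, Br, Cr, M, N, K, +1.f); // C_re += A_re * B_re
    gemm_acc(Ai, Bi, Cr, M, N, K, -1.f); // C_re -= A_im * B_im
    gemm_acc(Ar, Bi, Ci, M, N, K, +1.f); // C_im += A_re * B_im
    gemm_acc(Ai, Br, Ci, M, N, K, +1.f); // C_im += A_im * B_re

    // (1+2i)(3+4i) = -5 + 10i, summed over K=2 gives -10 + 20i.
    std::printf("C[0] = %.1f + %.1fi\n", Cr[0], Ci[0]);
    return 0;
}
```

Four MxNxK real GEMMs at `2 * M * N * K` flops each is where the example's `8 * M * N * K` figure comes from.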
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using CDataType = F16; +using AccDataType = F32; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_CShuffle + , // typename ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder + 2, // index_t ABlockTransferSrcVectorDim + 8, // index_t ABlockTransferSrcScalarPerVector + 8, // index_t ABlockTransferDstScalarPerVector_AK1 + 1, // index_t ABlockLdsExtraM + S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder + 2, // index_t BBlockTransferSrcVectorDim + 8, // index_t BBlockTransferSrcScalarPerVector + 8, // index_t BBlockTransferDstScalarPerVector_BK1 + 1, // index_t BBlockLdsExtraN + 1, // index_t CShuffleMXdlPerWavePerShuffle + 1, // index_t CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock +// clang-format on + +using ReferenceCGemmInstance = ck::tensor_operation::host:: + ReferenceCGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // CGEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: run kernel # of times (>1)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + 
{ + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k_real(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor a_m_k_imag(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n_real(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_k_n_imag(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_real_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_imag_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k_real: " << a_m_k_real.mDesc << std::endl; + std::cout << "a_m_k_imag: " << a_m_k_imag.mDesc << std::endl; + std::cout << "b_k_n_real: " << b_k_n_real.mDesc << std::endl; + std::cout << "b_k_n_imag: " << b_k_n_imag.mDesc << std::endl; + std::cout << "c_m_n_real: " << c_m_n_real_device_result.mDesc << std::endl; + std::cout << "c_m_n_imag: " << c_m_n_imag_device_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k_real.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a_m_k_imag.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n_real.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n_imag.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + a_m_k_real.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a_m_k_imag.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b_k_n_real.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b_k_n_imag.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + auto cgemm = DeviceCGemmInstance{}; + + DeviceMem a_m_k_real_device_buf(sizeof(ADataType) * a_m_k_real.mDesc.GetElementSpace()); + DeviceMem a_m_k_imag_device_buf(sizeof(ADataType) * a_m_k_imag.mDesc.GetElementSpace()); + DeviceMem b_k_n_real_device_buf(sizeof(BDataType) * b_k_n_real.mDesc.GetElementSpace()); + DeviceMem b_k_n_imag_device_buf(sizeof(BDataType) * b_k_n_imag.mDesc.GetElementSpace()); + DeviceMem c_m_n_real_device_buf(sizeof(CDataType) * + c_m_n_real_device_result.mDesc.GetElementSpace()); + DeviceMem c_m_n_imag_device_buf(sizeof(CDataType) * + c_m_n_imag_device_result.mDesc.GetElementSpace()); + DeviceMem workspace_device_buf(cgemm.GetWorkspaceSize(M, N, K, StrideA, StrideB, StrideC)); + + a_m_k_real_device_buf.ToDevice(a_m_k_real.mData.data()); + a_m_k_imag_device_buf.ToDevice(a_m_k_imag.mData.data()); + b_k_n_real_device_buf.ToDevice(b_k_n_real.mData.data()); + b_k_n_imag_device_buf.ToDevice(b_k_n_imag.mData.data()); + + auto a_element_op = PassThrough{}; + auto b_element_op = PassThrough{}; + auto c_element_op = PassThrough{}; + + // do GEMM + auto invoker = cgemm.MakeInvoker(); + auto argument = + cgemm.MakeArgument(static_cast(a_m_k_real_device_buf.GetDeviceBuffer()), + static_cast(a_m_k_imag_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_real_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_imag_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_real_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_imag_device_buf.GetDeviceBuffer()), + static_cast(workspace_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!cgemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_cgemm with the specified compilation parameters does " + "not support this CGEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(8) * M * N * K; + std::size_t num_btype = + std::size_t(2) * + (sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << cgemm.GetTypeString() << std::endl; + + c_m_n_real_device_buf.FromDevice(c_m_n_real_device_result.mData.data()); + c_m_n_imag_device_buf.FromDevice(c_m_n_imag_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n_real_host_result( + f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_imag_host_result( + f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + auto ref_cgemm = ReferenceCGemmInstance{}; + auto ref_invoker = ref_cgemm.MakeInvoker(); + + auto ref_argument = ref_cgemm.MakeArgument(a_m_k_real, + a_m_k_imag, + b_k_n_real, + b_k_n_imag, + c_m_n_real_host_result, + c_m_n_imag_host_result, + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + + ck::utils::check_err(c_m_n_real_device_result.mData, + c_m_n_real_host_result.mData, + "Verification error: incorrect results in real part!", + 1e-2f, + 1e-1f); + ck::utils::check_err(c_m_n_imag_device_result.mData, + c_m_n_imag_host_result.mData, + "Verification error: incorrect results in imaginary part!", + 1e-2f, + 1e-1f); + } + + return 0; +} diff --git a/example/23_softmax/CMakeLists.txt b/example/23_softmax/CMakeLists.txt new file mode 100644 index 00000000..dafe6552 --- /dev/null +++ b/example/23_softmax/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_softmax_blockwise softmax_blockwise.cpp) \ No newline at end of file diff --git a/example/23_softmax/README.md b/example/23_softmax/README.md new file mode 100644 index 00000000..37c43e9b --- /dev/null +++ b/example/23_softmax/README.md @@ -0,0 +1,18 @@ +# Instructions for ```example_softmax_blockwise``` + +## Run ```example_softmax_blockwise``` +```bash +# -D : input 3-d tensor lengths +# -v : verification (0=no, 1=yes) +#arg1: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value) +#arg2: time kernel (0=no, 1=yes) +example_softmax_blockwise -D 4,128,2048 -v 1 1 1 +``` + +Result +``` +launch_and_time_kernel: grid_dim {64, 1, 1}, block_dim {256, 1, 1} +Warm up 1 time +Start running 10 times... +Perf: 0.0242877 ms, 259.039 GB/s, DeviceReduceSoftmax<256,M_C8_S1,K_C32_S8,InSrcVectorDim_1_InSrcVectorSize_8_OutDstVectorSize_8> +``` diff --git a/example/23_softmax/softmax_blockwise.cpp b/example/23_softmax/softmax_blockwise.cpp new file mode 100644 index 00000000..388cbd8c --- /dev/null +++ b/example/23_softmax/softmax_blockwise.cpp @@ -0,0 +1,256 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
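The example added below runs softmax over the innermost dimension of a 3-d tensor and blends the result into any prior output through the two `scales`, i.e. (as the reference path is read here) `out = alpha * softmax(in) + beta * out_prior`. A plain-C++ sketch of the numerically stable row softmax it verifies against; the 2x4 input and the `alpha`/`beta` choices are made up:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    const int rows = 2, cols = 4;
    const float alpha = 2.0f, beta = 0.0f;
    std::vector<float> in = {1, 2, 3, 4, -1, 0, 1, 2};
    std::vector<float> out(rows * cols, 0.f);

    for(int r = 0; r < rows; ++r)
    {
        // 1) subtract the row max so exp() cannot overflow for large inputs
        float mx = in[r * cols];
        for(int c = 1; c < cols; ++c)
            mx = std::max(mx, in[r * cols + c]);
        // 2) exponentiate and accumulate the normalizer
        std::vector<float> e(cols);
        float sum = 0.f;
        for(int c = 0; c < cols; ++c)
        {
            e[c] = std::exp(in[r * cols + c] - mx);
            sum += e[c];
        }
        // 3) normalize, then blend with any prior output via alpha/beta
        for(int c = 0; c < cols; ++c)
            out[r * cols + c] = alpha * (e[c] / sum) + beta * out[r * cols + c];
    }
    std::printf("row 0: %f %f %f %f\n", out[0], out[1], out[2], out[3]);
    return 0;
}
```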
+ +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +using namespace ck; +using namespace ck::tensor_operation::device; + +using InDataType = ck::half_t; +using OutDataType = ck::half_t; +using AccDataType = float; + +constexpr int Rank = 3; +constexpr int NumReduceDim = 1; + +using DeviceInstance = DeviceSoftmax; // OutScalarPerVector + +static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'}, + {"verify", required_argument, nullptr, 'v'}, + {"help", no_argument, nullptr, '?'}, + {nullptr, 0, nullptr, 0}}; + +class SimpleAppArgs +{ + private: + int option_index = 0; + + public: + std::vector inLengths = {8, 128, 2048}; + std::vector scales = {2.0f, 2.0f}; + + bool do_verification = true; + int init_method = 2; + bool time_kernel = true; + + public: + void show_usage(const char* cmd) + { + std::cout << "Usage of " << cmd << std::endl; + std::cout << "--inLengths or -D, comma separated list of input tensor dimension lengths" + << std::endl; + std::cout << "--verify or -v, 1/0 to indicate whether to verify the reduction result by " + "comparing with the host-based reduction" + << std::endl; + std::cout << "Arg1 -- init method (0=no init, 1=single integer value, 2=scope integer " + "value, 3=decimal value)" + << std::endl; + std::cout << "Arg2 -- time kernel (0=no, 1=yes)" << std::endl; + }; + + int processArgs(int argc, char* argv[]) + { + using ck::host_common::getTypeValuesFromString; + + int ch; + + while(1) + { + ch = getopt_long(argc, argv, "D:v:l:", long_options, &option_index); + if(ch == -1) + break; + switch(ch) + { + case 'D': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + inLengths = getTypeValuesFromString(optarg); + break; + case 'v': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + do_verification = static_cast(std::atoi(optarg)); + break; + case '?': + if(std::string(long_options[option_index].name) == "help") + { + show_usage(argv[0]); + return (-1); + }; + break; + default: show_usage(argv[0]); return (-1); + }; + }; + + if(optind + 2 > argc) + throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!"); + + init_method = std::atoi(argv[optind++]); + time_kernel = static_cast(std::atoi(argv[optind])); + + if(scales.empty()) + { + scales.push_back(1.0f); + scales.push_back(0.0f); + }; + + return (0); + }; +}; + +int main(int argc, char* argv[]) +{ + // Example: batched gemm C[G, M, N] applies max/sum reduction along N internally + const std::vector invariantDims{0, 1}; + const std::vector reduceDims{2}; + + SimpleAppArgs args; + + if(argc > 1) + { + if(args.processArgs(argc, argv) < 0) + return (-1); + }; + + Tensor in(args.inLengths); + Tensor out_ref(args.inLengths); + Tensor out(args.inLengths); + + auto inStrides = in.mDesc.GetStrides(); + auto outStrides = out.mDesc.GetStrides(); + + AccDataType alpha = args.scales[0]; + AccDataType beta = args.scales[1]; + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "out: " << out.mDesc << std::endl; + + std::size_t num_thread = 1; + + if(args.do_verification) + { + switch(args.init_method) + { + 
case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + break; + case 2: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, num_thread); + } + + if(beta != 0.0f) + for(size_t i = 0; i < out_ref.mDesc.GetElementSpace(); i++) + out.mData[i] = out_ref.mData[i]; + }; + // std::cout << "beta = " << beta << std::endl; + // LogRangeAsType(std::cout << "tensor in: " , in.mData, ",") << std::endl; + // LogRangeAsType(std::cout << "tensor prior out: " , out.mData, ",") << std::endl; + + // these buffers are usually provided by the user application + DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpace()); + DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpace()); + + in_dev.ToDevice(in.mData.data()); + + if(beta != 0.0f) + out_dev.ToDevice(out.mData.data()); + + if(args.do_verification) + { + using ReferenceInstance = + tensor_operation::host::ReferenceSoftmax; + ReferenceInstance ref; + auto ref_arg = ref.MakeArgument(in, out_ref, alpha, beta, reduceDims); + auto invoker = ref.MakeInvoker(); + invoker.Run(ref_arg); + // LogRangeAsType(std::cout << "tensor out_ref: ", out_ref.mData, ",") << std::endl; + }; + + std::vector i_inLengths; + std::vector i_inStrides; + + i_inLengths.assign(args.inLengths.begin(), args.inLengths.end()); + i_inStrides.assign(inStrides.begin(), inStrides.end()); + + auto device_instance = DeviceInstance{}; + + auto argument_ptr = device_instance.MakeArgumentPointer(i_inLengths, + i_inStrides, + reduceDims, + &alpha, + &beta, + in_dev.GetDeviceBuffer(), + out_dev.GetDeviceBuffer()); + + if(!device_instance.IsSupportedArgument(argument_ptr.get())) + { + std::cout + << "The runtime parameters seems not supported by the DeviceReduce instance, exiting!" + << std::endl; + return 1; + }; + + std::string instance_name = device_instance.GetTypeString(); + + auto invoker_ptr = device_instance.MakeInvokerPointer(); + + bool pass = true; + if(args.do_verification) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + out_dev.FromDevice(out.mData.data()); + // LogRangeAsType(std::cout << "tensor out: " , out.mData, ",") << std::endl; + pass = pass && ck::utils::check_err(out.mData, out_ref.mData); + }; + + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, args.time_kernel}); + + std::size_t num_bytes = + in.mDesc.GetElementSize() * sizeof(InDataType) + + (beta == 0.0f ? 1 : 2) * out.mDesc.GetElementSize() * sizeof(OutDataType); + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " << instance_name + << std::endl; + + return (pass ? 
diff --git a/example/24_batched_gemm_c_permute/CMakeLists.txt b/example/24_batched_gemm_c_permute/CMakeLists.txt
new file mode 100644
index 00000000..79c612d0
--- /dev/null
+++ b/example/24_batched_gemm_c_permute/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_example_executable(example_batched_gemm_c_permute_xdl_fp16 batched_gemm_c_permute_xdl_fp16.cpp)
+
diff --git a/example/24_batched_gemm_c_permute/batched_gemm_c_permute_xdl_fp16.cpp b/example/24_batched_gemm_c_permute/batched_gemm_c_permute_xdl_fp16.cpp
new file mode 100644
index 00000000..aad6f526
--- /dev/null
+++ b/example/24_batched_gemm_c_permute/batched_gemm_c_permute_xdl_fp16.cpp
@@ -0,0 +1,245 @@
+#include <iostream>
+#include <numeric>
+#include <initializer_list>
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_batched_gemm_c_permute_xdl.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+using ADataType   = ck::half_t;
+using BDataType   = ck::half_t;
+using CDataType   = ck::half_t;
+using AccDataType = float;
+
+using ALayout = ck::tensor_layout::gemm::RowMajor;
+using BLayout = ck::tensor_layout::gemm::ColumnMajor;
+using CLayout = ck::tensor_layout::gemm::RowMajor;
+
+using AElementOp = ck::tensor_operation::element_wise::PassThrough;
+using BElementOp = ck::tensor_operation::element_wise::PassThrough;
+using CElementOp = ck::tensor_operation::element_wise::PassThrough;
+
+// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+// static constexpr auto MNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
+static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+// clang-format off
+using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmCPermuteXdl
+//######| ALayout| BLayout| AData| BData| CData| AccData| A| B| C| GEMM| Num| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+//######| | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
+//######| | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
+//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+// < Row, Col, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, MNPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
+        < Row, Col, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, MNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
+// clang-format on
+
+using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
+    ReferenceBatchedGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
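Why `MNKPadding` rather than the commented-out `MNPadding`: the default sizes in `main()` below (M = 88, N = 64, K = 88) are not all multiples of the instance's 128x64x32 tile, and K in particular must be padded. A minimal sketch of the rounding this specialization implies; `integer_divide_ceil` mirrors the CK math helper of the same name and is re-implemented here only for the standalone demo:

```cpp
#include <cstdio>

// mirrors ck::math::integer_divide_ceil
constexpr int integer_divide_ceil(int x, int y) { return (x + y - 1) / y; }

int main()
{
    constexpr int MPerBlock = 128, NPerBlock = 64, KPerBlock = 32; // from the table above
    const int M = 88, N = 64, K = 88;                              // defaults used below

    std::printf("M: %d -> %d\n", M, integer_divide_ceil(M, MPerBlock) * MPerBlock); // 88 -> 128
    std::printf("N: %d -> %d\n", N, integer_divide_ceil(N, NPerBlock) * NPerBlock); // 64 -> 64
    std::printf("K: %d -> %d\n", K, integer_divide_ceil(K, KPerBlock) * KPerBlock); // 88 -> 96
}
```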
+int main(int argc, char* argv[])
+{
+    bool do_verification = true;
+    int init_method      = 1;
+    bool time_kernel     = false;
+
+    const int M = 88;
+    const int N = 64;
+    const int K = 88;
+
+    const int stride_A = K;
+    const int stride_B = K;
+
+    const int G0 = 1024;
+    const int G1 = 10;
+
+    const int batch_count = G0 * G1;
+
+    // output layout - [G0, M, G1, N]
+    const int stride_G0 = M * G1 * N;
+    const int stride_G1 = N;
+    const int stride_M  = G1 * N;
+    const int stride_N  = 1;
+
+    if(argc == 4)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+    }
+    else
+    {
+        printf("arg1: verification (0=no, 1=yes)\n");
+        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
+        printf("arg3: time kernel (0=no, 1=yes)\n");
+        exit(0);
+    }
+
+    // GEMM shape
+    ck::tensor_operation::device::BatchedGemmCPermuteDesc batched_gemm_c_permute_desc{
+        G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N};
+
+    auto f_host_tensor_descriptor = [](std::size_t batch_count_,
+                                       std::size_t row,
+                                       std::size_t col,
+                                       std::size_t stride,
+                                       auto layout) {
+        if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
+        {
+            return HostTensorDescriptor(std::vector<std::size_t>({batch_count_, row, col}),
+                                        std::vector<std::size_t>({row * stride, stride, 1}));
+        }
+        else
+        {
+            return HostTensorDescriptor(std::vector<std::size_t>({batch_count_, row, col}),
+                                        std::vector<std::size_t>({col * stride, 1, stride}));
+        }
+    };
+
+    Tensor<ADataType> a_g_m_k(f_host_tensor_descriptor(batch_count, M, K, stride_A, ALayout{}));
+    Tensor<BDataType> b_g_k_n(f_host_tensor_descriptor(batch_count, K, N, stride_B, BLayout{}));
+
+    auto f_host_c_tensor_descriptor = [](std::size_t G0_,
+                                         std::size_t G1_,
+                                         std::size_t M_,
+                                         std::size_t N_,
+                                         std::size_t stride_G0_,
+                                         std::size_t stride_G1_,
+                                         std::size_t stride_M_,
+                                         std::size_t stride_N_) {
+        return HostTensorDescriptor(
+            std::vector<std::size_t>({G0_, G1_, M_, N_}),
+            std::vector<std::size_t>({stride_G0_, stride_G1_, stride_M_, stride_N_}));
+    };
+
+    Tensor<CDataType> c_g0_g1_m_n_host_result(
+        f_host_c_tensor_descriptor(G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N));
+
+    Tensor<CDataType> c_g0_g1_m_n_device_result(
+        f_host_c_tensor_descriptor(G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N));
+
+    std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
+    std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
+    std::cout << "c_g0_g1_m_n: " << c_g0_g1_m_n_host_result.mDesc << std::endl;
+
+    switch(init_method)
+    {
+    case 0: break;
+    case 1:
+        a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
+        b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
+        break;
+    default:
+        a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
+        b_g_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
+        break;
+    }
+
+    DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace());
+    DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace());
+    DeviceMem c_device_buf(sizeof(CDataType) * c_g0_g1_m_n_device_result.mDesc.GetElementSpace());
+
+    a_device_buf.ToDevice(a_g_m_k.mData.data());
+    b_device_buf.ToDevice(b_g_k_n.mData.data());
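What "C permute" buys here: the kernel writes C through the strides in `batched_gemm_c_permute_desc`, so the logical view stays [G0, G1, M, N] while the bytes land in [G0, M, G1, N] order, with no separate transpose pass. A standalone demo of the stride formulas defined above (demo sizes only, much smaller than the example's G0=1024, G1=10):

```cpp
#include <cassert>

int main()
{
    const int G1 = 10, M = 4, N = 8; // demo sizes
    const int stride_G0 = M * G1 * N, stride_G1 = N, stride_M = G1 * N, stride_N = 1;

    auto offset = [&](int g0, int g1, int m, int n) {
        return g0 * stride_G0 + g1 * stride_G1 + m * stride_M + n * stride_N;
    };

    // n is fastest, then g1, then m: memory order [G0, M, G1, N]
    assert(offset(0, 0, 0, 1) - offset(0, 0, 0, 0) == 1);
    assert(offset(0, 1, 0, 0) - offset(0, 0, 0, 0) == N);
    assert(offset(0, 0, 1, 0) - offset(0, 0, 0, 0) == G1 * N);

    // the flat batch index the verification loop reconstructs later
    const int g0 = 3, g1 = 7;
    assert(g0 * G1 + g1 == 37);
    return 0;
}
```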
+    auto a_element_op = AElementOp{};
+    auto b_element_op = BElementOp{};
+    auto c_element_op = CElementOp{};
+
+    auto gemm    = DeviceGemmInstance{};
+    auto invoker = gemm.MakeInvoker();
+
+    // do GEMM
+    auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
+                                      static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
+                                      static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
+                                      M,
+                                      N,
+                                      K,
+                                      stride_A,
+                                      stride_B,
+                                      batched_gemm_c_permute_desc,
+                                      a_element_op,
+                                      b_element_op,
+                                      c_element_op,
+                                      batch_count);
+
+    if(!gemm.IsSupportedArgument(argument))
+    {
+        throw std::runtime_error(
+            "wrong! device_gemm with the specified compilation parameters does "
+            "not support this GEMM problem");
+    }
+
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
+
+    std::size_t flop = std::size_t(2) * batch_count * M * N * K;
+    std::size_t num_btype = sizeof(ADataType) * batch_count * M * K +
+                            sizeof(BDataType) * batch_count * K * N +
+                            sizeof(CDataType) * batch_count * M * N;
+
+    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+    float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
+              << gemm.GetTypeString() << std::endl;
+
+    bool pass = true;
+
+    if(do_verification)
+    {
+        c_device_buf.FromDevice(c_g0_g1_m_n_device_result.mData.data());
+
+        auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
+        auto ref_invoker      = ref_batched_gemm.MakeInvoker();
+
+        Tensor<CDataType> c_g_m_n_host_result(HostTensorDescriptor(
+            std::vector<std::size_t>({batch_count, M, N}),
+            std::vector<std::size_t>({M * N, N, 1})));
+
+        auto ref_argument = ref_batched_gemm.MakeArgument(
+            a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op);
+
+        ref_invoker.Run(ref_argument);
+
+        for(int g0 = 0; g0 < G0; g0++)
+        {
+            for(int g1 = 0; g1 < G1; g1++)
+            {
+                for(int m = 0; m < M; m++)
+                {
+                    for(int n = 0; n < N; n++)
+                    {
+                        int g = g0 * G1 + g1;
+                        c_g0_g1_m_n_host_result(g0, g1, m, n) = c_g_m_n_host_result(g, m, n);
+                    }
+                }
+            }
+        }
+
+        pass = ck::utils::check_err(c_g0_g1_m_n_host_result.mData,
+                                    c_g0_g1_m_n_device_result.mData,
+                                    "Error: Incorrect results c");
+    }
+
+    return pass ? 0 : 1;
+}
diff --git a/example/25_gemm_bias_c_permute/CMakeLists.txt b/example/25_gemm_bias_c_permute/CMakeLists.txt
new file mode 100644
index 00000000..29b1d94b
--- /dev/null
+++ b/example/25_gemm_bias_c_permute/CMakeLists.txt
@@ -0,0 +1 @@
+add_example_executable(example_gemm_bias_c_permute_xdl_fp16 gemm_bias_c_permute_xdl_fp16.cpp)
diff --git a/example/25_gemm_bias_c_permute/gemm_bias_c_permute_xdl_fp16.cpp b/example/25_gemm_bias_c_permute/gemm_bias_c_permute_xdl_fp16.cpp
new file mode 100644
index 00000000..7c242cc7
--- /dev/null
+++ b/example/25_gemm_bias_c_permute/gemm_bias_c_permute_xdl_fp16.cpp
@@ -0,0 +1,284 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
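One detail worth calling out before reading this file: the bias tensor D is described with stride 0 along every M dimension (`stride_D_M0 = stride_D_M1 = stride_D_M2 = 0` below), so a single bias value per (n0, n1) is broadcast across all of M without materializing an M-sized copy. A minimal standalone demonstration of the zero-stride trick (plain C++, not the CK descriptor machinery):

```cpp
#include <cassert>
#include <vector>

int main()
{
    const int N0 = 2, N1 = 3;
    std::vector<float> d(N0 * N1, 0.f);
    d[1 * N1 + 2] = 7.f; // bias element (n0=1, n1=2)

    // all M strides are 0, so the offset is independent of (m0, m1, m2)
    auto d_offset = [&](int m0, int m1, int m2, int n0, int n1) {
        return m0 * 0 + m1 * 0 + m2 * 0 + n0 * N1 + n1;
    };

    for(int m = 0; m < 5; ++m)
        assert(d[d_offset(m, m, m, 1, 2)] == 7.f); // same bias for every m
    return 0;
}
```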
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_bias_c_permute_xdl.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Add; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmBiasCPermute_Xdl +//######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + ck::index_t M0 = 4; + ck::index_t M1 = 32; + ck::index_t M2 = 128; + ck::index_t N0 = 16; + ck::index_t N1 = 256; + + // GEMM shape + ck::index_t M = M0 * M1 * M2; + ck::index_t N = N0 * N1; + ck::index_t K = 128; + + ck::index_t stride_A = K; + ck::index_t stride_B = K; + +#if 1 + // E = [M0, N0, M1, N1, M2] + ck::index_t stride_E_M0 = N0 * M1 * N1 * 
M2; + ck::index_t stride_E_M1 = N1 * M2; + ck::index_t stride_E_M2 = 1; + ck::index_t stride_E_N0 = M1 * N1 * M2; + ck::index_t stride_E_N1 = M2; + + // D = [0, N0, 0, N1, 0] + ck::index_t stride_D_M0 = 0; + ck::index_t stride_D_M1 = 0; + ck::index_t stride_D_M2 = 0; + ck::index_t stride_D_N0 = N1; + ck::index_t stride_D_N1 = 1; +#else + // D = [0, 0, 0, N0, N1] + ck::index_t stride_D_M0 = 0; + ck::index_t stride_D_M1 = 0; + ck::index_t stride_D_M2 = 0; + ck::index_t stride_D_N0 = N1; + ck::index_t stride_D_N1 = 1; + + // E = [M0, M1, M2, N0, N1] + ck::index_t stride_E_M0 = M1 * M2 * N0 * N1; + ck::index_t stride_E_M1 = M2 * N0 * N1; + ck::index_t stride_E_M2 = N0 * N1; + ck::index_t stride_E_N0 = N1; + ck::index_t stride_E_N1 = 1; +#endif + + const ck::tensor_operation::device::DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc{ + M0, M1, M2, N0, N1, stride_D_M0, stride_D_M1, stride_D_M2, stride_D_N0, stride_D_N1}; + const ck::tensor_operation::device::DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc{ + M0, M1, M2, N0, N1, stride_E_M0, stride_E_M1, stride_E_M2, stride_E_N0, stride_E_N1}; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + auto f_host_de_tensor_descriptor = + [](ck::tensor_operation::device::DEGridDesc_M0_M1_M2_N0_N1 de_grid_desc) { + std::size_t m0 = de_grid_desc.M0_; + std::size_t m1 = de_grid_desc.M1_; + std::size_t m2 = de_grid_desc.M2_; + std::size_t n0 = de_grid_desc.N0_; + std::size_t n1 = de_grid_desc.N1_; + std::size_t stride_m0 = de_grid_desc.stride_M0_; + std::size_t stride_m1 = de_grid_desc.stride_M1_; + std::size_t stride_m2 = de_grid_desc.stride_M2_; + std::size_t stride_n0 = de_grid_desc.stride_N0_; + std::size_t stride_n1 = de_grid_desc.stride_N1_; + return HostTensorDescriptor( + std::vector({m0, m1, m2, n0, n1}), + std::vector({stride_m0, stride_m1, stride_m2, stride_n0, stride_n1})); + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{})); + Tensor d_m0_m1_m2_n0_n1(f_host_de_tensor_descriptor(d_grid_desc)); + Tensor e_m0_m1_m2_n0_n1_host_result(f_host_de_tensor_descriptor(e_grid_desc)); + Tensor e_m0_m1_m2_n0_n1_device_result(f_host_de_tensor_descriptor(e_grid_desc)); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m0_m1_m2_n0_n1: " << d_m0_m1_m2_n0_n1.mDesc << std::endl; + std::cout << "e_m0_m1_m2_n0_n1: " << e_m0_m1_m2_n0_n1_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m0_m1_m2_n0_n1.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + 
d_m0_m1_m2_n0_n1.GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
+    }
+
+    DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
+    DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
+    DeviceMem d_m0_m1_m2_n0_n1_device_buf(sizeof(DDataType) *
+                                          d_m0_m1_m2_n0_n1.mDesc.GetElementSpace());
+    DeviceMem e_m0_m1_m2_n0_n1_device_buf(sizeof(EDataType) *
+                                          e_m0_m1_m2_n0_n1_device_result.mDesc.GetElementSpace());
+
+    a_m_k_device_buf.ToDevice(a_m_k.mData.data());
+    b_k_n_device_buf.ToDevice(b_k_n.mData.data());
+    d_m0_m1_m2_n0_n1_device_buf.ToDevice(d_m0_m1_m2_n0_n1.mData.data());
+
+    auto a_element_op   = AElementOp{};
+    auto b_element_op   = BElementOp{};
+    auto cde_element_op = CDEElementOp{};
+
+    // do GEMM
+    auto device_op = DeviceOpInstance{};
+    auto invoker   = device_op.MakeInvoker();
+    auto argument  = device_op.MakeArgument(a_m_k_device_buf.GetDeviceBuffer(),
+                                            b_k_n_device_buf.GetDeviceBuffer(),
+                                            d_m0_m1_m2_n0_n1_device_buf.GetDeviceBuffer(),
+                                            e_m0_m1_m2_n0_n1_device_buf.GetDeviceBuffer(),
+                                            M,
+                                            N,
+                                            K,
+                                            stride_A,
+                                            stride_B,
+                                            d_grid_desc,
+                                            e_grid_desc,
+                                            a_element_op,
+                                            b_element_op,
+                                            cde_element_op);
+
+    if(!device_op.IsSupportedArgument(argument))
+    {
+        throw std::runtime_error("wrong! this device_op instance does not support this problem");
+    }
+
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
+
+    std::size_t flop = std::size_t(2) * M * N * K;
+    std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
+                            sizeof(DDataType) * N + sizeof(EDataType) * M * N;
+
+    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+    float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
+              << device_op.GetTypeString() << std::endl;
+
+    if(do_verification)
+    {
+        Tensor<CShuffleDataType> c_m_n(HostTensorDescriptor(std::vector<std::size_t>{
+            static_cast<std::size_t>(M), static_cast<std::size_t>(N)}));
+
+        using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
+                                                                                BDataType,
+                                                                                CShuffleDataType,
+                                                                                AccDataType,
+                                                                                AElementOp,
+                                                                                BElementOp,
+                                                                                PassThrough>;
+
+        auto ref_gemm    = ReferenceGemmInstance{};
+        auto ref_invoker = ref_gemm.MakeInvoker();
+
+        auto ref_argument =
+            ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
+
+        ref_invoker.Run(ref_argument);
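The verification loop that follows folds the split output indices back onto the flat GEMM coordinates: m = m0·M1·M2 + m1·M2 + m2 and n = n0·N1 + n1, i.e. plain row-major linearisation of the factored dimensions. A standalone check (the `fold_m`/`fold_n` helper names are invented for the demo; the constants are the example's defaults):

```cpp
#include <cassert>

int main()
{
    const int M1 = 32, M2 = 128, N1 = 256; // defaults from this example

    auto fold_m = [&](int m0, int m1, int m2) { return m0 * M1 * M2 + m1 * M2 + m2; };
    auto fold_n = [&](int n0, int n1) { return n0 * N1 + n1; };

    assert(fold_m(0, 0, 5) == 5);       // m2 is the fastest m index
    assert(fold_m(0, 1, 0) == M2);      // one step in m1 skips M2 elements
    assert(fold_m(1, 0, 0) == M1 * M2); // one step in m0 skips M1*M2 elements
    assert(fold_n(2, 3) == 2 * N1 + 3);
    return 0;
}
```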
+        for(int m0 = 0; m0 < M0; ++m0)
+            for(int m1 = 0; m1 < M1; ++m1)
+                for(int m2 = 0; m2 < M2; ++m2)
+                    for(int n0 = 0; n0 < N0; ++n0)
+                        for(int n1 = 0; n1 < N1; ++n1)
+                        {
+                            int m = m0 * M1 * M2 + m1 * M2 + m2;
+                            int n = n0 * N1 + n1;
+
+                            cde_element_op(e_m0_m1_m2_n0_n1_host_result(m0, m1, m2, n0, n1),
+                                           ck::type_convert<EDataType>(c_m_n(m, n)),
+                                           d_m0_m1_m2_n0_n1(m0, m1, m2, n0, n1));
+                        }
+
+        e_m0_m1_m2_n0_n1_device_buf.FromDevice(e_m0_m1_m2_n0_n1_device_result.mData.data());
+
+        return ck::utils::check_err(e_m0_m1_m2_n0_n1_device_result.mData,
+                                    e_m0_m1_m2_n0_n1_host_result.mData)
+                   ? 0
+                   : 1;
+    }
+
+    return 0;
+}
diff --git a/example/26_contraction/CMakeLists.txt b/example/26_contraction/CMakeLists.txt
new file mode 100644
index 00000000..87f4750e
--- /dev/null
+++ b/example/26_contraction/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp)
+add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp)
diff --git a/example/26_contraction/README.md b/example/26_contraction/README.md
new file mode 100644
index 00000000..c88d93cf
--- /dev/null
+++ b/example/26_contraction/README.md
@@ -0,0 +1,20 @@
+# Instructions for ```example_contraction_bilinear_xdl_fp32```
+
+## Run
+```bash
+#arg1: verification (0=no, 1=yes)
+#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
+#arg3: time kernel (0=no, 1=yes)
+./bin/example_contraction_bilinear_xdl_fp32 1 1 1
+```
+
+Result (MI100 @ dynamic freq, 46 TFlops peak FP32)
+```
+a_ms_ks: dim 4, lengths {30, 128, 32, 64}, strides {524288, 4096, 128, 1}
+b_ks_ns: dim 4, lengths {32, 64, 32, 64}, strides {128, 1, 524288, 4096}
+c_ms_ns: dim 4, lengths {30, 128, 32, 64}, strides {524288, 4096, 128, 1}
+launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1}
+Warm up 1 time
+Start running 10 times...
+Perf: 0.843286 ms, 38.1985 TFlops, 94.5014 GB/s, DeviceContractionMultipleD_Xdl_CShuffle<256, 256, 128, 16, 4, 4>
+```
diff --git a/example/26_contraction/contraction_bilinear_xdl_fp32.cpp b/example/26_contraction/contraction_bilinear_xdl_fp32.cpp
new file mode 100644
index 00000000..43972482
--- /dev/null
+++ b/example/26_contraction/contraction_bilinear_xdl_fp32.cpp
@@ -0,0 +1,444 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
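A quick arithmetic check of the README numbers above: the 2-2-2 contraction folds to a GEMM with M = 30·128 = 3840, N = 32·64 = 2048 and K = 32·64 = 2048, so flop = 2·M·N·K ≈ 3.221e10 and 3.221e10 / 1e9 / 0.843286 ms ≈ 38.2 TFlops, matching the log. The 94.5 GB/s figure appears to count only the A, B and E bytes; counting the bilinear D read as well would give roughly 131.8 GB/s.

```cpp
#include <cstdio>

int main()
{
    const double ms = 0.843286; // averaged kernel time from the README log
    const double M = 30.0 * 128, N = 32.0 * 64, K = 32.0 * 64;

    const double tflops = 2.0 * M * N * K / 1.e9 / ms;   // ~38.1985
    const double bytes  = 4.0 * (M * K + K * N + M * N); // A + B + E, FP32
    const double gbps   = bytes / 1.e6 / ms;             // ~94.5014

    std::printf("%.4f TFlops, %.4f GB/s\n", tflops, gbps);
    return 0;
}
```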
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +template +using S = ck::Sequence; + +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F32; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// clang-format off +using DeviceOpInstanceKKNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; + +using DeviceOpInstanceKNNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| 
CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>; + +using DeviceOpInstanceMKNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; + +using DeviceOpInstanceMNNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| 
CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>; +// clang-format on + +using DeviceOpInstance = DeviceOpInstanceKKNN; + +// hardcoded for NumDimM == NumDimN == NumDimK == 2 +template = false> +struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator +{ + // Argument + struct Argument : public ck::tensor_operation::device::BaseArgument + { + Argument(const Tensor& a_ms_ks, + const Tensor& b_ns_ks, + Tensor& e_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : a_ms_ks_{a_ms_ks}, + b_ns_ks_{b_ns_ks}, + e_ms_ns_{e_ms_ns}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + } + + const Tensor& a_ms_ks_; + const Tensor& b_ns_ks_; + Tensor& e_ms_ns_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public ck::tensor_operation::device::BaseInvoker + { + using Argument = ReferenceContraction_M2_N2_K2::Argument; + + float Run(const Argument& arg) + { + auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) { + const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2]; + const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3]; + + AccDataType v_acc = 0; + + for(int k0 = 0; k0 < K0; ++k0) + { + for(int k1 = 0; k1 < K1; ++k1) + { + AccDataType v_a; + AccDataType v_b; + + arg.a_element_op_( + v_a, ck::type_convert(arg.a_ms_ks_(m0, m1, k0, k1))); + arg.b_element_op_( + v_b, ck::type_convert(arg.b_ns_ks_(n0, n1, k0, k1))); + + v_acc += v_a * v_b; + } + } + + AccDataType v_c; + + arg.cde_element_op_(v_c, v_acc); + + arg.e_ms_ns_(m0, m1, n0, n1) = v_c; + }; + + make_ParallelTensorFunctor(f_ms_ns, + arg.e_ms_ns_.mDesc.GetLengths()[0], + arg.e_ms_ns_.mDesc.GetLengths()[1], + arg.e_ms_ns_.mDesc.GetLengths()[2], + arg.e_ms_ns_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const ck::tensor_operation::device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const 
ck::tensor_operation::device::BaseArgument*) override + { + return true; + } + + static auto MakeArgument(const Tensor& a_ms_ks, + const Tensor& b_ns_ks, + Tensor& e_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceContraction_M2_N2_K2" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; + + float alpha = 1.f; + float beta = 1.f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 28) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + const ck::index_t M0 = std::stoi(argv[4]); + const ck::index_t M1 = std::stoi(argv[5]); + + const ck::index_t N0 = std::stoi(argv[6]); + const ck::index_t N1 = std::stoi(argv[7]); + + const ck::index_t K0 = std::stoi(argv[8]); + const ck::index_t K1 = std::stoi(argv[9]); + + a_ms_ks_lengths = {M0, M1, K0, K1}; + a_ms_ks_strides = { + std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; + + b_ns_ks_lengths = {N0, N1, K0, K1}; + b_ns_ks_strides = { + std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; + + d_ms_ns_lengths = {M0, M1, N0, N1}; + d_ms_ns_strides = { + std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; + + e_ms_ns_lengths = {M0, M1, N0, N1}; + e_ms_ns_strides = { + std::stoi(argv[22]), std::stoi(argv[23]), std::stoi(argv[24]), std::stoi(argv[25])}; + + alpha = std::stof(argv[26]); + beta = std::stof(argv[27]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 7: M0, M1, N0, N1, K0, K1\n"); + printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); + printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); + printf("arg18 to 21: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n"); + printf("arg22 to 25: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); + printf("arg26 to 27: alpha, beta\n"); + exit(0); + } + + Tensor a_ms_ks( + std::vector(a_ms_ks_lengths.begin(), a_ms_ks_lengths.end()), + std::vector(a_ms_ks_strides.begin(), a_ms_ks_strides.end())); + Tensor b_ns_ks( + std::vector(b_ns_ks_lengths.begin(), b_ns_ks_lengths.end()), + std::vector(b_ns_ks_strides.begin(), 
b_ns_ks_strides.end())); + Tensor d_ms_ns( + std::vector(d_ms_ns_lengths.begin(), d_ms_ns_lengths.end()), + std::vector(d_ms_ns_strides.begin(), d_ms_ns_strides.end())); + Tensor e_ms_ns_host_result( + std::vector(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()), + std::vector(e_ms_ns_strides.begin(), e_ms_ns_strides.end())); + Tensor e_ms_ns_device_result( + std::vector(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()), + std::vector(e_ms_ns_strides.begin(), e_ms_ns_strides.end())); + + std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; + std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; + std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl; + std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpace()); + DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpace()); + DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_ms_ks.mData.data()); + b_device_buf.ToDevice(b_ns_ks.mData.data()); + d_device_buf.ToDevice(d_ms_ns.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{alpha, beta}; + + // device operation + auto op = DeviceOpInstance{}; + auto invoker = op.MakeInvoker(); + auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 1>{d_ms_ns_lengths}, + std::array, 1>{d_ms_ns_strides}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + ck::index_t M = std::accumulate(e_ms_ns_lengths.begin(), + e_ms_ns_lengths.begin() + NumDimM, + ck::index_t{1}, + std::multiplies{}); + + ck::index_t N = std::accumulate(e_ms_ns_lengths.begin() + NumDimM, + e_ms_ns_lengths.begin() + NumDimM + NumDimN, + ck::index_t{1}, + std::multiplies{}); + + ck::index_t K = std::accumulate(a_ms_ks_lengths.begin() + NumDimM, + a_ms_ks_lengths.begin() + NumDimM + NumDimK, + ck::index_t{1}, + std::multiplies{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(DDataType) * M * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); + + if(do_verification) + { + Tensor 
c_ms_ns_host_result( + std::vector(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()), + std::vector(e_ms_ns_strides.begin(), e_ms_ns_strides.end())); + + using ReferenceOpInstance = ReferenceContraction_M2_N2_K2; + + auto ref_gemm = ReferenceOpInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), + c_ms_ns_host_result(m0, m1, n0, n1), + d_ms_ns(m0, m1, n0, n1)); + } + } + } + } + + return ck::utils::check_err(e_ms_ns_device_result.mData, e_ms_ns_host_result.mData) ? 0 : 1; + } + + return 0; +} diff --git a/example/26_contraction/contraction_scale_xdl_fp32.cpp b/example/26_contraction/contraction_scale_xdl_fp32.cpp new file mode 100644 index 00000000..deee6338 --- /dev/null +++ b/example/26_contraction/contraction_scale_xdl_fp32.cpp @@ -0,0 +1,424 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +template +using S = ck::Sequence; + +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F32; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// clang-format off +using DeviceOpInstanceKKN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| 
SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; + +using DeviceOpInstanceKNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>; + +using DeviceOpInstanceMKN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| 
SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; + +using DeviceOpInstanceMNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>; +// clang-format on + +using DeviceOpInstance = DeviceOpInstanceKKN; + +// hardcoded for NumDimM == NumDimN == NumDimK == 2 +template = false> +struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator +{ + // Argument + struct Argument : public ck::tensor_operation::device::BaseArgument + { + Argument(const Tensor& a_ms_ks, + const Tensor& b_ns_ks, + Tensor& e_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : a_ms_ks_{a_ms_ks}, + b_ns_ks_{b_ns_ks}, + e_ms_ns_{e_ms_ns}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + } + + const Tensor& a_ms_ks_; + const Tensor& b_ns_ks_; + Tensor& e_ms_ns_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : 
public ck::tensor_operation::device::BaseInvoker + { + using Argument = ReferenceContraction_M2_N2_K2::Argument; + + float Run(const Argument& arg) + { + auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) { + const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2]; + const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3]; + + AccDataType v_acc = 0; + + for(int k0 = 0; k0 < K0; ++k0) + { + for(int k1 = 0; k1 < K1; ++k1) + { + AccDataType v_a; + AccDataType v_b; + + arg.a_element_op_( + v_a, ck::type_convert(arg.a_ms_ks_(m0, m1, k0, k1))); + arg.b_element_op_( + v_b, ck::type_convert(arg.b_ns_ks_(n0, n1, k0, k1))); + + v_acc += v_a * v_b; + } + } + + AccDataType v_c; + + arg.cde_element_op_(v_c, v_acc); + + arg.e_ms_ns_(m0, m1, n0, n1) = v_c; + }; + + make_ParallelTensorFunctor(f_ms_ns, + arg.e_ms_ns_.mDesc.GetLengths()[0], + arg.e_ms_ns_.mDesc.GetLengths()[1], + arg.e_ms_ns_.mDesc.GetLengths()[2], + arg.e_ms_ns_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const ck::tensor_operation::device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override + { + return true; + } + + static auto MakeArgument(const Tensor& a_ms_ks, + const Tensor& b_ns_ks, + Tensor& e_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceContraction_M2_N2_K2" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; + + float scale = 1.f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 23) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + const ck::index_t M0 = std::stoi(argv[4]); + const ck::index_t M1 = std::stoi(argv[5]); + + const ck::index_t N0 = std::stoi(argv[6]); + const ck::index_t N1 = std::stoi(argv[7]); + + const ck::index_t K0 = std::stoi(argv[8]); + const ck::index_t K1 = std::stoi(argv[9]); + + a_ms_ks_lengths = {M0, M1, K0, K1}; + a_ms_ks_strides = { + std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; + + b_ns_ks_lengths = {N0, N1, K0, K1}; + b_ns_ks_strides = { + std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; + + e_ms_ns_lengths = {M0, M1, N0, N1}; 
+ e_ms_ns_strides = { + std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; + + scale = std::stof(argv[22]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M0, M1, N0, N1, K0, K1\n"); + printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); + printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); + printf("arg18 to 21: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); + printf("arg22: scale\n"); + exit(0); + } + + Tensor a_ms_ks( + std::vector(a_ms_ks_lengths.begin(), a_ms_ks_lengths.end()), + std::vector(a_ms_ks_strides.begin(), a_ms_ks_strides.end())); + Tensor b_ns_ks( + std::vector(b_ns_ks_lengths.begin(), b_ns_ks_lengths.end()), + std::vector(b_ns_ks_strides.begin(), b_ns_ks_strides.end())); + Tensor e_ms_ns_host_result( + std::vector(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()), + std::vector(e_ms_ns_strides.begin(), e_ms_ns_strides.end())); + Tensor e_ms_ns_device_result( + std::vector(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()), + std::vector(e_ms_ns_strides.begin(), e_ms_ns_strides.end())); + + std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; + std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; + std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpace()); + DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_ms_ks.mData.data()); + b_device_buf.ToDevice(b_ns_ks.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{scale}; + + // device operation + auto op = DeviceOpInstance{}; + auto invoker = op.MakeInvoker(); + auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{}, + e_device_buf.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 0>{}, + std::array, 0>{}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + ck::index_t M = std::accumulate(e_ms_ns_lengths.begin(), + e_ms_ns_lengths.begin() + NumDimM, + ck::index_t{1}, + std::multiplies{}); + + ck::index_t N = std::accumulate(e_ms_ns_lengths.begin() + NumDimM, + e_ms_ns_lengths.begin() + NumDimM + NumDimN, + ck::index_t{1}, + std::multiplies{}); + + ck::index_t K = std::accumulate(a_ms_ks_lengths.begin() + NumDimM, + a_ms_ks_lengths.begin() + NumDimM + NumDimK, + ck::index_t{1}, + std::multiplies{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + 
+
+    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+    float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
+              << op.GetTypeString() << std::endl;
+
+    e_device_buf.FromDevice(e_ms_ns_device_result.mData.data());
+
+    if(do_verification)
+    {
+        Tensor c_ms_ns_host_result(
+            std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
+            std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
+
+        using ReferenceOpInstance = ReferenceContraction_M2_N2_K2;
+
+        auto ref_gemm    = ReferenceOpInstance{};
+        auto ref_invoker = ref_gemm.MakeInvoker();
+
+        auto ref_argument = ref_gemm.MakeArgument(
+            a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});
+
+        ref_invoker.Run(ref_argument);
+
+        for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0)
+        {
+            for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1)
+            {
+                for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0)
+                {
+                    for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1)
+                    {
+                        cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1),
+                                       c_ms_ns_host_result(m0, m1, n0, n1));
+                    }
+                }
+            }
+        }
+
+        return ck::utils::check_err(e_ms_ns_device_result.mData, e_ms_ns_host_result.mData) ? 0 : 1;
+    }
+
+    return 0;
+}
diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt
new file mode 100644
index 00000000..65815b65
--- /dev/null
+++ b/example/CMakeLists.txt
@@ -0,0 +1,47 @@
+include_directories(BEFORE
+    ${PROJECT_SOURCE_DIR}/include
+    ${PROJECT_SOURCE_DIR}/library/include
+)
+
+add_custom_target(examples)
+
+function(add_example_executable EXAMPLE_NAME FILE_NAME)
+    message("adding example ${EXAMPLE_NAME}")
+    add_executable(${EXAMPLE_NAME} ${FILE_NAME})
+    target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
+    add_test(NAME ${EXAMPLE_NAME} COMMAND $<TARGET_FILE:${EXAMPLE_NAME}> ${ARGN})
+    add_dependencies(examples ${EXAMPLE_NAME})
+    add_dependencies(check ${EXAMPLE_NAME})
+endfunction(add_example_executable EXAMPLE_NAME)
+
+function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
+    message("adding example ${EXAMPLE_NAME}")
+    add_executable(${EXAMPLE_NAME} ${FILE_NAME})
+    target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
+    add_dependencies(examples ${EXAMPLE_NAME})
+endfunction(add_example_executable_no_testing EXAMPLE_NAME)
+
+add_subdirectory(01_gemm)
+add_subdirectory(02_gemm_bilinear)
+add_subdirectory(03_gemm_bias_relu)
+add_subdirectory(04_gemm_add_add_fastgelu)
+add_subdirectory(06_conv2d_fwd_bias_relu)
+add_subdirectory(07_conv2d_fwd_bias_relu_add)
+add_subdirectory(09_convnd_fwd)
+add_subdirectory(10_conv2d_bwd_data)
+add_subdirectory(11_conv2d_bwd_weight)
+add_subdirectory(12_reduce)
+add_subdirectory(13_pool2d_fwd)
+add_subdirectory(14_gemm_xdl_requant_relu_requant)
+add_subdirectory(15_grouped_gemm)
+add_subdirectory(16_gemm_reduce)
+add_subdirectory(17_convnd_bwd_data_xdl)
+add_subdirectory(18_batched_gemm_reduce)
+add_subdirectory(19_binary_elementwise)
+add_subdirectory(20_convnd_bwd_weight_xdl)
+add_subdirectory(21_gemm_layernorm)
+add_subdirectory(22_cgemm)
+add_subdirectory(23_softmax)
+add_subdirectory(24_batched_gemm_c_permute)
+add_subdirectory(25_gemm_bias_c_permute)
+add_subdirectory(26_contraction)
diff --git a/external/rocm/include/bfloat16_dev.hpp b/external/rocm/include/bfloat16_dev.hpp
deleted file mode 100644
index 52d00346..00000000
--- a/external/rocm/include/bfloat16_dev.hpp
+++ /dev/null
@@ -1,125 +0,0 @@
-/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2019 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef BFLOAT16_DEVICE_HPP -#define BFLOAT16_DEVICE_HPP - -#ifdef __cplusplus -extern "C" { -#endif - -#ifdef __HIP_PLATFORM_HCC__ -#define EXECUTION_SPECIFIER __device__ -#else -#define EXECUTION_SPECIFIER -#endif // MIOPEN_BACKEND_HIP - -typedef union -{ - uint u32; - ushort2 ushortx2; - -// Composable kernels are written in HIP language. The language doesnt support -// ushort2.hi or ushort2.low. -#ifdef __HIP_PLATFORM_HCC__ - ushort ushortvec[2]; -#endif // MIOPEN_BACKEND_HIP - float f32; -} cvt_bf16_fp32_t; - -EXECUTION_SPECIFIER float bfloat16_to_float(ushort src_val) -{ - cvt_bf16_fp32_t target_val; - -#ifdef __HIP_PLATFORM_HCC__ - target_val.ushortx2 = make_ushort2(0, src_val); -#else - target_val.ushortx2 = (ushort2)(0, src_val); -#endif - - return target_val.f32; -} - -EXECUTION_SPECIFIER ushort float_to_bfloat16(float src_val) -{ - cvt_bf16_fp32_t target_val; - target_val.f32 = src_val; - // BF16 round and NaN preservation code matches - // https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/include/rocblas_bfloat16.h - if((~target_val.u32 & 0x7f800000) == 0) // Inf or NaN - { - // When all of the exponent bits are 1, the value is Inf or NaN. - // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero - // mantissa bit. Quiet NaN is indicated by the most significant mantissa - // bit being 1. Signaling NaN is indicated by the most significant - // mantissa bit being 0 but some other bit(s) being 1. If any of the - // lower 16 bits of the mantissa are 1, we set the least significant bit - // of the bfloat16 mantissa, in order to preserve signaling NaN in case - // the bloat16's mantissa bits are all 0. - if((target_val.u32 & 0xffff) != 0) - { - target_val.u32 |= 0x10000; // Preserve signaling NaN - } - } - else - { -#ifdef MIOPEN_USE_RNE_BFLOAT16 -// When the exponent bits are not all 1s, then the value is zero, normal, -// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus -// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). 
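// A worked trace of the round-to-nearest-even path described above (values are
// illustrative only, not part of the deleted header):
//   src_val = 0x3F808000 (1.00390625f): the low 16 mantissa bits are exactly
//   0x8000 (a tie) and u32 bit 16 -- the bf16 mantissa LSB -- is 0 (even), so
//   u32 += 0x7FFF + 0 gives 0x3F80FFFF and the high ushort 0x3F80 (1.0f) is
//   returned: the tie rounds down to the even mantissa.
//   src_val = 0x3F818000 (1.01171875f): the same tie, but bit 16 is 1 (odd), so
//   u32 += 0x7FFF + 1 gives 0x3F820000 and the high ushort 0x3F82 (1.015625f)
//   is returned: the tie rounds up, again landing on an even mantissa.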
-// This causes the bfloat16's mantissa to be incremented by 1 if the 16 -// least significant bits of the float mantissa are greater than 0x8000, -// or if they are equal to 0x8000 and the least significant bit of the -// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when -// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already -// has the value 0x7f, then incrementing it causes it to become 0x00 and -// the exponent is incremented by one, which is the next higher FP value -// to the unrounded bfloat16 value. When the bfloat16 value is subnormal -// with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up -// to a normal value with an exponent of 0x01 and a mantissa of 0x00. -// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, -// incrementing it causes it to become an exponent of 0xFF and a mantissa -// of 0x00, which is Inf, the next higher value to the unrounded value. -#ifdef __HIP_PLATFORM_HCC__ - target_val.u32 += (0x7fff + (target_val.ushortvec[1] & 1)); -#else - target_val.u32 += - (0x7fff + (target_val.ushortx2.hi & 1)); // Round to nearest, round to even -#endif // MIOPEN_BACKEND_HIP -#endif // MIOPEN_USE_RNE_BFLOAT16 - } - -#ifdef __HIP_PLATFORM_HCC__ - return target_val.ushortvec[1]; -#else - return target_val.ushortx2.hi; -#endif // MIOPEN_BACKEND_HIP -} - -#ifdef __cplusplus -} -#endif - -#endif // BFLOAT16_DEVICE_HPP diff --git a/host/CMakeLists.txt b/host/CMakeLists.txt deleted file mode 100644 index 30cc14d8..00000000 --- a/host/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -add_subdirectory(host_tensor) -add_subdirectory(driver_offline) diff --git a/host/driver_offline/CMakeLists.txt b/host/driver_offline/CMakeLists.txt deleted file mode 100644 index a3b36132..00000000 --- a/host/driver_offline/CMakeLists.txt +++ /dev/null @@ -1,27 +0,0 @@ -include_directories(BEFORE - include - ${PROJECT_SOURCE_DIR}/host/host_tensor/include - ${PROJECT_SOURCE_DIR}/host/solver/include - ${PROJECT_SOURCE_DIR}/composable_kernel/include - ${PROJECT_SOURCE_DIR}/composable_kernel/include/utility - ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description - ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation - ${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform - ${PROJECT_SOURCE_DIR}/composable_kernel/include/driver - ${PROJECT_SOURCE_DIR}/external/rocm/include -) - -set(CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp) -set(CONV_BWD_DRIVER_OFFLINE_SOURCE src/conv_bwd_driver_offline.cpp) -set(CONV_WRW_DRIVER_OFFLINE_SOURCE src/conv_wrw_driver_offline.cpp) -set(GEMM_DRIVER_OFFLINE_SOURCE src/gemm_driver_offline.cpp) - -add_executable(conv_fwd_driver_offline ${CONV_FWD_DRIVER_OFFLINE_SOURCE}) -add_executable(conv_bwd_driver_offline ${CONV_BWD_DRIVER_OFFLINE_SOURCE}) -add_executable(conv_wrw_driver_offline ${CONV_WRW_DRIVER_OFFLINE_SOURCE}) -add_executable(gemm_driver_offline ${GEMM_DRIVER_OFFLINE_SOURCE}) - -target_link_libraries(conv_fwd_driver_offline PRIVATE host_tensor) -target_link_libraries(conv_bwd_driver_offline PRIVATE host_tensor) -target_link_libraries(conv_wrw_driver_offline PRIVATE host_tensor) -target_link_libraries(gemm_driver_offline PRIVATE host_tensor) diff --git a/host/driver_offline/include/debug.hpp b/host/driver_offline/include/debug.hpp deleted file mode 100644 index 72fd0763..00000000 --- a/host/driver_offline/include/debug.hpp +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef DEBUG_HPP -#define DEBUG_HPP - -namespace debug { -namespace 
debug_driver_gemm_xdlops_v2r3 { - -// these vars are on host, they control block_id to C matrix tile idx (m0, n0) mapping -static ck::index_t M01 = 1; -static ck::index_t N01 = 1; - -} // namespace debug_driver_gemm_xdlops_v2r3 -} // namespace debug -#endif diff --git a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 8258aa0e..00000000 --- a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,309 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp" -#include "driver_gemm_xdlops_v2r3.hpp" -#include "debug.hpp" - -template -void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( - const InLengths& in_n_hi_wi_c_lengths, - const WeiLengths& wei_k_y_x_c_lengths, - const OutLengths& out_n_ho_wo_k_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - Tensor& in_n_hi_wi_c, - const Tensor& wei_k_y_x_c, - const Tensor& out_n_ho_wo_k, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); - DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); - DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); - - in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); - wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); - out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); - -#if 0 - // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - 
constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 256; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#endif - - const auto descs = - transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(wei_k_y_x_c_desc, - out_n_ho_wo_k_desc, - in_n_hi_wi_c_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - I0, - I0, - Number{}); - - const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto out_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto 
in_gemmm_gemmn_grid_desc = descs[I2]; - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmM - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: GemmM - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1 - - constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: GemmN - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: GemmN - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1 - - // clang-format off - constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple( - make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - //clang-format on - - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}; - - constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = driver_gemm_xdlops_v2r3< - BlockSize, - TInWei, - TAcc, - TOut, - InMemoryDataOperationEnum_t::Set, - decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), - decltype(out_gemmk0_gemmn_gemmk1_grid_desc), - decltype(in_gemmm_gemmn_grid_desc), - GemmMPerBlock, - 
GemmNPerBlock, - GemmKPerBlock, - GemmMPerXDL, - GemmNPerXDL, - GemmK1, - MRepeat, - NRepeat, - GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, - GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, - Sequence<2, 0, 1>, - Sequence<0, 2, 1>, - 1, - GemmABlockTransferSrcScalarPerVector_GemmM, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmBBlockTransferSrcScalarPerVector_GemmK1, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - Sequence<1, 3, 7, 0, 2, 4, 5, 6>, - 6, - GemmCThreadTransferDstScalarPerVector, - decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), - decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), - decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - false, // CAccessOrderMRepeatNRepeat - false, // ABlockLdsExtraM - false // BBlockLdsExtraN - >(static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - wei_gemmk0_gemmm_gemmk1_grid_desc, - out_gemmk0_gemmn_gemmk1_grid_desc, - in_gemmm_gemmn_grid_desc, - debug::debug_driver_gemm_xdlops_v2r3::M01, - debug::debug_driver_gemm_xdlops_v2r3::N01, - wei_gemmk0_gemmm_gemmk1_grid_step_hacks, - out_gemmk0_gemmn_gemmk1_grid_step_hacks, - in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); - - { - const auto N = out_n_ho_wo_k_lengths[I0]; - const auto K = out_n_ho_wo_k_lengths[I3]; - const auto C = wei_k_y_x_c_lengths[I3]; - - const auto Ho = out_n_ho_wo_k_lengths[I1]; - const auto Wo = out_n_ho_wo_k_lengths[I2]; - - const auto Y = wei_k_y_x_c_lengths[I1]; - const auto X = wei_k_y_x_c_lengths[I2]; - - float perf = static_cast((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } - } - - // copy result back to host - in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data()); -} diff --git a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 28d6226f..00000000 --- a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,423 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp" -#include "driver_gemm_xdlops_v2r3.hpp" - -template -void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk( - const InLengths& in_n_hi_wi_c_lengths, - const WeiLengths& wei_k_y_x_c_lengths, - const OutLengths& out_n_ho_wo_k_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - Tensor& in_n_hi_wi_c, - const Tensor& wei_k_y_x_c, - const Tensor& 
out_n_ho_wo_k, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); - DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); - DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); - - in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); - wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); - out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); - -#if 0 - // [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr 
index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 256; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16 - constexpr index_t BlockSize = 128; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 64; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>; - - 
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 64; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#endif - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmm - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmm - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: - // gemmk1 - - constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmn - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: Gemmk0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmn - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1 - - // clang-format off - constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple( - make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - // clang-format on - - constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{}; - - constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}; - - for(index_t i = 0; i < 5; ++i) - { - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); - const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - - const auto YTilda = ConvStrideH / GcdStrideDilationH; - const auto XTilda = ConvStrideW / GcdStrideDilationW; - - float ave_time = 0; - - for(index_t i_ytilda = 0; i_ytilda < YTilda; ++i_ytilda) - { - for(index_t i_xtilda = 0; i_xtilda < XTilda; ++i_xtilda) - { - const auto descs = - transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( - out_n_ho_wo_k_desc, - wei_k_y_x_c_desc, - in_n_hi_wi_c_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - i_ytilda, - i_xtilda, - Number{}); - - const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto in_gemmm_gemmn_grid_desc = descs[I2]; - - const auto GemmK0 = out_gemmk0_gemmm_gemmk1_grid_desc.GetLength(I0); - - if(GemmK0 != 0) - { - ave_time += driver_gemm_xdlops_v2r3< - BlockSize, - TInWei, - TAcc, - TOut, - InMemoryDataOperationEnum_t::Set, - decltype(out_gemmk0_gemmm_gemmk1_grid_desc), - decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), - decltype(in_gemmm_gemmn_grid_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerWave, - GemmNPerWave, - GemmK1, - MRepeat, - NRepeat, - GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, - GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmABlockTransferSrcScalarPerVector_GemmK1, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<2, 0, 1>, - Sequence<0, 2, 1>, - 1, - GemmBBlockTransferSrcScalarPerVector_GemmN, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy -#if 0 - Sequence<0, 2, 4, 5, 6, 1, 3, 7>, -#else - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, 
-#endif - 7, - GemmCThreadTransferDstScalarPerVector, - decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks), - decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), - decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - true, // CAccessOrderMRepeatNRepeat - false, // ABlockLdsExtraM - false // BBlockLdsExtraN - >(static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - out_gemmk0_gemmm_gemmk1_grid_desc, - wei_gemmk0_gemmn_gemmk1_grid_desc, - in_gemmm_gemmn_grid_desc, - debug::debug_driver_gemm_xdlops_v2r3::M01, - debug::debug_driver_gemm_xdlops_v2r3::N01, - out_gemmk0_gemmm_gemmk1_grid_step_hacks, - wei_gemmk0_gemmn_gemmk1_grid_step_hacks, - in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); - } - } - } - - { - const auto N = out_n_ho_wo_k_lengths[I0]; - const auto K = out_n_ho_wo_k_lengths[I3]; - const auto C = wei_k_y_x_c_lengths[I3]; - - const auto Ho = out_n_ho_wo_k_lengths[I1]; - const auto Wo = out_n_ho_wo_k_lengths[I2]; - - const auto Y = wei_k_y_x_c_lengths[I1]; - const auto X = wei_k_y_x_c_lengths[I2]; - - float perf = static_cast((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } - } - - // copy result back to host - in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data()); -} diff --git a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp deleted file mode 100644 index d6955ec0..00000000 --- a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp +++ /dev/null @@ -1,389 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp" -#include "driver_gemm_xdlops_v2r3.hpp" - -template -void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1( - const InLengths& in_n_hi_wi_c_lengths, - const WeiLengths& wei_k_y_x_c_lengths, - const OutLengths& out_n_ho_wo_k_lengths, - const ConvStrides& conv_strides, - const ConvDilations&, - const InLeftPads&, - const InRightPads&, - Tensor& in_n_hi_wi_c, - const Tensor& wei_k_y_x_c, - const Tensor& out_n_ho_wo_k, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); - DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); - DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); - - in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); - wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); - out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - - const auto in_n_hi_wi_c_desc = 
make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); - -#if 0 - // [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 
1; -#elif 0 - // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 256; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16 - constexpr index_t BlockSize = 128; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 64; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 64; - constexpr index_t GemmKPerBlock = 4; - - constexpr 
index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#endif - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: gemmk0 - Sequence<0, 0, 0>{}, // 1+: gemmm - Sequence<0, 0, 0>{}), // 2+: gemmk1 - make_tuple(Sequence<0, 0, 0>{}, // 0-: gemmk0 - Sequence<0, 0, 0>{}, // 1-: gemmm - Sequence<0, 0, 0>{})); // 2-: gemmk1 - - constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: gemmk0 - Sequence<0, 0, 0>{}, // 1+: gemmn - Sequence<0, 0, 0>{}), // 2+: gemmk1 - make_tuple(Sequence<0, 0, 0>{}, // 0-: Gemmk0 - Sequence<0, 0, 0>{}, // 1-: Gemmn - Sequence<0, 0, 0>{})); // 2-: Gemmk1 - - // clang-format off - constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple( - make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - // clang-format on - - constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; - - constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; - - for(index_t i = 0; i < 5; ++i) - { - const auto descs = transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk_1x1( - out_n_ho_wo_k_desc, - wei_k_y_x_c_desc, - in_n_hi_wi_c_desc, - conv_strides, - Number{}); - - const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto in_gemmm_gemmn_grid_desc = descs[I2]; - - float 
ave_time = driver_gemm_xdlops_v2r3< - BlockSize, - TInWei, - TAcc, - TOut, - InMemoryDataOperationEnum_t::Set, - decltype(out_gemmk0_gemmm_gemmk1_grid_desc), - decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), - decltype(in_gemmm_gemmn_grid_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerWave, - GemmNPerWave, - GemmK1, - MRepeat, - NRepeat, - GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, - GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmABlockTransferSrcScalarPerVector_GemmK1, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<2, 0, 1>, - Sequence<0, 2, 1>, - 1, - GemmBBlockTransferSrcScalarPerVector_GemmN, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy -#if 0 - Sequence<0, 2, 4, 5, 6, 1, 3, 7>, -#else - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, -#endif - 7, - GemmCThreadTransferDstScalarPerVector, - decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks), - decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), - decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - true, // CAccessOrderMRepeatNRepeat - false, // ABlockLdsExtraM - false // BBlockLdsExtraN - >(static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - out_gemmk0_gemmm_gemmk1_grid_desc, - wei_gemmk0_gemmn_gemmk1_grid_desc, - in_gemmm_gemmn_grid_desc, - debug::debug_driver_gemm_xdlops_v2r3::M01, - debug::debug_driver_gemm_xdlops_v2r3::N01, - out_gemmk0_gemmm_gemmk1_grid_step_hacks, - wei_gemmk0_gemmn_gemmk1_grid_step_hacks, - in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); - - { - const auto N = out_n_ho_wo_k_lengths[I0]; - const auto K = out_n_ho_wo_k_lengths[I3]; - const auto C = wei_k_y_x_c_lengths[I3]; - - const auto Ho = out_n_ho_wo_k_lengths[I1]; - const auto Wo = out_n_ho_wo_k_lengths[I2]; - - const auto Y = wei_k_y_x_c_lengths[I1]; - const auto X = wei_k_y_x_c_lengths[I2]; - - float perf = static_cast((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } - } - - // copy result back to host - in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data()); -} diff --git a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp deleted file mode 100644 index 8207e2cb..00000000 --- a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,256 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp" -#include "driver_gemm_xdlops_v2r4.hpp" - -template -void 
device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw( - const InLengths& in_n_c_hi_wi_lengths, - const WeiLengths& wei_k_c_y_x_lengths, - const OutLengths& out_n_k_ho_wo_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_c_hi_wi, - Tensor& wei_k_c_y_x, - const Tensor& out_n_k_ho_wo, - GridSizeType desired_grid_size, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - DeviceMem in_n_c_hi_wi_device_buf(sizeof(TIn) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem wei_k_c_y_x_device_buf(sizeof(TWei) * wei_k_c_y_x.mDesc.GetElementSpace()); - DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); - - in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); - wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); - out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); - - const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths); - const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); - const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); - -#if 1 - // [M, N, K0, K1] = [128, 128, 4, 8] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmB_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmB_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 64, 1>; - // using vector load 4, so config's wo*ho must be a multiple of 4 - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmB_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmB_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#endif - - const auto N = in_n_c_hi_wi_desc.GetLength(I0); - const auto C = in_n_c_hi_wi_desc.GetLength(I1); - const auto K = out_n_k_ho_wo_desc.GetLength(I1); - - const auto Ho = out_n_k_ho_wo_desc.GetLength(I2); - const auto Wo = out_n_k_ho_wo_desc.GetLength(I3); - - const auto Y = wei_k_c_y_x_desc.GetLength(I2); - const auto X = wei_k_c_y_x_desc.GetLength(I3); - - const auto GemmM = K; - const auto GemmN = Y * X * C; - const auto GemmKTotal = N * Ho * Wo; - - const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock); - const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1); - const index_t GemmK0 = - math::integer_divide_ceil(GemmKTotal, GemmK1 * GemmKPerBlock * GemmKBatch) * GemmKPerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1; - - std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN - << " GemmKBatch: " << GemmKBatch << " GemmK0: " << 
GemmK0 << " gemmKPad: " << GemmKPad - << std::endl; - const auto descs = - transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw_pad( - wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - Number{}, - GemmKBatch, - GemmKPad); - - const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto wei_gemmm_gemmn_grid_desc = descs[I2]; - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmB - Sequence<0, 0, 1, 0, 0, 0, 0>{}, // 1+: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmM - Sequence<0, 0, 1, 0, 0, 0, 0>{}), // 3+: GemmK1 - make_tuple(Sequence<0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemB - Sequence<0, 0, 2, 0, 0, 0, 0>{}, // 1-: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmM - Sequence<0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1 - - constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmB - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 1+: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmN - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 3+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmB - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 1-: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GemmN - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1 - - constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - - constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 1, 0, 0, 0, 0>{}; - - constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{}; - - const auto driver_gemm_xdlops = - driver_gemm_xdlops_v2r4, - Sequence<0, 2, 1, 3>, - 3, - GemmABlockTransferSrcScalarPerVector_GemmK1, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmB_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmB_GemmK0_GemmN_GemmK1, - Sequence<0, 2, 1, 3>, - Sequence<0, 2, 1, 3>, - 3, - GemmBBlockTransferSrcScalarPerVector_GemmN, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // 
don't move back src coordinate after threadwise copy - Sequence<3, 0, 1, 2, 7, 5, 4, 6>, - 7, - GemmCThreadTransferDstScalarPerVector, - decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - false, - true, - true>; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = - driver_gemm_xdlops(static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), - static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), - static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), - out_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc, - debug::debug_driver_gemm_xdlops_v2r3::M01, - debug::debug_driver_gemm_xdlops_v2r3::N01, - out_gemmk0_gemmm_gemmk1_grid_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_step_hacks, - wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); - - float perf = static_cast(calculate_convolution_flops( - in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); - driver_gemm_xdlops(static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), - static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), - static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), - out_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc, - debug::debug_driver_gemm_xdlops_v2r3::M01, - debug::debug_driver_gemm_xdlops_v2r3::N01, - out_gemmk0_gemmm_gemmk1_grid_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_step_hacks, - wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - 0); - // copy result back to host - wei_k_c_y_x_device_buf.FromDevice(wei_k_c_y_x.mData.data()); -} diff --git a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp deleted file mode 100644 index ac75c56b..00000000 --- a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,234 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp" -#include "driver_gemm_xdlops_v2r3.hpp" - -template -void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( - const InLengths& in_n_c_hi_wi_lengths, - const WeiLengths& wei_k_c_y_x_lengths, - const OutLengths& out_n_k_ho_wo_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_c_hi_wi, - Tensor& wei_k_c_y_x, - const Tensor& out_n_k_ho_wo, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - DeviceMem 
in_n_c_hi_wi_device_buf(sizeof(TIn) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem wei_k_c_y_x_device_buf(sizeof(TWei) * wei_k_c_y_x.mDesc.GetElementSpace()); - DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); - - in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); - wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); - out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); - - const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths); - const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); - const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); - -#if 0 - // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - // using vector load 4, so config's wo*ho must be a multiple of 4 - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - // using vector load 4, so config's wo*ho must be a multiple of 4 - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#endif - - const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( - wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - Number{}); - - const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto wei_gemmm_gemmn_grid_desc = descs[I2]; - - // HACK: hacks that control index calculation when iterating over A, B, C matrix 
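Before the index-calculation hack tables below, one note on the tuning configs above: the second config's banner comment still says [128, 128, 4, 8] although it sets GemmMPerBlock = 256. The tile parameters are tied together by the wave layout, and a quick compile-time check can catch such mismatches. A sketch, under my assumption of 64-lane wavefronts (gfx9); the Waves* names are mine, not CK's:

#include <cstddef>

// Sanity-check sketch for an xdlops tile config (assumption: one wave is 64
// lanes, and BlockSize = waves_in_M * waves_in_N * 64). Values mirror the
// 256x128 fp16 config above.
constexpr std::size_t WaveSize      = 64;
constexpr std::size_t GemmMPerBlock = 256, GemmNPerBlock = 128;
constexpr std::size_t GemmMPerWave  = 32,  GemmNPerWave  = 32;
constexpr std::size_t MRepeat       = 4,   NRepeat       = 2;

constexpr std::size_t WavesM = GemmMPerBlock / (GemmMPerWave * MRepeat); // 2
constexpr std::size_t WavesN = GemmNPerBlock / (GemmNPerWave * NRepeat); // 2
static_assert(WavesM * WavesN * WaveSize == 256, "BlockSize must match the wave layout");

int main() { return 0; }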
- constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 1, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM - Sequence<0, 0, 1, 0, 0>{}), // 2+: GemmK1 - make_tuple(Sequence<0, 0, 2, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmM - Sequence<0, 0, 2, 0, 0>{})); // 2-: GemmK1 - - constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 1+: GemmN - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 2+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 1-: GemmN - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 2-: GemmK1 - - constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - - constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 1, 0, 0>{}; - - constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = driver_gemm_xdlops_v2r3< - BlockSize, - TIn, - TAcc, - TWei, - InMemoryDataOperationEnum_t::Set, - decltype(out_gemmk0_gemmm_gemmk1_grid_desc), - decltype(in_gemmk0_gemmn_gemmk1_grid_desc), - decltype(wei_gemmm_gemmn_grid_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerWave, - GemmNPerWave, - GemmK1, - MRepeat, - NRepeat, - GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, - GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmABlockTransferSrcScalarPerVector_GemmK1, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmBBlockTransferSrcScalarPerVector_GemmN, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - Sequence<3, 0, 1, 2, 7, 5, 4, 6>, - 7, - GemmCThreadTransferDstScalarPerVector, - decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - false, // 
CAccessOrderMRepeatNRepeat - true, // ABlockLdsExtraM - true // BBlockLdsExtraN - >(static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), - static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), - static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), - out_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc, - debug::debug_driver_gemm_xdlops_v2r3::M01, - debug::debug_driver_gemm_xdlops_v2r3::N01, - out_gemmk0_gemmm_gemmk1_grid_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_step_hacks, - wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); - - float perf = static_cast(calculate_convolution_flops( - in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - // copy result back to host - wei_k_c_y_x_device_buf.FromDevice(wei_k_c_y_x.mData.data()); -} diff --git a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 6381ce8b..00000000 --- a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,288 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp" -#include "driver_gemm_xdlops_v2r4.hpp" - -template -void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk( - const InLengths& in_n_hi_wi_c_lengths, - const WeiLengths& wei_k_y_x_c_lengths, - const OutLengths& out_n_ho_wo_k_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_hi_wi_c, - Tensor& wei_k_y_x_c, - const Tensor& out_n_ho_wo_k, - GridSizeType desired_grid_size, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace()); - DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace()); - DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); - - in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); - wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); - out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); - -#if 0 - // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 256; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 2; - 
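The ThreadSliceLengths / ThreadClusterLengths Sequence pairs defined just below must multiply elementwise to the per-block copy window, and the cluster must account for every thread in the workgroup. A compile-time sketch against the fp32 [128, 256, 4, 4] config of this file (the [KBatch, K0, M, K1] axis reading is my interpretation of the name suffixes):

#include <array>
#include <cstddef>

// Hedged sketch: verify slice * cluster == the block's A-copy window and that
// the thread cluster uses exactly BlockSize threads.
constexpr std::array<std::size_t, 4> Slice{1, 1, 4, 2};    // elements per thread
constexpr std::array<std::size_t, 4> Cluster{1, 4, 32, 2}; // threads along each dim

constexpr std::size_t Threads = Cluster[0] * Cluster[1] * Cluster[2] * Cluster[3];
static_assert(Threads == 256, "cluster must use the whole workgroup");
static_assert(Slice[1] * Cluster[1] == 4,   "GemmK0 per block");
static_assert(Slice[2] * Cluster[2] == 128, "GemmM per block");
static_assert(Slice[3] * Cluster[3] == 4,   "GemmK1");

int main() { return 0; }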
constexpr index_t NRepeat = 4; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 2>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#endif - - const auto N = in_n_hi_wi_c_desc.GetLength(I0); - const auto C = in_n_hi_wi_c_desc.GetLength(I3); - const auto K = out_n_ho_wo_k_desc.GetLength(I3); - - const auto Ho = out_n_ho_wo_k_desc.GetLength(I1); - const auto Wo = out_n_ho_wo_k_desc.GetLength(I2); - - const auto Y = wei_k_y_x_c_desc.GetLength(I1); - const auto X = wei_k_y_x_c_desc.GetLength(I2); - - const auto GemmM = Y * X * C; - const auto GemmN = K; - const auto GemmKTotal = N * Ho * Wo; - - const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock); - const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1); - const index_t GemmK0 = - math::integer_divide_ceil(GemmKTotal, GemmK1 * GemmKPerBlock * GemmKBatch) * GemmKPerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1; - - std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN - << " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " gemmKPad: " << GemmKPad - << std::endl; - - const auto descs = - transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk_pad( - in_n_hi_wi_c_desc, - wei_k_y_x_c_desc, - out_n_ho_wo_k_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - Number{}, - GemmKBatch, - GemmKPad); - - const auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto wei_gemmm_gemmn_grid_desc = descs[I2]; - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmKBatch - Sequence<0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 1+: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmM - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 3+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmKBatch - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 1-: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GemmM - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1 - - constexpr auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN - Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN - Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 - - constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - - constexpr auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{}; - - constexpr auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0>{}; - - const auto driver_gemm_xdlops = driver_gemm_xdlops_v2r4< - BlockSize, - TIn, - TAcc, - TWei, - InMemoryDataOperationEnum_t::AtomicAdd, - decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc), - decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc), - decltype(wei_gemmm_gemmn_grid_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerXDL, - GemmNPerXDL, - GemmK1, - MRepeat, - NRepeat, - GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, - GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, - Sequence<0, 1, 2, 3>, - Sequence<0, 1, 2, 3>, - 2, - GemmABlockTransferSrcScalarPerVector_GemmM, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<0, 1, 2, 3>, - Sequence<0, 1, 2, 3>, - 2, - GemmBBlockTransferSrcScalarPerVector_GemmN, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - Sequence<2, 3, 0, 1, 7, 5, 4, 6>, - 6, - GemmCThreadTransferDstScalarPerVector, - decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks), - decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks), - 
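Because this driver instantiates driver_gemm_xdlops_v2r4 with InMemoryDataOperationEnum_t::AtomicAdd, every K-batch adds a partial product into C, which is why the code below re-uploads the host weight buffer (which the caller presumably initializes, e.g. to zero) before the final verification launch. A CPU-side sketch of the same accumulation scheme, purely illustrative and not the CK kernel:

#include <algorithm>
#include <cstddef>
#include <vector>

// Split-K GEMM emulation: each k-batch computes a partial C tile and adds it
// in; on the GPU the "+=" is a global atomic add, so C must start zeroed.
void splitk_gemm(const std::vector<float>& a, // M x K, row-major
                 const std::vector<float>& b, // K x N, row-major
                 std::vector<float>& c,       // M x N, must start zeroed
                 std::size_t M, std::size_t N, std::size_t K, std::size_t KBatch)
{
    const std::size_t KPerBatch = (K + KBatch - 1) / KBatch;
    for(std::size_t kb = 0; kb < KBatch; ++kb) // each batch = one grid slice
        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t n = 0; n < N; ++n)
            {
                float partial = 0;
                for(std::size_t k = kb * KPerBatch; k < std::min(K, (kb + 1) * KPerBatch); ++k)
                    partial += a[m * K + k] * b[k * N + n];
                c[m * N + n] += partial; // the "atomic add"
            }
}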
decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), - decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - false, // CAccessOrderMRepeatNRepeat - true, - true>; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = - driver_gemm_xdlops(static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc, - debug::debug_driver_gemm_xdlops_v2r3::M01, - debug::debug_driver_gemm_xdlops_v2r3::N01, - in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks, - out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks, - wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); - - { - - float perf = static_cast((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } - } - - wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); - driver_gemm_xdlops(static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc, - debug::debug_driver_gemm_xdlops_v2r3::M01, - debug::debug_driver_gemm_xdlops_v2r3::N01, - in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks, - out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks, - wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - 0); - // copy result back to host - wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data()); -} diff --git a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index bc5d5996..00000000 --- a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,276 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp" -#include "driver_gemm_xdlops_v2r3.hpp" -#include "debug.hpp" - -template -void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( - const InLengths& in_n_hi_wi_c_lengths, - const WeiLengths& wei_k_y_x_c_lengths, - const OutLengths& out_n_ho_wo_k_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_hi_wi_c, - Tensor& wei_k_y_x_c, - const Tensor& out_n_ho_wo_k, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * 
in_n_hi_wi_c.mDesc.GetElementSpace()); - DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace()); - DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); - - in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); - wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); - out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); - -#if 0 - // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 2>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 2>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 2>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; - -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 2>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; - constexpr index_t 
GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#endif - - const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( - in_n_hi_wi_c_desc, - wei_k_y_x_c_desc, - out_n_ho_wo_k_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - Number{}); - - const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto out_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto wei_gemmm_gemmn_grid_desc = descs[I2]; - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 1+: GemmM - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 2+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 1-: GemmM - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 2-: GemmK1 - - constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN - Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN - Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 - - constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - - constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0>{}; - - constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = driver_gemm_xdlops_v2r3< - BlockSize, - TIn, - TAcc, - TWei, - InMemoryDataOperationEnum_t::Set, - decltype(in_gemmk0_gemmm_gemmk1_grid_desc), - decltype(out_gemmk0_gemmn_gemmk1_grid_desc), - decltype(wei_gemmm_gemmn_grid_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerXDL, - GemmNPerXDL, - GemmK1, - MRepeat, - NRepeat, - GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, - 
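Worth pinning down the units in the perf printout that follows: the FLOP count is 2*N*K*Ho*Wo*C*Y*X (one multiply plus one add per MAC), and dividing by 10^9 and then by milliseconds lands directly in TFLOP/s, since 10^9 per millisecond equals 10^12 per second. A worked example with hypothetical numbers:

#include <cstdint>
#include <iostream>

int main()
{
    // hypothetical layer: N=256, K=256, C=128, Ho=Wo=28, Y=X=3
    const std::uint64_t flops =
        2ull * 256 * 256 * 28 * 28 * 128 * 3 * 3; // 2*N*K*Ho*Wo*C*Y*X
    const float ave_ms = 1.2f;                    // pretend measured average time
    // flops / 1e9 = GFLOP; GFLOP per millisecond = TFLOP per second
    const float tflops = static_cast<float>(flops) / 1.0e9f / ave_ms;
    std::cout << tflops << " TFlop/s\n";          // ~98.65
}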
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, - Sequence<0, 2, 1>, - Sequence<0, 2, 1>, - 1, - GemmABlockTransferSrcScalarPerVector_GemmM, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<0, 2, 1>, - Sequence<0, 2, 1>, - 1, - GemmBBlockTransferSrcScalarPerVector_GemmN, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - Sequence<2, 3, 0, 1, 7, 5, 4, 6>, - 7, - GemmCThreadTransferDstScalarPerVector, - decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks), - decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), - decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - false, // CAccessOrderMRepeatNRepeat - true, - true>(static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - in_gemmk0_gemmm_gemmk1_grid_desc, - out_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc, - debug::debug_driver_gemm_xdlops_v2r3::M01, - debug::debug_driver_gemm_xdlops_v2r3::N01, - in_gemmk0_gemmm_gemmk1_grid_step_hacks, - out_gemmk0_gemmn_gemmk1_grid_step_hacks, - wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); - - { - const auto N = out_n_ho_wo_k_lengths[I0]; - const auto K = out_n_ho_wo_k_lengths[I3]; - const auto C = wei_k_y_x_c_lengths[I3]; - - const auto Ho = out_n_ho_wo_k_lengths[I1]; - const auto Wo = out_n_ho_wo_k_lengths[I2]; - - const auto Y = wei_k_y_x_c_lengths[I1]; - const auto X = wei_k_y_x_c_lengths[I2]; - - float perf = static_cast((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } - } - - // copy result back to host - wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data()); -} diff --git a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 603f8726..00000000 --- a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,456 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp" -#include "driver_gemm_xdlops_v2r4.hpp" - -template -void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk( - const InLengths& in_n_hi_wi_c_lengths, - const WeiLengths& wei_k_y_x_c_lengths, - const OutLengths& out_n_ho_wo_k_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_hi_wi_c, - Tensor& wei_k_y_x_c, - const Tensor& out_n_ho_wo_k, - GridSizeType desired_grid_size, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = 
Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace()); - DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace()); - DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); - - in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); - wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); - out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); - -#if 0 - // [M, N, K0, K1] = [256, 128, 4, 4], C 128, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 2>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 256, 4, 4], C 128, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 256; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [128, 128, 4, 4], C 64, for fp32 and fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using 
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 2>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8], C 128, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 16, 2>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 16, 4>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 16, 4>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [128, 128, 4, 8], C 64, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 16, 4>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 16, 4>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 64, 4, 8], C 64, for fp16 - constexpr index_t BlockSize = 128; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 64; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 16, 2>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>; - - constexpr index_t 
GemmABlockTransferSrcScalarPerVector_GemmM = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [64, 128, 4, 8], C 64, for fp16 - constexpr index_t BlockSize = 128; - - constexpr index_t GemmMPerBlock = 64; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 16, 2>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [64, 64, 4, 8], C 32, for fp16 - constexpr index_t BlockSize = 128; - - constexpr index_t GemmMPerBlock = 64; - constexpr index_t GemmNPerBlock = 64; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#endif - - const auto N = in_n_hi_wi_c_desc.GetLength(I0); - const auto C = in_n_hi_wi_c_desc.GetLength(I3); - const auto K = out_n_ho_wo_k_desc.GetLength(I3); - - const auto Ho = out_n_ho_wo_k_desc.GetLength(I1); - const auto Wo = out_n_ho_wo_k_desc.GetLength(I2); - - const auto Y = wei_k_y_x_c_desc.GetLength(I1); - const auto X = wei_k_y_x_c_desc.GetLength(I2); - - const auto GemmM = K; - const auto GemmN = Y * X * C; - const auto GemmKTotal = N * Ho * Wo; - - const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock); - const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1); - const index_t GemmK0 = - math::integer_divide_ceil(GemmKTotal, GemmK1 * GemmKPerBlock * GemmKBatch) * GemmKPerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1; - - std::cout << "GemmKTotal: " << GemmKTotal << " 
GrideSizeMN: " << GridMN - << " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " gemmKPad: " << GemmKPad - << std::endl; - - const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk_pad( - in_n_hi_wi_c_desc, - wei_k_y_x_c_desc, - out_n_ho_wo_k_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - Number{}, - GemmKBatch, - GemmKPad); - - const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto wei_gemmm_gemmn_grid_desc = descs[I2]; - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN - Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN - Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 - - constexpr auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 2+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GemmM - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 2-: GemmK1 - - constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - - constexpr auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0>{}; - - constexpr auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{}; - - const auto driver_gemm_xdlops = driver_gemm_xdlops_v2r4< - BlockSize, - TIn, - TAcc, - TWei, - InMemoryDataOperationEnum_t::AtomicAdd, - decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc), - decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc), - decltype(wei_gemmm_gemmn_grid_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerXDL, - GemmNPerXDL, - GemmK1, - MRepeat, - NRepeat, - GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, - GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, - Sequence<0, 1, 2, 3>, - 
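For reference, the dimension mapping this v4r4r5 transform realizes, taken straight from the GemmM/GemmN/GemmKTotal lines above (the sizes in the sketch are hypothetical): the weight gradient dW[K, Y, X, C] is a GEMM whose M axis is K, whose N axis flattens Y*X*C, and whose reduction axis runs over N*Ho*Wo output positions.

#include <cstdint>
#include <iostream>

int main()
{
    const std::int64_t N = 128, Ho = 14, Wo = 14; // hypothetical sizes
    const std::int64_t K = 256, C = 192, Y = 3, X = 3;

    std::cout << "GemmM = " << K << "\n"          // 256
              << "GemmN = " << Y * X * C << "\n"  // 1728
              << "GemmK = " << N * Ho * Wo << "\n"; // 25088
}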
Sequence<0, 1, 2, 3>, - 2, - GemmABlockTransferSrcScalarPerVector_GemmM, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<0, 1, 2, 3>, - Sequence<0, 1, 3, 2>, - 2, - GemmBBlockTransferSrcScalarPerVector_GemmN, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - Sequence<2, 3, 0, 1, 7, 5, 4, 6>, - 7, - GemmCThreadTransferDstScalarPerVector, - decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks), - decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), - decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - false, // CAccessOrderMRepeatNRepeat - true, - true>; - - // timing - for(index_t i = 0; i < 5; ++i) - { - float ave_time = - driver_gemm_xdlops(static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc, - debug::debug_driver_gemm_xdlops_v2r3::M01, - debug::debug_driver_gemm_xdlops_v2r3::N01, - out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks, - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks, - wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); - - { - float perf = static_cast((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } - } - - // verification - wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); - driver_gemm_xdlops(static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc, - debug::debug_driver_gemm_xdlops_v2r3::M01, - debug::debug_driver_gemm_xdlops_v2r3::N01, - out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks, - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks, - wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - 0); - // copy result back to host - wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data()); -} diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp deleted file mode 100644 index e6554cf0..00000000 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,201 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp" -#include "driver_gemm_dlops_v1r2.hpp" - -template -void 
device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( - const InLengths& in_n_c_hi_wi_lengths, - const WeiLengths& wei_k_c_y_x_lengths, - const OutLengths& out_n_k_ho_wo_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_c_hi_wi, - const Tensor& wei_k_c_y_x, - Tensor& out_n_k_ho_wo, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); - DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); - - in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); - wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); - out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); - - const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths); - const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); - const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); - -#if 1 - // cdata = 64, BlockSize = 256, 128x128x8 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlockM1 = 128; - constexpr index_t GemmNPerBlockN1 = 128; - constexpr index_t GemmKPerBlock = 8; - - constexpr index_t GemmM1PerThreadM111 = 4; - constexpr index_t GemmN1PerThreadN111 = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmM11N11ThreadClusterM1100 = 8; - constexpr index_t GemmM11N11ThreadClusterN1100 = 8; - constexpr index_t GemmM11N11ThreadClusterM1101 = 2; - constexpr index_t GemmM11N11ThreadClusterN1101 = 2; - - using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<4, 1, 1>; - using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<2, 1, 128>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1; - - using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 1>; - using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<2, 1, 128>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_N1 = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 1; -#endif - - const auto descs = - transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads); - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_gemmk_gemmm0_gemmn1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{})); - - constexpr auto in_gemmk_gemmn0_gemmn1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); - - constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{})); - - constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0>{}; - - constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; - - const auto wei_gemmk_gemmm_grid_desc = descs[I0]; - const auto in_gemmk_gemmn_grid_desc = descs[I1]; - const auto out_gemmm_gemmn_grid_desc = descs[I2]; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = driver_gemm_dlops_v1r2< - BlockSize, - TInWei, - TAcc, - TOut, - InMemoryDataOperationEnum_t::Set, - decltype(wei_gemmk_gemmm_grid_desc), - decltype(in_gemmk_gemmn_grid_desc), - decltype(out_gemmm_gemmn_grid_desc), - GemmMPerBlockM1, - GemmNPerBlockN1, - GemmKPerBlock, - GemmM1PerThreadM111, - GemmN1PerThreadN111, - GemmKPerThread, - GemmM11N11ThreadClusterM1100, - GemmM11N11ThreadClusterN1100, - GemmM11N11ThreadClusterM1101, - GemmM11N11ThreadClusterN1101, - GemmABlockTransferThreadSliceLengths_K_M0_M1, - GemmABlockTransferThreadClusterLengths_K_M0_M1, - Sequence<2, 1, 0>, // ABlockTransferThreadClusterArrangeOrder - Sequence<2, 1, 0>, // ABlockTransferSrcAccessOrder - 0, // ABlockTransferSrcVectorDim - GemmABlockTransferSrcScalarPerVector_K, - GemmABlockTransferDstScalarPerVector_M1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_K_N0_N1, - GemmBBlockTransferThreadClusterLengths_K_N0_N1, - Sequence<0, 1, 2>, // BBlockTransferThreadClusterArrangeOrder - Sequence<0, 1, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - GemmBBlockTransferSrcScalarPerVector_N1, - GemmBBlockTransferDstScalarPerVector_N1, - false, // don't move back src coordinate after threadwise copy - Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder - 5, // CThreadTransferSrcDstVectorDim - GemmCThreadTransferDstScalarPerVector_N11, - decltype(wei_gemmk_gemmm0_gemmn1_grid_step_hacks), - decltype(in_gemmk_gemmn0_gemmn1_grid_step_hacks), - decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks), - decltype(wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks), - decltype(in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks)>( - static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), - static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), - static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), - wei_gemmk_gemmm_grid_desc, - in_gemmk_gemmn_grid_desc, - out_gemmm_gemmn_grid_desc, - wei_gemmk_gemmm0_gemmn1_grid_step_hacks, - in_gemmk_gemmn0_gemmn1_grid_step_hacks, - out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks, - wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks, - in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks, - nrepeat); - - float perf = static_cast(calculate_convolution_flops( - in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - // 
copy result back to host - out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data()); -} diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 40685e81..00000000 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,273 +0,0 @@ -#include <unistd.h> -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp" -#include "driver_gemm_dlops_v1r3.hpp" - -template <typename TInWei, - typename TAcc, - typename TOut, - typename InLengths, - typename WeiLengths, - typename OutLengths, - typename ConvStrides, - typename ConvDilations, - typename InLeftPads, - typename InRightPads> -void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk( - const InLengths& in_n_hi_wi_c_lengths, - const WeiLengths& wei_k_y_x_c_lengths, - const OutLengths& out_n_ho_wo_k_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor<TInWei>& in_n_hi_wi_c, - const Tensor<TInWei>& wei_k_y_x_c, - Tensor<TOut>& out_n_ho_wo_k, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); - DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); - DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); - - in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); - wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); - out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); - -#if 0 - // [M, N, K0, K1] = [128, 128, 8, 1] for fp32 - // cdata = 64, BlockSize = 256 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlockM1 = 128; - constexpr index_t GemmNPerBlockN1 = 128; - constexpr index_t GemmKPerBlock = 8; - constexpr index_t GemmK1 = 1; - - constexpr index_t GemmM1PerThreadM111 = 4; - constexpr index_t GemmN1PerThreadN111 = 4; - constexpr index_t GemmKPerThread = 1; - - using GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>; - using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>; - - using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 1>; - using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>; - - using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 1>; - using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 1>; - - using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 1>; - using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>; - - using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 1>; - using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 1>; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4; -#elif 1 - // [M, N, K0, K1] = [128, 128, 8, 2] for fp16 - // cdata = 64, BlockSize = 256 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlockM1 = 128; - 
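// NOTE: the tile constants in this block parameterize the implicit-GEMM view of the
// forward convolution. A minimal sketch of the mapping, assuming the NHWC/KYXC/NHWK
// layout this file targets (the helper below is hypothetical, not part of the source):
//   GemmM = N * Ho * Wo   -- one GEMM row per output pixel
//   GemmN = K             -- one GEMM column per output channel
//   GemmK = Y * X * C     -- reduction over the filter window, split as GemmK0 * GemmK1
#include <cstddef>
// hypothetical helper: how many workgroups a tiled implicit GEMM launches
constexpr std::size_t implicit_gemm_grid_size(std::size_t N, std::size_t Ho, std::size_t Wo,
                                              std::size_t K, std::size_t MPerBlock,
                                              std::size_t NPerBlock)
{
    const std::size_t GemmM = N * Ho * Wo; // output pixels
    const std::size_t GemmN = K;           // output channels
    return (GemmM / MPerBlock) * (GemmN / NPerBlock); // one workgroup per C tile
}
// e.g. N = 128, Ho = Wo = 14, K = 256 with 128x128 tiles:
// (128 * 14 * 14 / 128) * (256 / 128) = 196 * 2 = 392 workgroups.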
constexpr index_t GemmNPerBlockN1 = 128; - constexpr index_t GemmKPerBlock = 8; - constexpr index_t GemmK1 = 2; - - constexpr index_t GemmM1PerThreadM111 = 4; - constexpr index_t GemmN1PerThreadN111 = 4; - constexpr index_t GemmKPerThread = 1; - - using GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>; - using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>; - - using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 2>; - using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>; - - using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 2>; - using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 2>; - - using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 2>; - using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>; - - using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 2>; - using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 2>; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4; -#elif 1 - // [M, N, K0, K1] = [128, 128, 8, 4] for i8 - // cdata = 64, BlockSize = 256 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlockM1 = 128; - constexpr index_t GemmNPerBlockN1 = 128; - constexpr index_t GemmKPerBlock = 8; - constexpr index_t GemmK1 = 4; - - constexpr index_t GemmM1PerThreadM111 = 4; - constexpr index_t GemmN1PerThreadN111 = 4; - constexpr index_t GemmKPerThread = 1; - - using GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>; - using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>; - - using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 4>; - using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>; - - using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 4>; - using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 4>; - - using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 4>; - using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>; - - using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 4>; - using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 4>; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4; -#endif - - const auto descs = - transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc, - wei_k_y_x_c_desc, - out_n_ho_wo_k_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - Number{}); - - const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto out_gemmm_gemmn_grid_desc = descs[I2]; - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmM1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}), // 3+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GemmM0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GemmM1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 
2, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1 - - constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmN0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmN1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}), // 3+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: GemmN0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmN1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1 - - constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmM0 - Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM10 - Sequence<0, 0, 0, 0, 0>{}, // 2+: GemmM11 - Sequence<0, 0, 0, 0, 0>{}, // 3+: GemmN0 - Sequence<0, 0, 0, 0, 0>{}, // 4+: GemmN10 - Sequence<0, 0, 0, 0, 0>{}), // 5+: GemmN11 - make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmM0 - Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmM10 - Sequence<0, 0, 0, 0, 0>{}, // 2-: GemmM11 - Sequence<0, 0, 0, 0, 0>{}, // 3-: GemmN0 - Sequence<0, 0, 0, 0, 0>{}, // 4-: GemmN10 - Sequence<0, 0, 0, 0, 0>{})); // 5-: GemmN11 - - constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{}; - - constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = driver_gemm_dlops_v1r3< - BlockSize, - TInWei, - TAcc, - TOut, - InMemoryDataOperationEnum_t::Set, - decltype(in_gemmk0_gemmm_gemmk1_grid_desc), - decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), - decltype(out_gemmm_gemmn_grid_desc), - GemmMPerBlockM1, - GemmNPerBlockN1, - GemmKPerBlock, - GemmM1PerThreadM111, - GemmN1PerThreadN111, - GemmKPerThread, - GemmM11N11ThreadClusterM110Xs, - GemmM11N11ThreadClusterN110Xs, - GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1, - GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1, - Sequence<1, 2, 0, 3>, // ABlockTransferThreadClusterArrangeOrder - Sequence<1, 2, 0, 3>, // ABlockTransferSrcAccessOrder - GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, - Sequence<1, 2, 0, 3>, // ABlockTransferSrcVectorTensorContiguousDimOrder - GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, - GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1, - GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1, - Sequence<1, 2, 0, 3>, // BBlockTransferThreadClusterArrangeOrder - Sequence<1, 2, 0, 3>, // BBlockTransferSrcAccessOrder - GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, - Sequence<1, 2, 0, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder - GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, - Sequence<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder - 5, // CThreadTransferSrcDstVectorDim - GemmCThreadTransferDstScalarPerVector_N11, - decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks), - decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks), - decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks), - decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks), - decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks)>( - static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - in_gemmk0_gemmm_gemmk1_grid_desc, - wei_gemmk0_gemmn_gemmk1_grid_desc, - 
out_gemmm_gemmn_grid_desc, - in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks, - wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks, - out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks, - in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks, - wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); - - { - const auto N = out_n_ho_wo_k_lengths[I0]; - const auto K = out_n_ho_wo_k_lengths[I3]; - const auto C = wei_k_y_x_c_lengths[I3]; - - const auto Ho = out_n_ho_wo_k_lengths[I1]; - const auto Wo = out_n_ho_wo_k_lengths[I2]; - - const auto Y = wei_k_y_x_c_lengths[I1]; - const auto X = wei_k_y_x_c_lengths[I2]; - - float perf = static_cast<float>(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } - } - - // copy result back to host - out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); -} diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp deleted file mode 100644 index d65ecadb..00000000 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,228 +0,0 @@ -#include <unistd.h> -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp" -#include "driver_gemm_xdlops_v2r3.hpp" - -template <typename TInWei, - typename TAcc, - typename TOut, - typename InLengths, - typename WeiLengths, - typename OutLengths, - typename ConvStrides, - typename ConvDilations, - typename InLeftPads, - typename InRightPads> -void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( - const InLengths& in_n_c_hi_wi_lengths, - const WeiLengths& wei_k_c_y_x_lengths, - const OutLengths& out_n_k_ho_wo_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor<TInWei>& in_n_c_hi_wi, - const Tensor<TInWei>& wei_k_c_y_x, - Tensor<TOut>& out_n_k_ho_wo, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); - DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); - - in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); - wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); - out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); - - const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths); - const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); - const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); - -#if 0 - // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr 
index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#endif - - const auto descs = - transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - Number{}); - - const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto out_gemmm_gemmn_grid_desc = descs[I2]; - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM - Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmM - Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 - - constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: GemmN - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), // 2+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: GemmN - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // 2-: GemmK1 - - constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 
0-: M0 - Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0>{}; - - constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = driver_gemm_xdlops_v2r3< - BlockSize, - TInWei, - TAcc, - TOut, - InMemoryDataOperationEnum_t::Set, - decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), - decltype(in_gemmk0_gemmn_gemmk1_grid_desc), - decltype(out_gemmm_gemmn_grid_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerWave, - GemmNPerWave, - GemmK1, - MRepeat, - NRepeat, - GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, - GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmABlockTransferSrcScalarPerVector_GemmK1, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<0, 2, 1>, - Sequence<1, 0, 2>, - 1, - GemmBBlockTransferSrcScalarPerVector_GemmN, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - Sequence<3, 0, 1, 2, 7, 5, 4, 6>, - 7, - GemmCThreadTransferDstScalarPerVector, - decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - false>(static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), - static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), - static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), - wei_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmk0_gemmn_gemmk1_grid_desc, - out_gemmm_gemmn_grid_desc, - wei_gemmk0_gemmm_gemmk1_grid_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_step_hacks, - out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); - - float perf = static_cast(calculate_convolution_flops( - in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - // copy result back to host - out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data()); -} diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 1b23aa1a..00000000 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,414 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include 
"transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp" -#include "driver_gemm_xdlops_v2r3.hpp" - -template -void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( - const InLengths& in_n_hi_wi_c_lengths, - const WeiLengths& wei_k_y_x_c_lengths, - const OutLengths& out_n_ho_wo_k_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_hi_wi_c, - const Tensor& wei_k_y_x_c, - Tensor& out_n_ho_wo_k, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); - DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); - DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); - - in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); - wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); - out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); - -#if 0 - // [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 4], C = 128, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t 
GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [256, 256, 4, 8], C = 256, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 256; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 4; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 256; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, 
for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16 - constexpr index_t BlockSize = 128; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 64; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 64; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#endif - - const auto descs = - transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc, - wei_k_y_x_c_desc, - out_n_ho_wo_k_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - Number{}); - - const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - 
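// NOTE: each #if branch above pairs GemmK1 with the element type so the innermost
// contiguous read spans 16 bytes (fp32: GemmK1 = 4, fp16: GemmK1 = 8), matching a
// 16-byte-wide vector load. A minimal sanity-check sketch, assuming GemmK1 is carved
// out of the contiguous C dimension of NHWC/KYXC (hypothetical helper, not from this file):
#include <cassert>
#include <cstddef>
inline void check_gemmk1(std::size_t C, std::size_t GemmK1, std::size_t bytes_per_element)
{
    assert(C % GemmK1 == 0);                  // GemmK1 must evenly divide C
    assert(GemmK1 * bytes_per_element <= 16); // stay within one 16-byte vector load
}
// usage: check_gemmk1(C, 8, 2) for fp16 -> 8 elements * 2 bytes = 16 bytes per load.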
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto out_gemmm_gemmn_grid_desc = descs[I2]; - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: GemmM - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), // 2+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: GemmM - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // 2-: GemmK1 - - constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 - Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN - Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1 - make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0 - Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN - Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 - - constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - - constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; - - constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = driver_gemm_xdlops_v2r3< - BlockSize, - TInWei, - TAcc, - TOut, - InMemoryDataOperationEnum_t::Set, - decltype(in_gemmk0_gemmm_gemmk1_grid_desc), - decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), - decltype(out_gemmm_gemmn_grid_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerXDL, - GemmNPerXDL, - GemmK1, - MRepeat, - NRepeat, - GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, - GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmABlockTransferSrcScalarPerVector_GemmK1, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmBBlockTransferSrcScalarPerVector_GemmK1, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - Sequence<2, 3, 0, 1, 7, 5, 4, 6>, - 7, - GemmCThreadTransferDstScalarPerVector, - decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks), - decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks), - 
decltype(out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), - decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - false, // CAccessOrderMRepeatNRepeat - true, // ABlockLdsExtraM - true // BBlockLdsExtraN - >(static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - in_gemmk0_gemmm_gemmk1_grid_desc, - wei_gemmk0_gemmn_gemmk1_grid_desc, - out_gemmm_gemmn_grid_desc, - debug::debug_driver_gemm_xdlops_v2r3::M01, - debug::debug_driver_gemm_xdlops_v2r3::N01, - in_gemmk0_gemmm_gemmk1_grid_step_hacks, - wei_gemmk0_gemmn_gemmk1_grid_step_hacks, - out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); - - { - const auto N = out_n_ho_wo_k_lengths[I0]; - const auto K = out_n_ho_wo_k_lengths[I3]; - const auto C = wei_k_y_x_c_lengths[I3]; - - const auto Ho = out_n_ho_wo_k_lengths[I1]; - const auto Wo = out_n_ho_wo_k_lengths[I2]; - - const auto Y = wei_k_y_x_c_lengths[I1]; - const auto X = wei_k_y_x_c_lengths[I2]; - - float perf = static_cast((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } - } - - // copy result back to host - out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); -} diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp deleted file mode 100644 index b5e5f91d..00000000 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,190 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp" -#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp" - -template -void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw( - const InLengths& in_n_c_hi_wi_lengths, - const WeiLengths& wei_k_c_y_x_lengths, - const OutLengths& out_n_k_ho_wo_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_c_hi_wi, - const Tensor& wei_k_c_y_x, - Tensor& out_n_k_ho_wo, - ck::index_t /* nrepeat */) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - const auto N = out_n_k_ho_wo_lengths[I0]; - const auto K = out_n_k_ho_wo_lengths[I1]; - const auto C = wei_k_c_y_x_lengths[I1]; - - const auto Hi = in_n_c_hi_wi_lengths[I2]; - const auto Wi = in_n_c_hi_wi_lengths[I3]; - - const auto Ho = out_n_k_ho_wo_lengths[I2]; - const auto Wo = out_n_k_ho_wo_lengths[I3]; - - const auto Y = wei_k_c_y_x_lengths[I2]; - const auto X = wei_k_c_y_x_lengths[I3]; - - const auto C0 = C / Number{}; - const auto C1 = Number{}; - - const auto K0 = K / Number{}; - const auto K1 = Number{}; - - Tensor in_n_c0_hi_wi_c1( - HostTensorDescriptor(std::initializer_list{N, C0, Hi, Wi, C1})); - Tensor wei_k_c0_y_x_c1( - 
HostTensorDescriptor(std::initializer_list{K, C0, Y, X, C1})); - Tensor out_n_k0_ho_wo_k1( - HostTensorDescriptor(std::initializer_list{N, K0, Ho, Wo, K1})); - - auto f_nchw2nc0hwc1 = [&](auto n, auto hi, auto wi, auto c) { - in_n_c0_hi_wi_c1(n, c / InWeiVectorSize, hi, wi, c % InWeiVectorSize) = - in_n_c_hi_wi(n, c, hi, wi); - }; - - auto f_kcyx2kc0yxc1 = [&](auto k, auto y, auto x, auto c) { - wei_k_c0_y_x_c1(k, c / InWeiVectorSize, y, x, c % InWeiVectorSize) = - wei_k_c_y_x(k, c, y, x); - }; - - make_ParallelTensorFunctor(f_nchw2nc0hwc1, N, Hi, Wi, C)(); - make_ParallelTensorFunctor(f_kcyx2kc0yxc1, K, Y, X, C)(); - - DeviceMem in_n_c0_hi_wi_c1_device_buf(sizeof(TInWei) * - in_n_c0_hi_wi_c1.mDesc.GetElementSpace()); - DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) * wei_k_c0_y_x_c1.mDesc.GetElementSpace()); - DeviceMem out_n_k0_ho_wo_k1_device_buf(sizeof(TOut) * - out_n_k0_ho_wo_k1.mDesc.GetElementSpace()); - - in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data()); - wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data()); - - const auto in_n_c0_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi)); - const auto wei_k_c0_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X)); - const auto out_n_k0_ho_wo_k1_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)); - -#if 1 - // cdata = 64, BlockSize = 64, 16x8x32x4 - constexpr index_t BlockSize = 64; - - constexpr index_t KPerBlock = 16; - constexpr index_t HoPerBlock = 8; - constexpr index_t WoPerBlock = 32; - constexpr index_t EPerBlock = 1; - - constexpr index_t KPerThread = KPerBlock; - constexpr index_t HoPerThread = 2; - constexpr index_t WoPerThread = 2; - constexpr index_t EPerThread = EPerBlock; - - using ABlockTransferThreadSliceLengths_E_K = Sequence<3, 1>; - using ABlockTransferThreadClusterLengths_E_K = Sequence<3 * EPerBlock, KPerBlock>; - - constexpr index_t ABlockTransferSrcScalarPerVector_E = 1; - constexpr index_t ABlockTransferDstScalarPerVector_K = 1; - - constexpr index_t BThreadTransferSrcScalarPerVector_W = 1; - - constexpr index_t CThreadTransferDstScalarPerVector_W = 16; - - static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, ""); -#else - constexpr index_t BlockSize = 64; - - constexpr index_t KPerBlock = 16; - constexpr index_t HoPerBlock = 8; - constexpr index_t WoPerBlock = 32; - constexpr index_t EPerBlock = 1; - - constexpr index_t KPerThread = 16; - constexpr index_t HoPerThread = 2; - constexpr index_t WoPerThread = 2; - constexpr index_t EPerThread = EPerBlock; - - using ABlockTransferThreadSliceLengths_E_K = Sequence<9, 1>; - using ABlockTransferThreadClusterLengths_E_K = Sequence; - - constexpr index_t ABlockTransferSrcScalarPerVector_E = 1; - constexpr index_t ABlockTransferDstScalarPerVector_K = 1; - - constexpr index_t BThreadTransferSrcScalarPerVector_W = 1; - - constexpr index_t CThreadTransferDstScalarPerVector_W = K1; - - static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, ""); -#endif - - constexpr auto conv_driver = -#if 0 - DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad -#else - DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad -#endif - ::type, - TAcc, - TOut, - KPerBlock, - HoPerBlock, - WoPerBlock, - EPerBlock, - KPerThread, - HoPerThread, - WoPerThread, - EPerThread, - ABlockTransferThreadSliceLengths_E_K, - ABlockTransferThreadClusterLengths_E_K, - ABlockTransferSrcScalarPerVector_E, - 
ABlockTransferDstScalarPerVector_K, - BThreadTransferSrcScalarPerVector_W, - CThreadTransferDstScalarPerVector_W>{}; - - conv_driver.Run(wei_k_c0_y_x_desc, - in_n_c0_hi_wi_desc, - out_n_k0_ho_wo_k1_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - static_cast::type*>( - wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()), - static_cast::type*>( - in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()), - static_cast(out_n_k0_ho_wo_k1_device_buf.GetDeviceBuffer())); - - out_n_k0_ho_wo_k1_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data()); - - auto f_nk0hwk1_to_nkhw = [&](auto n, auto k, auto ho, auto wo) { - out_n_k_ho_wo(n, k, ho, wo) = - out_n_k0_ho_wo_k1(n, k / InWeiVectorSize, ho, wo, k % InWeiVectorSize); - }; - - make_ParallelTensorFunctor(f_nk0hwk1_to_nkhw, N, K, Ho, Wo)(); -} diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp deleted file mode 100644 index e1b7c548..00000000 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,241 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp" -#include "driver_contraction_dlops_v1r2.hpp" - -template -void device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( - const InLengths& in_n_c_hi_wi_lengths, - const WeiLengths& wei_k_c_y_x_lengths, - const OutLengths& out_n_k_ho_wo_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_c_hi_wi, - const Tensor& wei_k_c_y_x, - Tensor& out_n_k_ho_wo, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); - DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); - - in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); - wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); - out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); - - const auto in_desc_n_c_hi_wi = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths); - const auto wei_desc_k_c_y_x = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); - const auto out_desc_n_k_ho_wo = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); - -#if 1 - // [8, 1, 128, 1] * [8, 4, 32, 1] = [1, 128, 4, 32] for fp32 - // cdata = 64, BlockSize = 256 - constexpr index_t BlockSize = 256; - - constexpr index_t GN0 = 4; - constexpr index_t GK1 = 1; - - constexpr index_t GM1PerBlockGM11 = 128; - constexpr index_t GN1PerBlockGN11 = 32; - constexpr index_t GK0PerBlock = 8; - - constexpr index_t BM1PerThreadBM11 = 4; - constexpr index_t BN1PerThreadBN11 = 4; - constexpr index_t BK0PerThread = 1; - - using BM10BN10ThreadClusterBM10Xs = Sequence<8, 2>; - using BM10BN10ThreadClusterBN10Xs = Sequence<8, 2>; - - using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>; - using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>; - - using 
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>; - using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 1>; - - using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 1>; - using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>; - - using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>; - using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>; - - constexpr index_t CThreadTransferDstScalarPerVector_BN1 = 1; -#elif 1 - // [8, 1, 128, 2] * [8, 4, 32, 2] = [1, 128, 4, 32] for fp16 - // cdata = 64, BlockSize = 256 - constexpr index_t BlockSize = 256; - - constexpr index_t GN0 = 4; - constexpr index_t GK1 = 2; - - constexpr index_t GM1PerBlockGM11 = 128; - constexpr index_t GN1PerBlockGN11 = 32; - constexpr index_t GK0PerBlock = 8; - - constexpr index_t BM1PerThreadBM11 = 4; - constexpr index_t BN1PerThreadBN11 = 4; - constexpr index_t BK0PerThread = 1; - - using BM10BN10ThreadClusterBM10Xs = Sequence<8, 2>; - using BM10BN10ThreadClusterBN10Xs = Sequence<8, 2>; - - using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 2>; - using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>; - - using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>; - using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 2>; - - using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 2>; - using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>; - - using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>; - using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 2>; - - constexpr index_t CThreadTransferDstScalarPerVector_BN1 = 1; -#endif - - const auto descs = - transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(wei_desc_k_c_y_x, - in_desc_n_c_hi_wi, - out_desc_n_k_ho_wo, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - Number{}, - Number{}); - - const auto wei_grid_desc_gk0_gm0_gm1_gk1 = descs[I0]; - const auto in_grid_desc_gk0_gn0_gn1_gk1 = descs[I1]; - const auto out_grid_desc_gm0_gm1_gn0_gn1 = descs[I2]; - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3+: GM11 - Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GM0 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GM10 - Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11 - Sequence<0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1 - - constexpr auto in_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 3+: GN11 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 
0, 0, 0, 0>{}, // 0-: GK0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GN0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GN10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1 - - constexpr auto out_grid_step_hacks = make_tuple( - make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 2+: BM1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: GN10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 4+: BN0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 5+: GN1 - make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GM10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: BM0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 2-: BM1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: GN10 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 5-: GN1 - - constexpr auto wei_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0>{}; - - constexpr auto in_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = driver_contraction_dlops_v1r2< - BlockSize, - TInWei, - TAcc, - TOut, - InMemoryDataOperationEnum_t::Set, - decltype(wei_grid_desc_gk0_gm0_gm1_gk1), - decltype(in_grid_desc_gk0_gn0_gn1_gk1), - decltype(out_grid_desc_gm0_gm1_gn0_gn1), - GM1PerBlockGM11, - GN1PerBlockGN11, - GK0PerBlock, - BM1PerThreadBM11, - BN1PerThreadBN11, - BK0PerThread, - BM10BN10ThreadClusterBM10Xs, - BM10BN10ThreadClusterBN10Xs, - ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, - Sequence<1, 2, 3, 0, 4>, // ABlockTransferThreadClusterArrangeOrder - Sequence<3, 2, 1, 0, 4>, // ABlockTransferSrcAccessOrder - ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, - Sequence<0, 1, 2, 3, 4>, // ABlockTransferSrcVectorTensorContiguousDimOrder - BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, - Sequence<0, 4, 1, 2, 3>, // BBlockTransferThreadClusterArrangeOrder - Sequence<4, 3, 2, 0, 1>, // BBlockTransferSrcAccessOrder - BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, - Sequence<0, 1, 2, 3, 4>, // BBlockTransferSrcVectorTensorContiguousDimOrder - Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder - 5, // CThreadTransferSrcDstVectorDim - CThreadTransferDstScalarPerVector_BN1, - decltype(wei_grid_step_hacks), - decltype(in_grid_step_hacks), - decltype(out_grid_step_hacks), - decltype(wei_grid_move_slice_window_step_hacks), - decltype(in_grid_move_slice_window_step_hacks)>( - static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), - static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), - static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), - 
wei_grid_desc_gk0_gm0_gm1_gk1, - in_grid_desc_gk0_gn0_gn1_gk1, - out_grid_desc_gm0_gm1_gn0_gn1, - wei_grid_step_hacks, - in_grid_step_hacks, - out_grid_step_hacks, - wei_grid_move_slice_window_step_hacks, - in_grid_move_slice_window_step_hacks, - nrepeat); - - float perf = static_cast(calculate_convolution_flops( - in_desc_n_c_hi_wi, wei_desc_k_c_y_x, out_desc_n_k_ho_wo)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - // copy result back to host - out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data()); -} diff --git a/host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp b/host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp deleted file mode 100644 index c44aa7d9..00000000 --- a/host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp +++ /dev/null @@ -1,463 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "driver_gemm_xdlops_v2r3.hpp" - -template -void device_gemm_xdlops_km_kn_mn(const Tensor& a_k_m, - const Tensor& b_k_n, - Tensor& c_m_n, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace()); - DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace()); - DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace()); - - a_k_m_device_buf.ToDevice(a_k_m.mData.data()); - b_k_n_device_buf.ToDevice(b_k_n.mData.data()); - c_m_n_device_buf.ToDevice(c_m_n.mData.data()); - -#if 0 - // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 256; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 256, 4, 4], C = 128, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 256; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 
0 - // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 64; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 1; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 64; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 1; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 1; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 256; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; - constexpr index_t 
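
A pattern worth noting across these configuration ladders: K1 = 4 for the fp32 variants and K1 = 8 for the fp16 variants, i.e. one K1-wide vector is 16 bytes either way, matching a single 128-bit global load. That is my reading of the numbers, not something the diff states; the arithmetic, as compile-time checks:

static_assert(sizeof(float) * 4 == 16, "fp32: K1 = 4 spans one 128-bit vector");
static_assert(sizeof(unsigned short) * 8 == 16, // 2-byte stand-in for half
              "fp16: K1 = 8 spans one 128-bit vector");

This also lines up with ABlockTransferDstScalarPerVector_K1 being 4 and 8 in the respective configurations.
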
ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 256; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 128; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 64; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr 
index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 1; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 64; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 1; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 1; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#endif - - const auto K = a_k_m.mDesc.GetLengths()[0]; - const auto M = a_k_m.mDesc.GetLengths()[1]; - const auto N = b_k_n.mDesc.GetLengths()[1]; - - constexpr auto K1Number = Number{}; - const auto K0 = K / K1Number; - - const auto a_k0_m_k1_grid_desc = - make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), - make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0], - a_k_m.mDesc.GetStrides()[1], - a_k_m.mDesc.GetStrides()[0])); - - const auto b_k0_n_k1_grid_desc = - make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), - make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0], - b_k_n.mDesc.GetStrides()[1], - b_k_n.mDesc.GetStrides()[0])); - - const auto c_m_n_grid_desc = make_naive_tensor_descriptor( - make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1])); - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 - Sequence<0>{}, // 1+: M - Sequence<0>{}), // 2+: K1 - make_tuple(Sequence<0>{}, // 0-: K0 - Sequence<0>{}, // 1-: M - Sequence<0>{})); // 2-: K1 - - constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 - Sequence<0>{}, // 1+: N - Sequence<0>{}), // 2+: K1 - make_tuple(Sequence<0>{}, // 0-: K0 - Sequence<0>{}, // 1-: N - Sequence<0>{})); // 2-: K1 - - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - 
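
The make_naive_tensor_descriptor call for A reinterprets the K x M matrix as a 3-D [K0, M, K1] view without moving data: coordinate (k0, m, k1) maps to row k0 * K1 + k1, column m of the original tensor. A minimal host-side sketch of the same offset arithmetic (function and parameter names are illustrative):

#include <cstddef>

std::size_t a_k0_m_k1_offset(std::size_t k0, std::size_t m, std::size_t k1,
                             std::size_t stride_k, std::size_t stride_m,
                             std::size_t K1)
{
    // strides handed to the descriptor: {K1 * stride_k, stride_m, stride_k}
    return k0 * (K1 * stride_k) + m * stride_m + k1 * stride_k;
}

Expanding the expression gives (k0 * K1 + k1) * stride_k + m * stride_m, which is exactly the (k0 * K1 + k1, m) element of the original K x M tensor.
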
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - - constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; - - constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = - driver_gemm_xdlops_v2r3, - Sequence<0, 2, 1>, - 1, - ABlockTransferSrcScalarPerVector_M, - ABlockTransferDstScalarPerVector_K1, - false, // don't move back src coordinate after threadwise copy - BBlockTransferThreadSliceLengths_K0_N_K1, - BBlockTransferThreadClusterLengths_K0_N_K1, - Sequence<0, 2, 1>, - Sequence<0, 2, 1>, - 1, - BBlockTransferSrcScalarPerVector_N, - BBlockTransferDstScalarPerVector_K1, - false, // don't move back src coordinate after threadwise copy - Sequence<0, 2, 4, 5, 6, 1, 3, 7>, - 7, - CThreadTransferDstScalarPerVector, - decltype(a_k0_m_k1_grid_step_hacks), - decltype(b_k0_n_k1_grid_step_hacks), - decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), - decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), - false, // CAccessOrderMRepeatNRepeat - true, // ABlockLdsExtraM - true // BBlockLdsExtraN - >(static_cast(a_k_m_device_buf.GetDeviceBuffer()), - static_cast(b_k_n_device_buf.GetDeviceBuffer()), - static_cast(c_m_n_device_buf.GetDeviceBuffer()), - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m_n_grid_desc, - debug::debug_driver_gemm_xdlops_v2r3::M01, - debug::debug_driver_gemm_xdlops_v2r3::N01, - a_k0_m_k1_grid_step_hacks, - b_k0_n_k1_grid_step_hacks, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - a_k0_m_k1_grid_move_slice_window_step_hacks, - b_k0_n_k1_grid_move_slice_window_step_hacks, - nrepeat); - - float perf = static_cast((std::size_t(2) * M * N * K)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - // copy result back to host - c_m_n_device_buf.FromDevice(c_m_n.mData.data()); -} diff --git a/host/driver_offline/include/device_gemm_xdlops_km_kn_nm.hpp b/host/driver_offline/include/device_gemm_xdlops_km_kn_nm.hpp deleted file mode 100644 index abaaf321..00000000 --- a/host/driver_offline/include/device_gemm_xdlops_km_kn_nm.hpp +++ /dev/null @@ -1,263 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "driver_gemm_xdlops_v2r3.hpp" - -template -void device_gemm_xdlops_km_kn_nm(const Tensor& a_k_m, - const Tensor& b_k_n, - Tensor& c_n_m, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace()); - DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace()); - DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace()); - - a_k_m_device_buf.ToDevice(a_k_m.mData.data()); - b_k_n_device_buf.ToDevice(b_k_n.mData.data()); - c_n_m_device_buf.ToDevice(c_n_m.mData.data()); - -#if 0 - // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 256; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 
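
The perf line in these drivers counts 2 * M * N * K flops for the GEMM (one multiply and one add per MAC); because ave_time is in milliseconds, dividing by 1e9 and then by the time yields TFlop/s directly. A self-contained restatement of that formula:

#include <cstddef>

double gemm_tflops(std::size_t M, std::size_t N, std::size_t K,
                   double ave_time_ms)
{
    const double flop = 2.0 * static_cast<double>(M) * N * K;
    return flop / 1.0e9 / ave_time_ms; // flop / (ms * 1e9) == TFlop/s
}

// e.g. gemm_tflops(4096, 4096, 4096, 1.37) is roughly 100.3 TFlop/s.
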
32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 4; -#elif 0 - // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 256; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 256; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 128; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; - constexpr index_t 
BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 4; -#endif - - const auto K = a_k_m.mDesc.GetLengths()[0]; - const auto M = a_k_m.mDesc.GetLengths()[1]; - const auto N = b_k_n.mDesc.GetLengths()[1]; - - constexpr auto K1Number = Number{}; - const auto K0 = K / K1Number; - - const auto a_k0_m_k1_grid_desc = - make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), - make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0], - a_k_m.mDesc.GetStrides()[1], - a_k_m.mDesc.GetStrides()[0])); - - const auto b_k0_n_k1_grid_desc = - make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), - make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0], - b_k_n.mDesc.GetStrides()[1], - b_k_n.mDesc.GetStrides()[0])); - - const auto c_m_n_grid_desc = make_naive_tensor_descriptor( - make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0])); - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 - Sequence<0>{}, // 1+: M - Sequence<0>{}), // 2+: K1 - make_tuple(Sequence<0>{}, // 0-: K0 - Sequence<0>{}, // 1-: M - Sequence<0>{})); // 2-: K1 - - constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 - Sequence<0>{}, // 1+: N - Sequence<0>{}), // 2+: K1 - make_tuple(Sequence<0>{}, // 0-: K0 - Sequence<0>{}, // 1-: N - Sequence<0>{})); // 2-: K1 - - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - - constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; - - constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = - driver_gemm_xdlops_v2r3, - Sequence<0, 2, 1>, - 1, - ABlockTransferSrcScalarPerVector_M, - ABlockTransferDstScalarPerVector_K1, - false, // don't move back src coordinate after threadwise copy - BBlockTransferThreadSliceLengths_K0_N_K1, - BBlockTransferThreadClusterLengths_K0_N_K1, - Sequence<0, 2, 1>, - Sequence<0, 2, 1>, - 1, - BBlockTransferSrcScalarPerVector_N, - BBlockTransferDstScalarPerVector_K1, - false, // don't move back src coordinate after threadwise copy - Sequence<2, 3, 0, 1, 7, 5, 4, 6>, - 6, - CThreadTransferDstScalarPerVector, - decltype(a_k0_m_k1_grid_step_hacks), - decltype(b_k0_n_k1_grid_step_hacks), - decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), - decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), - false // CAccessOrderMRepeatNRepeat - >(static_cast(a_k_m_device_buf.GetDeviceBuffer()), - static_cast(b_k_n_device_buf.GetDeviceBuffer()), - 
static_cast(c_n_m_device_buf.GetDeviceBuffer()), - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m_n_grid_desc, - a_k0_m_k1_grid_step_hacks, - b_k0_n_k1_grid_step_hacks, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - a_k0_m_k1_grid_move_slice_window_step_hacks, - b_k0_n_k1_grid_move_slice_window_step_hacks, - nrepeat); - - float perf = static_cast((std::size_t(2) * M * N * K)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - // copy result back to host - c_n_m_device_buf.FromDevice(c_n_m.mData.data()); -} diff --git a/host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp b/host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp deleted file mode 100644 index 0a97d361..00000000 --- a/host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp +++ /dev/null @@ -1,463 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "driver_gemm_xdlops_v2r3.hpp" - -template -void device_gemm_xdlops_km_nk_mn(const Tensor& a_k_m, - const Tensor& b_n_k, - Tensor& c_m_n, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace()); - DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace()); - DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace()); - - a_k_m_device_buf.ToDevice(a_k_m.mData.data()); - b_n_k_device_buf.ToDevice(b_n_k.mData.data()); - c_m_n_device_buf.ToDevice(c_m_n.mData.data()); - -#if 0 - // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 256; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 256; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, 
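
In the *_nm variants such as the one just above, the output tensor is stored N x M but the kernel still operates on a logical M x N view: the descriptor simply swaps the two strides, so logical (m, n) lands on element (n, m) of the buffer. This is presumably also why those variants reorder the C access pattern and can use CThreadTransferDstScalarPerVector = 4 where the *_mn variants use 1: the store direction that is contiguous in memory changes. A sketch of the stride swap (names illustrative):

#include <cstddef>

// c_n_m is row-major N x M, so GetStrides() == {M, 1}; the descriptor is
// built with make_tuple(strides[1], strides[0]) == {1, M}.
std::size_t c_view_offset(std::size_t m, std::size_t n, std::size_t M)
{
    return m * 1 + n * M; // logical (m, n) -> element (n, m) of the N x M buffer
}
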
K0, K1] = [128, 128, 4, 4], C = 64, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 64; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 64; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 1; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 1; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 256; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 
8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 256; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 128; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 64; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr 
index_t K1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 64; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 1; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 1; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#endif - - const auto K = a_k_m.mDesc.GetLengths()[0]; - const auto M = a_k_m.mDesc.GetLengths()[1]; - const auto N = b_n_k.mDesc.GetLengths()[0]; - - constexpr auto K1Number = Number{}; - const auto K0 = K / K1Number; - - const auto a_k0_m_k1_grid_desc = - make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), - make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0], - a_k_m.mDesc.GetStrides()[1], - a_k_m.mDesc.GetStrides()[0])); - - const auto b_k0_n_k1_grid_desc = - make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), - make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1], - b_n_k.mDesc.GetStrides()[0], - b_n_k.mDesc.GetStrides()[1])); - - const auto c_m_n_grid_desc = make_naive_tensor_descriptor( - make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1])); - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 - Sequence<0>{}, // 1+: M - Sequence<0>{}), // 2+: K1 - make_tuple(Sequence<0>{}, // 0-: K0 - Sequence<0>{}, // 1-: M - Sequence<0>{})); // 2-: K1 - - constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 - Sequence<0>{}, // 1+: N - Sequence<0>{}), // 2+: K1 - make_tuple(Sequence<0>{}, // 0-: K0 - Sequence<0>{}, // 1-: N - Sequence<0>{})); // 2-: K1 - - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple(Sequence<0, 0, 0, 0, 
0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - - constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; - - constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = - driver_gemm_xdlops_v2r3, - Sequence<0, 2, 1>, - 1, - ABlockTransferSrcScalarPerVector_M, - ABlockTransferDstScalarPerVector_K1, - false, // don't move back src coordinate after threadwise copy - BBlockTransferThreadSliceLengths_K0_N_K1, - BBlockTransferThreadClusterLengths_K0_N_K1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - BBlockTransferSrcScalarPerVector_K1, - BBlockTransferDstScalarPerVector_K1, - false, // don't move back src coordinate after threadwise copy - Sequence<0, 2, 4, 5, 6, 1, 3, 7>, - 7, - CThreadTransferDstScalarPerVector, - decltype(a_k0_m_k1_grid_step_hacks), - decltype(b_k0_n_k1_grid_step_hacks), - decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), - decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), - false, // CAccessOrderMRepeatNRepeat - true, // ABlockLdsExtraM - true // BBlockLdsExtraN - >(static_cast(a_k_m_device_buf.GetDeviceBuffer()), - static_cast(b_n_k_device_buf.GetDeviceBuffer()), - static_cast(c_m_n_device_buf.GetDeviceBuffer()), - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m_n_grid_desc, - debug::debug_driver_gemm_xdlops_v2r3::M01, - debug::debug_driver_gemm_xdlops_v2r3::N01, - a_k0_m_k1_grid_step_hacks, - b_k0_n_k1_grid_step_hacks, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - a_k0_m_k1_grid_move_slice_window_step_hacks, - b_k0_n_k1_grid_move_slice_window_step_hacks, - nrepeat); - - float perf = static_cast((std::size_t(2) * M * N * K)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - // copy result back to host - c_m_n_device_buf.FromDevice(c_m_n.mData.data()); -} diff --git a/host/driver_offline/include/device_gemm_xdlops_km_nk_nm.hpp b/host/driver_offline/include/device_gemm_xdlops_km_nk_nm.hpp deleted file mode 100644 index d51caa38..00000000 --- a/host/driver_offline/include/device_gemm_xdlops_km_nk_nm.hpp +++ /dev/null @@ -1,263 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "driver_gemm_xdlops_v2r3.hpp" - -template -void device_gemm_xdlops_km_nk_nm(const Tensor& a_k_m, - const Tensor& b_n_k, - Tensor& c_n_m, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace()); - DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace()); - DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace()); - - a_k_m_device_buf.ToDevice(a_k_m.mData.data()); - b_n_k_device_buf.ToDevice(b_n_k.mData.data()); - c_n_m_device_buf.ToDevice(c_n_m.mData.data()); - -#if 0 - // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 256; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL 
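
The km_nk_* drivers differ from the km_kn_* ones mainly in how B is read: with B stored N x K, the K axis is the fast one, so the source copy vectorizes along K1 (BBlockTransferSrcScalarPerVector_K1, with source access order Sequence<1, 0, 2> and vector dimension 2) instead of along N. The descriptor strides make this concrete; a host-side sketch under those assumptions (names illustrative):

#include <cstddef>

// b_n_k is row-major N x K, so GetStrides() == {K, 1}. The descriptor strides
// {K1 * strides[1], strides[0], strides[1]} come out to {K1, K, 1}: successive
// k1 values sit in adjacent memory and can be fetched as one vector.
std::size_t b_k0_n_k1_offset(std::size_t k0, std::size_t n, std::size_t k1,
                             std::size_t K, std::size_t K1)
{
    return k0 * K1 + n * K + k1;
}
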
= 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 4; -#elif 0 - // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 256; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 256; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 128; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - 
constexpr index_t CThreadTransferDstScalarPerVector = 4; -#endif - - const auto K = a_k_m.mDesc.GetLengths()[0]; - const auto M = a_k_m.mDesc.GetLengths()[1]; - const auto N = b_n_k.mDesc.GetLengths()[0]; - - constexpr auto K1Number = Number{}; - const auto K0 = K / K1Number; - - const auto a_k0_m_k1_grid_desc = - make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), - make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0], - a_k_m.mDesc.GetStrides()[1], - a_k_m.mDesc.GetStrides()[0])); - - const auto b_k0_n_k1_grid_desc = - make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), - make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1], - b_n_k.mDesc.GetStrides()[0], - b_n_k.mDesc.GetStrides()[1])); - - const auto c_m_n_grid_desc = make_naive_tensor_descriptor( - make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0])); - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 - Sequence<0>{}, // 1+: M - Sequence<0>{}), // 2+: K1 - make_tuple(Sequence<0>{}, // 0-: K0 - Sequence<0>{}, // 1-: M - Sequence<0>{})); // 2-: K1 - - constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 - Sequence<0>{}, // 1+: N - Sequence<0>{}), // 2+: K1 - make_tuple(Sequence<0>{}, // 0-: K0 - Sequence<0>{}, // 1-: N - Sequence<0>{})); // 2-: K1 - - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - - constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; - - constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = - driver_gemm_xdlops_v2r3, - Sequence<0, 2, 1>, - 1, - ABlockTransferSrcScalarPerVector_M, - ABlockTransferDstScalarPerVector_K1, - false, // don't move back src coordinate after threadwise copy - BBlockTransferThreadSliceLengths_K0_N_K1, - BBlockTransferThreadClusterLengths_K0_N_K1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - BBlockTransferSrcScalarPerVector_K1, - BBlockTransferDstScalarPerVector_K1, - false, // don't move back src coordinate after threadwise copy - Sequence<2, 3, 0, 1, 7, 5, 4, 6>, - 6, - CThreadTransferDstScalarPerVector, - decltype(a_k0_m_k1_grid_step_hacks), - decltype(b_k0_n_k1_grid_step_hacks), - decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), - decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), - false // CAccessOrderMRepeatNRepeat - >(static_cast(a_k_m_device_buf.GetDeviceBuffer()), - static_cast(b_n_k_device_buf.GetDeviceBuffer()), - static_cast(c_n_m_device_buf.GetDeviceBuffer()), - 
a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m_n_grid_desc, - a_k0_m_k1_grid_step_hacks, - b_k0_n_k1_grid_step_hacks, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - a_k0_m_k1_grid_move_slice_window_step_hacks, - b_k0_n_k1_grid_move_slice_window_step_hacks, - nrepeat); - - float perf = static_cast((std::size_t(2) * M * N * K)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - // copy result back to host - c_n_m_device_buf.FromDevice(c_n_m.mData.data()); -} diff --git a/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp b/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp deleted file mode 100644 index 30ede251..00000000 --- a/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp +++ /dev/null @@ -1,463 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "driver_gemm_xdlops_v2r3.hpp" - -template -void device_gemm_xdlops_mk_kn_mn(const Tensor& a_m_k, - const Tensor& b_k_n, - Tensor& c_m_n, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace()); - DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace()); - DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace()); - - a_m_k_device_buf.ToDevice(a_m_k.mData.data()); - b_k_n_device_buf.ToDevice(b_k_n.mData.data()); - c_m_n_device_buf.ToDevice(c_m_n.mData.data()); - -#if 0 - // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 256; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 256; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32 - 
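
One precondition shared by every driver in this family: K0 = K / K1 uses integer division, so K must be a multiple of the chosen K1 (4 in the fp32 configs, 8 in fp16) or the [K0, M, K1] view no longer covers the whole K dimension. The deleted code does not guard against this; a caller-side check, as a sketch:

#include <cstddef>
#include <stdexcept>

void check_k_divisibility(std::size_t K, std::size_t K1)
{
    if(K % K1 != 0)
        throw std::runtime_error("K must be divisible by K1 for the K0/K1 split");
}
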
constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 64; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 1; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 64; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 4; - - constexpr index_t MRepeat = 1; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 256; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using 
BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 256; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 128; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 128; - constexpr index_t NPerBlock = 64; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; 
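
The mk_* drivers mirror the nk change on the A side: with A stored M x K, K is the contiguous axis, so these configurations vectorize the A copy along K1 (ABlockTransferSrcScalarPerVector_K1, source access order Sequence<1, 0, 2>, vector dimension 2), where the km_* drivers vectorized along M. The corresponding descriptor arithmetic, sketched on the host (names illustrative):

#include <cstddef>

// a_m_k is row-major M x K, so GetStrides() == {K, 1}. The descriptor strides
// {K1 * strides[1], strides[0], strides[1]} come out to {K1, K, 1}.
std::size_t a_k0_m_k1_offset_mk(std::size_t k0, std::size_t m, std::size_t k1,
                                std::size_t K, std::size_t K1)
{
    return k0 * K1 + m * K + k1; // k1 runs along contiguous memory
}
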
- - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 1; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t MPerBlock = 64; - constexpr index_t NPerBlock = 128; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 1; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#endif - - const auto K = a_m_k.mDesc.GetLengths()[1]; - const auto M = a_m_k.mDesc.GetLengths()[0]; - const auto N = b_k_n.mDesc.GetLengths()[1]; - - constexpr auto K1Number = Number{}; - const auto K0 = K / K1Number; - - const auto a_k0_m_k1_grid_desc = - make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), - make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1], - a_m_k.mDesc.GetStrides()[0], - a_m_k.mDesc.GetStrides()[1])); - - const auto b_k0_n_k1_grid_desc = - make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), - make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0], - b_k_n.mDesc.GetStrides()[1], - b_k_n.mDesc.GetStrides()[0])); - - const auto c_m_n_grid_desc = make_naive_tensor_descriptor( - make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1])); - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 - Sequence<0>{}, // 1+: M - Sequence<0>{}), // 2+: K1 - make_tuple(Sequence<0>{}, // 0-: K0 - Sequence<0>{}, // 1-: M - Sequence<0>{})); // 2-: K1 - - constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 - Sequence<0>{}, // 1+: N - Sequence<0>{}), // 2+: K1 - make_tuple(Sequence<0>{}, // 0-: K0 - Sequence<0>{}, // 1-: N - Sequence<0>{})); // 2-: K1 - - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, 
-
-    // HACK: hacks that control index calculation when iterating over A, B, C matrix
-    constexpr auto a_k0_m_k1_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0>{},   // 0+: K0
-                              Sequence<0>{},   // 1+: M
-                              Sequence<0>{}),  // 2+: K1
-                   make_tuple(Sequence<0>{},   // 0-: K0
-                              Sequence<0>{},   // 1-: M
-                              Sequence<0>{})); // 2-: K1
-
-    constexpr auto b_k0_n_k1_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0>{},   // 0+: K0
-                              Sequence<0>{},   // 1+: N
-                              Sequence<0>{}),  // 2+: K1
-                   make_tuple(Sequence<0>{},   // 0-: K0
-                              Sequence<0>{},   // 1-: N
-                              Sequence<0>{})); // 2-: K1
-
-    constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
-                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
-
-    constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
-
-    constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
-
-    for(index_t i = 0; i < 5; ++i)
-    {
-        float ave_time =
-            driver_gemm_xdlops_v2r3<BlockSize,
-                ABType,
-                AccType,
-                CType,
-                InMemoryDataOperationEnum_t::Set,
-                decltype(a_k0_m_k1_grid_desc),
-                decltype(b_k0_n_k1_grid_desc),
-                decltype(c_m_n_grid_desc),
-                MPerBlock,
-                NPerBlock,
-                KPerBlock,
-                MPerXDL,
-                NPerXDL,
-                K1,
-                MRepeat,
-                NRepeat,
-                ABlockTransferThreadSliceLengths_K0_M_K1,
-                ABlockTransferThreadClusterLengths_K0_M_K1,
-                Sequence<1, 0, 2>,
-                Sequence<1, 0, 2>,
-                2,
-                ABlockTransferSrcScalarPerVector_K1,
-                ABlockTransferDstScalarPerVector_K1,
-                false, // don't move back src coordinate after threadwise copy
-                BBlockTransferThreadSliceLengths_K0_N_K1,
-                BBlockTransferThreadClusterLengths_K0_N_K1,
-                Sequence<0, 2, 1>,
-                Sequence<0, 2, 1>,
-                1,
-                BBlockTransferSrcScalarPerVector_N,
-                BBlockTransferDstScalarPerVector_K1,
-                false, // don't move back src coordinate after threadwise copy
-                Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
-                7,
-                CThreadTransferDstScalarPerVector,
-                decltype(a_k0_m_k1_grid_step_hacks),
-                decltype(b_k0_n_k1_grid_step_hacks),
-                decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
-                decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
-                decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
-                false, // CAccessOrderMRepeatNRepeat
-                true,  // ABlockLdsExtraM
-                true   // BBlockLdsExtraN
-                >(static_cast<ABType*>(a_m_k_device_buf.GetDeviceBuffer()),
-                  static_cast<ABType*>(b_k_n_device_buf.GetDeviceBuffer()),
-                  static_cast<CType*>(c_m_n_device_buf.GetDeviceBuffer()),
-                  a_k0_m_k1_grid_desc,
-                  b_k0_n_k1_grid_desc,
-                  c_m_n_grid_desc,
-                  debug::debug_driver_gemm_xdlops_v2r3::M01,
-                  debug::debug_driver_gemm_xdlops_v2r3::N01,
-                  a_k0_m_k1_grid_step_hacks,
-                  b_k0_n_k1_grid_step_hacks,
-                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
-                  a_k0_m_k1_grid_move_slice_window_step_hacks,
-                  b_k0_n_k1_grid_move_slice_window_step_hacks,
-                  nrepeat);
-
-        float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
-                     (std::size_t(1000) * 1000 * 1000) / ave_time;
-
-        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
-    }
-
-    // copy result back to host
-    c_m_n_device_buf.FromDevice(c_m_n.mData.data());
-}
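The perf expression in the loop above is the standard GEMM throughput formula: 2*M*N*K flops (one multiply and one add per inner-product term) divided by the averaged kernel time. Because ave_time is in milliseconds, dividing by 1e9 yields TFlop/s directly. The same arithmetic, spelled out:

    #include <cstddef>

    // 2*M*N*K flops executed in ave_time_ms milliseconds:
    // flops / (ave_time_ms * 1e-3 s) / 1e12 == flops / 1e9 / ave_time_ms (TFlop/s)
    float gemm_tflops(std::size_t M, std::size_t N, std::size_t K, float ave_time_ms)
    {
        const std::size_t flops = std::size_t(2) * M * N * K;
        return static_cast<float>(flops) / (1000.0f * 1000.0f * 1000.0f) / ave_time_ms;
    }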
diff --git a/host/driver_offline/include/device_gemm_xdlops_mk_kn_nm.hpp b/host/driver_offline/include/device_gemm_xdlops_mk_kn_nm.hpp
deleted file mode 100644
index 58ac3880..00000000
--- a/host/driver_offline/include/device_gemm_xdlops_mk_kn_nm.hpp
+++ /dev/null
@@ -1,291 +0,0 @@
-#pragma once
-#include
-#include "device.hpp"
-#include "host_tensor.hpp"
-#include "driver_gemm_xdlops_v2r3.hpp"
-
-template <typename ABType, typename AccType, typename CType>
-void device_gemm_xdlops_mk_kn_nm(const Tensor<ABType>& a_m_k,
-                                 const Tensor<ABType>& b_k_n,
-                                 Tensor<CType>& c_n_m,
-                                 ck::index_t nrepeat)
-{
-    using namespace ck;
-
-    std::cout << __func__ << std::endl;
-
-    DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace());
-    DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace());
-    DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace());
-
-    a_m_k_device_buf.ToDevice(a_m_k.mData.data());
-    b_k_n_device_buf.ToDevice(b_k_n.mData.data());
-    c_n_m_device_buf.ToDevice(c_n_m.mData.data());
-
-#if 0
-    // [M, N, K0, K1] = [256, 128, 4, 4] for fp32
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t MPerBlock = 256;
-    constexpr index_t NPerBlock = 128;
-    constexpr index_t KPerBlock = 4;
-
-    constexpr index_t MPerXDL = 32;
-    constexpr index_t NPerXDL = 32;
-    constexpr index_t K1 = 4;
-
-    constexpr index_t MRepeat = 4;
-    constexpr index_t NRepeat = 2;
-
-    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
-    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
-    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
-
-    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
-    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
-    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
-
-    constexpr index_t CThreadTransferDstScalarPerVector = 4;
-#elif 0
-    // [M, N, K0, K1] = [128, 256, 4, 4] for fp32
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t MPerBlock = 128;
-    constexpr index_t NPerBlock = 256;
-    constexpr index_t KPerBlock = 4;
-
-    constexpr index_t MPerXDL = 32;
-    constexpr index_t NPerXDL = 32;
-    constexpr index_t K1 = 4;
-
-    constexpr index_t MRepeat = 2;
-    constexpr index_t NRepeat = 4;
-
-    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
-    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
-    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
-
-    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
-    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
-    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
-
-    constexpr index_t CThreadTransferDstScalarPerVector = 4;
-#elif 1
-    // [M, N, K0, K1] = [256, 128, 4, 8] for fp16
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t MPerBlock = 256;
-    constexpr index_t NPerBlock = 128;
-    constexpr index_t KPerBlock = 4;
-
-    constexpr index_t MPerXDL = 32;
-    constexpr index_t NPerXDL = 32;
-    constexpr index_t K1 = 8;
-
-    constexpr index_t MRepeat = 4;
-    constexpr index_t NRepeat = 2;
-
-    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
-    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
-    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
-
-    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
-    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
-    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
-
-    constexpr index_t CThreadTransferDstScalarPerVector = 4;
-#elif 1
-    // [M, N, K0, K1] = [128, 256, 4, 8] for fp16
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t MPerBlock = 128;
-    constexpr index_t NPerBlock = 256;
-    constexpr index_t KPerBlock = 4;
-
-    constexpr index_t MPerXDL = 32;
-    constexpr index_t NPerXDL = 32;
-    constexpr index_t K1 = 8;
-
-    constexpr index_t MRepeat = 2;
-    constexpr index_t NRepeat = 4;
-
-    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
-    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
-    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
-
-    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
-    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
-    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
-
-    constexpr index_t CThreadTransferDstScalarPerVector = 4;
-#elif 1
-    // [M, N, K0, K1] = [128, 128, 4, 8] for fp16
-    constexpr index_t BlockSize = 128;
-
-    constexpr index_t MPerBlock = 128;
-    constexpr index_t NPerBlock = 128;
-    constexpr index_t KPerBlock = 4;
-
-    constexpr index_t MPerXDL = 32;
-    constexpr index_t NPerXDL = 32;
-    constexpr index_t K1 = 8;
-
-    constexpr index_t MRepeat = 4;
-    constexpr index_t NRepeat = 2;
-
-    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
-    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
-
-    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
-    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
-
-    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
-    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
-
-    constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
-    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
-
-    constexpr index_t CThreadTransferDstScalarPerVector = 4;
-#endif
-
-    const auto K = a_m_k.mDesc.GetLengths()[1];
-    const auto M = a_m_k.mDesc.GetLengths()[0];
-    const auto N = b_k_n.mDesc.GetLengths()[1];
-
-    constexpr auto K1Number = Number<K1>{};
-    const auto K0 = K / K1Number;
-
-    const auto a_k0_m_k1_grid_desc =
-        make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
-                                     make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1],
-                                                a_m_k.mDesc.GetStrides()[0],
-                                                a_m_k.mDesc.GetStrides()[1]));
-
-    const auto b_k0_n_k1_grid_desc =
-        make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
-                                     make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0],
-                                                b_k_n.mDesc.GetStrides()[1],
-                                                b_k_n.mDesc.GetStrides()[0]));
-
-    const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
-        make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0]));
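Note how this mk_kn_nm variant handles its transposed output: the destination tensor is stored as c_n_m, yet the kernel still addresses a logical M x N grid, so the descriptor simply swaps the two strides. A sketch of the idea (the row-major N x M layout is an assumption for illustration):

    #include <cstddef>

    // Logical C(m, n) backed by a row-major buffer of shape N x M:
    // swapping the strides, as c_m_n_grid_desc does above, makes the
    // transpose free at the descriptor level.
    std::size_t c_offset(std::size_t m, std::size_t n, std::size_t M)
    {
        const std::size_t stride_n = M; // c_n_m[n][m]
        const std::size_t stride_m = 1;
        return n * stride_n + m * stride_m;
    }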
-
-    // HACK: hacks that control index calculation when iterating over A, B, C matrix
-    constexpr auto a_k0_m_k1_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0>{},   // 0+: K0
-                              Sequence<0>{},   // 1+: M
-                              Sequence<0>{}),  // 2+: K1
-                   make_tuple(Sequence<0>{},   // 0-: K0
-                              Sequence<0>{},   // 1-: M
-                              Sequence<0>{})); // 2-: K1
-
-    constexpr auto b_k0_n_k1_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0>{},   // 0+: K0
-                              Sequence<0>{},   // 1+: N
-                              Sequence<0>{}),  // 2+: K1
-                   make_tuple(Sequence<0>{},   // 0-: K0
-                              Sequence<0>{},   // 1-: N
-                              Sequence<0>{})); // 2-: K1
-
-    constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
-                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
-
-    constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
-
-    constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
-
-    for(index_t i = 0; i < 5; ++i)
-    {
-        float ave_time =
-            driver_gemm_xdlops_v2r3<BlockSize,
-                ABType,
-                AccType,
-                CType,
-                InMemoryDataOperationEnum_t::Set,
-                decltype(a_k0_m_k1_grid_desc),
-                decltype(b_k0_n_k1_grid_desc),
-                decltype(c_m_n_grid_desc),
-                MPerBlock,
-                NPerBlock,
-                KPerBlock,
-                MPerXDL,
-                NPerXDL,
-                K1,
-                MRepeat,
-                NRepeat,
-                ABlockTransferThreadSliceLengths_K0_M_K1,
-                ABlockTransferThreadClusterLengths_K0_M_K1,
-                Sequence<1, 0, 2>,
-                Sequence<1, 0, 2>,
-                2,
-                ABlockTransferSrcScalarPerVector_K1,
-                ABlockTransferDstScalarPerVector_K1,
-                false, // don't move back src coordinate after threadwise copy
-                BBlockTransferThreadSliceLengths_K0_N_K1,
-                BBlockTransferThreadClusterLengths_K0_N_K1,
-                Sequence<0, 2, 1>,
-                Sequence<0, 2, 1>,
-                1,
-                BBlockTransferSrcScalarPerVector_N,
-                BBlockTransferDstScalarPerVector_K1,
-                false, // don't move back src coordinate after threadwise copy
-                Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
-                6,
-                CThreadTransferDstScalarPerVector,
-                decltype(a_k0_m_k1_grid_step_hacks),
-                decltype(b_k0_n_k1_grid_step_hacks),
-                decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
-                decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
-                decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
-                false // CAccessOrderMRepeatNRepeat
-                >(static_cast<ABType*>(a_m_k_device_buf.GetDeviceBuffer()),
-                  static_cast<ABType*>(b_k_n_device_buf.GetDeviceBuffer()),
-                  static_cast<CType*>(c_n_m_device_buf.GetDeviceBuffer()),
-                  a_k0_m_k1_grid_desc,
-                  b_k0_n_k1_grid_desc,
-                  c_m_n_grid_desc,
-                  a_k0_m_k1_grid_step_hacks,
-                  b_k0_n_k1_grid_step_hacks,
-                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
-                  a_k0_m_k1_grid_move_slice_window_step_hacks,
-                  b_k0_n_k1_grid_move_slice_window_step_hacks,
-                  nrepeat);
-
-        float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
-                     (std::size_t(1000) * 1000 * 1000) / ave_time;
-
-        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
-    }
-
-    // copy result back to host
-    c_n_m_device_buf.FromDevice(c_n_m.mData.data());
-}
diff --git a/host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp b/host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp
deleted file mode 100644
index e99d5704..00000000
--- a/host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp
+++ /dev/null
@@ -1,564 +0,0 @@
-#pragma once
-#include
-#include "device.hpp"
-#include "host_tensor.hpp"
-#include "driver_gemm_xdlops_v2r3.hpp"
-
-template <typename ABType, typename AccType, typename CType>
-void device_gemm_xdlops_mk_nk_mn(const Tensor<ABType>& a_m_k,
-                                 const Tensor<ABType>& b_n_k,
-                                 Tensor<CType>& c_m_n,
-                                 ck::index_t nrepeat)
-{
-    using namespace ck;
-
-    std::cout << __func__ << std::endl;
-
-    DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace());
-    DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace());
-    DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace());
-
-    a_m_k_device_buf.ToDevice(a_m_k.mData.data());
-    b_n_k_device_buf.ToDevice(b_n_k.mData.data());
-    c_m_n_device_buf.ToDevice(c_m_n.mData.data());
-
-#if 0
-    // [M, N, K0, K1] = [256, 128, 4, 4] for fp32
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t MPerBlock = 256;
-    constexpr index_t NPerBlock = 128;
-    constexpr index_t KPerBlock = 4;
-
-    constexpr index_t MPerXDL = 32;
-    constexpr index_t NPerXDL = 32;
-    constexpr index_t K1 = 4;
-
-    constexpr index_t MRepeat = 4;
-    constexpr index_t NRepeat = 2;
-
-    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
-    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
-    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
-
-    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
-    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
-    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
-
-    constexpr index_t CThreadTransferDstScalarPerVector = 1;
-#elif 0
-    // [M, N, K0, K1] = [128, 256, 4, 4] for fp32
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t MPerBlock = 128;
-    constexpr index_t NPerBlock = 256;
-    constexpr index_t KPerBlock = 4;
-
-    constexpr index_t MPerXDL = 32;
-    constexpr index_t NPerXDL = 32;
-    constexpr index_t K1 = 4;
-
-    constexpr index_t MRepeat = 2;
-    constexpr index_t NRepeat = 4;
-
-    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
-    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
-    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
-
-    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
-    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
-    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
-
-    constexpr index_t CThreadTransferDstScalarPerVector = 1;
-#elif 0
-    // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t MPerBlock = 128;
-    constexpr index_t NPerBlock = 128;
-    constexpr index_t KPerBlock = 4;
-
-    constexpr index_t MPerXDL = 32;
-    constexpr index_t NPerXDL = 32;
-    constexpr index_t K1 = 4;
-
-    constexpr index_t MRepeat = 2;
-    constexpr index_t NRepeat = 2;
-
-    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
-    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
-    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
-
-    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
-    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
-    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
-
-    constexpr index_t CThreadTransferDstScalarPerVector = 1;
-#elif 0
-    // [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t MPerBlock = 128;
-    constexpr index_t NPerBlock = 64;
-    constexpr index_t KPerBlock = 4;
-
-    constexpr index_t MPerXDL = 32;
-    constexpr index_t NPerXDL = 32;
-    constexpr index_t K1 = 4;
-
-    constexpr index_t MRepeat = 2;
-    constexpr index_t NRepeat = 1;
-
-    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
-    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
-    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
-
-    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>;
-    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
-    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
-
-    constexpr index_t CThreadTransferDstScalarPerVector = 1;
-#elif 0
-    // [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t MPerBlock = 64;
-    constexpr index_t NPerBlock = 128;
-    constexpr index_t KPerBlock = 4;
-
-    constexpr index_t MPerXDL = 32;
-    constexpr index_t NPerXDL = 32;
-    constexpr index_t K1 = 4;
-
-    constexpr index_t MRepeat = 1;
-    constexpr index_t NRepeat = 2;
-
-    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>;
-    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
-    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
-
-    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
-    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
-    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
-
-    constexpr index_t CThreadTransferDstScalarPerVector = 1;
-#elif 1
-    // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t MPerBlock = 256;
-    constexpr index_t NPerBlock = 128;
-    constexpr index_t KPerBlock = 4;
-
-    constexpr index_t MPerXDL = 32;
-    constexpr index_t NPerXDL = 32;
-    constexpr index_t K1 = 8;
-
-    constexpr index_t MRepeat = 4;
-    constexpr index_t NRepeat = 2;
-
-    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
-    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
-    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
-
-    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
-    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
-    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
-
-    constexpr index_t CThreadTransferDstScalarPerVector = 1;
-#elif 0
-    // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t MPerBlock = 128;
-    constexpr index_t NPerBlock = 256;
-    constexpr index_t KPerBlock = 4;
-
-    constexpr index_t MPerXDL = 32;
-    constexpr index_t NPerXDL = 32;
-    constexpr index_t K1 = 8;
-
-    constexpr index_t MRepeat = 2;
-    constexpr index_t NRepeat = 4;
-
-    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
-    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
-    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
-
-    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
-    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
-    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
-
-    constexpr index_t CThreadTransferDstScalarPerVector = 1;
-#elif 0
-    // [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
-    constexpr index_t BlockSize = 128;
-
-    constexpr index_t MPerBlock = 128;
-    constexpr index_t NPerBlock = 128;
-    constexpr index_t KPerBlock = 4;
-
-    constexpr index_t MPerXDL = 32;
-    constexpr index_t NPerXDL = 32;
-    constexpr index_t K1 = 8;
-
-    constexpr index_t MRepeat = 4;
-    constexpr index_t NRepeat = 2;
-
-    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
-    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
-
-    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
-    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
-
-    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
-    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
-
-    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
-    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
-
-    constexpr index_t CThreadTransferDstScalarPerVector = 1;
-#elif 0
-    // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t MPerBlock = 128;
-    constexpr index_t NPerBlock = 128;
-    constexpr index_t KPerBlock = 4;
-
-    constexpr index_t MPerXDL = 32;
-    constexpr index_t NPerXDL = 32;
-    constexpr index_t K1 = 8;
-
-    constexpr index_t MRepeat = 2;
-    constexpr index_t NRepeat = 2;
-
-    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
-    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
-    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
-
-    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
-    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
-    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
-
-    constexpr index_t CThreadTransferDstScalarPerVector = 1;
-#elif 0
-    // [M, N, K0, K1] = [64, 128, 4, 8], C = 64, for fp16
-    constexpr index_t BlockSize = 128;
-
-    constexpr index_t MPerBlock = 64;
-    constexpr index_t NPerBlock = 128;
-    constexpr index_t KPerBlock = 4;
-
-    constexpr index_t MPerXDL = 32;
-    constexpr index_t NPerXDL = 32;
-    constexpr index_t K1 = 8;
-
-    constexpr index_t MRepeat = 2;
-    constexpr index_t NRepeat = 2;
-
-    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
-    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
-
-    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
-    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
-
-    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
-    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
-
-    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
-    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
-
-    constexpr index_t CThreadTransferDstScalarPerVector = 1;
-#elif 1
-    // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t MPerBlock = 128;
-    constexpr index_t NPerBlock = 64;
-    constexpr index_t KPerBlock = 4;
-
-    constexpr index_t MPerXDL = 32;
-    constexpr index_t NPerXDL = 32;
-    constexpr index_t K1 = 8;
-
-    constexpr index_t MRepeat = 2;
-    constexpr index_t NRepeat = 1;
-
-    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
-    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
-    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
-
-    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>;
-    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
-    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
-
-    constexpr index_t CThreadTransferDstScalarPerVector = 1;
-#elif 1
-    // [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t MPerBlock = 64;
-    constexpr index_t NPerBlock = 128;
-    constexpr index_t KPerBlock = 4;
-
-    constexpr index_t MPerXDL = 32;
-    constexpr index_t NPerXDL = 32;
-    constexpr index_t K1 = 8;
-
-    constexpr index_t MRepeat = 1;
-    constexpr index_t NRepeat = 2;
-
-    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>;
-    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
-    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
-
-    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
-    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
-    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
-
-    constexpr index_t CThreadTransferDstScalarPerVector = 1;
-#endif
-
-    const auto K = a_m_k.mDesc.GetLengths()[1];
-    const auto M = a_m_k.mDesc.GetLengths()[0];
-    const auto N = b_n_k.mDesc.GetLengths()[0];
-
-    constexpr auto K1Number = Number<K1>{};
-    const auto K0 = K / K1Number;
-
-#if 1
-    // non-padded GEMM
-    const auto a_k0_m_k1_grid_desc =
-        make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
-                                     make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1],
-                                                a_m_k.mDesc.GetStrides()[0],
-                                                a_m_k.mDesc.GetStrides()[1]));
-
-    const auto b_k0_n_k1_grid_desc =
-        make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
-                                     make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1],
-                                                b_n_k.mDesc.GetStrides()[0],
-                                                b_n_k.mDesc.GetStrides()[1]));
-
-    const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
-        make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1]));
-
-    // HACK: hacks that control index calculation when iterating over A, B, C matrix
-    constexpr auto a_k0_m_k1_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0>{},   // 0+: K0
-                              Sequence<0>{},   // 1+: M
-                              Sequence<0>{}),  // 2+: K1
-                   make_tuple(Sequence<0>{},   // 0-: K0
-                              Sequence<0>{},   // 1-: M
-                              Sequence<0>{})); // 2-: K1
-
-    constexpr auto b_k0_n_k1_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0>{},   // 0+: K0
-                              Sequence<0>{},   // 1+: N
-                              Sequence<0>{}),  // 2+: K1
-                   make_tuple(Sequence<0>{},   // 0-: K0
-                              Sequence<0>{},   // 1-: N
-                              Sequence<0>{})); // 2-: K1
-
-    constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
-                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
-
-    constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
-
-    constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
-#else
-    // padded GEMM
-    const auto a_k0_m_k1_grid_desc_tmp =
-        make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
-                                     make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1],
-                                                a_m_k.mDesc.GetStrides()[0],
-                                                a_m_k.mDesc.GetStrides()[1]));
-
-    const auto MRightPad = math::integer_divide_ceil(M, MPerBlock) * MPerBlock - M;
-
-    const auto a_k0_m_k1_grid_desc =
-        transform_tensor_descriptor(a_k0_m_k1_grid_desc_tmp,
-                                    make_tuple(make_pass_through_transform(K0),
-                                               make_right_pad_transform(M, MRightPad),
-                                               make_pass_through_transform(K1Number)),
-                                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-                                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
-
-    const auto b_k0_n_k1_grid_desc =
-        make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
-                                     make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1],
-                                                b_n_k.mDesc.GetStrides()[0],
-                                                b_n_k.mDesc.GetStrides()[1]));
-
-    const auto c_m_n_grid_desc_tmp = make_naive_tensor_descriptor(
-        make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1]));
-
-    const auto c_m_n_grid_desc = transform_tensor_descriptor(
-        c_m_n_grid_desc_tmp,
-        make_tuple(make_right_pad_transform(M, MRightPad), make_pass_through_transform(N)),
-        make_tuple(Sequence<0>{}, Sequence<1>{}),
-        make_tuple(Sequence<0>{}, Sequence<1>{}));
-
-    // HACK: hacks that control index calculation when iterating over A, B, C matrix
-    constexpr auto a_k0_m_k1_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0, 0>{},   // 0+: K0
-                              Sequence<0, 0, 0, 0>{},   // 1+: M
-                              Sequence<0, 0, 0, 0>{}),  // 2+: K1
-                   make_tuple(Sequence<0, 0, 0, 0>{},   // 0-: K0
-                              Sequence<0, 0, 0, 0>{},   // 1-: M
-                              Sequence<0, 0, 0, 0>{})); // 2-: K1
-
-    constexpr auto b_k0_n_k1_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0>{},   // 0+: K0
-                              Sequence<0>{},   // 1+: N
-                              Sequence<0>{}),  // 2+: K1
-                   make_tuple(Sequence<0>{},   // 0-: K0
-                              Sequence<0>{},   // 1-: N
-                              Sequence<0>{})); // 2-: K1
-
-    constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
-                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
-
-    constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0>{};
-
-    constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
-#endif
-
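The padded branch above rounds M up to the next multiple of MPerBlock and right-pads both A and C so that any problem size maps onto whole blocks; math::integer_divide_ceil is the usual integer-ceiling idiom. The same computation in isolation:

    #include <cassert>
    #include <cstddef>

    // MRightPad such that M + MRightPad is the smallest multiple of MPerBlock
    // that is >= M, mirroring integer_divide_ceil(M, MPerBlock) * MPerBlock - M.
    std::size_t right_pad(std::size_t M, std::size_t MPerBlock)
    {
        const std::size_t M_padded = (M + MPerBlock - 1) / MPerBlock * MPerBlock;
        assert(M_padded % MPerBlock == 0 && M_padded >= M);
        return M_padded - M;
    }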
-    for(index_t i = 0; i < 5; ++i)
-    {
-        float ave_time =
-            driver_gemm_xdlops_v2r3<BlockSize,
-                ABType,
-                AccType,
-                CType,
-                InMemoryDataOperationEnum_t::Set,
-                decltype(a_k0_m_k1_grid_desc),
-                decltype(b_k0_n_k1_grid_desc),
-                decltype(c_m_n_grid_desc),
-                MPerBlock,
-                NPerBlock,
-                KPerBlock,
-                MPerXDL,
-                NPerXDL,
-                K1,
-                MRepeat,
-                NRepeat,
-                ABlockTransferThreadSliceLengths_K0_M_K1,
-                ABlockTransferThreadClusterLengths_K0_M_K1,
-                Sequence<1, 0, 2>,
-                Sequence<1, 0, 2>,
-                2,
-                ABlockTransferSrcScalarPerVector_K1,
-                ABlockTransferDstScalarPerVector_K1,
-                false, // don't move back src coordinate after threadwise copy
-                BBlockTransferThreadSliceLengths_K0_N_K1,
-                BBlockTransferThreadClusterLengths_K0_N_K1,
-                Sequence<1, 0, 2>,
-                Sequence<1, 0, 2>,
-                2,
-                BBlockTransferSrcScalarPerVector_K1,
-                BBlockTransferDstScalarPerVector_K1,
-                false, // don't move back src coordinate after threadwise copy
-                Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
-                7,
-                CThreadTransferDstScalarPerVector,
-                decltype(a_k0_m_k1_grid_step_hacks),
-                decltype(b_k0_n_k1_grid_step_hacks),
-                decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
-                decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
-                decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
-                false, // CAccessOrderMRepeatNRepeat
-                true,  // ABlockLdsExtraM
-                true   // BBlockLdsExtraN
-                >(static_cast<ABType*>(a_m_k_device_buf.GetDeviceBuffer()),
-                  static_cast<ABType*>(b_n_k_device_buf.GetDeviceBuffer()),
-                  static_cast<CType*>(c_m_n_device_buf.GetDeviceBuffer()),
-                  a_k0_m_k1_grid_desc,
-                  b_k0_n_k1_grid_desc,
-                  c_m_n_grid_desc,
-                  debug::debug_driver_gemm_xdlops_v2r3::M01,
-                  debug::debug_driver_gemm_xdlops_v2r3::N01,
-                  a_k0_m_k1_grid_step_hacks,
-                  b_k0_n_k1_grid_step_hacks,
-                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
-                  a_k0_m_k1_grid_move_slice_window_step_hacks,
-                  b_k0_n_k1_grid_move_slice_window_step_hacks,
-                  nrepeat);
-
-        float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
-                     (std::size_t(1000) * 1000 * 1000) / ave_time;
-
-        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
-    }
-
-    // copy result back to host
-    c_m_n_device_buf.FromDevice(c_m_n.mData.data());
-}
diff --git a/host/driver_offline/include/device_gemm_xdlops_mk_nk_nm.hpp b/host/driver_offline/include/device_gemm_xdlops_mk_nk_nm.hpp
deleted file mode 100644
index a12cf073..00000000
--- a/host/driver_offline/include/device_gemm_xdlops_mk_nk_nm.hpp
+++ /dev/null
@@ -1,347 +0,0 @@
-#pragma once
-#include
-#include "device.hpp"
-#include "host_tensor.hpp"
-#include "driver_gemm_xdlops_v2r3.hpp"
-
-template <typename ABType, typename AccType, typename CType>
-void device_gemm_xdlops_mk_nk_nm(const Tensor<ABType>& a_m_k,
-                                 const Tensor<ABType>& b_n_k,
-                                 Tensor<CType>& c_n_m,
-                                 ck::index_t nrepeat)
-{
-    using namespace ck;
-
-    std::cout << __func__ << std::endl;
-
-    DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace());
-    DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace());
-    DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace());
-
-    a_m_k_device_buf.ToDevice(a_m_k.mData.data());
-    b_n_k_device_buf.ToDevice(b_n_k.mData.data());
-    c_n_m_device_buf.ToDevice(c_n_m.mData.data());
-
-#if 0
-    // [M, N, K0, K1] = [256, 128, 4, 4] for fp32
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t MPerBlock = 256;
-    constexpr index_t NPerBlock = 128;
-    constexpr index_t KPerBlock = 4;
-
-    constexpr index_t MPerXDL = 32;
-    constexpr index_t NPerXDL = 32;
-    constexpr index_t K1 = 4;
-
-    constexpr index_t MRepeat = 4;
-    constexpr index_t NRepeat = 2;
-
-    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
-    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
-    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
-
-    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
-    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
-    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
-
-    constexpr index_t CThreadTransferDstScalarPerVector = 4;
-#elif 0
-    // [M, N, K0, K1] = [128, 256, 4, 4] for fp32
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t MPerBlock = 128;
-    constexpr index_t NPerBlock = 256;
-    constexpr index_t KPerBlock = 4;
-
-    constexpr index_t MPerXDL = 32;
-    constexpr index_t NPerXDL = 32;
-    constexpr index_t K1 = 4;
-
-    constexpr index_t MRepeat = 2;
-    constexpr index_t NRepeat = 4;
-
-    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
-    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
-    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
-
-    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
-    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
-    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
-
-    constexpr index_t CThreadTransferDstScalarPerVector = 4;
-#elif 0
-    // [M, N, K0, K1] = [256, 128, 4, 8] for fp16
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t MPerBlock = 256;
-    constexpr index_t NPerBlock = 128;
-    constexpr index_t KPerBlock = 4;
-
-    constexpr index_t MPerXDL = 32;
-    constexpr index_t NPerXDL = 32;
-    constexpr index_t K1 = 8;
-
-    constexpr index_t MRepeat = 4;
-    constexpr index_t NRepeat = 2;
-
-    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
-    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
-    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
-
-    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
-    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
-    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
-
-    constexpr index_t CThreadTransferDstScalarPerVector = 4;
-#elif 0
-    // [M, N, K0, K1] = [128, 256, 4, 8] for fp16
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t MPerBlock = 128;
-    constexpr index_t NPerBlock = 256;
-    constexpr index_t KPerBlock = 4;
-
-    constexpr index_t MPerXDL = 32;
-    constexpr index_t NPerXDL = 32;
-    constexpr index_t K1 = 8;
-
-    constexpr index_t MRepeat = 2;
-    constexpr index_t NRepeat = 4;
-
-    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
-    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
-    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
-
-    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
-    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
-    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
-
-    constexpr index_t CThreadTransferDstScalarPerVector = 4;
-#elif 0
-    // [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
-    constexpr index_t BlockSize = 128;
-
-    constexpr index_t MPerBlock = 128;
-    constexpr index_t NPerBlock = 128;
-    constexpr index_t KPerBlock = 4;
-
-    constexpr index_t MPerXDL = 32;
-    constexpr index_t NPerXDL = 32;
-    constexpr index_t K1 = 8;
-
-    constexpr index_t MRepeat = 4;
-    constexpr index_t NRepeat = 2;
-
-    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
-    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
-
-    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
-    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
-
-    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
-    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
-
-    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
-    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
-
-    constexpr index_t CThreadTransferDstScalarPerVector = 4;
-#elif 0
-    // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t MPerBlock = 128;
-    constexpr index_t NPerBlock = 128;
-    constexpr index_t KPerBlock = 4;
-
-    constexpr index_t MPerXDL = 32;
-    constexpr index_t NPerXDL = 32;
-    constexpr index_t K1 = 8;
-
-    constexpr index_t MRepeat = 2;
-    constexpr index_t NRepeat = 2;
-
-    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
-    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
-    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
-
-    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
-    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
-    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
-
-    constexpr index_t CThreadTransferDstScalarPerVector = 4;
-#elif 1
-    // [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t MPerBlock = 64;
-    constexpr index_t NPerBlock = 128;
-    constexpr index_t KPerBlock = 4;
-
-    constexpr index_t MPerXDL = 32;
-    constexpr index_t NPerXDL = 32;
-    constexpr index_t K1 = 8;
-
-    constexpr index_t MRepeat = 1;
-    constexpr index_t NRepeat = 2;
-
-    using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>;
-    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
-    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
-
-    using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
-    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
-
-    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
-    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
-
-    constexpr index_t CThreadTransferDstScalarPerVector = 4;
-#endif
-
-    const auto K = a_m_k.mDesc.GetLengths()[1];
-    const auto M = a_m_k.mDesc.GetLengths()[0];
-    const auto N = b_n_k.mDesc.GetLengths()[0];
-
-    constexpr auto K1Number = Number<K1>{};
-    const auto K0 = K / K1Number;
-
-    const auto a_k0_m_k1_grid_desc =
-        make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
-                                     make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1],
-                                                a_m_k.mDesc.GetStrides()[0],
-                                                a_m_k.mDesc.GetStrides()[1]));
-
-    const auto b_k0_n_k1_grid_desc =
-        make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
-                                     make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1],
-                                                b_n_k.mDesc.GetStrides()[0],
-                                                b_n_k.mDesc.GetStrides()[1]));
-
-    const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
-        make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0]));
-
-    // HACK: hacks that control index calculation when iterating over A, B, C matrix
-    constexpr auto a_k0_m_k1_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0>{},   // 0+: K0
-                              Sequence<0>{},   // 1+: M
-                              Sequence<0>{}),  // 2+: K1
-                   make_tuple(Sequence<0>{},   // 0-: K0
-                              Sequence<0>{},   // 1-: M
-                              Sequence<0>{})); // 2-: K1
-
-    constexpr auto b_k0_n_k1_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0>{},   // 0+: K0
-                              Sequence<0>{},   // 1+: N
-                              Sequence<0>{}),  // 2+: K1
-                   make_tuple(Sequence<0>{},   // 0-: K0
-                              Sequence<0>{},   // 1-: N
-                              Sequence<0>{})); // 2-: K1
-
-    constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
-                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
-
-    constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
-
-    constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
-
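With B stored as N x K, both operands are contiguous in K, which is why the mk_nk variants vectorize the B copy along K1 (BBlockTransferSrcScalarPerVector_K1, access order Sequence<1, 0, 2>, vector dim 2) rather than along N as the kn variants above do. A sketch of the kind of load this layout enables (float and a 4-wide clang/hipcc ext_vector_type are illustrative stand-ins for the fp16, K1 = 8 case):

    #include <cstddef>

    using float4_t = float __attribute__((ext_vector_type(4)));

    // One thread fetching K1 consecutive scalars of a K-contiguous row as a
    // single vector load; assumes the address is suitably aligned.
    float4_t load_k1_vector(const float* p_row, std::size_t k0)
    {
        return *reinterpret_cast<const float4_t*>(p_row + k0 * 4);
    }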
-    for(index_t i = 0; i < 5; ++i)
-    {
-        float ave_time =
-            driver_gemm_xdlops_v2r3<BlockSize,
-                ABType,
-                AccType,
-                CType,
-                InMemoryDataOperationEnum_t::Set,
-                decltype(a_k0_m_k1_grid_desc),
-                decltype(b_k0_n_k1_grid_desc),
-                decltype(c_m_n_grid_desc),
-                MPerBlock,
-                NPerBlock,
-                KPerBlock,
-                MPerXDL,
-                NPerXDL,
-                K1,
-                MRepeat,
-                NRepeat,
-                ABlockTransferThreadSliceLengths_K0_M_K1,
-                ABlockTransferThreadClusterLengths_K0_M_K1,
-                Sequence<1, 0, 2>,
-                Sequence<1, 0, 2>,
-                2,
-                ABlockTransferSrcScalarPerVector_K1,
-                ABlockTransferDstScalarPerVector_K1,
-                false, // don't move back src coordinate after threadwise copy
-                BBlockTransferThreadSliceLengths_K0_N_K1,
-                BBlockTransferThreadClusterLengths_K0_N_K1,
-                Sequence<1, 0, 2>,
-                Sequence<1, 0, 2>,
-                2,
-                BBlockTransferSrcScalarPerVector_K1,
-                BBlockTransferDstScalarPerVector_K1,
-                false, // don't move back src coordinate after threadwise copy
-                Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
-                6,
-                CThreadTransferDstScalarPerVector,
-                decltype(a_k0_m_k1_grid_step_hacks),
-                decltype(b_k0_n_k1_grid_step_hacks),
-                decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
-                decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
-                decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
-                false // CAccessOrderMRepeatNRepeat
-                >(static_cast<ABType*>(a_m_k_device_buf.GetDeviceBuffer()),
-                  static_cast<ABType*>(b_n_k_device_buf.GetDeviceBuffer()),
-                  static_cast<CType*>(c_n_m_device_buf.GetDeviceBuffer()),
-                  a_k0_m_k1_grid_desc,
-                  b_k0_n_k1_grid_desc,
-                  c_m_n_grid_desc,
-                  a_k0_m_k1_grid_step_hacks,
-                  b_k0_n_k1_grid_step_hacks,
-                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
-                  a_k0_m_k1_grid_move_slice_window_step_hacks,
-                  b_k0_n_k1_grid_move_slice_window_step_hacks,
-                  nrepeat);
-
-        float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
-                     (std::size_t(1000) * 1000 * 1000) / ave_time;
-
-        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
-    }
-
-    // copy result back to host
-    c_n_m_device_buf.FromDevice(c_n_m.mData.data());
-}
diff --git a/host/driver_offline/include/driver_contraction_dlops_v1r2.hpp b/host/driver_offline/include/driver_contraction_dlops_v1r2.hpp
deleted file mode 100644
index d207728a..00000000
--- a/host/driver_offline/include/driver_contraction_dlops_v1r2.hpp
+++ /dev/null
@@ -1,286 +0,0 @@
-#ifndef DRIVER_CONTRACTION_DLOPS_V1R2_HPP
-#define DRIVER_CONTRACTION_DLOPS_V1R2_HPP
-
-#include "common_header.hpp"
-#include "tensor_descriptor.hpp"
-#include "tensor_descriptor_helper.hpp"
-#include "gridwise_contraction_dlops_v1r2.hpp"
-
-template <ck::index_t BlockSize,
-          typename FloatAB,
-          typename FloatAcc,
-          typename FloatC,
-          ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
-          typename AGridDesc_GK0_GM0_GM1_GK1,
-          typename BGridDesc_GK0_GN0_GN1_GK1,
-          typename CGridDesc_GM0_GM1_GN0_GN1,
-          ck::index_t GM1PerBlockGM11,
-          ck::index_t GN1PerBlockGN11,
-          ck::index_t GK0PerBlock,
-          ck::index_t BM1PerThreadBM11,
-          ck::index_t BN1PerThreadBN11,
-          ck::index_t BK0PerThread,
-          typename BM10BN10ThreadClusterBM10Xs,
-          typename BM10BN10ThreadClusterBN10Xs,
-          typename ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
-          typename ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
-          typename ABlockTransferThreadClusterArrangeOrder,
-          typename ABlockTransferSrcAccessOrder,
-          typename ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
-          typename ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
-          typename ABlockTransferSrcVectorTensorContiguousDimOrder,
-          typename BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
-          typename BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
-          typename BBlockTransferThreadClusterArrangeOrder,
-          typename BBlockTransferSrcAccessOrder,
-          typename BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
-          typename BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
-          typename BBlockTransferSrcVectorTensorContiguousDimOrder,
-          typename CThreadTransferSrcDstAccessOrder,
-          ck::index_t CThreadTransferSrcDstVectorDim,
-          ck::index_t CThreadTransferDstScalarPerVector,
-          typename AGridStepHacks,
-          typename BGridStepHacks,
-          typename CGridStepHacks,
-          typename AGridMoveSliceWindowStepHacks,
-          typename BGridMoveSliceWindowStepHacks>
-__host__ float
-driver_contraction_dlops_v1r2(const FloatAB* p_a_grid,
-                              const FloatAB* p_b_grid,
-                              FloatC* p_c_grid,
-                              const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1,
-                              const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1,
-                              const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1,
-                              AGridStepHacks,
-                              BGridStepHacks,
-                              CGridStepHacks,
-                              AGridMoveSliceWindowStepHacks,
-                              BGridMoveSliceWindowStepHacks,
-                              ck::index_t nrepeat)
-
-{
-    using namespace ck;
-
-    constexpr auto I0 = Number<0>{};
-    constexpr auto I1 = Number<1>{};
-    constexpr auto I2 = Number<2>{};
-    constexpr auto I3 = Number<3>{};
-    constexpr auto I4 = Number<4>{};
-    constexpr auto I5 = Number<5>{};
-
-    // GEMM
-    using GridwiseContraction =
-        GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
-            BlockSize,
-            FloatAB,
-            FloatAcc,
-            FloatC,
-            CGlobalMemoryDataOperation,
-            AGridDesc_GK0_GM0_GM1_GK1,
-            BGridDesc_GK0_GN0_GN1_GK1,
-            CGridDesc_GM0_GM1_GN0_GN1,
-            GM1PerBlockGM11,
-            GN1PerBlockGN11,
-            GK0PerBlock,
-            BM1PerThreadBM11,
-            BN1PerThreadBN11,
-            BK0PerThread,
-            BM10BN10ThreadClusterBM10Xs,
-            BM10BN10ThreadClusterBN10Xs,
-            ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
-            ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
-            ABlockTransferThreadClusterArrangeOrder,
-            ABlockTransferSrcAccessOrder,
-            ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
-            ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
-            ABlockTransferSrcVectorTensorContiguousDimOrder,
-            BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
-            BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
-            BBlockTransferThreadClusterArrangeOrder,
-            BBlockTransferSrcAccessOrder,
-            BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
-            BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
-            BBlockTransferSrcVectorTensorContiguousDimOrder,
-            CThreadTransferSrcDstAccessOrder,
-            CThreadTransferSrcDstVectorDim,
-            CThreadTransferDstScalarPerVector,
-            AGridStepHacks,
-            BGridStepHacks,
-            CGridStepHacks,
-            AGridMoveSliceWindowStepHacks,
-            BGridMoveSliceWindowStepHacks>;
-
-    const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
-
-    if(!GridwiseContraction::CheckValidity(
-           a_grid_desc_gk0_gm0_gm1_gk1, b_grid_desc_gk0_gn0_gn1_gk1, c_grid_desc_gm0_gm1_gn0_gn1))
-    {
-        throw std::runtime_error("wrong! "
-                                 "GridwiseContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_"
-                                 "GM0_GM1_GN0_GN1 has invalid setting");
-    }
-
-    const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 =
-        GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(a_grid_desc_gk0_gm0_gm1_gk1);
-    const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 =
-        GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(b_grid_desc_gk0_gn0_gn1_gk1);
-
-    using AGridDesc_GK0_GM0_GM10_GM11_GK1 = decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1);
-    using BGridDesc_GK0_GN0_GN10_GN11_GK1 = decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1);
-
-    // c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1
-    const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
-        GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
-            c_grid_desc_gm0_gm1_gn0_gn1);
-
-    using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 = decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1);
-
-    // c_grid_block_cluster_blockid_to_gm10_gn10
-    const auto c_grid_block_cluster_blockid_to_gm10_gn10 =
-        GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10(
-            c_grid_desc_gm0_gm1_gn0_gn1);
-
-    using CGridBlockCluster_BlockId_To_GM10_GN10 =
-        decltype(c_grid_block_cluster_blockid_to_gm10_gn10);
-
-    const index_t grid_size = GridwiseContraction::CalculateGridSize(c_grid_desc_gm0_gm1_gn0_gn1);
-
-    const bool has_main_k_block_loop = GridwiseContraction::CalculateHasMainKBlockLoop(GK0);
-
-    const bool has_double_tail_k_block_loop =
-        GridwiseContraction::CalculateHasDoubleTailKBlockLoop(GK0);
-
-    {
-        std::cout << "a_grid_desc_gk0_gm0_gm10_gm11_gk1{"
-                  << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0) << ", "
-                  << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I1) << ", "
-                  << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I2) << ", "
-                  << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I3) << ", "
-                  << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I4) << "}" << std::endl;
-
-        std::cout << "b_grid_desc_gk0_gn0_gn10_gn11_gk1{"
-                  << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I0) << ", "
-                  << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I1) << ", "
-                  << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I2) << ", "
-                  << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I3) << ", "
-                  << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I4) << "}" << std::endl;
-
-        std::cout << "c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1{ "
-                  << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I0) << ", "
-                  << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I1) << ", "
-                  << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I2) << ", "
-                  << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I3) << ", "
-                  << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I4) << ", "
-                  << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I5) << "}" << std::endl;
-    }
-
-    float ave_time = 0;
-
-    if(has_main_k_block_loop && has_double_tail_k_block_loop)
-    {
-        const auto kernel = kernel_contraction_dlops_v1r2<
-            GridwiseContraction,
-            FloatAB,
-            FloatC,
-            remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
-            remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
-            remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
-            remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
-            true,
-            true>;
-
-        ave_time = launch_and_time_kernel(kernel,
-                                          nrepeat,
-                                          dim3(grid_size),
-                                          dim3(BlockSize),
-                                          0,
-                                          p_a_grid,
-                                          p_b_grid,
-                                          p_c_grid,
-                                          a_grid_desc_gk0_gm0_gm10_gm11_gk1,
-                                          b_grid_desc_gk0_gn0_gn10_gn11_gk1,
-                                          c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
-                                          c_grid_block_cluster_blockid_to_gm10_gn10);
-    }
-    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
-    {
-        const auto kernel = kernel_contraction_dlops_v1r2<
-            GridwiseContraction,
-            FloatAB,
-            FloatC,
-            remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
-            remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
-            remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
-            remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
-            true,
-            false>;
-
-        ave_time = launch_and_time_kernel(kernel,
-                                          nrepeat,
-                                          dim3(grid_size),
-                                          dim3(BlockSize),
-                                          0,
-                                          p_a_grid,
-                                          p_b_grid,
-                                          p_c_grid,
-                                          a_grid_desc_gk0_gm0_gm10_gm11_gk1,
-                                          b_grid_desc_gk0_gn0_gn10_gn11_gk1,
-                                          c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
-                                          c_grid_block_cluster_blockid_to_gm10_gn10);
-    }
-    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
-    {
-        const auto kernel = kernel_contraction_dlops_v1r2<
-            GridwiseContraction,
-            FloatAB,
-            FloatC,
-            remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
-            remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
-            remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
-            remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
-            false,
-            true>;
-
-        ave_time = launch_and_time_kernel(kernel,
-                                          nrepeat,
-                                          dim3(grid_size),
-                                          dim3(BlockSize),
-                                          0,
-                                          p_a_grid,
-                                          p_b_grid,
-                                          p_c_grid,
-                                          a_grid_desc_gk0_gm0_gm10_gm11_gk1,
-                                          b_grid_desc_gk0_gn0_gn10_gn11_gk1,
-                                          c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
-                                          c_grid_block_cluster_blockid_to_gm10_gn10);
-    }
-    else
-    {
-        const auto kernel = kernel_contraction_dlops_v1r2<
-            GridwiseContraction,
-            FloatAB,
-            FloatC,
-            remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
-            remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
-            remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
-            remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
-            false,
-            false>;
-
-        ave_time = launch_and_time_kernel(kernel,
-                                          nrepeat,
-                                          dim3(grid_size),
-                                          dim3(BlockSize),
-                                          0,
-                                          p_a_grid,
-                                          p_b_grid,
-                                          p_c_grid,
-                                          a_grid_desc_gk0_gm0_gm10_gm11_gk1,
-                                          b_grid_desc_gk0_gn0_gn10_gn11_gk1,
-                                          c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
-                                          c_grid_block_cluster_blockid_to_gm10_gn10);
-    }
-
-    return ave_time;
-}
-#endif
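The four-way if/else above is the usual trick for turning runtime loop properties into compile-time template parameters: each branch instantiates a distinct kernel whose tail-loop handling is resolved at compile time, at the cost of four instantiations. Skeleton of the pattern (printf stands in for the kernel launch):

    #include <cstdio>

    template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
    void run_kernel() // stands in for a __global__ kernel instantiation
    {
        std::printf("main=%d double_tail=%d\n", HasMainKBlockLoop, HasDoubleTailKBlockLoop);
    }

    void dispatch(bool has_main, bool has_double_tail)
    {
        if(has_main && has_double_tail)
            run_kernel<true, true>();
        else if(has_main && !has_double_tail)
            run_kernel<true, false>();
        else if(!has_main && has_double_tail)
            run_kernel<false, true>();
        else
            run_kernel<false, false>();
    }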
diff --git a/host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp
deleted file mode 100644
index efd4ce6a..00000000
--- a/host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp
+++ /dev/null
@@ -1,349 +0,0 @@
-#ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
-#define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
-
-#include "common_header.hpp"
-#include "tensor_descriptor.hpp"
-#include "tensor_descriptor_helper.hpp"
-#include "gridwise_gemm_dlops_v2.hpp"
-#include "gridwise_operation_wrapper.hpp"
-
-template <ck::index_t BlockSize,
-          typename FloatAB,
-          typename FloatAcc,
-          typename FloatC,
-          ck::index_t KPerBlock,
-          ck::index_t HoPerBlock,
-          ck::index_t WoPerBlock,
-          ck::index_t EPerBlock,
-          ck::index_t KPerThread,
-          ck::index_t HoPerThread,
-          ck::index_t WoPerThread,
-          ck::index_t EPerThread,
-          typename ABlockTransferThreadSliceLengths_E_K,
-          typename ABlockTransferThreadClusterLengths_E_K,
-          ck::index_t ABlockTransferSrcScalarPerVector_E,
-          ck::index_t ABlockTransferDstScalarPerVector_K,
-          ck::index_t BThreadTransferSrcScalarPerVector_W,
-          ck::index_t CThreadTransferDstScalarPerVector_W>
-struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
-{
-    template <typename... Wei,
-              typename... In,
-              typename... Out,
-              typename ConvStrides,
-              typename ConvDilations,
-              typename InLeftPads,
-              typename InRightPads>
-    __host__ void Run(const ck::TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
-                      const ck::TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
-                      const ck::TensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
-                      const ConvStrides& conv_strides,
-                      const ConvDilations& conv_dilations,
-                      const InLeftPads& in_left_pads,
-                      const InRightPads& in_right_pads,
-                      const FloatAB* __restrict__ p_wei_global,
-                      const FloatAB* __restrict__ p_in_global,
-                      FloatC* __restrict__ p_out_global) const
-    {
-        using namespace ck;
-
-        constexpr auto I0 = Number<0>{};
-        constexpr auto I1 = Number<1>{};
-        constexpr auto I2 = Number<2>{};
-        constexpr auto I3 = Number<3>{};
-        constexpr auto I4 = Number<4>{};
-
-        const auto N  = in_n_c_hi_wi_global_desc.GetLength(I0);
-        const auto C  = in_n_c_hi_wi_global_desc.GetLength(I1);
-        const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1);
-
-        const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
-        const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
-
-        const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
-        const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3);
-
-        const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4);
-
-        const auto K = wei_k_c_y_x_global_desc.GetLength(I0);
-        const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
-        const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
-
-        const auto ConvStrideH = conv_strides[I0];
-        const auto ConvStrideW = conv_strides[I1];
-
-        const auto ConvDilationH = conv_dilations[I0];
-        const auto ConvDilationW = conv_dilations[I1];
-
-        const auto InLeftPadH = in_left_pads[I0];
-        const auto InLeftPadW = in_left_pads[I1];
-
-        const auto InRightPadH = in_right_pads[I0];
-        const auto InRightPadW = in_right_pads[I1];
-
-        // weight tensor
-        const auto wei_e_k_global_desc = transform_tensor_descriptor(
-            make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
-            make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<1>{}, Sequence<0>{}));
-
-        // input tensor
-        const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
-            in_n_c_hi_wi_global_desc,
-            make_tuple(make_pass_through_transform(N),
-                       make_pass_through_transform(C),
-                       make_pad_transform(Hi, InLeftPadH, InRightPadH),
-                       make_pad_transform(Wi, InLeftPadW, InRightPadW)),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
-
-        const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
-            in_n_c_hip_wip_global_desc,
-            make_tuple(
-                make_pass_through_transform(N),
-                make_pass_through_transform(C),
-                make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
-                make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
-
-        const auto in_e_n_ho_wo_global_desc = transform_tensor_descriptor(
-            in_n_c_y_ho_x_wo_global_desc,
-            make_tuple(make_merge_transform(make_tuple(C, Y, X)),
-                       make_pass_through_transform(N),
-                       make_pass_through_transform(Ho),
-                       make_pass_through_transform(Wo)),
-            make_tuple(Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
-
-        // output tensor
-        const auto out_k_n_ho_wo_global_desc = transform_tensor_descriptor(
-            make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)),
-            make_tuple(make_merge_transform(make_tuple(K0, K1)),
-                       make_pass_through_transform(N),
-                       make_pass_through_transform(Ho),
-                       make_pass_through_transform(Wo)),
-            make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
-
-        const auto E = C * Y * X;
-
-        if(!((K % KPerBlock) == 0 && (Ho % HoPerBlock) == 0 && (Wo % WoPerBlock) == 0 &&
-             (E % EPerBlock) == 0))
-        {
-            throw std::runtime_error("wrong! GEMM size no divisible");
-        }
-
-        // hack to control index calculation when iterating over a_k_m_global tensor
-        constexpr auto a_e_k_global_step_hacks =
-            make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
-                       make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
-
-        constexpr auto a_e_k_global_move_slice_window_step_hack = Sequence<0, 0, 0>{};
-
-        constexpr auto b_e_n_ho_wo_global_step_hacks =
-            make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
-                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
-                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
-                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
-                       make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
-                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
-                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
-                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
-
-        constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
-            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{};
-
-        // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
-        // hack for NKHW format
-        constexpr auto c_k_n_ho_wo_global_tensor_step_hacks =
-            make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{},
-                                  Sequence<0, 0, 0, 0, 0>{},
-                                  Sequence<0, 0, 0, 0, 0>{},
-                                  Sequence<0, 0, 0, 0, 0>{}),
-                       make_tuple(Sequence<0, 2, 0, 0, 0>{},
-                                  Sequence<0, 0, 0, 0, 0>{},
-                                  Sequence<0, 0, 0, 0, 0>{},
-                                  Sequence<0, 0, 0, 0, 0>{}));
-
-#if 1
-        // GEMM
-        using gridwise_gemm = GridwiseGemmDlops_km_kn_mn_v3<
-            BlockSize,
-            FloatAB,
-            FloatAcc,
-            FloatC,
-            InMemoryDataOperationEnum_t::Set,
-            decltype(wei_e_k_global_desc),
-            decltype(in_e_n_ho_wo_global_desc),
-            decltype(out_k_n_ho_wo_global_desc),
-            KPerBlock,
-            HoPerBlock,
-            WoPerBlock,
-            EPerBlock,
-            KPerThread,
-            HoPerThread,
-            WoPerThread,
-            EPerThread,
-            ABlockTransferThreadSliceLengths_E_K,
-            ABlockTransferThreadClusterLengths_E_K,
-            Sequence<1, 0>,
-            Sequence<1, 0>,
-            0,
-            ABlockTransferSrcScalarPerVector_E,
-            ABlockTransferDstScalarPerVector_K,
-            false, // don't move back src coordinate after threadwise copy
-            Sequence<0, 2, 3, 1>,
-            3,
-            BThreadTransferSrcScalarPerVector_W,
-            false, // don't move back src coordinate after threadwise copy, which will be fused with
-                   // MoveSrcSliceWindow() to save addr computation
-            Sequence<0, 2, 3, 1>,
-            0,
-            CThreadTransferDstScalarPerVector_W,
-            decltype(a_e_k_global_step_hacks),
-            decltype(b_e_n_ho_wo_global_step_hacks),
-            decltype(c_k_n_ho_wo_global_tensor_step_hacks),
-            decltype(a_e_k_global_move_slice_window_step_hack),
-            decltype(b_e_n_ho_wo_global_move_slice_window_step_hack)>;
-
-        const auto GridSize = (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock) * N;
-
-        const bool has_main_k_block_loop = (E + EPerBlock) / (2 * EPerBlock) > 1;
-
-        const bool has_double_tail_k_block_loop = (E / EPerBlock) % 2 == 0;
-
-        index_t nrepeat = 100;
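The descriptor chain above is the implicit-GEMM formulation of convolution: the reduction dimension E merges (C, Y, X), and the embed transforms map each (y, ho) / (x, wo) pair back to an input pixel without ever materializing an im2col buffer. The same index math written out (coordinates are relative to the padded input, i.e. before subtracting the left pads):

    #include <cstddef>

    struct InCoord { std::size_t c, hip, wip; };

    // Decode e = (c * Y + y) * X + x and locate the padded-input pixel that an
    // implicit GEMM reads for output position (ho, wo).
    InCoord decode_e(std::size_t e, std::size_t Y, std::size_t X,
                     std::size_t ho, std::size_t wo,
                     std::size_t stride_h, std::size_t stride_w,
                     std::size_t dilation_h, std::size_t dilation_w)
    {
        const std::size_t x = e % X;
        const std::size_t y = (e / X) % Y;
        const std::size_t c = e / (X * Y);
        return {c, ho * stride_h + y * dilation_h, wo * stride_w + x * dilation_w};
    }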
<< std::endl; - - KernelTimer timer; - timer.Start(); - std::cout << "has_main_k_block_loop: " << has_main_k_block_loop - << " has_double_tail_k_block_loop: " << has_double_tail_k_block_loop - << std::endl; - - for(index_t j = 0; j < nrepeat; ++j) - { - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - wei_e_k_global_desc, - p_wei_global, - in_e_n_ho_wo_global_desc, - p_in_global, - out_k_n_ho_wo_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - wei_e_k_global_desc, - p_wei_global, - in_e_n_ho_wo_global_desc, - p_in_global, - out_k_n_ho_wo_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - wei_e_k_global_desc, - p_wei_global, - in_e_n_ho_wo_global_desc, - p_in_global, - out_k_n_ho_wo_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - else - { - const auto kernel = run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - wei_e_k_global_desc, - p_wei_global, - in_e_n_ho_wo_global_desc, - p_in_global, - out_k_n_ho_wo_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - } - - timer.End(); - - float ave_time = timer.GetElapsedTime() / nrepeat; - - float perf = - static_cast(calculate_convolution_flops(in_n_c_hi_wi_global_desc, - wei_k_c_y_x_global_desc, - out_n_k0_ho_wo_k1_global_desc)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } -#endif - } -}; -#endif diff --git a/host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp b/host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp deleted file mode 100644 index 70f73cbf..00000000 --- a/host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp +++ /dev/null @@ -1,364 +0,0 @@ -#ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP -#define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_dlops_v2.hpp" -#include "gridwise_operation_wrapper.hpp" - -template -struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad -{ - template - __host__ void Run(const ck::TensorDescriptor& wei_k_c_y_x_global_desc, - const ck::TensorDescriptor& in_n_c_hi_wi_global_desc, - const ck::TensorDescriptor& out_n_k0_ho_wo_k1_global_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const FloatAB* __restrict__ p_wei_global, - const FloatAB* __restrict__ p_in_global, - FloatC* __restrict__ p_out_global) const - { - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr 
auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - - const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); - const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); - const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1); - - const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); - const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); - - const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2); - const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3); - - const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4); - - const auto K = wei_k_c_y_x_global_desc.GetLength(I0); - const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); - const auto X = wei_k_c_y_x_global_desc.GetLength(I3); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock; - const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock; - - const auto OutRightPadH = Hop - Ho; - const auto OutRightPadW = Wop - Wo; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH; - const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW; - - std::cerr << "OutRightPadH = " << OutRightPadH << " OutRightPadW = " << OutRightPadW - << std::endl; - std::cerr << "InRightPadH = " << InRightPadH << " InRightPadW = " << InRightPadW - << std::endl; - - // weight tensor - const auto wei_e_k_global_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - // input tensor - const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( - in_n_c_hi_wi_global_desc, - make_tuple(make_pass_through_transform(N), - make_pass_through_transform(C), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor( - in_n_c_hip_wip_global_desc, - make_tuple( - make_pass_through_transform(N), - make_pass_through_transform(C), - make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wop), make_tuple(ConvDilationW, ConvStrideW))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - - const auto in_e_n_ho_wo_global_desc = transform_tensor_descriptor( - in_n_c_y_ho_x_wo_global_desc, - make_tuple(make_merge_transform(make_tuple(C, Y, X)), - make_pass_through_transform(N), - make_pass_through_transform(Hop), - make_pass_through_transform(Wop)), - make_tuple(Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - // output tensor - const auto out_k_n_hop_wop_global_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)), - 
make_tuple(make_merge_transform(make_tuple(K0, K1)), - make_pass_through_transform(N), - make_pad_transform(Ho, 0, OutRightPadH), - make_pad_transform(Wo, 0, OutRightPadW)), - make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto E = C * Y * X; - - std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl; - - if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 && - (E % EPerBlock) == 0)) - { - throw std::runtime_error("wrong! GEMM size not divisible"); - } - - // hack to control index calculation when iterating over a_k_m_global tensor - constexpr auto a_e_k_global_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); - - constexpr auto a_e_k_global_move_slice_window_step_hack = Sequence<0, 0, 0>{}; - - constexpr auto b_e_n_ho_wo_global_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); - - constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}; - - // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor - // hack for NKHW format - constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = - make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{})); - - // GEMM - using gridwise_gemm = GridwiseGemmDlops_km_kn_mn_v3< - BlockSize, - FloatAB, - FloatAcc, - FloatC, - InMemoryDataOperationEnum_t::Set, - decltype(wei_e_k_global_desc), - decltype(in_e_n_ho_wo_global_desc), - decltype(out_k_n_hop_wop_global_desc), - KPerBlock, - HoPerBlock, - WoPerBlock, - EPerBlock, - KPerThread, - HoPerThread, - WoPerThread, - EPerThread, - ABlockTransferThreadSliceLengths_E_K, - ABlockTransferThreadClusterLengths_E_K, - Sequence<1, 0>, - Sequence<1, 0>, - 0, - ABlockTransferSrcScalarPerVector_E, - ABlockTransferDstScalarPerVector_K, - false, // don't move back src coordinate after threadwise copy - Sequence<0, 2, 3, 1>, - 3, - BThreadTransferSrcScalarPerVector_W, - false, // don't move back src coordinate after threadwise copy, which will be fused with - // MoveSrcSliceWindow() to save addr computation - Sequence<0, 2, 3, 1>, - 0, - CThreadTransferDstScalarPerVector_W, - decltype(a_e_k_global_step_hacks), - decltype(b_e_n_ho_wo_global_step_hacks), - decltype(c_k_n_ho_wo_global_tensor_step_hacks), - decltype(a_e_k_global_move_slice_window_step_hack), - decltype(b_e_n_ho_wo_global_move_slice_window_step_hack)>; - - const auto GridSize = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N; - - const bool has_main_k_block_loop = (E + EPerBlock) / (2 * EPerBlock) > 1; - - const bool has_double_tail_k_block_loop = (E / EPerBlock) % 2 == 0; - - index_t nrepeat = 100; - - for(index_t i = 0; i < 5; ++i) - { - std::cout << "Start running " << nrepeat << " times..."
<< std::endl; - - KernelTimer timer; - timer.Start(); - std::cout << "has_main_k_block_loop: " << has_main_k_block_loop - << " has_double_tail_k_block_loop: " << has_double_tail_k_block_loop - << std::endl; - - for(index_t j = 0; j < nrepeat; ++j) - { - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - wei_e_k_global_desc, - p_wei_global, - in_e_n_ho_wo_global_desc, - p_in_global, - out_k_n_hop_wop_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = - run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - wei_e_k_global_desc, - p_wei_global, - in_e_n_ho_wo_global_desc, - p_in_global, - out_k_n_hop_wop_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - wei_e_k_global_desc, - p_wei_global, - in_e_n_ho_wo_global_desc, - p_in_global, - out_k_n_hop_wop_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - else - { - const auto kernel = - run_gridwise_operation, - integral_constant>; - - launch_kernel(kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - wei_e_k_global_desc, - p_wei_global, - in_e_n_ho_wo_global_desc, - p_in_global, - out_k_n_hop_wop_global_desc, - p_out_global, - integral_constant{}, - integral_constant{}); - } - } - - timer.End(); - - float ave_time = timer.GetElapsedTime() / nrepeat; - - float perf = - static_cast(calculate_convolution_flops(in_n_c_hi_wi_global_desc, - wei_k_c_y_x_global_desc, - out_n_k0_ho_wo_k1_global_desc)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } - } -}; -#endif diff --git a/host/driver_offline/include/driver_gemm_dlops_v1r2.hpp b/host/driver_offline/include/driver_gemm_dlops_v1r2.hpp deleted file mode 100644 index bf5f7f1c..00000000 --- a/host/driver_offline/include/driver_gemm_dlops_v1r2.hpp +++ /dev/null @@ -1,413 +0,0 @@ -#ifndef DRIVER_GEMM_DLOPS_V1R2 -#define DRIVER_GEMM_DLOPS_V1R2 - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_dlops_v1r2.hpp" - -template -__host__ float driver_gemm_dlops_v1r2(const FloatAB* p_a_grid, - const FloatAB* p_b_grid, - FloatC* p_c_grid, - const AKMGridDesc& a_k_m_grid_desc, - const BKNGridDesc& b_k_n_grid_desc, - const CMNGridDesc& c_m_n_grid_desc, - AGridStepHacks, - BGridStepHacks, - CGridStepHacks, - AGridMoveSliceWindowStepHacks, - BGridMoveSliceWindowStepHacks, - ck::index_t nrepeat) - -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - - // GEMM - using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v1r2; - - const auto M = a_k_m_grid_desc.GetLength(I1); - const auto N = b_k_n_grid_desc.GetLength(I1); - const auto K = a_k_m_grid_desc.GetLength(I0); - - if(!GridwiseGemm::CheckValidity(a_k_m_grid_desc, b_k_n_grid_desc, c_m_n_grid_desc)) - { - throw 
std::runtime_error("wrong! GridwiseGemmDlops_km_kn_mn_v1r2 has invalid setting"); - } - - const auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); - const auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc); - - using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc); - using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc); - - // c_m0_m10_m11_n0_n10_n11_grid_desc - const auto c_m0_m10_m11_n0_n10_n11_grid_desc = - GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); - - using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc); - - // c_blockid_to_m0_n0_block_cluster_adaptor - const auto c_blockid_to_m0_n0_block_cluster_adaptor = - GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); - - using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor); - - const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N); - - const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K); - - const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K); - - { - std::cout << "a_k_m0_m1_grid_desc{" << a_k_m0_m1_grid_desc.GetLength(I0) << ", " - << a_k_m0_m1_grid_desc.GetLength(I1) << ", " << a_k_m0_m1_grid_desc.GetLength(I2) - << "}" << std::endl; - - std::cout << "b_k_n0_n1_grid_desc{" << b_k_n0_n1_grid_desc.GetLength(I0) << ", " - << b_k_n0_n1_grid_desc.GetLength(I1) << ", " << b_k_n0_n1_grid_desc.GetLength(I2) - << "}" << std::endl; - - std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl; - } - -#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE - float ave_time = 0; - - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - kernel_gemm_dlops_v1r2, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - true>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k_m0_m1_grid_desc, - b_k_n0_n1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = - kernel_gemm_dlops_v1r2, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - false>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k_m0_m1_grid_desc, - b_k_n0_n1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - kernel_gemm_dlops_v1r2, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - true>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k_m0_m1_grid_desc, - b_k_n0_n1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor); - } - else - { - const auto kernel = - kernel_gemm_dlops_v1r2, - remove_reference_t, - 
remove_reference_t, - remove_reference_t, - false, - false>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k_m0_m1_grid_desc, - b_k_n0_n1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor); - } - - return ave_time; -#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER - DeviceMem a_k_m0_m1_grid_desc_dev_buf(sizeof(AKM0M1GridDesc)); - DeviceMem b_k_n0_n1_grid_desc_dev_buf(sizeof(BKN0N1GridDesc)); - DeviceMem c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf(sizeof(CM0M10M11N0N10N11GridDesc)); - DeviceMem c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf( - sizeof(CBlockIdToM0N0BlockClusterAdaptor)); - - a_k_m0_m1_grid_desc_dev_buf.ToDevice(&a_k_m0_m1_grid_desc); - b_k_n0_n1_grid_desc_dev_buf.ToDevice(&b_k_n0_n1_grid_desc); - c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.ToDevice(&c_m0_m10_m11_n0_n10_n11_grid_desc); - c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.ToDevice( - &c_blockid_to_m0_n0_block_cluster_adaptor); - - float ave_time = 0; - - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - kernel_gemm_dlops_v1r2, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - true>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = - kernel_gemm_dlops_v1r2, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - false>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - kernel_gemm_dlops_v1r2, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - true>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); - } - else - { - const auto kernel = - kernel_gemm_dlops_v1r2, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - false>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 
0, - p_a_grid, - p_b_grid, - p_c_grid, - cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); - } - - return ave_time; -#endif -} -#endif diff --git a/host/driver_offline/include/driver_gemm_dlops_v1r3.hpp b/host/driver_offline/include/driver_gemm_dlops_v1r3.hpp deleted file mode 100644 index 44709188..00000000 --- a/host/driver_offline/include/driver_gemm_dlops_v1r3.hpp +++ /dev/null @@ -1,418 +0,0 @@ -#ifndef DRIVER_GEMM_DLOPS_V1R3 -#define DRIVER_GEMM_DLOPS_V1R3 - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_dlops_v1r3.hpp" - -template -__host__ float driver_gemm_dlops_v1r3(const FloatAB* p_a_grid, - const FloatAB* p_b_grid, - FloatC* p_c_grid, - const AK0MK1GridDesc& a_k0_m_k1_grid_desc, - const BK0NK1GridDesc& b_k0_n_k1_grid_desc, - const CMNGridDesc& c_m_n_grid_desc, - AGridStepHacks, - BGridStepHacks, - CGridStepHacks, - AGridMoveSliceWindowStepHacks, - BGridMoveSliceWindowStepHacks, - ck::index_t nrepeat) - -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - - // GEMM - using GridwiseGemm = - GridwiseGemmDlops_km_kn_mn_v1r3; - - const auto M = a_k0_m_k1_grid_desc.GetLength(I1); - const auto N = b_k0_n_k1_grid_desc.GetLength(I1); - const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); - - if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc)) - { - throw std::runtime_error("wrong! 
GridwiseGemmDlops_km_kn_mn_v1r3 has invalid setting"); - } - - const auto a_k0_m0_m1_k1_grid_desc = - GridwiseGemm::MakeAK0M0M1K1GridDescriptor(a_k0_m_k1_grid_desc); - const auto b_k0_n0_n1_k1_grid_desc = - GridwiseGemm::MakeBK0N0N1K1GridDescriptor(b_k0_n_k1_grid_desc); - - using AK0M0M1K1GridDesc = decltype(a_k0_m0_m1_k1_grid_desc); - using BK0N0N1K1GridDesc = decltype(b_k0_n0_n1_k1_grid_desc); - - // c_m0_m10_m11_n0_n10_n11_grid_desc - const auto c_m0_m10_m11_n0_n10_n11_grid_desc = - GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); - - using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc); - - // c_blockid_to_m0_n0_block_cluster_adaptor - const auto c_blockid_to_m0_n0_block_cluster_adaptor = - GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); - - using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor); - - const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N); - - const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); - - const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); - - { - std::cout << "a_k0_m0_m1_k1_grid_desc{" << a_k0_m0_m1_k1_grid_desc.GetLength(I0) << ", " - << a_k0_m0_m1_k1_grid_desc.GetLength(I1) << ", " - << a_k0_m0_m1_k1_grid_desc.GetLength(I2) << ", " - << a_k0_m0_m1_k1_grid_desc.GetLength(I3) << "}" << std::endl; - - std::cout << "b_k0_n0_n1_k1_grid_desc{" << b_k0_n0_n1_k1_grid_desc.GetLength(I0) << ", " - << b_k0_n0_n1_k1_grid_desc.GetLength(I1) << ", " - << b_k0_n0_n1_k1_grid_desc.GetLength(I2) << ", " - << b_k0_n0_n1_k1_grid_desc.GetLength(I3) << "}" << std::endl; - - std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl; - } - -#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE - float ave_time = 0; - - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - kernel_gemm_dlops_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - true>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k0_m0_m1_k1_grid_desc, - b_k0_n0_n1_k1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = - kernel_gemm_dlops_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - false>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k0_m0_m1_k1_grid_desc, - b_k0_n0_n1_k1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - kernel_gemm_dlops_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - true>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k0_m0_m1_k1_grid_desc, - b_k0_n0_n1_k1_grid_desc, - 
c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor); - } - else - { - const auto kernel = - kernel_gemm_dlops_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - false>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k0_m0_m1_k1_grid_desc, - b_k0_n0_n1_k1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor); - } - - return ave_time; -#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER - DeviceMem a_k0_m0_m1_k1_grid_desc_dev_buf(sizeof(AK0M0M1K1GridDesc)); - DeviceMem b_k0_n0_n1_k1_grid_desc_dev_buf(sizeof(BK0N0N1K1GridDesc)); - DeviceMem c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf(sizeof(CM0M10M11N0N10N11GridDesc)); - DeviceMem c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf( - sizeof(CBlockIdToM0N0BlockClusterAdaptor)); - - a_k0_m0_m1_k1_grid_desc_dev_buf.ToDevice(&a_k0_m0_m1_k1_grid_desc); - b_k0_n0_n1_k1_grid_desc_dev_buf.ToDevice(&b_k0_n0_n1_k1_grid_desc); - c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.ToDevice(&c_m0_m10_m11_n0_n10_n11_grid_desc); - c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.ToDevice( - &c_blockid_to_m0_n0_block_cluster_adaptor); - - float ave_time = 0; - - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - kernel_gemm_dlops_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - true>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - cast_pointer_to_constant_address_space( - a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = - kernel_gemm_dlops_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - false>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - cast_pointer_to_constant_address_space( - a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - kernel_gemm_dlops_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - true>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - cast_pointer_to_constant_address_space( - a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); 
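Under CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER, as in the branch above, each trivially-copyable grid descriptor is uploaded to a device buffer once and the kernel receives only a pointer cast into the constant address space, rather than taking the descriptor by value through its argument list. DeviceMem, ToDevice and cast_pointer_to_constant_address_space are CK helpers; the snippet below is a minimal standalone HIP sketch of the same upload pattern, not their implementation.

    // Minimal sketch: upload a POD descriptor once, then reuse the device
    // pointer across many kernel launches. Error handling omitted for brevity.
    #include <hip/hip_runtime.h>

    template <typename Desc>
    void* copy_desc_to_device(const Desc& host_desc)
    {
        void* dev_ptr = nullptr;
        hipMalloc(&dev_ptr, sizeof(Desc));                                   // one small buffer per descriptor
        hipMemcpy(dev_ptr, &host_desc, sizeof(Desc), hipMemcpyHostToDevice); // copy the host object bit-for-bit
        return dev_ptr; // kernel side: *reinterpret_cast<const Desc*>(dev_ptr)
    }

Keeping the descriptors out of the kernel argument list keeps the kernarg segment small and lets every repeat launch reuse the same constant-memory copy, which is presumably the motivation for this experimental path alongside the pass-by-value one.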
- } - else - { - const auto kernel = - kernel_gemm_dlops_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - false>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - cast_pointer_to_constant_address_space( - a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); - } - - return ave_time; -#endif -} -#endif diff --git a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp deleted file mode 100644 index 4ccfbaab..00000000 --- a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp +++ /dev/null @@ -1,275 +0,0 @@ -#ifndef DRIVER_GEMM_XDLOPS_V2R3_HPP -#define DRIVER_GEMM_XDLOPS_V2R3_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v2r3.hpp" - -template -__host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, - const FloatAB* p_b_grid, - FloatC* p_c_grid, - const AK0MK1GridDesc& a_k0_m_k1_grid_desc, - const BK0NK1GridDesc& b_k0_n_k1_grid_desc, - const CMNGridDesc& c_m_n_grid_desc, - ck::index_t M01, - ck::index_t N01, - AGridStepHacks, - BGridStepHacks, - CGridStepHacks, - AGridMoveSliceWindowStepHacks, - BGridMoveSliceWindowStepHacks, - ck::index_t nrepeat) - -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - using GridwiseGemm = - GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3; - - { - std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", " - << a_k0_m_k1_grid_desc.GetLength(I1) << ", " << a_k0_m_k1_grid_desc.GetLength(I2) - << "}" << std::endl; - - std::cout << "b_k0_n_k1_grid_desc{" << b_k0_n_k1_grid_desc.GetLength(I0) << ", " - << b_k0_n_k1_grid_desc.GetLength(I1) << ", " << b_k0_n_k1_grid_desc.GetLength(I2) - << "}" << std::endl; - - std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", " - << c_m_n_grid_desc.GetLength(I1) << "}" << std::endl; - } - - if(!GridwiseGemm::CheckValidity( - a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc, M01, N01)) - { - throw std::runtime_error( - "wrong! 
GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); - } - - const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc = - GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc); - - using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); - - const auto c_block_cluster_adaptor = - GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc, M01, N01); - - using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor); - - const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc); - - const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); - - const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); - - float ave_time = 0; - -#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE - if(has_main_k0_block_loop) - { - const auto kernel = kernel_gemm_xdlops_v2r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_block_cluster_adaptor); - } - else - { - const auto kernel = kernel_gemm_xdlops_v2r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_block_cluster_adaptor); - } -#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER - DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc)); - DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc)); - DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CM0N0M1N1M2M3M4N2GridDesc)); - DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor)); - - a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc); - b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc); - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); - c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor); - - if(has_main_k0_block_loop) - { - const auto kernel = kernel_gemm_xdlops_v2r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); - } - else - { - const auto kernel = kernel_gemm_xdlops_v2r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); - } -} -#endif - 
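The M01 and N01 arguments threaded through this driver control how MakeCBlockClusterAdaptor remaps the linear block id onto (m, n) output-tile coordinates: block ids are grouped into small M01 x N01 clusters instead of plain row-major order, so blocks that are resident at the same time work on neighboring C tiles and can reuse A/B data in cache. The exact mapping is defined inside the GridwiseGemm template; the sketch below shows one plausible swizzle of this kind, not CK's actual formula.

    // Illustrative block-id -> (m0, n0) swizzle: M01 tile-rows are filled
    // before advancing along n, so consecutive ids share an M01-tall strip.
    #include <cstdio>

    void block_to_tile(int bid, int NBlocks, int M01, int& m0, int& n0)
    {
        const int ids_per_strip = M01 * NBlocks; // ids covering one strip of M01 tile-rows
        const int strip         = bid / ids_per_strip;
        const int r             = bid % ids_per_strip;
        m0 = strip * M01 + r % M01; // walk the M01 rows inside the strip first
        n0 = r / M01;               // then step along the n direction
    }

    int main()
    {
        for(int bid = 0; bid < 8; ++bid)
        {
            int m0 = 0, n0 = 0;
            block_to_tile(bid, /*NBlocks=*/4, /*M01=*/2, m0, n0);
            std::printf("bid=%d -> (m0=%d, n0=%d)\n", bid, m0, n0); // ids 0..7 tile a 2 x 4 patch
        }
    }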
return ave_time; -} -#endif diff --git a/host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp b/host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp deleted file mode 100644 index 30ecb02d..00000000 --- a/host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp +++ /dev/null @@ -1,278 +0,0 @@ -#ifndef DRIVER_GEMM_XDLOPS_V2R4 -#define DRIVER_GEMM_XDLOPS_V2R4 - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v2r4.hpp" - -template -__host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid, - const FloatAB* p_b_grid, - FloatC* p_c_grid, - const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc, - const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc, - const CMNGridDesc& c_m_n_grid_desc, - ck::index_t M01, - ck::index_t N01, - AGridStepHacks, - BGridStepHacks, - CGridStepHacks, - AGridMoveSliceWindowStepHacks, - BGridMoveSliceWindowStepHacks, - ck::index_t nrepeat) - -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - using GridwiseGemm = - GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4; - - { - std::cout << "a_b_k0_m_k1_grid_desc{" << a_b_k0_m_k1_grid_desc.GetLength(I0) << ", " - << a_b_k0_m_k1_grid_desc.GetLength(I1) << ", " - << a_b_k0_m_k1_grid_desc.GetLength(I2) << ", " - << a_b_k0_m_k1_grid_desc.GetLength(I3) << "}" << std::endl; - - std::cout << "b_b_k0_n_k1_grid_desc{" << b_b_k0_n_k1_grid_desc.GetLength(I0) << ", " - << b_b_k0_n_k1_grid_desc.GetLength(I1) << ", " - << b_b_k0_n_k1_grid_desc.GetLength(I2) << ", " - << b_b_k0_n_k1_grid_desc.GetLength(I3) << "}" << std::endl; - - std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", " - << c_m_n_grid_desc.GetLength(I1) << "}" << std::endl; - } - - if(!GridwiseGemm::CheckValidity( - a_b_k0_m_k1_grid_desc, b_b_k0_n_k1_grid_desc, c_m_n_grid_desc, M01, N01)) - { - throw std::runtime_error( - "wrong! 
GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r4 has invalid setting"); - } - - const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc = - GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc); - - using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); - - const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0); - - const auto c_block_cluster_adaptor = - GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc, M01, N01, KBatch); - - using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor); - - const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc, KBatch); - { - std::cout << "gridSize : " << grid_size << std::endl; - } - - const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); - - const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); - - float ave_time = 0; -#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE - if(has_main_k0_block_loop) - { - const auto kernel = kernel_gemm_xdlops_v2r4, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true>; - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_b_k0_m_k1_grid_desc, - b_b_k0_n_k1_grid_desc, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_block_cluster_adaptor); - } - else - { - const auto kernel = kernel_gemm_xdlops_v2r4, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false>; - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_b_k0_m_k1_grid_desc, - b_b_k0_n_k1_grid_desc, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_block_cluster_adaptor); - } - -#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER - DeviceMem a_b_k0_m_k1_grid_desc_dev_buf(sizeof(ABK0MK1GridDesc)); - DeviceMem b_b_k0_n_k1_grid_desc_dev_buf(sizeof(BBK0NK1GridDesc)); - DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CM0N0M1N1M2M3M4N2GridDesc)); - DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor)); - - a_b_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_b_k0_m_k1_grid_desc); - b_b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_b_k0_n_k1_grid_desc); - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); - c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor); - - if(has_main_k0_block_loop) - { - const auto kernel = kernel_gemm_xdlops_v2r4, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true>; - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - cast_pointer_to_constant_address_space(a_b_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space(b_b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); - } - else - { - const auto kernel = kernel_gemm_xdlops_v2r4, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false>; - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - cast_pointer_to_constant_address_space(a_b_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space(b_b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - 
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); - } -#endif - return ave_time; -} -#endif diff --git a/host/driver_offline/src/conv_bwd_driver_offline.cpp b/host/driver_offline/src/conv_bwd_driver_offline.cpp deleted file mode 100644 index 366b5dff..00000000 --- a/host/driver_offline/src/conv_bwd_driver_offline.cpp +++ /dev/null @@ -1,346 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include "config.hpp" -#include "debug.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "conv_common.hpp" -#include "host_conv_bwd_data.hpp" -#include "device_tensor.hpp" -#include "device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp" -#include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp" -#include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp" - -#define USE_MODE 1 -#define USE_CONV_BWD_V4R1_XDL_NHWC 0 -#define USE_CONV_BWD_V4R1R2_XDL_NHWC 1 - -enum ConvBackwardDataAlgo -{ - V4R1XDLNHWC, // 0 - V4R1R2XDLNHWC, // 1 -}; - -int main(int argc, char* argv[]) -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - -#if USE_MODE - // dynamic mode - if(argc != 22) - { - printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); - printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n"); - exit(1); - } - - const ConvTensorLayout layout = static_cast(std::stoi(argv[1])); - const ConvBackwardDataAlgo algo = static_cast(std::stoi(argv[2])); - const bool do_verification = std::stoi(argv[3]); - const int init_method = std::stoi(argv[4]); - const bool do_log = std::stoi(argv[5]); - const int nrepeat = std::stoi(argv[6]); - - const index_t N = std::stoi(argv[7]); - const index_t K = std::stoi(argv[8]); - const index_t C = std::stoi(argv[9]); - const index_t Y = std::stoi(argv[10]); - const index_t X = std::stoi(argv[11]); - const index_t Hi = std::stoi(argv[12]); - const index_t Wi = std::stoi(argv[13]); - - const index_t conv_stride_h = std::stoi(argv[14]); - const index_t conv_stride_w = std::stoi(argv[15]); - const index_t conv_dilation_h = std::stoi(argv[16]); - const index_t conv_dilation_w = std::stoi(argv[17]); - const index_t in_left_pad_h = std::stoi(argv[18]); - const index_t in_left_pad_w = std::stoi(argv[19]); - const index_t in_right_pad_h = std::stoi(argv[20]); - const index_t in_right_pad_w = std::stoi(argv[21]); - - const index_t YEff = (Y - 1) * conv_dilation_h + 1; - const index_t XEff = (X - 1) * conv_dilation_w + 1; - - const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; - const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; -#else - // static mode - if(argc < 7) - { - printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); - exit(1); - } - - const ConvTensorLayout layout = static_cast(std::stoi(argv[1])); - const ConvBackwardDataAlgo algo = static_cast(std::stoi(argv[2])); - const bool do_verification = std::stoi(argv[3]); - const int init_method = std::stoi(argv[4]); - const bool do_log = std::stoi(argv[5]); - const int nrepeat = 
std::stoi(argv[6]); - - constexpr auto N = Number<128>{}; - constexpr auto C = Number<192>{}; - constexpr auto Hi = Number<71>{}; - constexpr auto Wi = Number<71>{}; - constexpr auto K = Number<256>{}; - constexpr auto Y = Number<3>{}; - constexpr auto X = Number<3>{}; - - constexpr auto conv_stride_h = I2; - constexpr auto conv_stride_w = I2; - constexpr auto conv_dilation_h = I1; - constexpr auto conv_dilation_w = I1; - constexpr auto in_left_pad_h = I1; - constexpr auto in_left_pad_w = I1; - constexpr auto in_right_pad_h = I1; - constexpr auto in_right_pad_w = I1; - - constexpr auto YEff = (Y - I1) * conv_dilation_h + I1; - constexpr auto XEff = (X - I1) * conv_dilation_w + I1; - - constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1; - constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1; -#endif - -#if 0 - using in_data_t = float; - using acc_data_t = float; - using out_data_t = float; -#elif 1 - using in_data_t = half_t; - using acc_data_t = float; - using out_data_t = half_t; -#endif - - std::vector in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4); - - if(layout == ConvTensorLayout::NCHW) - { - in_lengths_host[0] = static_cast(N); - in_lengths_host[1] = static_cast(C); - in_lengths_host[2] = static_cast(Hi); - in_lengths_host[3] = static_cast(Wi); - wei_lengths_host[0] = static_cast(K); - wei_lengths_host[1] = static_cast(C); - wei_lengths_host[2] = static_cast(Y); - wei_lengths_host[3] = static_cast(X); - out_lengths_host[0] = static_cast(N); - out_lengths_host[1] = static_cast(K); - out_lengths_host[2] = static_cast(Ho); - out_lengths_host[3] = static_cast(Wo); - } - else if(layout == ConvTensorLayout::NHWC) - { - in_lengths_host[0] = static_cast(N); - in_lengths_host[1] = static_cast(Hi); - in_lengths_host[2] = static_cast(Wi); - in_lengths_host[3] = static_cast(C); - wei_lengths_host[0] = static_cast(K); - wei_lengths_host[1] = static_cast(Y); - wei_lengths_host[2] = static_cast(X); - wei_lengths_host[3] = static_cast(C); - out_lengths_host[0] = static_cast(N); - out_lengths_host[1] = static_cast(Ho); - out_lengths_host[2] = static_cast(Wo); - out_lengths_host[3] = static_cast(K); - } - else - { - throw std::runtime_error("wrong! 
not implemented"); - } - - Tensor in_host(in_lengths_host); - Tensor in_device(in_lengths_host); - Tensor wei(wei_lengths_host); - Tensor out(out_lengths_host); - - std::cout << "layout: " << layout << std::endl; - ostream_HostTensorDescriptor(in_host.mDesc, std::cout << "in: "); - ostream_HostTensorDescriptor(wei.mDesc, std::cout << "wei: "); - ostream_HostTensorDescriptor(out.mDesc, std::cout << "out: "); - print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w)); - print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w)); - print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w)); - print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w)); - - std::size_t num_thread = std::thread::hardware_concurrency(); - - switch(init_method) - { - case 0: - // no initialization - break; - case 1: - out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - break; - case 2: - out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - break; - case 3: - out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - break; - case 4: - out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - break; - case 5: - out.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); - break; - default: - out.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); - - auto gen_wei = [](auto... is) { - return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...); - }; - wei.GenerateTensorValue(gen_wei, num_thread); - } - - auto f_make_for_device_nhwc = [&]() { -#if USE_MODE - const auto in_lengths_dev = make_tuple(N, Hi, Wi, C); - const auto wei_lengths_dev = make_tuple(K, Y, X, C); - const auto out_lengths_dev = make_tuple(N, Ho, Wo, K); - const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); - const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); - const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); - const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); -#else - const auto in_lengths_dev = - make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto wei_lengths_dev = make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto out_lengths_dev = - make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto conv_strides_dev = make_tuple(Number{}, Number{}); - const auto conv_dilations_dev = - make_tuple(Number{}, Number{}); - const auto in_left_pads_dev = make_tuple(Number{}, Number{}); - const auto in_right_pads_dev = - make_tuple(Number{}, Number{}); -#endif - - return make_tuple(in_lengths_dev, - wei_lengths_dev, - out_lengths_dev, - conv_strides_dev, - conv_dilations_dev, - in_left_pads_dev, - in_right_pads_dev); - }; - -#if USE_CONV_BWD_V4R1_XDL_NHWC - if(algo == ConvBackwardDataAlgo::V4R1XDLNHWC) - { - if(layout != ConvTensorLayout::NHWC) - { - throw std::runtime_error("wrong! 
layout"); - } - - const auto tmp = f_make_for_device_nhwc(); - - device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( - tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in_device, - wei, - out, - nrepeat); - } -#endif - -#if USE_CONV_BWD_V4R1R2_XDL_NHWC - if(algo == ConvBackwardDataAlgo::V4R1R2XDLNHWC) - { - if(layout != ConvTensorLayout::NHWC) - { - throw std::runtime_error("wrong! layout"); - } - - const auto tmp = f_make_for_device_nhwc(); - - if(Y == 1 && X == 1 && in_left_pad_h == 0 && in_left_pad_w == 0 && in_right_pad_h == 0 && - in_right_pad_w == 0) - { - device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1< - in_data_t, - acc_data_t, - out_data_t>(tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in_device, - wei, - out, - nrepeat); - } - else - { -#if 1 - device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk( - tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in_device, - wei, - out, - nrepeat); -#endif - } - } -#endif - - if(do_verification) - { - host_direct_convolution_backward_data(in_host, - wei, - out, - make_tuple(conv_stride_h, conv_stride_w), - make_tuple(conv_dilation_h, conv_dilation_w), - make_tuple(in_left_pad_h, in_left_pad_w), - make_tuple(in_right_pad_h, in_right_pad_w), - layout); - - check_error(in_host, in_device); - - if(do_log) - { - LogRangeAsType(std::cout << "out : ", out.mData, ",") << std::endl; - LogRangeAsType(std::cout << "wei: ", wei.mData, ",") << std::endl; - LogRangeAsType(std::cout << "in_host : ", in_host.mData, ",") << std::endl; - LogRangeAsType(std::cout << "in_device: ", in_device.mData, ",") << std::endl; - } - } -} diff --git a/host/driver_offline/src/conv_fwd_driver_offline.cpp b/host/driver_offline/src/conv_fwd_driver_offline.cpp deleted file mode 100644 index 48eba2b3..00000000 --- a/host/driver_offline/src/conv_fwd_driver_offline.cpp +++ /dev/null @@ -1,447 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include "config.hpp" -#include "debug.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "conv_common.hpp" -#include "host_conv.hpp" -#include "device_tensor.hpp" -#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp" -#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp" -#include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp" -#include "device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp" -#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" -#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" - -#define USE_DYNAMIC_MODE 1 -#define USE_CONV_FWD_V4R4_NCHW 0 -#define USE_CONV_FWD_V4R4R2_NHWC 0 -#define USE_CONV_FWD_V6R1_NCHW 0 -#define USE_CONV_FWD_V5R1_NCHW 0 -#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0 -#define USE_CONV_FWD_V4R4R4_XDL_NHWC 1 - -enum ConvForwardAlgo -{ - V4R4NCHW, // 0 - V4R4R2NHWC, // 1 - V6R1NCHW, // 2 - V5R1NCHW, // 3 - V4R4R2XDLNCHW, // 4 - V4R4R4XDLNHWC // 5 -}; - -int main(int argc, char* argv[]) -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - -#if USE_DYNAMIC_MODE - // dynamic mode - 
if(argc != 22) - { - printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); - printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n"); - exit(1); - } - - const ConvTensorLayout layout = static_cast(std::stoi(argv[1])); - const ConvForwardAlgo algo = static_cast(std::stoi(argv[2])); - const bool do_verification = std::stoi(argv[3]); - const int init_method = std::stoi(argv[4]); - const bool do_log = std::stoi(argv[5]); - const int nrepeat = std::stoi(argv[6]); - - const index_t N = std::stoi(argv[7]); - const index_t K = std::stoi(argv[8]); - const index_t C = std::stoi(argv[9]); - const index_t Y = std::stoi(argv[10]); - const index_t X = std::stoi(argv[11]); - const index_t Hi = std::stoi(argv[12]); - const index_t Wi = std::stoi(argv[13]); - - const index_t conv_stride_h = std::stoi(argv[14]); - const index_t conv_stride_w = std::stoi(argv[15]); - const index_t conv_dilation_h = std::stoi(argv[16]); - const index_t conv_dilation_w = std::stoi(argv[17]); - const index_t in_left_pad_h = std::stoi(argv[18]); - const index_t in_left_pad_w = std::stoi(argv[19]); - const index_t in_right_pad_h = std::stoi(argv[20]); - const index_t in_right_pad_w = std::stoi(argv[21]); - - const index_t YEff = (Y - 1) * conv_dilation_h + 1; - const index_t XEff = (X - 1) * conv_dilation_w + 1; - - const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; - const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; -#else - // static mode - if(argc < 7) - { - printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); - exit(1); - } - - const ConvTensorLayout layout = static_cast(std::stoi(argv[1])); - const ConvForwardAlgo algo = static_cast(std::stoi(argv[2])); - const bool do_verification = std::stoi(argv[3]); - const int init_method = std::stoi(argv[4]); - const bool do_log = std::stoi(argv[5]); - const int nrepeat = std::stoi(argv[6]); - - constexpr auto N = Number<128>{}; - constexpr auto C = Number<192>{}; - constexpr auto Hi = Number<71>{}; - constexpr auto Wi = Number<71>{}; - constexpr auto K = Number<256>{}; - constexpr auto Y = Number<3>{}; - constexpr auto X = Number<3>{}; - - constexpr auto conv_stride_h = I2; - constexpr auto conv_stride_w = I2; - constexpr auto conv_dilation_h = I1; - constexpr auto conv_dilation_w = I1; - constexpr auto in_left_pad_h = I1; - constexpr auto in_left_pad_w = I1; - constexpr auto in_right_pad_h = I1; - constexpr auto in_right_pad_w = I1; - - constexpr auto YEff = (Y - I1) * conv_dilation_h + I1; - constexpr auto XEff = (X - I1) * conv_dilation_w + I1; - - constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1; - constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1; -#endif - -#if 0 - using in_data_t = float; - using acc_data_t = float; - using out_data_t = float; -#elif 1 - using in_data_t = half_t; - using acc_data_t = float; - using out_data_t = half_t; -#elif 1 - using in_data_t = int8_t; - using acc_data_t = int32_t; - using out_data_t = int8_t; -#endif - - std::vector in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4); - - if(layout == ConvTensorLayout::NCHW) - { - in_lengths_host[0] = static_cast(N); - in_lengths_host[1] = static_cast(C); - in_lengths_host[2] = static_cast(Hi); - in_lengths_host[3] = static_cast(Wi); - wei_lengths_host[0] = static_cast(K); - wei_lengths_host[1] = static_cast(C); - wei_lengths_host[2] = static_cast(Y); - 
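The #if ladder above picks an (in, acc, out) type triple -- float/float/float, half/float/half, or int8/int32/int8 -- where the accumulator type is always wide (float or int32) because the implicit-GEMM reduction runs over the full E = C * Y * X extent. Note that with both branches written as #elif 1, the first one (the half triple) always wins, so the int8 triple is dead code until the flags are edited. A minimal standalone illustration of why the int8 path needs an int32 accumulator:

    // Even a single int8 x int8 product can leave int8 range, so partial sums
    // of the E reduction must be held in a wider type.
    #include <cstdint>
    #include <cstdio>

    int main()
    {
        const int8_t  a   = 100, b = 100;
        const int32_t acc = int32_t(a) * int32_t(b); // 10000, far outside [-128, 127]
        std::printf("acc = %d\n", acc);
    }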
wei_lengths_host[3] = static_cast(X); - out_lengths_host[0] = static_cast(N); - out_lengths_host[1] = static_cast(K); - out_lengths_host[2] = static_cast(Ho); - out_lengths_host[3] = static_cast(Wo); - } - else if(layout == ConvTensorLayout::NHWC) - { - in_lengths_host[0] = static_cast(N); - in_lengths_host[1] = static_cast(Hi); - in_lengths_host[2] = static_cast(Wi); - in_lengths_host[3] = static_cast(C); - wei_lengths_host[0] = static_cast(K); - wei_lengths_host[1] = static_cast(Y); - wei_lengths_host[2] = static_cast(X); - wei_lengths_host[3] = static_cast(C); - out_lengths_host[0] = static_cast(N); - out_lengths_host[1] = static_cast(Ho); - out_lengths_host[2] = static_cast(Wo); - out_lengths_host[3] = static_cast(K); - } - else - { - std::runtime_error("wrong! not implemented"); - } - - Tensor in(in_lengths_host); - Tensor wei(wei_lengths_host); - Tensor out_host(out_lengths_host); - Tensor out_device(out_lengths_host); - - std::cout << "layout: " << layout << std::endl; - ostream_HostTensorDescriptor(in.mDesc, std::cout << "in: "); - ostream_HostTensorDescriptor(wei.mDesc, std::cout << "wei: "); - ostream_HostTensorDescriptor(out_host.mDesc, std::cout << "out: "); - print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w)); - print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w)); - print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w)); - print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w)); - - std::size_t num_thread = std::thread::hardware_concurrency(); - - switch(init_method) - { - case 0: - // no initialization - break; - case 1: - in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - break; - case 2: - in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - break; - case 3: - in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - break; - case 4: - in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - break; - case 5: - in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); - break; - default: - in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); - - auto gen_wei = [](auto... is) { - return GeneratorTensor_2{1, 5}(is...) 
* GeneratorTensor_Checkboard{}(is...); - }; - wei.GenerateTensorValue(gen_wei, num_thread); - } - - auto f_make_for_device_nchw = [&]() { - const auto in_lengths_dev = make_tuple(N, C, Hi, Wi); - const auto wei_lengths_dev = make_tuple(K, C, Y, X); - const auto out_lengths_dev = make_tuple(N, K, Ho, Wo); - const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); - const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); - const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); - const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); - - return make_tuple(in_lengths_dev, - wei_lengths_dev, - out_lengths_dev, - conv_strides_dev, - conv_dilations_dev, - in_left_pads_dev, - in_right_pads_dev); - }; - - auto f_make_for_device_nhwc = [&]() { - const auto in_lengths_dev = make_tuple(N, Hi, Wi, C); - const auto wei_lengths_dev = make_tuple(K, Y, X, C); - const auto out_lengths_dev = make_tuple(N, Ho, Wo, K); - const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); - const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); - const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); - const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); - - return make_tuple(in_lengths_dev, - wei_lengths_dev, - out_lengths_dev, - conv_strides_dev, - conv_dilations_dev, - in_left_pads_dev, - in_right_pads_dev); - }; - -#if USE_CONV_FWD_V4R4_NCHW - if(algo == ConvForwardAlgo::V4R4NCHW) - { - if(layout != ConvTensorLayout::NCHW) - { - throw std::runtime_error("wrong! layout"); - } - - const auto tmp = f_make_for_device_nchw(); - - device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei, - out_device, - nrepeat); - } -#endif - -#if USE_CONV_FWD_V4R4R2_NHWC - if(algo == ConvForwardAlgo::V4R4R2NHWC) - { - if(layout != ConvTensorLayout::NHWC) - { - throw std::runtime_error("wrong! layout"); - } - - const auto tmp = f_make_for_device_nhwc(); - - device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei, - out_device, - nrepeat); - } -#endif - -#if USE_CONV_FWD_V6R1_NCHW - if(algo == ConvForwardAlgo::V6R1NCHW) - { - if(layout != ConvTensorLayout::NCHW) - { - throw std::runtime_error("wrong! layout"); - } - - const auto tmp = f_make_for_device_nchw(); - - device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei, - out_device, - nrepeat); - } -#endif - -#if USE_CONV_FWD_V5R1_NCHW - if(algo == ConvForwardAlgo::V5R1NCHW) - { - if(layout != ConvTensorLayout::NCHW) - { - throw std::runtime_error("wrong! layout"); - } - - const auto tmp = f_make_for_device_nchw(); - - device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei, - out_device, - nrepeat); - } -#endif - -#if USE_CONV_FWD_V4R4R2_XDL_NCHW - if(algo == ConvForwardAlgo::V4R4R2XDLNCHW) - { - if(layout != ConvTensorLayout::NCHW) - { - throw std::runtime_error("wrong! 
layout"); - } - - const auto tmp = f_make_for_device_nchw(); - - device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( - tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei, - out_device, - nrepeat); - } -#endif - -#if USE_CONV_FWD_V4R4R4_XDL_NHWC - if(algo == ConvForwardAlgo::V4R4R4XDLNHWC) - { - if(layout != ConvTensorLayout::NHWC) - { - throw std::runtime_error("wrong! layout"); - } - - const auto tmp = f_make_for_device_nhwc(); - - device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( - tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei, - out_device, - nrepeat); - } -#endif - - if(do_verification) - { - host_direct_convolution(in, - wei, - out_host, - make_tuple(conv_stride_h, conv_stride_w), - make_tuple(conv_dilation_h, conv_dilation_w), - make_tuple(in_left_pad_h, in_left_pad_w), - make_tuple(in_right_pad_h, in_right_pad_w), - layout); - - check_error(out_host, out_device); - - if(do_log) - { - LogRangeAsType(std::cout << "in : ", in.mData, ",") << std::endl; - LogRangeAsType(std::cout << "wei: ", wei.mData, ",") << std::endl; - LogRangeAsType(std::cout << "out_host : ", out_host.mData, ",") << std::endl; - LogRangeAsType(std::cout << "out_device: ", out_device.mData, ",") << std::endl; - } - } -} diff --git a/host/driver_offline/src/conv_wrw_driver_offline.cpp b/host/driver_offline/src/conv_wrw_driver_offline.cpp deleted file mode 100644 index 50f4d6a9..00000000 --- a/host/driver_offline/src/conv_wrw_driver_offline.cpp +++ /dev/null @@ -1,436 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include "config.hpp" -#include "debug.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "conv_common.hpp" -#include "host_conv_bwd_weight.hpp" -#include "device_tensor.hpp" -#include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" -#include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" -#include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp" -#include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp" -#include "device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp" - -#define USE_DYNAMIC_MODE 1 -#define USE_CONV_WRW_V4R4R2_XDL_NCHW 0 -#define USE_CONV_WRW_V4R4R4_XDL_NHWC 0 -#define USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW 0 -#define USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC 0 -#define USE_CONV_WRW_V4R4R5_XDL_ATOMIC_NHWC 1 - -enum ConvBackwardWeightAlgo -{ - V4R4R2XDLNCHW, // 0 - V4R4R4XDLNHWC, // 1 - V4R4R2XDLATOMICNCHW, // 2 - V4R4R4XDLATOMICNHWC, // 3 - V4R4R5XDLATOMICNHWC, // 4 -}; - -int main(int argc, char* argv[]) -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - -#if USE_DYNAMIC_MODE - // dynamic mode - if(argc != 23) - { - printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); - printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n"); - printf("additional: desired_grid_size\n"); - exit(1); - } - - const ConvTensorLayout layout = static_cast(std::stoi(argv[1])); - const ConvBackwardWeightAlgo algo = 
static_cast(std::stoi(argv[2])); - const bool do_verification = std::stoi(argv[3]); - const int init_method = std::stoi(argv[4]); - const bool do_log = std::stoi(argv[5]); - const int nrepeat = std::stoi(argv[6]); - - const index_t N = std::stoi(argv[7]); - const index_t K = std::stoi(argv[8]); - const index_t C = std::stoi(argv[9]); - const index_t Y = std::stoi(argv[10]); - const index_t X = std::stoi(argv[11]); - const index_t Hi = std::stoi(argv[12]); - const index_t Wi = std::stoi(argv[13]); - - const index_t conv_stride_h = std::stoi(argv[14]); - const index_t conv_stride_w = std::stoi(argv[15]); - const index_t conv_dilation_h = std::stoi(argv[16]); - const index_t conv_dilation_w = std::stoi(argv[17]); - const index_t in_left_pad_h = std::stoi(argv[18]); - const index_t in_left_pad_w = std::stoi(argv[19]); - const index_t in_right_pad_h = std::stoi(argv[20]); - const index_t in_right_pad_w = std::stoi(argv[21]); - - const index_t desired_grid_size = std::stoi(argv[22]); - - const index_t YEff = (Y - 1) * conv_dilation_h + 1; - const index_t XEff = (X - 1) * conv_dilation_w + 1; - - const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; - const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; -#else - // static mode - if(argc < 7) - { - printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); - exit(1); - } - - const ConvTensorLayout layout = static_cast(std::stoi(argv[1])); - const ConvBackwardWeightAlgo algo = static_cast(std::stoi(argv[2])); - const bool do_verification = std::stoi(argv[3]); - const int init_method = std::stoi(argv[4]); - const bool do_log = std::stoi(argv[5]); - const int nrepeat = std::stoi(argv[6]); - - constexpr auto N = Number<128>{}; - constexpr auto C = Number<128>{}; - constexpr auto Hi = Number<14>{}; - constexpr auto Wi = Number<14>{}; - constexpr auto K = Number<256>{}; - constexpr auto Y = Number<3>{}; - constexpr auto X = Number<3>{}; - - constexpr auto conv_stride_h = I1; - constexpr auto conv_stride_w = I1; - constexpr auto conv_dilation_h = I1; - constexpr auto conv_dilation_w = I1; - constexpr auto in_left_pad_h = I1; - constexpr auto in_left_pad_w = I1; - constexpr auto in_right_pad_h = I1; - constexpr auto in_right_pad_w = I1; - - constexpr auto YEff = (Y - I1) * conv_dilation_h + I1; - constexpr auto XEff = (X - I1) * conv_dilation_w + I1; - - constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1; - constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1; -#endif - -#if 0 - using in_data_t = float; - using wei_data_t = float; - using acc_data_t = float; - using out_data_t = float; -#elif 1 - using in_data_t = half_t; - using out_data_t = half_t; - using acc_data_t = float; - using wei_data_t = float; -#elif 1 - using in_data_t = int8_t; - using out_data_t = int8_t; - using acc_data_t = int32_t; - using wei_data_t = int8_t; -#endif - - std::vector in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4); - - if(layout == ConvTensorLayout::NCHW) - { - in_lengths_host[0] = static_cast(N); - in_lengths_host[1] = static_cast(C); - in_lengths_host[2] = static_cast(Hi); - in_lengths_host[3] = static_cast(Wi); - wei_lengths_host[0] = static_cast(K); - wei_lengths_host[1] = static_cast(C); - wei_lengths_host[2] = static_cast(Y); - wei_lengths_host[3] = static_cast(X); - out_lengths_host[0] = static_cast(N); - out_lengths_host[1] = static_cast(K); - out_lengths_host[2] = static_cast(Ho); - 
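// [Editorial note, not part of the original patch] Two details in the
// deleted wrw driver are easy to miss. First, in the fp16 configuration the
// weight-gradient type (wei_data_t) stays float even though in/out are
// half_t, presumably so the atomic/split-K variants can accumulate partial
// sums without fp16 rounding loss. Second, the per-layout length assignments
// here and in the NHWC branch that follows reduce to one table:
//
//   layout   in              wei            out
//   NCHW     {N, C, Hi, Wi}  {K, C, Y, X}   {N, K, Ho, Wo}
//   NHWC     {N, Hi, Wi, C}  {K, Y, X, C}   {N, Ho, Wo, K}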
out_lengths_host[3] = static_cast(Wo); - } - else if(layout == ConvTensorLayout::NHWC) - { - in_lengths_host[0] = static_cast(N); - in_lengths_host[1] = static_cast(Hi); - in_lengths_host[2] = static_cast(Wi); - in_lengths_host[3] = static_cast(C); - wei_lengths_host[0] = static_cast(K); - wei_lengths_host[1] = static_cast(Y); - wei_lengths_host[2] = static_cast(X); - wei_lengths_host[3] = static_cast(C); - out_lengths_host[0] = static_cast(N); - out_lengths_host[1] = static_cast(Ho); - out_lengths_host[2] = static_cast(Wo); - out_lengths_host[3] = static_cast(K); - } - else - { - std::runtime_error("wrong! not implemented"); - } - - Tensor in(in_lengths_host); - Tensor wei_device(wei_lengths_host); - Tensor wei_host(wei_lengths_host); - Tensor out(out_lengths_host); - - std::cout << "layout: " << layout << std::endl; - ostream_HostTensorDescriptor(in.mDesc, std::cout << "in: "); - ostream_HostTensorDescriptor(wei_host.mDesc, std::cout << "wei: "); - ostream_HostTensorDescriptor(out.mDesc, std::cout << "out: "); - print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w)); - print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w)); - print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w)); - print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w)); - - std::size_t num_thread = std::thread::hardware_concurrency(); - - switch(init_method) - { - case 0: - // no initialization - break; - case 1: - in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - break; - case 2: - in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - break; - case 3: - in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - break; - case 4: - in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - break; - case 5: - in.GenerateTensorValue(GeneratorTensor_3{-0.1, 0.1}, num_thread); - out.GenerateTensorValue(GeneratorTensor_3{-0.1, 0.1}, num_thread); - break; - default: - in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); - - auto gen_out = [](auto... is) { - return GeneratorTensor_2{1, 5}(is...) 
* GeneratorTensor_Checkboard{}(is...); - }; - out.GenerateTensorValue(gen_out, num_thread); - } - - auto f_make_for_device_nchw = [&]() { - const auto in_lengths_dev = make_tuple(N, C, Hi, Wi); - const auto wei_lengths_dev = make_tuple(K, C, Y, X); - const auto out_lengths_dev = make_tuple(N, K, Ho, Wo); - const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); - const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); - const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); - const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); - - return make_tuple(in_lengths_dev, - wei_lengths_dev, - out_lengths_dev, - conv_strides_dev, - conv_dilations_dev, - in_left_pads_dev, - in_right_pads_dev); - }; - - auto f_make_for_device_nhwc = [&]() { - const auto in_lengths_dev = make_tuple(N, Hi, Wi, C); - const auto wei_lengths_dev = make_tuple(K, Y, X, C); - const auto out_lengths_dev = make_tuple(N, Ho, Wo, K); - const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); - const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); - const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); - const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); - - return make_tuple(in_lengths_dev, - wei_lengths_dev, - out_lengths_dev, - conv_strides_dev, - conv_dilations_dev, - in_left_pads_dev, - in_right_pads_dev); - }; - - // set zero to wei_device - wei_device.GenerateTensorValue(GeneratorTensor_0{}, num_thread); -#if USE_CONV_WRW_V4R4R2_XDL_NCHW - if(algo == ConvBackwardWeightAlgo::V4R4R2XDLNCHW) - { - if(layout != ConvTensorLayout::NCHW) - { - throw std::runtime_error("wrong! layout"); - } - - const auto tmp = f_make_for_device_nchw(); - - device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( - tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei_device, - out, - nrepeat); - } -#endif - -#if USE_CONV_WRW_V4R4R4_XDL_NHWC - if(algo == ConvBackwardWeightAlgo::V4R4R4XDLNHWC) - { - if(layout != ConvTensorLayout::NHWC) - { - throw std::runtime_error("wrong! layout"); - } - - const auto tmp = f_make_for_device_nhwc(); - - device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( - tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei_device, - out, - nrepeat); - } -#endif - -#if USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW - if(algo == ConvBackwardWeightAlgo::V4R4R2XDLATOMICNCHW) - { - if(layout != ConvTensorLayout::NCHW) - { - throw std::runtime_error("wrong! layout"); - } - - const auto tmp = f_make_for_device_nchw(); - - device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw< - in_data_t, - wei_data_t, - acc_data_t, - out_data_t>(tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei_device, - out, - desired_grid_size, - nrepeat); - } -#endif - -#if USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC - if(algo == ConvBackwardWeightAlgo::V4R4R4XDLATOMICNHWC) - { - if(layout != ConvTensorLayout::NHWC) - { - throw std::runtime_error("wrong! 
layout"); - } - - const auto tmp = f_make_for_device_nhwc(); - - device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk< - in_data_t, - wei_data_t, - acc_data_t, - out_data_t>(tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei_device, - out, - desired_grid_size, - nrepeat); - } -#endif - -#if USE_CONV_WRW_V4R4R5_XDL_ATOMIC_NHWC - if(algo == ConvBackwardWeightAlgo::V4R4R5XDLATOMICNHWC) - { - if(layout != ConvTensorLayout::NHWC) - { - throw std::runtime_error("wrong! layout"); - } - - const auto tmp = f_make_for_device_nhwc(); - - device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk< - in_data_t, - wei_data_t, - acc_data_t, - out_data_t>(tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei_device, - out, - desired_grid_size, - nrepeat); - } -#endif - - if(do_verification) - { - host_direct_convolution_backward_weights(out, - in, - wei_host, - make_tuple(conv_stride_h, conv_stride_w), - make_tuple(conv_dilation_h, conv_dilation_w), - make_tuple(in_left_pad_h, in_left_pad_w), - make_tuple(in_right_pad_h, in_right_pad_w), - layout); - - check_error(wei_host, wei_device); - - if(do_log) - { - LogRangeAsType(std::cout << "out: ", out.mData, ",") << std::endl; - LogRangeAsType(std::cout << "in : ", in.mData, ",") << std::endl; - LogRangeAsType(std::cout << "wei_device: ", wei_device.mData, ",") << std::endl; - LogRangeAsType(std::cout << "wei_host : ", wei_host.mData, ",") << std::endl; - } - } -} diff --git a/host/driver_offline/src/gemm_driver_offline.cpp b/host/driver_offline/src/gemm_driver_offline.cpp deleted file mode 100644 index e60b4905..00000000 --- a/host/driver_offline/src/gemm_driver_offline.cpp +++ /dev/null @@ -1,288 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include "config.hpp" -#include "debug.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "gemm_common.hpp" -#include "host_gemm.hpp" -#include "device_tensor.hpp" -#include "device_gemm_xdlops_mk_kn_mn.hpp" -#include "device_gemm_xdlops_mk_nk_mn.hpp" -#include "device_gemm_xdlops_km_kn_mn.hpp" -#include "device_gemm_xdlops_km_nk_mn.hpp" -#include "device_gemm_xdlops_mk_kn_nm.hpp" -#include "device_gemm_xdlops_mk_nk_nm.hpp" -#include "device_gemm_xdlops_km_kn_nm.hpp" -#include "device_gemm_xdlops_km_nk_nm.hpp" - -#define USE_GEMM_XDL_MK_KN_MN 1 -#define USE_GEMM_XDL_MK_NK_MN 1 -#define USE_GEMM_XDL_KM_KN_MN 1 -#define USE_GEMM_XDL_KM_NK_MN 1 -#define USE_GEMM_XDL_MK_KN_NM 0 -#define USE_GEMM_XDL_MK_NK_NM 0 -#define USE_GEMM_XDL_KM_KN_NM 0 -#define USE_GEMM_XDL_KM_NK_NM 0 - -enum GemmAlgo -{ - Xdl_MK_KN_MN, // 0 - Xdl_MK_NK_MN, // 1 - Xdl_KM_KN_MN, // 2 - Xdl_KM_NK_MN, // 3 - Xdl_MK_KN_NM, // 4 - Xdl_MK_NK_NM, // 5 - Xdl_KM_KN_NM, // 6 - Xdl_KM_NK_NM, // 7 -}; - -int main(int argc, char* argv[]) -{ - using namespace ck; - - if(argc != 12) - { - printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); - printf("rest: M, N, K\n"); - printf("debug_driver_gemm_xdlops_v2r3::M01, debug_driver_gemm_xdlops_v2r3::N01\n"); - exit(1); - } - - const auto layout = static_cast(std::stoi(argv[1])); - const auto algo = static_cast(std::stoi(argv[2])); - const bool do_verification = std::stoi(argv[3]); - const int init_method = std::stoi(argv[4]); - const bool do_log = std::stoi(argv[5]); - const int nrepeat = std::stoi(argv[6]); - - const index_t M = std::stoi(argv[7]); - const index_t N = 
std::stoi(argv[8]); - const index_t K = std::stoi(argv[9]); - - debug::debug_driver_gemm_xdlops_v2r3::M01 = std::stoi(argv[10]); - debug::debug_driver_gemm_xdlops_v2r3::N01 = std::stoi(argv[11]); - -#if 0 - using ab_data_t = float; - using acc_data_t = float; - using c_data_t = float; -#elif 1 - using ab_data_t = half_t; - using acc_data_t = float; - using c_data_t = half_t; -#elif 1 - using ab_data_t = int8_t; - using acc_data_t = int32_t; - using c_data_t = int8_t; -#endif - - std::vector a_lengths_host(2), b_lengths_host(2), c_lengths_host(2); - std::vector a_strides_host(2), b_strides_host(2), c_strides_host(2); - - // A - if(layout == GemmMatrixLayout::MK_KN_MN || layout == GemmMatrixLayout::MK_NK_MN || - layout == GemmMatrixLayout::MK_KN_NM || layout == GemmMatrixLayout::MK_NK_NM) - { - a_lengths_host[0] = static_cast(M); - a_lengths_host[1] = static_cast(K); - a_strides_host[0] = static_cast(K); - a_strides_host[1] = static_cast(1); - } - else - { - a_lengths_host[0] = static_cast(K); - a_lengths_host[1] = static_cast(M); - a_strides_host[0] = static_cast(M); - a_strides_host[1] = static_cast(1); - } - - // B - if(layout == GemmMatrixLayout::MK_NK_MN || layout == GemmMatrixLayout::KM_NK_MN || - layout == GemmMatrixLayout::MK_NK_NM || layout == GemmMatrixLayout::KM_NK_NM) - { - b_lengths_host[0] = static_cast(N); - b_lengths_host[1] = static_cast(K); - b_strides_host[0] = static_cast(K); - b_strides_host[1] = static_cast(1); - } - else - { - b_lengths_host[0] = static_cast(K); - b_lengths_host[1] = static_cast(N); - b_strides_host[0] = static_cast(N); - b_strides_host[1] = static_cast(1); - } - - // C - if(layout == GemmMatrixLayout::MK_KN_MN || layout == GemmMatrixLayout::KM_KN_MN || - layout == GemmMatrixLayout::MK_NK_MN || layout == GemmMatrixLayout::KM_NK_MN) - { - c_lengths_host[0] = static_cast(M); - c_lengths_host[1] = static_cast(N); - c_strides_host[0] = static_cast(N); - c_strides_host[1] = static_cast(1); - } - else - { - c_lengths_host[0] = static_cast(N); - c_lengths_host[1] = static_cast(M); - c_strides_host[0] = static_cast(M); - c_strides_host[1] = static_cast(1); - } - - Tensor a(a_lengths_host, a_strides_host); - Tensor b(b_lengths_host, b_strides_host); - Tensor c_host(c_lengths_host, c_strides_host); - Tensor c_device(c_lengths_host, c_strides_host); - - std::cout << "layout: " << layout << std::endl; - ostream_HostTensorDescriptor(a.mDesc, std::cout << "a: "); - ostream_HostTensorDescriptor(b.mDesc, std::cout << "b: "); - ostream_HostTensorDescriptor(c_host.mDesc, std::cout << "c: "); - - std::size_t num_thread = std::thread::hardware_concurrency(); - - switch(init_method) - { - case 0: - // no initialization - break; - case 1: - a.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - b.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - break; - case 2: - a.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - b.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - break; - case 3: - a.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - b.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - break; - case 4: - a.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - b.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - break; - default: - a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); - b.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); - } - -#if USE_GEMM_XDL_MK_KN_MN - if(algo == GemmAlgo::Xdl_MK_KN_MN) - { - if(layout != GemmMatrixLayout::MK_KN_MN) - { - throw std::runtime_error("wrong! 
layout"); - } - - device_gemm_xdlops_mk_kn_mn(a, b, c_device, nrepeat); - } -#endif - -#if USE_GEMM_XDL_MK_NK_MN - if(algo == GemmAlgo::Xdl_MK_NK_MN) - { - if(layout != GemmMatrixLayout::MK_NK_MN) - { - throw std::runtime_error("wrong! layout"); - } - - device_gemm_xdlops_mk_nk_mn(a, b, c_device, nrepeat); - } -#endif - -#if USE_GEMM_XDL_KM_KN_MN - if(algo == GemmAlgo::Xdl_KM_KN_MN) - { - if(layout != GemmMatrixLayout::KM_KN_MN) - { - throw std::runtime_error("wrong! layout"); - } - - device_gemm_xdlops_km_kn_mn(a, b, c_device, nrepeat); - } -#endif - -#if USE_GEMM_XDL_KM_NK_MN - if(algo == GemmAlgo::Xdl_KM_NK_MN) - { - if(layout != GemmMatrixLayout::KM_NK_MN) - { - throw std::runtime_error("wrong! layout"); - } - - device_gemm_xdlops_km_nk_mn(a, b, c_device, nrepeat); - } -#endif - -#if USE_GEMM_XDL_MK_KN_NM - if(algo == GemmAlgo::Xdl_MK_KN_NM) - { - if(layout != GemmMatrixLayout::MK_KN_NM) - { - throw std::runtime_error("wrong! layout"); - } - - device_gemm_xdlops_mk_kn_nm(a, b, c_device, nrepeat); - } -#endif - -#if USE_GEMM_XDL_MK_NK_NM - if(algo == GemmAlgo::Xdl_MK_NK_NM) - { - if(layout != GemmMatrixLayout::MK_NK_NM) - { - throw std::runtime_error("wrong! layout"); - } - - device_gemm_xdlops_mk_nk_nm(a, b, c_device, nrepeat); - } -#endif - -#if USE_GEMM_XDL_KM_KN_NM - if(algo == GemmAlgo::Xdl_KM_KN_NM) - { - if(layout != GemmMatrixLayout::KM_KN_NM) - { - throw std::runtime_error("wrong! layout"); - } - - device_gemm_xdlops_km_kn_nm(a, b, c_device, nrepeat); - } -#endif - -#if USE_GEMM_XDL_KM_NK_NM - if(algo == GemmAlgo::Xdl_KM_NK_NM) - { - if(layout != GemmMatrixLayout::KM_NK_NM) - { - throw std::runtime_error("wrong! layout"); - } - - device_gemm_xdlops_km_nk_nm(a, b, c_device, nrepeat); - } -#endif - - if(do_verification) - { - host_gemm(a, b, c_host, layout); - - check_error(c_host, c_device); - - if(do_log) - { - LogRangeAsType(std::cout << "a : ", a.mData, ",") << std::endl; - LogRangeAsType(std::cout << "b: ", b.mData, ",") << std::endl; - LogRangeAsType(std::cout << "c_host : ", c_host.mData, ",") << std::endl; - LogRangeAsType(std::cout << "c_device: ", c_device.mData, ",") << std::endl; - } - } -} diff --git a/host/host_tensor/CMakeLists.txt b/host/host_tensor/CMakeLists.txt deleted file mode 100644 index 3dcecf64..00000000 --- a/host/host_tensor/CMakeLists.txt +++ /dev/null @@ -1,21 +0,0 @@ -include_directories(BEFORE - include -) - -set(HOST_TENSOR_SOURCE - src/host_tensor.cpp; - src/device.cpp; -) - -## the library target -add_library(host_tensor SHARED ${HOST_TENSOR_SOURCE}) - -target_include_directories(host_tensor SYSTEM PUBLIC $) - -target_link_libraries(host_tensor PRIVATE hip::device) -target_link_libraries(host_tensor INTERFACE hip::host) - -target_compile_features(host_tensor PUBLIC) -set_target_properties(host_tensor PROPERTIES POSITION_INDEPENDENT_CODE ON) - -install(TARGETS host_tensor LIBRARY DESTINATION lib) diff --git a/host/host_tensor/include/device.hpp b/host/host_tensor/include/device.hpp deleted file mode 100644 index cb1a6eff..00000000 --- a/host/host_tensor/include/device.hpp +++ /dev/null @@ -1,84 +0,0 @@ -#ifndef DEVICE_HPP -#define DEVICE_HPP - -#include -#include -#include -#include -#include "hip/hip_runtime.h" -#include "hip/hip_fp16.h" - -struct DeviceMem -{ - DeviceMem() = delete; - DeviceMem(std::size_t mem_size); - void* GetDeviceBuffer(); - void ToDevice(const void* p); - void FromDevice(void* p); - ~DeviceMem(); - - void* mpDeviceBuf; - std::size_t mMemSize; -}; - -struct KernelTimerImpl; - -struct KernelTimer -{ - KernelTimer(); - 
~KernelTimer(); - void Start(); - void End(); - float GetElapsedTime() const; - - std::unique_ptr impl; -}; - -using device_stream_t = hipStream_t; - -template -void launch_kernel(F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args) -{ - hipStream_t stream_id = nullptr; - - hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...); -} - -template -float launch_and_time_kernel( - F kernel, int nrepeat, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args) -{ - KernelTimer timer; - - printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", - __func__, - grid_dim.x, - grid_dim.y, - grid_dim.z, - block_dim.x, - block_dim.y, - block_dim.z); - - printf("Warm up\n"); - - hipStream_t stream_id = nullptr; - - // warm up - hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...); - - printf("Start running %d times...\n", nrepeat); - - timer.Start(); - - for(int i = 0; i < nrepeat; ++i) - { - hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...); - } - - timer.End(); - - // std::this_thread::sleep_for (std::chrono::microseconds(10)); - - return timer.GetElapsedTime() / nrepeat; -} -#endif diff --git a/host/host_tensor/include/device_tensor.hpp b/host/host_tensor/include/device_tensor.hpp deleted file mode 100644 index 1a7a34a4..00000000 --- a/host/host_tensor/include/device_tensor.hpp +++ /dev/null @@ -1,9 +0,0 @@ -#pragma once -#include "host_tensor.hpp" -#include "common_header.hpp" - -template -void ostream_tensor_descriptor(TensorDesc, std::ostream& os = std::cout) -{ - ostream_HostTensorDescriptor(make_HostTensorDescriptor(TensorDesc{}), os); -} diff --git a/host/host_tensor/include/gemm_common.hpp b/host/host_tensor/include/gemm_common.hpp deleted file mode 100644 index f6c0d6f9..00000000 --- a/host/host_tensor/include/gemm_common.hpp +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef GEMM_COMMON_HPP -#define GEMM_COMMON_HPP - -enum GemmMatrixLayout -{ - MK_KN_MN, // 0 - MK_NK_MN, // 1 - KM_KN_MN, // 2 - KM_NK_MN, // 3 - MK_KN_NM, // 4 - MK_NK_NM, // 5 - KM_KN_NM, // 6 - KM_NK_NM, // 7 -}; - -#endif diff --git a/host/host_tensor/include/host_conv.hpp b/host/host_tensor/include/host_conv.hpp deleted file mode 100644 index c1228f48..00000000 --- a/host/host_tensor/include/host_conv.hpp +++ /dev/null @@ -1,324 +0,0 @@ -#pragma once -#include "host_tensor.hpp" - -template -void host_direct_convolution(const Tensor& in, - const Tensor& wei, - Tensor& out, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads&, - const ConvTensorLayout layout = ConvTensorLayout::NCHW) -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - - auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { - double v = 0; - for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c) - { - for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y) - { - int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; - for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x) - { - int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; - if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && - wi < in.mDesc.GetLengths()[3]) - { - v += static_cast(in(n, c, hi, wi)) * - static_cast(wei(k, c, y, x)); - } - } - } - } - out(n, k, ho, wo) = v; - }; - - auto f_nhwc = [&](auto n, auto ho, auto wo, auto k) { - double v = 0; - for(int c = 0; c < wei.mDesc.GetLengths()[3]; ++c) - { - for(int y = 0; y < 
wei.mDesc.GetLengths()[1]; ++y) - { - int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; - for(int x = 0; x < wei.mDesc.GetLengths()[2]; ++x) - { - int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; - if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 && - wi < in.mDesc.GetLengths()[2]) - { - v += static_cast(in(n, hi, wi, c)) * - static_cast(wei(k, y, x, c)); - } - } - } - } - out(n, ho, wo, k) = v; - }; - - if(layout == ConvTensorLayout::NCHW) - { - make_ParallelTensorFunctor(f_nchw, - out.mDesc.GetLengths()[0], - out.mDesc.GetLengths()[1], - out.mDesc.GetLengths()[2], - out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - } - else if(layout == ConvTensorLayout::NHWC) - { - make_ParallelTensorFunctor(f_nhwc, - out.mDesc.GetLengths()[0], - out.mDesc.GetLengths()[1], - out.mDesc.GetLengths()[2], - out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - } - else - { - throw std::runtime_error("wrong! not supported layout"); - } -} - -template -void host_winograd_3x3_convolution(const Tensor& in_nchw, - const Tensor& wei_kcyx, - Tensor& out_nkhw, - InLeftPads, - InRightPads) -{ - using namespace ck; - - constexpr std::size_t HoPerTile = 2; - constexpr std::size_t WoPerTile = 2; - - std::size_t N = in_nchw.mDesc.GetLengths()[0]; - std::size_t C = in_nchw.mDesc.GetLengths()[1]; - - std::size_t K = wei_kcyx.mDesc.GetLengths()[0]; - std::size_t Y = wei_kcyx.mDesc.GetLengths()[2]; - std::size_t X = wei_kcyx.mDesc.GetLengths()[3]; - - std::size_t Ho = out_nkhw.mDesc.GetLengths()[2]; - std::size_t Wo = out_nkhw.mDesc.GetLengths()[3]; - - index_t h_pad_low = InLeftPads{}.Get(Number<0>{}); - index_t w_pad_low = InLeftPads{}.Get(Number<1>{}); - - std::size_t HiPerTile = HoPerTile + Y - 1; - std::size_t WiPerTile = WoPerTile + X - 1; - - std::size_t HTile = (Ho + HoPerTile - 1) / HoPerTile; - std::size_t WTile = (Wo + WoPerTile - 1) / WoPerTile; - - Tensor in_hold({N, C, HTile, WTile, HiPerTile, WiPerTile}); - Tensor in_transform({N, C, HTile, WTile, HiPerTile, WiPerTile}); - Tensor wei_transform({K, C, HiPerTile, WiPerTile}); - Tensor out_transform({N, K, HTile, WTile, HiPerTile, HiPerTile}); - Tensor out_hold({N, K, HTile, WTile, HoPerTile, WoPerTile}); - - auto f_in_hold = [&](auto n, auto c, auto htile, auto wtile) { - for(int j = 0; j < HiPerTile; ++j) - { - int hi = HoPerTile * htile + j - h_pad_low; - for(int i = 0; i < WiPerTile; ++i) - { - int wi = WoPerTile * wtile + i - w_pad_low; - - if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 && - wi < in_nchw.mDesc.GetLengths()[3]) - { - in_hold(n, c, htile, wtile, j, i) = in_nchw(n, c, hi, wi); - } - else - { - in_hold(n, c, htile, wtile, j, i) = TIn(0); - } - } - } - }; - - auto f_in_transform = [&](auto n, auto c, auto htile, auto wtile) { - in_transform(n, c, htile, wtile, 0, 0) = - in_hold(n, c, htile, wtile, 0, 0) - in_hold(n, c, htile, wtile, 0, 2) - - in_hold(n, c, htile, wtile, 2, 0) + in_hold(n, c, htile, wtile, 2, 2); - in_transform(n, c, htile, wtile, 0, 1) = - in_hold(n, c, htile, wtile, 0, 1) + in_hold(n, c, htile, wtile, 0, 2) - - in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 2); - in_transform(n, c, htile, wtile, 0, 2) = - -in_hold(n, c, htile, wtile, 0, 1) + in_hold(n, c, htile, wtile, 0, 2) + - in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 2); - in_transform(n, c, htile, wtile, 0, 3) = - in_hold(n, c, htile, wtile, 0, 1) - in_hold(n, c, htile, wtile, 0, 3) - - in_hold(n, c, htile, wtile, 2, 1) + 
in_hold(n, c, htile, wtile, 2, 3); - - in_transform(n, c, htile, wtile, 1, 0) = - in_hold(n, c, htile, wtile, 1, 0) - in_hold(n, c, htile, wtile, 1, 2) + - in_hold(n, c, htile, wtile, 2, 0) - in_hold(n, c, htile, wtile, 2, 2); - in_transform(n, c, htile, wtile, 1, 1) = - in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) + - in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2); - in_transform(n, c, htile, wtile, 1, 2) = - -in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) - - in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2); - in_transform(n, c, htile, wtile, 1, 3) = - in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 3) + - in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 3); - - in_transform(n, c, htile, wtile, 2, 0) = - -in_hold(n, c, htile, wtile, 1, 0) + in_hold(n, c, htile, wtile, 1, 2) + - in_hold(n, c, htile, wtile, 2, 0) - in_hold(n, c, htile, wtile, 2, 2); - in_transform(n, c, htile, wtile, 2, 1) = - -in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 2) + - in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2); - in_transform(n, c, htile, wtile, 2, 2) = - in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 2) - - in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2); - in_transform(n, c, htile, wtile, 2, 3) = - -in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 3) + - in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 3); - - in_transform(n, c, htile, wtile, 3, 0) = - in_hold(n, c, htile, wtile, 1, 0) - in_hold(n, c, htile, wtile, 1, 2) - - in_hold(n, c, htile, wtile, 3, 0) + in_hold(n, c, htile, wtile, 3, 2); - in_transform(n, c, htile, wtile, 3, 1) = - in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) - - in_hold(n, c, htile, wtile, 3, 1) - in_hold(n, c, htile, wtile, 3, 2); - in_transform(n, c, htile, wtile, 3, 2) = - -in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) + - in_hold(n, c, htile, wtile, 3, 1) - in_hold(n, c, htile, wtile, 3, 2); - in_transform(n, c, htile, wtile, 3, 3) = - in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 3) - - in_hold(n, c, htile, wtile, 3, 1) + in_hold(n, c, htile, wtile, 3, 3); - }; - - auto f_wei_transform = [&](auto k, auto c) { - wei_transform(k, c, 0, 0) = double(wei_kcyx(k, c, 0, 0)); - wei_transform(k, c, 0, 1) = 0.5 * double(wei_kcyx(k, c, 0, 0)) + - 0.5 * double(wei_kcyx(k, c, 0, 1)) + - 0.5 * double(wei_kcyx(k, c, 0, 2)); - wei_transform(k, c, 0, 2) = 0.5 * double(wei_kcyx(k, c, 0, 0)) - - 0.5 * double(wei_kcyx(k, c, 0, 1)) + - 0.5 * double(wei_kcyx(k, c, 0, 2)); - wei_transform(k, c, 0, 3) = double(wei_kcyx(k, c, 0, 2)); - - wei_transform(k, c, 1, 0) = 0.5 * double(wei_kcyx(k, c, 0, 0)) + - 0.5 * double(wei_kcyx(k, c, 1, 0)) + - 0.5 * double(wei_kcyx(k, c, 2, 0)); - wei_transform(k, c, 1, 1) = - 0.25 * double(wei_kcyx(k, c, 0, 0)) + 0.25 * double(wei_kcyx(k, c, 0, 1)) + - 0.25 * double(wei_kcyx(k, c, 0, 2)) + 0.25 * double(wei_kcyx(k, c, 1, 0)) + - 0.25 * double(wei_kcyx(k, c, 1, 1)) + 0.25 * double(wei_kcyx(k, c, 1, 2)) + - 0.25 * double(wei_kcyx(k, c, 2, 0)) + 0.25 * double(wei_kcyx(k, c, 2, 1)) + - 0.25 * double(wei_kcyx(k, c, 2, 2)); - wei_transform(k, c, 1, 2) = - 0.25 * double(wei_kcyx(k, c, 0, 0)) - 0.25 * double(wei_kcyx(k, c, 0, 1)) + - 0.25 * double(wei_kcyx(k, c, 0, 2)) + 0.25 * double(wei_kcyx(k, c, 1, 0)) - - 0.25 * double(wei_kcyx(k, c, 1, 1)) + 0.25 * 
double(wei_kcyx(k, c, 1, 2)) + - 0.25 * double(wei_kcyx(k, c, 2, 0)) - 0.25 * double(wei_kcyx(k, c, 2, 1)) + - 0.25 * double(wei_kcyx(k, c, 2, 2)); - wei_transform(k, c, 1, 3) = 0.5 * double(wei_kcyx(k, c, 0, 2)) + - 0.5 * double(wei_kcyx(k, c, 1, 2)) + - 0.5 * double(wei_kcyx(k, c, 2, 2)); - - wei_transform(k, c, 2, 0) = 0.5 * double(wei_kcyx(k, c, 0, 0)) - - 0.5 * double(wei_kcyx(k, c, 1, 0)) + - 0.5 * double(wei_kcyx(k, c, 2, 0)); - wei_transform(k, c, 2, 1) = - 0.25 * double(wei_kcyx(k, c, 0, 0)) + 0.25 * double(wei_kcyx(k, c, 0, 1)) + - 0.25 * double(wei_kcyx(k, c, 0, 2)) - 0.25 * double(wei_kcyx(k, c, 1, 0)) - - 0.25 * double(wei_kcyx(k, c, 1, 1)) - 0.25 * double(wei_kcyx(k, c, 1, 2)) + - 0.25 * double(wei_kcyx(k, c, 2, 0)) + 0.25 * double(wei_kcyx(k, c, 2, 1)) + - 0.25 * double(wei_kcyx(k, c, 2, 2)); - wei_transform(k, c, 2, 2) = - 0.25 * double(wei_kcyx(k, c, 0, 0)) - 0.25 * double(wei_kcyx(k, c, 0, 1)) + - 0.25 * double(wei_kcyx(k, c, 0, 2)) - 0.25 * double(wei_kcyx(k, c, 1, 0)) + - 0.25 * double(wei_kcyx(k, c, 1, 1)) - 0.25 * double(wei_kcyx(k, c, 1, 2)) + - 0.25 * double(wei_kcyx(k, c, 2, 0)) - 0.25 * double(wei_kcyx(k, c, 2, 1)) + - 0.25 * double(wei_kcyx(k, c, 2, 2)); - wei_transform(k, c, 2, 3) = 0.5 * double(wei_kcyx(k, c, 0, 2)) - - 0.5 * double(wei_kcyx(k, c, 1, 2)) + - 0.5 * double(wei_kcyx(k, c, 2, 2)); - - wei_transform(k, c, 3, 0) = double(wei_kcyx(k, c, 2, 0)); - wei_transform(k, c, 3, 1) = 0.5 * double(wei_kcyx(k, c, 2, 0)) + - 0.5 * double(wei_kcyx(k, c, 2, 1)) + - 0.5 * double(wei_kcyx(k, c, 2, 2)); - wei_transform(k, c, 3, 2) = 0.5 * double(wei_kcyx(k, c, 2, 0)) - - 0.5 * double(wei_kcyx(k, c, 2, 1)) + - 0.5 * double(wei_kcyx(k, c, 2, 2)); - wei_transform(k, c, 3, 3) = double(wei_kcyx(k, c, 2, 2)); - }; - - auto f_out_transform = [&](auto n, auto k, auto htile, auto wtile) { - for(int j = 0; j < HiPerTile; ++j) - { - for(int i = 0; i < WiPerTile; ++i) - { - double v = 0; - for(int c = 0; c < C; ++c) - { - v += in_transform(n, c, htile, wtile, j, i) * wei_transform(k, c, j, i); - } - - out_transform(n, k, htile, wtile, j, i) = v; - } - } - }; - - auto f_out_hold = [&](auto n, auto k, auto htile, auto wtile) { - out_hold(n, k, htile, wtile, 0, 0) = - out_transform(n, k, htile, wtile, 0, 0) + out_transform(n, k, htile, wtile, 0, 1) + - out_transform(n, k, htile, wtile, 0, 2) + out_transform(n, k, htile, wtile, 1, 0) + - out_transform(n, k, htile, wtile, 1, 1) + out_transform(n, k, htile, wtile, 1, 2) + - out_transform(n, k, htile, wtile, 2, 0) + out_transform(n, k, htile, wtile, 2, 1) + - out_transform(n, k, htile, wtile, 2, 2); - out_hold(n, k, htile, wtile, 0, 1) = - out_transform(n, k, htile, wtile, 0, 1) - out_transform(n, k, htile, wtile, 0, 2) - - out_transform(n, k, htile, wtile, 0, 3) + out_transform(n, k, htile, wtile, 1, 1) - - out_transform(n, k, htile, wtile, 1, 2) - out_transform(n, k, htile, wtile, 1, 3) + - out_transform(n, k, htile, wtile, 2, 1) - out_transform(n, k, htile, wtile, 2, 2) - - out_transform(n, k, htile, wtile, 2, 3); - out_hold(n, k, htile, wtile, 1, 0) = - out_transform(n, k, htile, wtile, 1, 0) + out_transform(n, k, htile, wtile, 1, 1) + - out_transform(n, k, htile, wtile, 1, 2) - out_transform(n, k, htile, wtile, 2, 0) - - out_transform(n, k, htile, wtile, 2, 1) - out_transform(n, k, htile, wtile, 2, 2) - - out_transform(n, k, htile, wtile, 3, 0) - out_transform(n, k, htile, wtile, 3, 1) - - out_transform(n, k, htile, wtile, 3, 2); - out_hold(n, k, htile, wtile, 1, 1) = - out_transform(n, k, htile, wtile, 1, 1) - out_transform(n, k, 
htile, wtile, 1, 2) - - out_transform(n, k, htile, wtile, 1, 3) - out_transform(n, k, htile, wtile, 2, 1) + - out_transform(n, k, htile, wtile, 2, 2) + out_transform(n, k, htile, wtile, 2, 3) - - out_transform(n, k, htile, wtile, 3, 1) + out_transform(n, k, htile, wtile, 3, 2) + - out_transform(n, k, htile, wtile, 3, 3); - }; - - auto f_out = [&](auto n, auto k, auto htile, auto wtile) { - for(int j = 0; j < HoPerTile; ++j) - { - std::size_t ho = HoPerTile * htile + j; - for(int i = 0; i < WoPerTile; ++i) - { - std::size_t wo = WoPerTile * wtile + i; - out_nkhw(n, k, ho, wo) = out_hold(n, k, htile, wtile, j, i); - } - } - }; - - std::size_t num_thread = std::thread::hardware_concurrency(); - - make_ParallelTensorFunctor(f_in_hold, N, C, HTile, WTile)(num_thread); - make_ParallelTensorFunctor(f_in_transform, N, C, HTile, WTile)(num_thread); - make_ParallelTensorFunctor(f_wei_transform, K, C)(num_thread); - make_ParallelTensorFunctor(f_out_transform, N, K, HTile, WTile)(num_thread); - make_ParallelTensorFunctor(f_out_hold, N, K, HTile, WTile)(num_thread); - make_ParallelTensorFunctor(f_out, N, K, HTile, WTile)(num_thread); -} diff --git a/host/host_tensor/include/host_conv_bwd_data.hpp b/host/host_tensor/include/host_conv_bwd_data.hpp deleted file mode 100644 index ca23422e..00000000 --- a/host/host_tensor/include/host_conv_bwd_data.hpp +++ /dev/null @@ -1,135 +0,0 @@ -#pragma once -#include "host_tensor.hpp" - -template -void host_direct_convolution_backward_data(Tensor& in, - const Tensor& wei, - const Tensor& out, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& /* in_right_pads */, - const ConvTensorLayout layout = ConvTensorLayout::NCHW) -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - auto f_nchw = [&](auto n, auto c, auto hi, auto wi) { - std::size_t K = wei.mDesc.GetLengths()[I0]; - std::size_t Y = wei.mDesc.GetLengths()[I2]; - std::size_t X = wei.mDesc.GetLengths()[I3]; - - std::size_t Ho = out.mDesc.GetLengths()[I2]; - std::size_t Wo = out.mDesc.GetLengths()[I3]; - - double v = 0; - - for(int y = 0; y < Y; ++y) - { - int h_tmp = hi + in_left_pads[I0] - y * conv_dilations[I0]; - - if(h_tmp % conv_strides[I0] == 0) - { - int ho = h_tmp / conv_strides[I0]; - - if(ho >= 0 && ho < Ho) - { - for(int x = 0; x < X; ++x) - { - int w_tmp = wi + in_left_pads[I1] - x * conv_dilations[I1]; - - if(w_tmp % conv_strides[I1] == 0) - { - int wo = w_tmp / conv_strides[I1]; - - if(wo >= 0 && wo < Wo) - { - for(int k = 0; k < K; ++k) - { - v += out(n, k, ho, wo) * wei(k, c, y, x); - } - } - } - } - } - } - } - - in(n, c, hi, wi) = v; - }; - - auto f_nhwc = [&](auto n, auto hi, auto wi, auto c) { - std::size_t K = wei.mDesc.GetLengths()[I0]; - std::size_t Y = wei.mDesc.GetLengths()[I1]; - std::size_t X = wei.mDesc.GetLengths()[I2]; - - std::size_t Ho = out.mDesc.GetLengths()[I1]; - std::size_t Wo = out.mDesc.GetLengths()[I2]; - - double v = 0; - - for(int y = 0; y < Y; ++y) - { - int h_tmp = hi + in_left_pads[I0] - y * conv_dilations[I0]; - - if(h_tmp % conv_strides[I0] == 0) - { - int ho = h_tmp / conv_strides[I0]; - - if(ho >= 0 && ho < Ho) - { - for(int x = 0; x < X; ++x) - { - int w_tmp = wi + in_left_pads[I1] - x * conv_dilations[I1]; - - if(w_tmp % conv_strides[I1] == 0) - { - int wo = w_tmp / conv_strides[I1]; - - if(wo >= 0 && wo < Wo) - { - for(int k = 0; k < K; ++k) - { - v += out(n, ho, wo, k) 
* wei(k, y, x, c); - } - } - } - } - } - } - } - - in(n, hi, wi, c) = v; - }; - - if(layout == ConvTensorLayout::NCHW) - { - make_ParallelTensorFunctor(f_nchw, - in.mDesc.GetLengths()[0], - in.mDesc.GetLengths()[1], - in.mDesc.GetLengths()[2], - in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - } - else if(layout == ConvTensorLayout::NHWC) - { - make_ParallelTensorFunctor(f_nhwc, - in.mDesc.GetLengths()[0], - in.mDesc.GetLengths()[1], - in.mDesc.GetLengths()[2], - in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - } - else - { - throw std::runtime_error("wrong! not supported layout"); - } -} diff --git a/host/host_tensor/include/host_conv_bwd_weight.hpp b/host/host_tensor/include/host_conv_bwd_weight.hpp deleted file mode 100644 index ed3e8c30..00000000 --- a/host/host_tensor/include/host_conv_bwd_weight.hpp +++ /dev/null @@ -1,89 +0,0 @@ -#pragma once -#include "host_tensor.hpp" - -template -void host_direct_convolution_backward_weights( - const Tensor& out, - const Tensor& in, - Tensor& wei, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads&, - const ConvTensorLayout layout = ConvTensorLayout::NCHW) -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - auto f_kcyx = [&](auto k, auto c, auto y, auto x) { - double v = 0; - for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n) - { - for(int ho = 0; ho < out.mDesc.GetLengths()[2]; ++ho) - { - int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; - for(int wo = 0; wo < out.mDesc.GetLengths()[3]; ++wo) - { - int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; - if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && - wi < in.mDesc.GetLengths()[3]) - { - v += static_cast(in(n, c, hi, wi)) * - static_cast(out(n, k, ho, wo)); - } - } - } - } - wei(k, c, y, x) = v; - }; - - auto f_kyxc = [&](auto k, auto y, auto x, auto c) { - double v = 0; - for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n) - { - for(int ho = 0; ho < out.mDesc.GetLengths()[1]; ++ho) - { - int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; - for(int wo = 0; wo < out.mDesc.GetLengths()[2]; ++wo) - { - int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; - if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 && - wi < in.mDesc.GetLengths()[2]) - { - v += static_cast(in(n, hi, wi, c)) * - static_cast(out(n, ho, wo, k)); - } - } - } - } - wei(k, y, x, c) = v; - }; - - if(layout == ConvTensorLayout::NCHW) - { - make_ParallelTensorFunctor(f_kcyx, - wei.mDesc.GetLengths()[0], - wei.mDesc.GetLengths()[1], - wei.mDesc.GetLengths()[2], - wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - } - else if(layout == ConvTensorLayout::NHWC) - { - make_ParallelTensorFunctor(f_kyxc, - wei.mDesc.GetLengths()[0], - wei.mDesc.GetLengths()[1], - wei.mDesc.GetLengths()[2], - wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - } - else - { - throw std::runtime_error("wrong! 
not supported layout"); - } -} diff --git a/host/host_tensor/include/host_gemm.hpp b/host/host_tensor/include/host_gemm.hpp deleted file mode 100644 index c582a342..00000000 --- a/host/host_tensor/include/host_gemm.hpp +++ /dev/null @@ -1,159 +0,0 @@ -#pragma once -#include "host_tensor.hpp" -#include "gemm_common.hpp" - -template -void host_gemm(const Tensor& a, - const Tensor& b, - Tensor& c, - const GemmMatrixLayout layout) -{ - if(layout == GemmMatrixLayout::MK_KN_MN) - { - auto f_mk_kn_mn = [&](auto m, auto n) { - const int K = a.mDesc.GetLengths()[1]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += static_cast(a(m, k)) * static_cast(b(k, n)); - } - - c(m, n) = v; - }; - - make_ParallelTensorFunctor(f_mk_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - } - else if(layout == GemmMatrixLayout::MK_NK_MN) - { - auto f_mk_nk_mn = [&](auto m, auto n) { - const int K = a.mDesc.GetLengths()[1]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += static_cast(a(m, k)) * static_cast(b(n, k)); - } - - c(m, n) = v; - }; - - make_ParallelTensorFunctor(f_mk_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - } - else if(layout == GemmMatrixLayout::KM_KN_MN) - { - auto f_km_kn_mn = [&](auto m, auto n) { - const int K = a.mDesc.GetLengths()[0]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += static_cast(a(k, m)) * static_cast(b(k, n)); - } - - c(m, n) = v; - }; - - make_ParallelTensorFunctor(f_km_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - } - else if(layout == GemmMatrixLayout::KM_NK_MN) - { - auto f_km_nk_mn = [&](auto m, auto n) { - const int K = a.mDesc.GetLengths()[0]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += static_cast(a(k, m)) * static_cast(b(n, k)); - } - - c(m, n) = v; - }; - - make_ParallelTensorFunctor(f_km_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - } - else if(layout == GemmMatrixLayout::MK_KN_NM) - { - auto f_mk_kn_nm = [&](auto n, auto m) { - const int K = a.mDesc.GetLengths()[1]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += static_cast(a(m, k)) * static_cast(b(k, n)); - } - - c(n, m) = v; - }; - - make_ParallelTensorFunctor(f_mk_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - } - else if(layout == GemmMatrixLayout::MK_NK_NM) - { - auto f_mk_nk_nm = [&](auto n, auto m) { - const int K = a.mDesc.GetLengths()[1]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += static_cast(a(m, k)) * static_cast(b(n, k)); - } - - c(n, m) = v; - }; - - make_ParallelTensorFunctor(f_mk_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - } - else if(layout == GemmMatrixLayout::KM_KN_NM) - { - auto f_km_kn_nm = [&](auto n, auto m) { - const int K = a.mDesc.GetLengths()[0]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += static_cast(a(k, m)) * static_cast(b(k, n)); - } - - c(n, m) = v; - }; - - make_ParallelTensorFunctor(f_km_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - } - else if(layout == GemmMatrixLayout::KM_NK_NM) - { - auto f_km_nk_nm = [&](auto n, auto m) { - const int K = a.mDesc.GetLengths()[0]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += static_cast(a(k, m)) * static_cast(b(n, k)); - } - - c(n, m) = v; - }; - - 
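        // [Editorial note, not part of the original patch] The eight
        // host_gemm branches above are the 2^3 combinations of
        // transposed/non-transposed A, B and C: A is read as a(m, k) or
        // a(k, m), B as b(k, n) or b(n, k), and the result is written to
        // c(m, n) or c(n, m). Every branch runs the same K-length dot
        // product over the two C dimensions; only the index order differs.
        // A condensed sketch (hypothetical helper):
        //
        //   auto dot_k = [&](auto read_a, auto read_b, int K) {
        //       double v = 0;
        //       for(int k = 0; k < K; ++k)
        //           v += read_a(k) * read_b(k);
        //       return v;
        //   };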
make_ParallelTensorFunctor(f_km_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - } - else - { - throw std::runtime_error("wrong! not supported layout"); - } -} diff --git a/host/host_tensor/include/host_tensor_generator.hpp b/host/host_tensor/include/host_tensor_generator.hpp deleted file mode 100644 index b0d53995..00000000 --- a/host/host_tensor/include/host_tensor_generator.hpp +++ /dev/null @@ -1,71 +0,0 @@ -#ifndef HOST_TENSOR_GENERATOR_HPP -#define HOST_TENSOR_GENERATOR_HPP - -#include -#include "config.hpp" - -struct GeneratorTensor_1 -{ - int value = 1; - - template - float operator()(Is...) - { - return value; - } -}; - -struct GeneratorTensor_0 -{ - int value = 0; - - template - float operator()(Is...) - { - return value; - } -}; - -struct GeneratorTensor_2 -{ - int min_value = 0; - int max_value = 1; - - template - float operator()(Is...) - { - return (std::rand() % (max_value - min_value)) + min_value; - } -}; - -template -struct GeneratorTensor_3 -{ - T min_value = 0; - T max_value = 1; - - template - float operator()(Is...) - { - float tmp = float(std::rand()) / float(RAND_MAX); - - return min_value + tmp * (max_value - min_value); - } -}; - -struct GeneratorTensor_Checkboard -{ - template - float operator()(Ts... Xs) const - { - std::array dims = {{static_cast(Xs)...}}; - return std::accumulate(dims.begin(), - dims.end(), - true, - [](bool init, ck::index_t x) -> int { return init != (x % 2); }) - ? 1 - : -1; - } -}; - -#endif diff --git a/host/host_tensor/src/device.cpp b/host/host_tensor/src/device.cpp deleted file mode 100644 index 0d1b3d68..00000000 --- a/host/host_tensor/src/device.cpp +++ /dev/null @@ -1,67 +0,0 @@ -#include "device.hpp" - -DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size) -{ - hipGetErrorString(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); -} - -void* DeviceMem::GetDeviceBuffer() { return mpDeviceBuf; } - -void DeviceMem::ToDevice(const void* p) -{ - hipGetErrorString( - hipMemcpy(mpDeviceBuf, const_cast(p), mMemSize, hipMemcpyHostToDevice)); -} - -void DeviceMem::FromDevice(void* p) -{ - hipGetErrorString(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost)); -} - -DeviceMem::~DeviceMem() { hipGetErrorString(hipFree(mpDeviceBuf)); } - -struct KernelTimerImpl -{ - KernelTimerImpl() - { - hipGetErrorString(hipEventCreate(&mStart)); - hipGetErrorString(hipEventCreate(&mEnd)); - } - - ~KernelTimerImpl() - { - hipGetErrorString(hipEventDestroy(mStart)); - hipGetErrorString(hipEventDestroy(mEnd)); - } - - void Start() - { - hipGetErrorString(hipDeviceSynchronize()); - hipGetErrorString(hipEventRecord(mStart, nullptr)); - } - - void End() - { - hipGetErrorString(hipEventRecord(mEnd, nullptr)); - hipGetErrorString(hipEventSynchronize(mEnd)); - } - - float GetElapsedTime() const - { - float time; - hipGetErrorString(hipEventElapsedTime(&time, mStart, mEnd)); - return time; - } - - hipEvent_t mStart, mEnd; -}; - -KernelTimer::KernelTimer() : impl(new KernelTimerImpl()) {} - -KernelTimer::~KernelTimer() {} - -void KernelTimer::Start() { impl->Start(); } - -void KernelTimer::End() { impl->End(); } - -float KernelTimer::GetElapsedTime() const { return impl->GetElapsedTime(); } diff --git a/host/solver/include/conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw.hpp b/host/solver/include/conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw.hpp deleted file mode 100644 index 2b645e3c..00000000 --- a/host/solver/include/conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,689 +0,0 @@ -#ifndef 
CONV_IGEMM_FWD_V6R1_DLOPS_NCHW_KCYX_NKHW_HPP -#define CONV_IGEMM_FWD_V6R1_DLOPS_NCHW_KCYX_NKHW_HPP - -#include -#include - -namespace ck { -namespace driver { - -struct CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw -{ - auto GetCompileParameterString() const - { - auto param = std::stringstream(); - - // clang-format off - param << - " -DCK_PARAM_ABDataTypeEnum=" << - ABDataTypeEnum << - " -DCK_PARAM_AccDataTypeEnum=" << - AccDataTypeEnum << - " -DCK_PARAM_CDataTypeEnum=" << - CDataTypeEnum << - " -DCK_PARAM_BlockSize=" << - BlockSize << - " -DCK_PARAM_GN0=" << - GN0 << - " -DCK_PARAM_GK1=" << - GK1 << - " -DCK_PARAM_GM1PerBlockGM11=" - << GM1PerBlockGM11 << - " -DCK_PARAM_GN1PerBlockGN11=" << - GN1PerBlockGN11 << - " -DCK_PARAM_GK0PerBlock=" << - GK0PerBlock << - " -DCK_PARAM_BM1PerThreadBM11=" << - BM1PerThreadBM11 << - " -DCK_PARAM_BN1PerThreadBN11=" << - BN1PerThreadBN11 << - " -DCK_PARAM_BK0PerThread=" << - BK0PerThread << - " -DCK_PARAM_BM10BN10ThreadClusterBM10Xs=" << - BM10BN10ThreadClusterBM10Xs[0] << "," << - BM10BN10ThreadClusterBM10Xs[1] << - " -DCK_PARAM_BM10BN10ThreadClusterBN10Xs=" << - BM10BN10ThreadClusterBN10Xs[0] << "," << - BM10BN10ThreadClusterBN10Xs[1] << - " -DCK_PARAM_ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1=" << - ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[0] << "," << - ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[1] << "," << - ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[2] << "," << - ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[3] << "," << - ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[4] << - " -DCK_PARAM_ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1=" << - ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[0] << "," << - ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[1] << "," << - ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[2] << "," << - ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[3] << "," << - ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[4] << - " -DCK_PARAM_ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1=" << - ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[0] << "," << - ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[1] << "," << - ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[2] << "," << - ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[3] << "," << - ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[4] << - " -DCK_PARAM_ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1=" << - ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[0] << "," << - ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[1] << "," << - ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[2] << "," << - ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[3] << "," << - ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[4] << - " -DCK_PARAM_BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1=" << - BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[0] << "," << - BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[1] << "," << - BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[2] << "," << - BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[3] << "," << - BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[4] << - " -DCK_PARAM_BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1=" << - BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[0] << "," << - 
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[1] << "," << - BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[2] << "," << - BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[3] << "," << - BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[4] << - " -DCK_PARAM_BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1=" << - BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[0] << "," << - BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[1] << "," << - BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[2] << "," << - BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[3] << "," << - BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[4] << - " -DCK_PARAM_BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1=" << - BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[0] << "," << - BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[1] << "," << - BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[2] << "," << - BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[3] << "," << - BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[4] << - " -DCK_PARAM_CThreadTransferDstScalarPerVector=" << - CThreadTransferDstScalarPerVector << - " -DCK_PARAM_HasMainKBlockLoop=" << - static_cast<int>(HasMainKBlockLoop) << - " -DCK_PARAM_HasDoubleTailKBlockLoop=" << - static_cast<int>(HasDoubleTailKBlockLoop); - // clang-format on - - return param.str(); - } - - ck::DataTypeEnum_t ABDataTypeEnum = ck::DataTypeEnum_t::Unknown; - ck::DataTypeEnum_t AccDataTypeEnum = ck::DataTypeEnum_t::Unknown; - ck::DataTypeEnum_t CDataTypeEnum = ck::DataTypeEnum_t::Unknown; - - int BlockSize = -1; - - int GN0 = -1; - int GK1 = -1; - - int GM1PerBlockGM11 = -1; - int GN1PerBlockGN11 = -1; - int GK0PerBlock = -1; - - int BM1PerThreadBM11 = -1; - int BN1PerThreadBN11 = -1; - int BK0PerThread = -1; - - std::array<int, 2> BM10BN10ThreadClusterBM10Xs = {-1, -1}; - std::array<int, 2> BM10BN10ThreadClusterBN10Xs = {-1, -1}; - - std::array<int, 5> ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = { - -1, -1, -1, -1, -1}; - std::array<int, 5> ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = { - -1, -1, -1, -1, -1}; - std::array<int, 5> ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = { - -1, -1, -1, -1, -1}; - std::array<int, 5> ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = { - -1, -1, -1, -1, -1}; - - std::array<int, 5> BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = { - -1, -1, -1, -1, -1}; - std::array<int, 5> BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = { - -1, -1, -1, -1, -1}; - std::array<int, 5> BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = { - -1, -1, -1, -1, -1}; - std::array<int, 5> BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = { - -1, -1, -1, -1, -1}; - - int CThreadTransferDstScalarPerVector = -1; - - bool HasMainKBlockLoop = false; - bool HasDoubleTailKBlockLoop = false; -}; - -struct TunableConvIgemmFwdV6r1DlopsNchwKcyxNkhw -{ - ck::DataTypeEnum_t ABDataTypeEnum; - ck::DataTypeEnum_t CDataTypeEnum; - - int BlockSize; - - int GN0; - int GK1; - - int GM1PerBlockGM11; - int GN1PerBlockGN11; - int GK0PerBlock; - - int BM1PerThreadBM11; - int BN1PerThreadBN11; - int BK0PerThread; - - std::array<int, 2> BM10BN10ThreadClusterBM10Xs; - std::array<int, 2> BM10BN10ThreadClusterBN10Xs; - - std::array<int, 5> ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1; - std::array<int, 5> ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1; - std::array<int, 5> ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; - std::array<int, 5>
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; - - std::array<int, 5> BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1; - std::array<int, 5> BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1; - std::array<int, 5> BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; - std::array<int, 5> BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; -}; - -inline static auto generate_tunable_list_conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw() -{ - constexpr auto f32 = ck::DataTypeEnum_t::Float; - constexpr auto f16 = ck::DataTypeEnum_t::Half; - constexpr auto i8 = ck::DataTypeEnum_t::Int8; - - return std::vector<TunableConvIgemmFwdV6r1DlopsNchwKcyxNkhw>{ - // clang-format off - // fp32 - {f32, f32, 256, 1, 1, 128, 128, 16, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 2, 1}, {4, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 1, 1, 4, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}}, - - {f32, f32, 256, 1, 1, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}}, - {f32, f32, 256, 1, 1, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 2, 1}, {1, 1, 1, 4, 1}}, - {f32, f32, 256, 1, 1, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}}, - - {f32, f32, 256, 1, 1, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {4, 1, 1, 1, 1}, { 2, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, - {f32, f32, 256, 2, 1, 128, 64, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 2, 1, 1, 1}, { 4, 1, 1, 64, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, - {f32, f32, 256, 4, 1, 128, 32, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 4, 1, 1, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, - - {f32, f32, 256, 8, 1, 128, 16, 16, 4, 4, 1, {8, 2}, {8, 2}, {8, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 8, 1, 1, 1}, {16, 1, 1, 16, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, - - {f32, f32, 128, 1, 1, 64, 128, 8, 4, 4, 1, {4, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {8, 1, 1, 1, 1}, { 1, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, - - // fp16 - {f16, f16, 256, 1, 2, 128, 128, 16, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 2, 2}, {4, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 1, 1, 4, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}}, - - {f16, f16, 256, 1, 2, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}}, - {f16, f16, 256, 1, 2, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 2, 1}, {1, 1, 1, 4, 1}}, - {f16, f16, 256, 1, 2, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}}, - - {f16, f16, 256, 1, 2, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {4, 1, 1, 1, 2}, { 2, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, - {f16, f16, 256, 2, 2, 128, 64, 8, 4, 4, 1,
{8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 2, 1, 1, 2}, { 4, 1, 1, 64, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, - {f16, f16, 256, 4, 2, 128, 32, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 4, 1, 1, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, - - {f16, f16, 256, 8, 2, 128, 16, 16, 4, 4, 1, {8, 2}, {8, 2}, {8, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 8, 1, 1, 2}, {16, 1, 1, 16, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, - - {f16, f16, 128, 1, 2, 64, 128, 8, 4, 4, 1, {4, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {8, 1, 1, 1, 2}, { 1, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, - - // i8 - { i8, i8, 256, 1, 4, 128, 128, 16, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 2, 4}, {4, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 1, 1, 4, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}}, - - { i8, i8, 256, 1, 4, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}}, - { i8, i8, 256, 1, 4, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 2, 1}, {1, 1, 1, 4, 1}}, - { i8, i8, 256, 1, 4, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}}, - - { i8, i8, 256, 1, 4, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {4, 1, 1, 1, 4}, { 2, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, - { i8, i8, 256, 2, 4, 128, 64, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 2, 1, 1, 4}, { 4, 1, 1, 64, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, - { i8, i8, 256, 4, 4, 128, 32, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 4, 1, 1, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, - - { i8, i8, 256, 8, 4, 128, 16, 16, 4, 4, 1, {8, 2}, {8, 2}, {8, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 8, 1, 1, 4}, {16, 1, 1, 16, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, - - { i8, i8, 128, 1, 4, 64, 128, 8, 4, 4, 1, {4, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {8, 1, 1, 1, 4}, { 1, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}} - // clang-format on - }; -} - -// TODO make this common interface and write specs for it -struct ConvIgemmFwdV6r1DlopsNchwKcyxNkhw -{ - static auto - CalculateCompileParameterBasedOnTunable(const ConvolutionProblemDescriptor& conv_problem_desc, - const TunableConvIgemmFwdV6r1DlopsNchwKcyxNkhw& tunable) - { - const int C = conv_problem_desc.C; - const int Y = conv_problem_desc.Y; - const int X = conv_problem_desc.X; - const int Ho = conv_problem_desc.Ho; - const int Wo = conv_problem_desc.Wo; - - if(!(conv_problem_desc.InDataTypeEnum == tunable.ABDataTypeEnum && - conv_problem_desc.WeiDataTypeEnum == tunable.ABDataTypeEnum && - conv_problem_desc.OutDataTypeEnum == tunable.CDataTypeEnum)) - return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false); - - const auto ABDataTypeEnum = conv_problem_desc.InDataTypeEnum; - const auto CDataTypeEnum = conv_problem_desc.OutDataTypeEnum; - - DataTypeEnum_t AccDataTypeEnum; - - 
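// The branch below applies the usual mixed-precision accumulation rule
// (a sketch of the mapping, mirroring the code that follows):
//     ABDataTypeEnum Float or Half -> AccDataTypeEnum Float
//     ABDataTypeEnum Int8          -> AccDataTypeEnum Int32
//     anything else                -> {CompileParameter..., false}, i.e. not applicable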
if(ABDataTypeEnum == DataTypeEnum_t::Float || ABDataTypeEnum == DataTypeEnum_t::Half) - { - AccDataTypeEnum = DataTypeEnum_t::Float; - } - else if(ABDataTypeEnum == DataTypeEnum_t::Int8) - { - AccDataTypeEnum = DataTypeEnum_t::Int32; - } - else - { - return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false); - } - - const int BlockSize = tunable.BlockSize; - - const int GN0 = tunable.GN0; - const int GK1 = tunable.GK1; - - const int GM11 = tunable.GM1PerBlockGM11; - const int GN11 = tunable.GN1PerBlockGN11; - const int GK0PerBlock = tunable.GK0PerBlock; - - const int BM11 = tunable.BM1PerThreadBM11; - const int BN11 = tunable.BN1PerThreadBN11; - const int BK0PerThread = tunable.BK0PerThread; - - const auto BM10BN10ThreadClusterBM10Xs = tunable.BM10BN10ThreadClusterBM10Xs; - const auto BM10BN10ThreadClusterBN10Xs = tunable.BM10BN10ThreadClusterBN10Xs; - - const auto ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = - tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1; - const auto ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = - tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1; - const auto ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = - tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; - const auto ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = - tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; - - const auto BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = - tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1; - const auto BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = - tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1; - const auto BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = - tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; - const auto BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = - tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; - - // C threadwise copy: {BN11} or {BN} or {BN1} or {GN11} is Dst vector dim - const int CThreadTransferDstScalarPerVector = gcd(4, GN11, BN11, Ho * Wo); - - const int C0 = GK1; - - if(!(C % C0 == 0)) - return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false); - - const int C1 = C / C0; - - const int GK0 = C1 * Y * X; - - if(!(GK0 % GK0PerBlock == 0)) - return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false); - - const bool HasMainKBlockLoop = ((GK0 + GK0PerBlock) / (2 * GK0PerBlock) > 1); - - const bool HasDoubleTailKBlockLoop = ((GK0 / GK0PerBlock) % 2 == 0); - - return std::make_tuple( - CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{ - ABDataTypeEnum, - AccDataTypeEnum, - CDataTypeEnum, - BlockSize, - GN0, - GK1, - GM11, - GN11, - GK0PerBlock, - BM11, - BN11, - BK0PerThread, - BM10BN10ThreadClusterBM10Xs, - BM10BN10ThreadClusterBN10Xs, - ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, - BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, - CThreadTransferDstScalarPerVector, - HasMainKBlockLoop, - HasDoubleTailKBlockLoop}, - true); - } - - static auto 
GetDefaultCompileParameter(const ConvolutionProblemDescriptor& conv_problem_desc) - { - for(const auto& tunable : generate_tunable_list_conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw()) - { - CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw compile_param{}; - bool found = false; - - std::tie(compile_param, found) = - CalculateCompileParameterBasedOnTunable(conv_problem_desc, tunable); - - if(found && IsValidCompileParameter(conv_problem_desc, compile_param)) - return std::make_tuple(compile_param, true); - } - - return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false); - } - - static bool IsApplicable(const ConvolutionProblemDescriptor& conv_problem_desc) - { - bool found = false; - - std::tie(std::ignore, found) = GetDefaultCompileParameter(conv_problem_desc); - - return found; - } - - static bool - IsValidCompileParameter(const ConvolutionProblemDescriptor& conv_problem_desc, - const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param) - { - const int N = conv_problem_desc.N; - const int K = conv_problem_desc.K; - const int C = conv_problem_desc.C; - const int Y = conv_problem_desc.Y; - const int X = conv_problem_desc.X; - const int Ho = conv_problem_desc.Ho; - const int Wo = conv_problem_desc.Wo; - - const int GK1 = compile_param.GK1; - const int GN0 = compile_param.GN0; - const int GM11 = compile_param.GM1PerBlockGM11; - const int GN11 = compile_param.GN1PerBlockGN11; - - const int BM11 = compile_param.BM1PerThreadBM11; - const int BN11 = compile_param.BN1PerThreadBN11; - - const int C0 = GK1; - const int N0 = GN0; - - if(!(C % C0 == 0)) - return false; - - const int C1 = C / C0; - - if(!(N % N0 == 0)) - return false; - - const int N1 = N / N0; - - const int GM0 = 1; - const int GM1 = K; - const int GN1 = N1 * Ho * Wo; - const int GK0 = C1 * Y * X; - - // check data type - { - if(!(conv_problem_desc.InDataTypeEnum == conv_problem_desc.WeiDataTypeEnum && - conv_problem_desc.InDataTypeEnum == compile_param.ABDataTypeEnum)) - return false; - - if(compile_param.ABDataTypeEnum == DataTypeEnum_t::Float || - compile_param.ABDataTypeEnum == DataTypeEnum_t::Half) - { - if(!(compile_param.AccDataTypeEnum == DataTypeEnum_t::Float)) - return false; - } - else if(compile_param.ABDataTypeEnum == DataTypeEnum_t::Int8) - { - if(!(compile_param.AccDataTypeEnum == DataTypeEnum_t::Int32)) - return false; - } - } - - // check gridwise contraction - { - if(!(GM1 % GM11 == 0 && GN1 % GN11 == 0 && GK0 % compile_param.GK0PerBlock == 0)) - return false; - - const bool has_main_k_block_loop = - ((GK0 + compile_param.GK0PerBlock) / (2 * compile_param.GK0PerBlock) > 1); - - const bool has_double_tail_k_block_loop = ((GK0 / compile_param.GK0PerBlock) % 2 == 0); - - if(!(has_main_k_block_loop == compile_param.HasMainKBlockLoop && - has_double_tail_k_block_loop == compile_param.HasDoubleTailKBlockLoop)) - return false; - } - - // check A blockwise copy - { - const auto block_slice_lengths = - std::array{compile_param.GK0PerBlock, GM0, 1, GM11, GK1}; - const auto& cluster_lengths = - compile_param.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1; - const auto& thread_slice_lengths = - compile_param.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1; - const auto& src_vector_lengths = - compile_param.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; - const auto& dst_vector_lengths = - compile_param.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; - - // check number of working thread - const int num_work_thread = std::accumulate( - 
cluster_lengths.begin(), cluster_lengths.end(), 1, std::multiplies{}); - - if(!(compile_param.BlockSize >= num_work_thread)) - return false; - - // check block slice lengths vs thread slice lengths vs cluster lengths - for(int i = 0; i < 5; ++i) - { - if(!(cluster_lengths[i] * thread_slice_lengths[i] == block_slice_lengths[i])) - return false; - } - - // check thread slice lengths vs vector lengths - for(int i = 0; i < 5; ++i) - { - if(!(thread_slice_lengths[i] % src_vector_lengths[i] == 0)) - return false; - - if(!(thread_slice_lengths[i] % dst_vector_lengths[i] == 0)) - return false; - } - - // check Src vectorization, GK0 is global mem vector dim - if(!(src_vector_lengths[1] == 1 && src_vector_lengths[2] == 1 && - src_vector_lengths[3] == 1 && src_vector_lengths[4] == 1)) - return false; - - // check Dst vectorization, {GM11, GK1} are LDS vector dims - if(dst_vector_lengths[4] == GK1) - { // vectorize on {GM11, GK1} - if(!(GM11 % dst_vector_lengths[3] == 0)) - return false; - } - else - { // vectorize on {GK1} only - if(!(GK1 % dst_vector_lengths[4] == 0)) - return false; - - if(!(dst_vector_lengths[3] == 1)) - return false; - } - } - - // check B blockwise copy - { - const auto block_slice_lengths = - std::array{compile_param.GK0PerBlock, GN0, 1, GN11, GK1}; - const auto& cluster_lengths = - compile_param.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1; - const auto& thread_slice_lengths = - compile_param.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1; - const auto& src_vector_lengths = - compile_param.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; - const auto& dst_vector_lengths = - compile_param.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; - - // check number of working thread - const int num_work_thread = std::accumulate( - cluster_lengths.begin(), cluster_lengths.end(), 1, std::multiplies{}); - - if(!(compile_param.BlockSize >= num_work_thread)) - return false; - - // check block slice lengths vs thread slice lengths vs cluster lengths - for(int i = 0; i < 5; ++i) - { - if(!(cluster_lengths[i] * thread_slice_lengths[i] == block_slice_lengths[i])) - return false; - } - - // check thread slice lengths vs vector lengths - for(int i = 0; i < 5; ++i) - { - if(!(thread_slice_lengths[i] % src_vector_lengths[i] == 0 && - thread_slice_lengths[i] % dst_vector_lengths[i] == 0)) - return false; - } - - // check Src vectorization: {GN11} is global mem vector dim - if(!(src_vector_lengths[0] == 1 && src_vector_lengths[1] == 1 && - src_vector_lengths[2] == 1 && src_vector_lengths[4] == 1)) - return false; - - // check Src tensor layout related vectorization - if(Y == 1 && X == 1 && conv_problem_desc.ConvStrideH == 1 && - conv_problem_desc.ConvStrideW == 1 && conv_problem_desc.InLeftPadH == 0 && - conv_problem_desc.InLeftPadW == 0 && conv_problem_desc.InRightPadH == 0 && - conv_problem_desc.InRightPadW == 0) - { - if(!((Ho * Wo) % src_vector_lengths[3] == 0)) - return false; - } - else if(conv_problem_desc.ConvStrideW == 1 && conv_problem_desc.InLeftPadW == 0 && - conv_problem_desc.InRightPadW == 0) - { - if(!(Wo % src_vector_lengths[3] == 0)) - return false; - } - else - { - if(!(src_vector_lengths[3] == 1)) - return false; - } - - // check Dst vectorization: {GN11, GK1} are LDS vector dims - if(dst_vector_lengths[4] == GK1) - { // vectorize on {GN11, GK1} - if(!(GN11 % dst_vector_lengths[3] == 0)) - return false; - } - else - { // vectorize on {GK1} only - if(!(dst_vector_lengths[3] == 1)) - return false; - - if(!(GK1 % dst_vector_lengths[4] == 0)) - 
return false; - } - } - - // check blockwise GEMM - { - const int BM10 = std::accumulate(compile_param.BM10BN10ThreadClusterBM10Xs.begin(), - compile_param.BM10BN10ThreadClusterBM10Xs.end(), - 1, - std::multiplies{}); - - const int BN10 = std::accumulate(compile_param.BM10BN10ThreadClusterBN10Xs.begin(), - compile_param.BM10BN10ThreadClusterBN10Xs.end(), - 1, - std::multiplies{}); - - if(!(compile_param.BlockSize == BM10 * BN10)) - return false; - - const int BM = GM0 * GM11; - const int BN = GN0 * GN11; - - const int BM1 = BM10 * BM11; - const int BN1 = BN10 * BN11; - - if(!(BM % BM1 == 0 && BN % BN1 == 0)) - return false; - - const int BM0 = BM / BM1; - const int BN0 = BN / BN1; - - // blockwise GEMM currently only supports BM0 == 2 && BN0 == 2 - if(!(BM0 == 2 && BN0 == 2)) - return false; - - if(!(compile_param.GK0PerBlock % compile_param.BK0PerThread == 0)) - return false; - } - - // check C threadwise copy - { - // {BN11} or {BN} or {BN1} or {GN11} is Dst vector dim - const int dst_vector_len_gn11 = compile_param.CThreadTransferDstScalarPerVector; - - // check slice length vs Dst vector length: - if(!(BN11 % dst_vector_len_gn11 == 0 && GN11 % dst_vector_len_gn11 == 0)) - return false; - - // check Dst memory layout related vectorization: - if(!((Ho * Wo) % compile_param.CThreadTransferDstScalarPerVector == 0)) - return false; - } - - return true; - }; - - static int GetBlockSize(const ConvolutionProblemDescriptor&, - const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param) - { - return compile_param.BlockSize; - } - - static int GetGridSize(const ConvolutionProblemDescriptor& conv_problem_desc, - const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param) - { - const int N = conv_problem_desc.N; - const int K = conv_problem_desc.K; - const int Ho = conv_problem_desc.Ho; - const int Wo = conv_problem_desc.Wo; - - const int N0 = compile_param.GN0; - const int N1 = N / N0; - - const int GM1 = K; - const int GN1 = N1 * Ho * Wo; - - const int GM11 = compile_param.GM1PerBlockGM11; - const int GN11 = compile_param.GN1PerBlockGN11; - - const int GM10 = GM1 / GM11; - const int GN10 = GN1 / GN11; - - return GM10 * GN10; - } - - static std::size_t GetWorkSpaceSize(const ConvolutionProblemDescriptor&, - const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw&) - { - // workspace is used to save transformed tensor descriptors created by the prepare kernel - return 4096L; - } - - static std::size_t GetMaxWorkSpaceSize(const ConvolutionProblemDescriptor&) { return 4096L; } - - static auto GetTunableList() - { - return generate_tunable_list_conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw(); - } -}; - -} // namespace driver -} // namespace ck -#endif diff --git a/host/solver/include/conv_tunable_fwd_v4r4_dlops_nchw_kcyx_nkhw.hpp b/host/solver/include/conv_tunable_fwd_v4r4_dlops_nchw_kcyx_nkhw.hpp deleted file mode 100644 index 58fe588a..00000000 --- a/host/solver/include/conv_tunable_fwd_v4r4_dlops_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,51 +0,0 @@ -#ifndef CONV_TUNABLE_FWD_V4R4_DLOPS_NCHW_KCYX_NKHW_HPP -#define CONV_TUNABLE_FWD_V4R4_DLOPS_NCHW_KCYX_NKHW_HPP - -struct tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw -{ - int BlockSize; - - int MPerBlock; - int NPerBlock; - int KPerBlock; - - int M1PerThread; - int N1PerThread; - int KPerThread; - - int M1N1ThreadClusterM10; - int M1N1ThreadClusterN10; - int M1N1ThreadClusterM11; - int M1N1ThreadClusterN11; - - std::array<int, 3> ABlockTransferThreadSliceLengths_K_M0_M1; - std::array<int, 3> ABlockTransferThreadClusterLengths_K_M0_M1; - std::array<int, 3>
ABlockTransferThreadClusterArrangeOrder; - std::array<int, 3> ABlockTransferSrcAccessOrder; - int ABlockTransferSrcVectorDim; - int ABlockTransferSrcScalarPerVector; - int ABlockTransferDstScalarPerVector_M1; - bool AThreadTransferSrcResetCoordinateAfterRun; - - std::array<int, 3> BBlockTransferThreadSliceLengths_K_N0_N1; - std::array<int, 3> BBlockTransferThreadClusterLengths_K_N0_N1; - std::array<int, 3> BBlockTransferThreadClusterArrangeOrder; - std::array<int, 3> BBlockTransferSrcAccessOrder; - int BBlockTransferSrcVectorDim; - int BBlockTransferSrcScalarPerVector; - int BBlockTransferDstScalarPerVector_N1; - bool BThreadTransferSrcResetCoordinateAfterRun; - - std::array<int, 6> CThreadTransferSrcDstAccessOrder; - int CThreadTransferSrcDstVectorDim; - int CThreadTransferDstScalarPerVector; -}; - -static tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw - default_tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw = { - 256, 128, 128, 8, 4, 4, 1, - 8, 8, 2, 2, {4, 1, 1}, {2, 1, 128}, {2, 1, 0}, - {2, 1, 0}, 0, 4, 1, false, {4, 1, 1}, {2, 1, 128}, - {0, 1, 2}, {0, 1, 2}, 2, 1, 1, false, {3, 4, 5, 0, 1, 2}, - 5, 1}; -#endif diff --git a/host/solver/include/conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp b/host/solver/include/conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp deleted file mode 100644 index 361f6e4a..00000000 --- a/host/solver/include/conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,73 +0,0 @@ -#ifndef CONV_TUNABLE_FWD_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP -#define CONV_TUNABLE_FWD_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP - -struct tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw -{ - int BlockSize; - - int MPerBlock; - int NPerBlock; - int KPerBlock; - - int MPerXDL; - int NPerXDL; - int K1; - - int MRepeat; - int NRepeat; - - std::array<int, 3> ABlockTransferThreadSliceLengths_K0_M_K1; - std::array<int, 3> ABlockTransferThreadClusterLengths_K0_M_K1; - std::array<int, 3> ABlockTransferThreadClusterArrangeOrder; - std::array<int, 3> ABlockTransferSrcAccessOrder; - int ABlockTransferSrcVectorDim; - int ABlockTransferSrcScalarPerVector; - int ABlockTransferDstScalarPerVector_K1; - bool AThreadTransferSrcResetCoordinateAfterRun; - - std::array<int, 3> BBlockTransferThreadSliceLengths_K0_N_K1; - std::array<int, 3> BBlockTransferThreadClusterLengths_K0_N_K1; - std::array<int, 3> BBlockTransferThreadClusterArrangeOrder; - std::array<int, 3> BBlockTransferSrcAccessOrder; - int BBlockTransferSrcVectorDim; - int BBlockTransferSrcScalarPerVector; - int BBlockTransferDstScalarPerVector_K1; - bool BThreadTransferSrcResetCoordinateAfterRun; - - std::array<int, 8> CThreadTransferSrcDstAccessOrder; - int CThreadTransferSrcDstVectorDim; - int CThreadTransferDstScalarPerVector; -}; - -static tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw - default_tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw = { - 256, // BlockSize - 128, // MPerBlock, - 128, // NPerBlock, - 4, // KPerBlock, - 32, // MPerXDL, - 32, // NPerXDL, - 4, // K1, - 2, // MRepeat, - 2, // NRepeat, - {1, 2, 4}, // ABlockTransferThreadSliceLengths_K0_M_K1, - {4, 64, 1}, // ABlockTransferThreadClusterLengths_K0_M_K1, - {1, 0, 2}, // ABlockTransferThreadClusterArrangeOrder, - {1, 0, 2}, // ABlockTransferSrcAccessOrder, - 2, // ABlockTransferSrcVectorDim - 1, // ABlockTransferSrcScalarPerVector, - 4, // ABlockTransferDstScalarPerVector_K1, - false, // AThreadTransferSrcResetCoordinateAfterRun, - {1, 2, 4}, // BBlockTransferThreadSliceLengths_K0_N_K1, - {4, 64, 1}, // BBlockTransferThreadClusterLengths_K0_N_K1, - {0, 2, 1}, // BBlockTransferThreadClusterArrangeOrder, - {1, 0, 2}, // BBlockTransferSrcAccessOrder, - 1, // BBlockTransferSrcVectorDim - 1, //
BBlockTransferSrcScalarPerVector - 4, // BBlockTransferDstScalarPerVector_K1 - false, // BThreadTransferSrcResetCoordinateAfterRun - {3, 0, 1, 2, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder - 7, // CThreadTransferSrcDstVectorDim, - 1 // CThreadTransferDstScalarPerVector -}; -#endif diff --git a/host/solver/include/conv_tunable_fwd_v4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/solver/include/conv_tunable_fwd_v4r4_xdlops_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 263c21a1..00000000 --- a/host/solver/include/conv_tunable_fwd_v4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,73 +0,0 @@ -#ifndef CONV_TUNABLE_FWD_V4R4_XDLOPS_NHWC_KYXC_NHWK_HPP -#define CONV_TUNABLE_FWD_V4R4_XDLOPS_NHWC_KYXC_NHWK_HPP - -struct tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk -{ - int BlockSize; - - int MPerBlock; - int NPerBlock; - int KPerBlock; - - int MPerWave; - int NPerWave; - int K1; - - int MRepeat; - int NRepeat; - - std::array<int, 3> ABlockTransferThreadSliceLengths_K0_M_K1; - std::array<int, 3> ABlockTransferThreadClusterLengths_K0_M_K1; - std::array<int, 3> ABlockTransferThreadClusterArrangeOrder; - std::array<int, 3> ABlockTransferSrcAccessOrder; - int ABlockTransferSrcVectorDim; - int ABlockTransferSrcScalarPerVector; - int ABlockTransferDstScalarPerVector_K1; - bool AThreadTransferSrcResetCoordinateAfterRun; - - std::array<int, 3> BBlockTransferThreadSliceLengths_K0_N_K1; - std::array<int, 3> BBlockTransferThreadClusterLengths_K0_N_K1; - std::array<int, 3> BBlockTransferThreadClusterArrangeOrder; - std::array<int, 3> BBlockTransferSrcAccessOrder; - int BBlockTransferSrcVectorDim; - int BBlockTransferSrcScalarPerVector; - int BBlockTransferDstScalarPerVector_K1; - bool BThreadTransferSrcResetCoordinateAfterRun; - - std::array<int, 8> CThreadTransferSrcDstAccessOrder; - int CThreadTransferSrcDstVectorDim; - int CThreadTransferDstScalarPerVector; -}; - -static tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk - default_tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk = { - 256, // BlockSize - 128, // MPerBlock, - 128, // NPerBlock, - 4, // KPerBlock, - 32, // MPerWave, - 32, // NPerWave, - 4, // K1, - 2, // MRepeat, - 2, // NRepeat, - {1, 2, 4}, // ABlockTransferThreadSliceLengths_K0_M_K1, - {4, 64, 1}, // ABlockTransferThreadClusterLengths_K0_M_K1, - {1, 0, 2}, // ABlockTransferThreadClusterArrangeOrder, - {1, 0, 2}, // ABlockTransferSrcAccessOrder, - 2, // ABlockTransferSrcVectorDim - 4, // ABlockTransferSrcScalarPerVector, - 4, // ABlockTransferDstScalarPerVector_K1, - false, // AThreadTransferSrcResetCoordinateAfterRun, - {1, 2, 4}, // BBlockTransferThreadSliceLengths_K0_N_K1, - {4, 64, 1}, // BBlockTransferThreadClusterLengths_K0_N_K1, - {1, 0, 2}, // BBlockTransferThreadClusterArrangeOrder, - {1, 0, 2}, // BBlockTransferSrcAccessOrder, - 2, // BBlockTransferSrcVectorDim - 4, // BBlockTransferSrcScalarPerVector - 4, // BBlockTransferDstScalarPerVector_K1 - false, // BThreadTransferSrcResetCoordinateAfterRun - {2, 3, 0, 1, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder - 7, // CThreadTransferSrcDstVectorDim, - 1 // CThreadTransferDstScalarPerVector -}; -#endif diff --git a/host/solver/include/convolution_problem_descriptor.hpp b/host/solver/include/convolution_problem_descriptor.hpp deleted file mode 100644 index 8c0ecbee..00000000 --- a/host/solver/include/convolution_problem_descriptor.hpp +++ /dev/null @@ -1,81 +0,0 @@ -#ifndef CONVOLUTION_PROBLEM_DESCRIPTOR -#define CONVOLUTION_PROBLEM_DESCRIPTOR - -namespace ck { -namespace driver { - -struct ConvolutionProblemDescriptor -{ - ConvolutionProblemDescriptor() = default; - - ConvolutionProblemDescriptor(int N_, -
int K_, - int C_, - int Y_, - int X_, - int Hi_, - int Wi_, - int Ho_, - int Wo_, - int ConvStrideH_, - int ConvStrideW_, - int ConvDilationH_, - int ConvDilationW_, - int InLeftPadH_, - int InLeftPadW_, - int InRightPadH_, - int InRightPadW_, - ck::DataTypeEnum_t InDataTypeEnum_, - ck::DataTypeEnum_t WeiDataTypeEnum_, - ck::DataTypeEnum_t OutDataTypeEnum_) - : N{N_}, - K{K_}, - C{C_}, - Y{Y_}, - X{X_}, - Hi{Hi_}, - Wi{Wi_}, - Ho{Ho_}, - Wo{Wo_}, - ConvStrideH{ConvStrideH_}, - ConvStrideW{ConvStrideW_}, - ConvDilationH{ConvDilationH_}, - ConvDilationW{ConvDilationW_}, - InLeftPadH{InLeftPadH_}, - InLeftPadW{InLeftPadW_}, - InRightPadH{InRightPadH_}, - InRightPadW{InRightPadW_}, - InDataTypeEnum{InDataTypeEnum_}, - WeiDataTypeEnum{WeiDataTypeEnum_}, - OutDataTypeEnum{OutDataTypeEnum_} - { - } - - int N; - int K; - int C; - int Y; - int X; - int Hi; - int Wi; - int Ho; - int Wo; - int ConvStrideH; - int ConvStrideW; - int ConvDilationH; - int ConvDilationW; - int InLeftPadH; - int InLeftPadW; - int InRightPadH; - int InRightPadW; - - ck::DataTypeEnum_t InDataTypeEnum; - ck::DataTypeEnum_t WeiDataTypeEnum; - ck::DataTypeEnum_t OutDataTypeEnum; - - std::size_t CalculateFlop() const { return 2L * N * K * C * Y * X * Ho * Wo; } -}; - -} // namespace driver -} // namespace ck -#endif diff --git a/host/solver/include/solver_common.hpp b/host/solver/include/solver_common.hpp deleted file mode 100644 index d1792f76..00000000 --- a/host/solver/include/solver_common.hpp +++ /dev/null @@ -1,46 +0,0 @@ -#ifndef CK_SOLVER_COMMON_HPP -#define CK_SOLVER_COMMON_HPP - -namespace ck { -namespace driver { - -// greatest common divisor, aka highest common factor -inline int gcd(int x, int y) -{ - if(x < 0) - { - return gcd(-x, y); - } - else if(y < 0) - { - return gcd(x, -y); - } - else if(x == y || x == 0) - { - return y; - } - else if(y == 0) - { - return x; - } - else if(x > y) - { - return gcd(x % y, y); - } - else - { - return gcd(x, y % x); - } -} - -template <typename X, - typename... Ys, - typename std::enable_if<(sizeof...(Ys) >= 2), bool>::type = false> -auto gcd(X x, Ys... ys) -{ - return gcd(x, gcd(ys...)); -} - -} // namespace driver -} // namespace ck -#endif diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp new file mode 100644 index 00000000..2721f7d1 --- /dev/null +++ b/include/ck/ck.hpp @@ -0,0 +1,179 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
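A note on the variadic gcd in solver_common.hpp above: it folds right, so a call such as gcd(4, GN11, BN11, Ho * Wo) evaluates as gcd(4, gcd(GN11, gcd(BN11, Ho * Wo))). For example, gcd(4, 6, 8, 10) = gcd(4, gcd(6, gcd(8, 10))) = gcd(4, gcd(6, 2)) = gcd(4, 2) = 2, which is how CalculateCompileParameterBasedOnTunable derives CThreadTransferDstScalarPerVector from several alignment constraints at once.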
+ +#pragma once + +#ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS +#include "hip/hip_runtime.h" +#include "hip/hip_fp16.h" +#endif + +#define CK_TIME_KERNEL 1 + +// constant address space for kernel parameter +// https://llvm.org/docs/AMDGPUUsage.html#address-spaces +#define CK_CONSTANT_ADDRESS_SPACE __attribute__((address_space(4))) + +// launch bounds +#define CK_USE_LAUNCH_BOUNDS 1 + +#ifdef CK_USE_LAUNCH_BOUNDS +#define CK_MAX_THREAD_PER_BLOCK 256 +#define CK_MIN_BLOCK_PER_CU 2 +#endif + +// check GPU target +#ifdef __HIP_DEVICE_COMPILE__ +#if !(defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx1030__)) +#error Not supported target +#endif +#endif + +// buffer resource +#ifndef __HIP_DEVICE_COMPILE__ // for host code +#define CK_BUFFER_RESOURCE_3RD_DWORD -1 +#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) // for GPU code +#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 +#elif defined(__gfx1030__) // for GPU code +#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 +#endif + +// FMA instruction +#ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing +#elif defined(__gfx803__) || defined(__gfx900__) // for GPU code +#define CK_USE_AMD_V_MAC_F32 +#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx1030__) // for GPU code +#define CK_USE_AMD_V_FMAC_F32 +#define CK_USE_AMD_V_DOT2_F32_F16 +#define CK_USE_AMD_V_DOT4_I32_I8 +#endif + +// MFMA instruction +#ifndef __HIP_DEVICE_COMPILE__ // for host code +#define CK_USE_AMD_MFMA +#elif defined(__gfx908__) || defined(__gfx90a__) // for GPU code +#define CK_USE_AMD_MFMA +#endif + +#if defined(__gfx90a__) +#define CK_USE_AMD_MFMA_BF16_1K_OP +#endif + +// buffer load +#define CK_USE_AMD_BUFFER_LOAD 1 + +// buffer store +#define CK_USE_AMD_BUFFER_STORE 1 + +// buffer atomic add: integer +#define CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER 1 + +// buffer atomic add: floating point +#ifndef __HIP_DEVICE_COMPILE__ // for host code +#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 +#elif defined(__gfx908__) || defined(__gfx90a__) // for GPU code +#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 +#else // for GPU code +#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0 +#endif + +#if defined(__gfx90a__) // for GPU code +#define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1 +#else +#define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0 +#endif + +// inline asm +#define CK_USE_AMD_INLINE_ASM 1 + +// inner product (DLOP) +#define CK_USE_AMD_INNER_PRODUCT_INLINE_ASM 1 + +// block synchronization only s_wait lgkmcnt(0), not vmcnt(0) +#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1 + +// experimental feature: multi index implemented as array +#define CK_EXPERIMENTAL_USE_DYNAMICALLY_INDEXED_MULTI_INDEX 0 + +// experimental feature: static tensor descriptor +#define CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR 0 + +// experimental feature: buffer load/store/atomic-add/ OOB trick +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting. Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter for each usage +#ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0 +#endif +#define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1 +#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1 +#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK 1 + +// experimental feature: in-register sub-dword transpose +#define CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE 1 + +// experimental feature: merge transformation uses magic number division +#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 1 + +// experimental feature: use __builtin_memcpy instead of pointer cast to access a vector from +// a pointer to scalar +#define CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS 0 + +// experimental feature: use __builtin_memcpy instead of union to do bit_cast +#define CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST 1 + +// experimental feature: optimize for inter-wave scheduling policy +#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING 0 +#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS 1 + +// hack: has underlying assumptions that need to be satisfied, otherwise it's a bug +// hack for forcing idx_diff_low_const to be kept in an SGPR; idx_diff_low_const must be +// thread-invariant, otherwise it's a bug +// TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread" +#define CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0 + +// workaround: compiler crash when compiling recursive lambda +#define CK_WORKAROUND_SWDEV_275126 1 + +// workaround: compiler crash when using buffer load/store for i8 +#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE 1 + +// workaround: compiler generating inefficient ds_write instructions +#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1 + +// workaround: verification failure, due to compiler regression, for conv bwd-data fp16 using some +// tuning parameters +#define CK_WORKAROUND_SWDEV_325164 1 + +namespace ck { + +enum struct InMemoryDataOperationEnum +{ + Set, + AtomicAdd, + AtomicMax, + Add +}; + +// FIXME: use regular Sequence and remove this +template <InMemoryDataOperationEnum... Is> +struct InMemoryDataOperationEnumSequence +{ + static constexpr int mSize = sizeof...(Is); + + __host__ __device__ static constexpr InMemoryDataOperationEnum At(int I) + { + // the last dummy element prevents the compiler from complaining about an empty array when mSize == 0 + const InMemoryDataOperationEnum mData[mSize + 1] = {Is..., InMemoryDataOperationEnum::Set}; + return mData[I]; + } +}; + +// index type +using index_t = int32_t; +using long_index_t = int64_t; + +} // namespace ck diff --git a/include/ck/device_utility/device_prop.hpp b/include/ck/device_utility/device_prop.hpp new file mode 100644 index 00000000..e2cbdb73 --- /dev/null +++ b/include/ck/device_utility/device_prop.hpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
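A quick sketch of the pack indexing in InMemoryDataOperationEnumSequence above (the usage below is illustrative, not taken from this diff):

    using Ops = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::Set,
                                                      ck::InMemoryDataOperationEnum::AtomicAdd>;
    static_assert(Ops::mSize == 2, "pack holds two operations");
    static_assert(Ops::At(1) == ck::InMemoryDataOperationEnum::AtomicAdd, "At() indexes the pack");
    // At(mSize) would read the trailing dummy Set element, which exists only so
    // the internal array is never empty when the pack itself is empty.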
+ +#pragma once + +#include <hip/hip_runtime.h> +#include <map> +#include <string> + +namespace ck { + +inline std::string get_device_name() +{ + hipDeviceProp_t props{}; + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) + { + return std::string(); + } + + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) + { + return std::string(); + } + const std::string raw_name(props.gcnArchName); + + // https://github.com/ROCmSoftwarePlatform/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40 + static std::map<std::string, std::string> device_name_map = { + {"Ellesmere", "gfx803"}, + {"Baffin", "gfx803"}, + {"RacerX", "gfx803"}, + {"Polaris10", "gfx803"}, + {"Polaris11", "gfx803"}, + {"Tonga", "gfx803"}, + {"Fiji", "gfx803"}, + {"gfx800", "gfx803"}, + {"gfx802", "gfx803"}, + {"gfx804", "gfx803"}, + {"Vega10", "gfx900"}, + {"gfx901", "gfx900"}, + {"10.3.0 Sienna_Cichlid 18", "gfx1030"}, + }; + + const auto name = raw_name.substr(0, raw_name.find(':')); // str.substr(0, npos) returns str. + + auto match = device_name_map.find(name); + if(match != device_name_map.end()) + return match->second; + return name; +} + +} // namespace ck diff --git a/include/ck/device_utility/hip_check_error.hpp b/include/ck/device_utility/hip_check_error.hpp new file mode 100644 index 00000000..d3dc8eaf --- /dev/null +++ b/include/ck/device_utility/hip_check_error.hpp @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include <sstream> +#include <stdexcept> + +#include <hip/hip_runtime.h> + +inline void hip_check_error(hipError_t x) +{ + if(x != hipSuccess) + { + std::ostringstream ss; + ss << "HIP runtime error: " << hipGetErrorString(x) << ". " << __FILE__ << ": " << __LINE__ + << " in function: " << __func__; + throw std::runtime_error(ss.str()); + } +} diff --git a/include/ck/device_utility/kernel_launch.hpp b/include/ck/device_utility/kernel_launch.hpp new file mode 100644 index 00000000..5879f999 --- /dev/null +++ b/include/ck/device_utility/kernel_launch.hpp @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include <hip/hip_runtime.h> + +#include "ck/ck.hpp" +#include "ck/stream_config.hpp" +#include "ck/device_utility/hip_check_error.hpp" + +template <typename... Args, typename F> +float launch_and_time_kernel(const StreamConfig& stream_config, + F kernel, + dim3 grid_dim, + dim3 block_dim, + std::size_t lds_byte, + Args...
args) +{ +#if CK_TIME_KERNEL + if(stream_config.time_kernel_) + { + printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", + __func__, + grid_dim.x, + grid_dim.y, + grid_dim.z, + block_dim.x, + block_dim.y, + block_dim.z); + + const int nrepeat = 10; + + printf("Warm up 1 time\n"); + + // warm up + kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...); + + printf("Start running %d times...\n", nrepeat); + + hipEvent_t start, stop; + + hip_check_error(hipEventCreate(&start)); + hip_check_error(hipEventCreate(&stop)); + + hip_check_error(hipDeviceSynchronize()); + hip_check_error(hipEventRecord(start, stream_config.stream_id_)); + + for(int i = 0; i < nrepeat; ++i) + { + kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...); + } + + hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); + hip_check_error(hipEventSynchronize(stop)); + + float total_time = 0; + + hip_check_error(hipEventElapsedTime(&total_time, start, stop)); + + return total_time / nrepeat; + } + else + { + kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...); + + return 0; + } +#else + kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...); + + return 0; +#endif +} diff --git a/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp similarity index 74% rename from composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp rename to include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp index 09ea16fa..db8e48df 100644 --- a/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp +++ b/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
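For context, a hypothetical call site for launch_and_time_kernel above (the kernel name, launch shape, and StreamConfig initialization are assumptions for illustration, not part of this diff):

    __global__ void copy_kernel(const float* src, float* dst, int n);

    StreamConfig cfg{};       // default stream
    cfg.time_kernel_ = true;  // take the warm-up + 10-iteration timing path
    float avg_ms = launch_and_time_kernel(
        cfg, copy_kernel, dim3(n / 256), dim3(256), 0 /* lds_byte */, src, dst, n);

With time_kernel_ false (or CK_TIME_KERNEL disabled), the kernel is launched once and 0 is returned.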
+ #ifndef CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP #define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP @@ -7,9 +10,9 @@ namespace ck { -// Number of GEMMs = YTilda * XTilda +// Number of GEMMs = YTilde * XTilde // GemmM = C -// GemmN = N * HTildaSlice * WTildaSlice +// GemmN = N * HTildeSlice * WTildeSlice // GemmK = K * YDotSlice * XDotSlice template __host__ __device__ constexpr auto transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( @@ -30,8 +33,8 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( const ConvDilations& conv_dilations, const InLeftPads& in_left_pads, const InRightPads& in_right_pads, - Number, - Number, + Number, + Number, Number) { constexpr auto I0 = Number<0>{}; @@ -40,8 +43,8 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( constexpr auto I3 = Number<3>{}; constexpr auto GemmK1 = Number{}; - constexpr auto IYTilda = Number{}; - constexpr auto IXTilda = Number{}; + constexpr auto IYTilde = Number{}; + constexpr auto IXTilde = Number{}; const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); @@ -71,55 +74,55 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - const auto YTilda = ConvStrideH / GcdStrideDilationH; - const auto XTilda = ConvStrideW / GcdStrideDilationW; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; - const auto YDot = math::integer_divide_ceil(Y, YTilda); - const auto XDot = math::integer_divide_ceil(X, XTilda); + const auto YDot = math::integer_divide_ceil(Y, YTilde); + const auto XDot = math::integer_divide_ceil(X, XTilde); - const auto HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); - const auto WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); - // only work on HTilda and WTilda that contribute to non-padding area of input tensor - const auto IHTildaSliceBegin = math::integer_divide_floor( - math::max(I0, InLeftPadH - ConvDilationH * (YTilda - I1)), ConvStrideH); - const auto IWTildaSliceBegin = math::integer_divide_floor( - math::max(I0, InLeftPadW - ConvDilationW * (XTilda - I1)), ConvStrideW); + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); - const auto IHTildaSliceEnd = - math::min(HTilda, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); - const auto IWTildaSliceEnd = - math::min(WTilda, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + const auto IHTildeSliceEnd = + math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildeSliceEnd = + math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); - const auto HTildaSlice = IHTildaSliceEnd - IHTildaSliceBegin; - const auto WTildaSlice = IWTildaSliceEnd 
- IWTildaSliceBegin; + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; // GemmK is different for each GEMM - const auto YDotSlice = math::integer_divide_ceil(Y - IYTilda, YTilda); - const auto XDotSlice = math::integer_divide_ceil(X - IXTilda, XTilda); + const auto YDotSlice = math::integer_divide_ceil(Y - IYTilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - IXTilde, XTilde); const auto K1 = GemmK1; const auto K0 = K / K1; // weight tensor - const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_tensor_descriptor( + const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( wei_k_y_x_c_grid_desc, make_tuple(make_pass_through_transform(K), - make_embed_transform(make_tuple(YDot, YTilda), + make_embed_transform(make_tuple(YDot, YTilde), make_tuple(ConvStrideH / GcdStrideDilationH, I1)), - make_embed_transform(make_tuple(XDot, XTilda), + make_embed_transform(make_tuple(XDot, XTilde), make_tuple(ConvStrideW / GcdStrideDilationW, I1)), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = - transform_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc, + transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_slice_transform(YDot, I0, YDotSlice), make_slice_transform(XDot, I0, XDotSlice), - make_freeze_transform(IYTilda), - make_freeze_transform(IXTilda), + make_freeze_transform(IYTilde), + make_freeze_transform(IXTilde), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, @@ -163,25 +166,25 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_tensor_descriptor( + const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( out_n_hop_wop_k_grid_desc, make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(YDot, HTilda), + make_embed_transform(make_tuple(YDot, HTilde), make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), - make_embed_transform(make_tuple(XDot, WTilda), + make_embed_transform(make_tuple(XDot, WTilde), make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), make_pass_through_transform(K)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc = + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc = transform_tensor_descriptor( - out_n_ydot_htilda_xdot_wtilda_k_grid_desc, + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, make_tuple(make_pass_through_transform(N), make_slice_transform(YDot, I0, YDotSlice), - make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), make_slice_transform(XDot, I0, XDotSlice), - make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), make_unmerge_transform(make_tuple(K0, K1))), make_tuple(Sequence<0>{}, Sequence<1>{}, @@ -198,17 +201,17 @@ 
transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( #if 1 const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), - make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), + make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), make_pass_through_transform(K1)), make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); #else const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), - make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), + make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), make_pass_through_transform(K1)), make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); @@ -224,24 +227,24 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_tensor_descriptor( + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( in_n_hip_wip_c_grid_desc, make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(YTilda, HTilda), + make_embed_transform(make_tuple(YTilde, HTilde), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(XTilda, WTilda), + make_embed_transform(make_tuple(XTilde, WTilde), make_tuple(ConvDilationW, ConvStrideW)), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor( - in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc, + const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, make_tuple(make_pass_through_transform(N), - make_freeze_transform(IYTilda), - make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice), - make_freeze_transform(IXTilda), - make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice), + make_freeze_transform(IYTilde), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(IXTilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, @@ -257,9 +260,9 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( Sequence<3>{})); const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( - in_n_htildaslice_wtildaslice_c_grid_desc, + in_n_htildeslice_wtildeslice_c_grid_desc, make_tuple(make_pass_through_transform(C), - make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice))), + make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice))), make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); diff --git 
a/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp similarity index 81% rename from composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp rename to include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp index fa78d769..5391b595 100644 --- a/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp +++ b/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP #define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP @@ -10,8 +13,8 @@ namespace ck { // A: out // B: wei // C: in -// Number of GEMMs = YTilda * XTilda -// GemmM = N * HTildaSlice * WTildaSlice +// Number of GEMMs = YTilde * XTilde +// GemmM = N * HTildeSlice * WTildeSlice // GemmN = C // GemmK = K * YDotSlice * XDotSlice template __host__ __device__ constexpr auto transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( @@ -33,8 +36,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( const ConvDilations& conv_dilations, const InLeftPads& in_left_pads, const InRightPads& in_right_pads, - IYTilda i_ytilda, - IXTilda i_xtilda, + IYTilde i_ytilde, + IXTilde i_xtilde, Number) { constexpr auto I0 = Number<0>{}; @@ -72,32 +75,32 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - const auto YTilda = ConvStrideH / GcdStrideDilationH; - const auto XTilda = ConvStrideW / GcdStrideDilationW; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; - const auto YDot = math::integer_divide_ceil(Y, YTilda); - const auto XDot = math::integer_divide_ceil(X, XTilda); + const auto YDot = math::integer_divide_ceil(Y, YTilde); + const auto XDot = math::integer_divide_ceil(X, XTilde); - const auto HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); - const auto WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); - // only work on HTilda and WTilda that contribute to non-padding area of input tensor - const auto IHTildaSliceBegin = math::integer_divide_floor( - math::max(I0, InLeftPadH - ConvDilationH * (YTilda - I1)), ConvStrideH); - const auto IWTildaSliceBegin = math::integer_divide_floor( - math::max(I0, InLeftPadW - ConvDilationW * (XTilda - I1)), ConvStrideW); + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); - const auto IHTildaSliceEnd = - 
math::min(HTilda, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); - const auto IWTildaSliceEnd = - math::min(WTilda, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + const auto IHTildeSliceEnd = + math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildeSliceEnd = + math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); - const auto HTildaSlice = IHTildaSliceEnd - IHTildaSliceBegin; - const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin; + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; // GemmK is different for each GEMM - const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilda, YTilda); - const auto XDotSlice = math::integer_divide_ceil(X - i_xtilda, XTilda); + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); const auto K1 = GemmK1; const auto K0 = K / K1; @@ -113,25 +116,25 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_tensor_descriptor( + const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( out_n_hop_wop_k_grid_desc, make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(YDot, HTilda), + make_embed_transform(make_tuple(YDot, HTilde), make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), - make_embed_transform(make_tuple(XDot, WTilda), + make_embed_transform(make_tuple(XDot, WTilde), make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), make_pass_through_transform(K)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc = + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc = transform_tensor_descriptor( - out_n_ydot_htilda_xdot_wtilda_k_grid_desc, + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, make_tuple(make_pass_through_transform(N), make_slice_transform(YDot, I0, YDotSlice), - make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), make_slice_transform(XDot, I0, XDotSlice), - make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), make_unmerge_transform(make_tuple(K0, K1))), make_tuple(Sequence<0>{}, Sequence<1>{}, @@ -148,41 +151,41 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( #if 1 const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), - make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), + make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), make_pass_through_transform(K1)), make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); #else const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - 
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), - make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), + make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), make_pass_through_transform(K1)), make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); #endif // B: weight tensor - const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_tensor_descriptor( + const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( wei_k_y_x_c_grid_desc, make_tuple(make_pass_through_transform(K), - make_embed_transform(make_tuple(YDot, YTilda), + make_embed_transform(make_tuple(YDot, YTilde), make_tuple(ConvStrideH / GcdStrideDilationH, I1)), - make_embed_transform(make_tuple(XDot, XTilda), + make_embed_transform(make_tuple(XDot, XTilde), make_tuple(ConvStrideW / GcdStrideDilationW, I1)), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = - transform_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc, + transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_slice_transform(YDot, I0, YDotSlice), make_slice_transform(XDot, I0, XDotSlice), - make_freeze_transform(i_ytilda), - make_freeze_transform(i_xtilda), + make_freeze_transform(i_ytilde), + make_freeze_transform(i_xtilde), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, @@ -225,24 +228,24 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_tensor_descriptor( + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( in_n_hip_wip_c_grid_desc, make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(YTilda, HTilda), + make_embed_transform(make_tuple(YTilde, HTilde), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(XTilda, WTilda), + make_embed_transform(make_tuple(XTilde, WTilde), make_tuple(ConvDilationW, ConvStrideW)), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor( - in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc, + const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, make_tuple(make_pass_through_transform(N), - make_freeze_transform(i_ytilda), - make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice), - make_freeze_transform(i_xtilda), - make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice), + make_freeze_transform(i_ytilde), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, @@ -258,8 +261,8 @@ 
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( Sequence<3>{})); const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( - in_n_htildaslice_wtildaslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), make_pass_through_transform(C)), make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); diff --git a/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp similarity index 98% rename from composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp rename to include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp index e533ad91..bb1dc239 100644 --- a/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp +++ b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_ATOMIC_NCHW_KCYX_NKHW_HPP #define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_ATOMIC_NCHW_KCYX_NKHW_HPP diff --git a/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp similarity index 98% rename from composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp rename to include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp index 949f044b..ca530934 100644 --- a/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp +++ b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
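As the renamed comment block notes, backward-data v4r1r2 runs YTilde * XTilde separate GEMMs, with GemmM = N * HTildeSlice * WTildeSlice, GemmN = C, and a per-GEMM reduction length GemmK = K * YDotSlice * XDotSlice. A host-side sketch of the resulting problem sizes; all sizes are hypothetical and the slice extents are taken as given from the computation shown earlier.

#include <cstdio>

int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main()
{
    const int N = 64, K = 256, C = 128, Y = 3, X = 3;
    const int YTilde = 2, XTilde = 2, HTildeSlice = 15, WTildeSlice = 15;

    for(int iy = 0; iy < YTilde; ++iy)
        for(int ix = 0; ix < XTilde; ++ix)
        {
            const int GemmM = N * HTildeSlice * WTildeSlice;
            const int GemmN = C;
            const int GemmK = K * ceil_div(Y - iy, YTilde) * ceil_div(X - ix, XTilde);
            std::printf("GEMM(%d,%d): M=%d N=%d K=%d\n", iy, ix, GemmM, GemmN, GemmK);
        }
}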
+ #ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP #define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP diff --git a/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp similarity index 98% rename from composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp rename to include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp index 213e1d61..e960f90c 100644 --- a/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp +++ b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP #define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP diff --git a/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp similarity index 98% rename from composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp rename to include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp index f1e1826d..052bab42 100644 --- a/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp +++ b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP #define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP diff --git a/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp similarity index 98% rename from composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp rename to include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp index 02e61c0e..c301a9e0 100644 --- a/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp +++ b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
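The backward-weight transforms renamed here follow the same implicit-GEMM recipe. As a rough sketch, assuming the usual NHWC/KYXC/NHWK weight-gradient mapping (the mapping itself is not spelled out in these hunks, whose bodies are unchanged):

#include <cstdio>

int main()
{
    const int N = 64, K = 256, C = 128, Y = 3, X = 3, Ho = 14, Wo = 14;

    const int GemmM = K;            // one row per output channel
    const int GemmN = Y * X * C;    // one column per filter element
    const int GemmK = N * Ho * Wo;  // reduction over all output positions

    // the *_atomic variants split the (large) reduction dimension across
    // workgroups, each accumulating its partial result with atomic adds
    const int KBatch = 4;           // hypothetical split factor
    std::printf("wrw GEMM: M=%d N=%d K=%d, as %d partial GEMMs of K=%d\n",
                GemmM, GemmN, GemmK, KBatch, GemmK / KBatch);
}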
+ #ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R5_NHWC_KYXC_NHWK_HPP #define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R5_NHWC_KYXC_NHWK_HPP diff --git a/include/ck/problem_transform/transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp b/include/ck/problem_transform/transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp new file mode 100644 index 00000000..41267536 --- /dev/null +++ b/include/ck/problem_transform/transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp @@ -0,0 +1,153 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION3D_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_FORWARD_CONVOLUTION3D_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// A: in +// B: wei +// C: out +// GemmM = N * Do * Ho * Wo +// GemmN = K +// GemmK = Z * Y * X * C +template +__host__ __device__ constexpr auto +transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk_pad( + const TensorDescriptor& in_grid_desc_n_di_hi_wi_c, + const TensorDescriptor& wei_k_z_y_x_c_grid_desc, + const TensorDescriptor& out_n_do_ho_wo_k_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_grid_desc_n_di_hi_wi_c.GetLength(I0); + const auto K = out_n_do_ho_wo_k_grid_desc.GetLength(I4); + const auto C = in_grid_desc_n_di_hi_wi_c.GetLength(I4); + + const auto Di = in_grid_desc_n_di_hi_wi_c.GetLength(I1); + const auto Hi = in_grid_desc_n_di_hi_wi_c.GetLength(I2); + const auto Wi = in_grid_desc_n_di_hi_wi_c.GetLength(I3); + + const auto Do = out_n_do_ho_wo_k_grid_desc.GetLength(I1); + const auto Ho = out_n_do_ho_wo_k_grid_desc.GetLength(I2); + const auto Wo = out_n_do_ho_wo_k_grid_desc.GetLength(I3); + + const auto Z = wei_k_z_y_x_c_grid_desc.GetLength(I1); + const auto Y = wei_k_z_y_x_c_grid_desc.GetLength(I2); + const auto X = wei_k_z_y_x_c_grid_desc.GetLength(I3); + + const auto ConvStrideD = conv_strides[I0]; + const auto ConvStrideH = conv_strides[I1]; + const auto ConvStrideW = conv_strides[I2]; + + const auto ConvDilationD = conv_dilations[I0]; + const auto ConvDilationH = conv_dilations[I1]; + const auto ConvDilationW = conv_dilations[I2]; + + const auto InLeftPadD = in_left_pads[I0]; + const auto InLeftPadH = in_left_pads[I1]; + const auto InLeftPadW = in_left_pads[I2]; + + const auto InRightPadD = in_right_pads[I0]; + const auto InRightPadH = in_right_pads[I1]; + const auto InRightPadW = in_right_pads[I2]; + + const auto GemmM = N * Do * Ho * Wo; + const auto GemmN = K; + const auto GemmK = Z * Y * X * C; + const auto GemmK0 = GemmK / GemmK1; + + // A: input tensor + const auto in_grid_desc_n_dip_hip_wip_c = transform_tensor_descriptor( + in_grid_desc_n_di_hi_wi_c, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, 
Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_grid_desc_n_z_do_y_ho_x_wo_c = transform_tensor_descriptor( + in_grid_desc_n_dip_hip_wip_c, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5, 6>{}, Sequence<7>{})); + + const auto in_grid_desc_gemmk_gemmm = + transform_tensor_descriptor(in_grid_desc_n_z_do_y_ho_x_wo_c, + make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)), + make_merge_transform(make_tuple(N, Do, Ho, Wo))), + make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_grid_desc_gemmk0_gemmm_gemmk1 = + transform_tensor_descriptor(in_grid_desc_gemmk_gemmm, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_grid_desc_gemmk_gemmn = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Z * Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_grid_desc_gemmk0_gemmn_gemmk1 = + transform_tensor_descriptor(wei_grid_desc_gemmk_gemmn, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_grid_desc_gemmm_gemmn = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Do * Ho * Wo), make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // const auto out_grid_desc_gemmm_gemmn = transform_tensor_descriptor( + // out_n_do_ho_wo_k_grid_desc, + // make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), + // make_pass_through_transform(K)), + // make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<3>{}), + // make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_grid_desc_gemmk0_gemmm_gemmk1, + wei_grid_desc_gemmk0_gemmn_gemmk1, + out_grid_desc_gemmm_gemmn); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp similarity index 99% rename from composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp rename to include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp index 093a4625..381f9ac9 100644 --- a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp @@ -1,3 
+1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP #define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP diff --git a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp similarity index 98% rename from composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp rename to include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp index 9aa27884..ebfaabb0 100644 --- a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp +++ b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP #define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP diff --git a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp similarity index 98% rename from composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp rename to include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp index 16ae8b47..6e576d69 100644 --- a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp +++ b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP #define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP diff --git a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp similarity index 98% rename from composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp rename to include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp index e81c87d0..13e1bf25 100644 --- a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp +++ b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
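The new 3-D forward transform above maps convolution onto a single GEMM with GemmM = N * Do * Ho * Wo, GemmN = K, and GemmK = Z * Y * X * C, then unmerges GemmK into (GemmK0, GemmK1). A small host-side sketch using the standard convolution output-length formula (assumed; the formula itself is not part of this diff):

#include <cstdio>

int out_size(int in, int pad_l, int pad_r, int dilation, int filter, int stride)
{
    // standard convolution output-length formula (assumption, see lead-in)
    return (in + pad_l + pad_r - dilation * (filter - 1) - 1) / stride + 1;
}

int main()
{
    const int N = 16, K = 256, C = 64, Z = 3, Y = 3, X = 3;
    const int Di = 16, Hi = 28, Wi = 28;

    const int Do = out_size(Di, 1, 1, 1, Z, 1);
    const int Ho = out_size(Hi, 1, 1, 1, Y, 1);
    const int Wo = out_size(Wi, 1, 1, 1, X, 1);

    const int GemmK1 = 8;                       // a compile-time Number<> in the real code
    const int GemmM  = N * Do * Ho * Wo;
    const int GemmN  = K;
    const int GemmK  = Z * Y * X * C;
    std::printf("M=%d N=%d K=%d (K0=%d, K1=%d)\n",
                GemmM, GemmN, GemmK, GemmK / GemmK1, GemmK1);
}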
+ #ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP #define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP diff --git a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp similarity index 96% rename from composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp rename to include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp index b0b07505..088d14b2 100644 --- a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp +++ b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP #define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP @@ -21,8 +24,7 @@ template -__host__ __device__ constexpr auto -transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( +__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk( const TensorDescriptor& in_n_hi_wi_c_grid_desc, const TensorDescriptor& wei_k_y_x_c_grid_desc, const TensorDescriptor& out_n_ho_wo_k_grid_desc, diff --git a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp similarity index 98% rename from composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp rename to include/ck/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp index f5cb7f48..a6785d56 100644 --- a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp +++ b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP #define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP diff --git a/include/ck/stream_config.hpp b/include/ck/stream_config.hpp new file mode 100644 index 00000000..95076606 --- /dev/null +++ b/include/ck/stream_config.hpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +struct StreamConfig +{ + hipStream_t stream_id_ = nullptr; + bool time_kernel_ = false; +}; diff --git a/include/ck/tensor/static_tensor.hpp b/include/ck/tensor/static_tensor.hpp new file mode 100644 index 00000000..fee679f9 --- /dev/null +++ b/include/ck/tensor/static_tensor.hpp @@ -0,0 +1,273 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
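The new stream_config.hpp above carries only a stream handle and a timing flag. A hypothetical launcher showing how an invoker might consume the two fields with HIP events; the real invokers live elsewhere in the library, and this sketch only assumes the StreamConfig struct added in this diff.

#include "ck/stream_config.hpp"
#include <hip/hip_runtime.h>

// Launch a kernel on cfg.stream_id_; time it with events only when requested.
template <typename F>
float launch_and_maybe_time(const StreamConfig& cfg, F kernel_launcher)
{
    if(!cfg.time_kernel_)
    {
        kernel_launcher(cfg.stream_id_);
        return 0.f;
    }

    hipEvent_t start, stop;
    hipEventCreate(&start);
    hipEventCreate(&stop);

    hipEventRecord(start, cfg.stream_id_);
    kernel_launcher(cfg.stream_id_);
    hipEventRecord(stop, cfg.stream_id_);
    hipEventSynchronize(stop);

    float ms = 0.f;
    hipEventElapsedTime(&ms, start, stop);
    hipEventDestroy(start);
    hipEventDestroy(stop);
    return ms;
}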
+ +#ifndef CK_STATIC_TENSOR_HPP +#define CK_STATIC_TENSOR_HPP + +namespace ck { + +// StaticTensor for Scalar +template ::type = false> +struct StaticTensor +{ + static constexpr auto desc_ = TensorDesc{}; + static constexpr index_t ndim_ = TensorDesc::GetNumOfDimension(); + static constexpr index_t element_space_size_ = desc_.GetElementSpaceSize(); + + __host__ __device__ constexpr StaticTensor() : invalid_element_scalar_value_{0} {} + + __host__ __device__ constexpr StaticTensor(T invalid_element_value) + : invalid_element_scalar_value_{invalid_element_value} + { + } + + // read access + template ::value && Idx::Size() == ndim_, + bool>::type = false> + __host__ __device__ constexpr const T& operator[](Idx) const + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord); + + if constexpr(is_valid) + { + return data_[Number{}]; + } + else + { + if constexpr(InvalidElementUseNumericalZeroValue) + { + return zero_scalar_value_; + } + else + { + return invalid_element_scalar_value_; + } + } + } + + // write access + template ::value && Idx::Size() == ndim_, + bool>::type = false> + __host__ __device__ constexpr T& operator()(Idx) + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord); + + if constexpr(is_valid) + { + return data_(Number{}); + } + else + { + return ignored_element_scalar_; + } + } + + StaticBuffer data_; + static constexpr T zero_scalar_value_ = T{0}; + const T invalid_element_scalar_value_; + T ignored_element_scalar_; +}; + +// StaticTensor for vector +template ::type = false> +struct StaticTensorTupleOfVectorBuffer +{ + static constexpr auto desc_ = TensorDesc{}; + static constexpr index_t ndim_ = TensorDesc::GetNumOfDimension(); + static constexpr index_t element_space_size_ = desc_.GetElementSpaceSize(); + + static constexpr index_t num_of_vector_ = + math::integer_divide_ceil(element_space_size_, ScalarPerVector); + + using V = vector_type; + + __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer() + : invalid_element_scalar_value_{0} + { + } + + __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer(S invalid_element_value) + : invalid_element_scalar_value_{invalid_element_value} + { + } + + // Get S + // Idx is for S, not V + template ::value && Idx::Size() == ndim_, + bool>::type = false> + __host__ __device__ constexpr const S& operator[](Idx) const + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord); + + if constexpr(is_valid) + { + return data_[Number{}]; + } + else + { + if constexpr(InvalidElementUseNumericalZeroValue) + { + return zero_scalar_value_; + } + else + { + return invalid_element_scalar_value_; + } + } + } + + // Set S + // Idx is for S, not V + template ::value && Idx::Size() == ndim_, + bool>::type = false> + __host__ __device__ constexpr S& operator()(Idx) + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord); + + if constexpr(is_valid) + { + return data_(Number{}); + } + else + { + return ignored_element_scalar_; + } + } + + // Get X + // Idx 
is for S, not X. Idx should be aligned with X + template ::value && + is_known_at_compile_time::value && Idx::Size() == ndim_, + bool>::type = false> + __host__ __device__ constexpr X GetAsType(Idx) const + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord); + + if constexpr(is_valid) + { + return data_.template GetAsType(Number{}); + } + else + { + if constexpr(InvalidElementUseNumericalZeroValue) + { + // TODO: is this right way to initialize a vector? + return X{0}; + } + else + { + // TODO: is this right way to initialize a vector? + return X{invalid_element_scalar_value_}; + } + } + } + + // Set X + // Idx is for S, not X. Idx should be aligned with X + template ::value && + is_known_at_compile_time::value && Idx::Size() == ndim_, + bool>::type = false> + __host__ __device__ constexpr void SetAsType(Idx, X x) + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord); + + if constexpr(is_valid) + { + data_.template SetAsType(Number{}, x); + } + } + + // Get read access to V. No is_valid check + // Idx is for S, not V. Idx should be aligned with V + template + __host__ __device__ constexpr const V& GetVectorTypeReference(Idx) const + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + return data_.GetVectorTypeReference(Number{}); + } + + // Get read access to V. No is_valid check + // Idx is for S, not V. Idx should be aligned with V + template + __host__ __device__ constexpr V& GetVectorTypeReference(Idx) + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + return data_.GetVectorTypeReference(Number{}); + } + + StaticBufferTupleOfVector data_; + static constexpr S zero_scalar_value_ = S{0}; + const S invalid_element_scalar_value_ = S{0}; + S ignored_element_scalar_; +}; + +template ::type = false> +__host__ __device__ constexpr auto make_static_tensor(TensorDesc) +{ + return StaticTensor{}; +} + +template < + AddressSpaceEnum AddressSpace, + typename T, + typename TensorDesc, + typename X, + typename enable_if::type = false, + typename enable_if, remove_cvref_t>::value, bool>::type = false> +__host__ __device__ constexpr auto make_static_tensor(TensorDesc, X invalid_element_value) +{ + return StaticTensor{invalid_element_value}; +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_description/cluster_descriptor.hpp b/include/ck/tensor_description/cluster_descriptor.hpp similarity index 81% rename from composable_kernel/include/tensor_description/cluster_descriptor.hpp rename to include/ck/tensor_description/cluster_descriptor.hpp index d69bfb70..0c9ea2ff 100644 --- a/composable_kernel/include/tensor_description/cluster_descriptor.hpp +++ b/include/ck/tensor_description/cluster_descriptor.hpp @@ -1,8 +1,10 @@ -#ifndef CK_CLUSTER_DESCRIPTOR_HPP -#define CK_CLUSTER_DESCRIPTOR_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
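The invalid-element machinery in static_tensor.hpp above returns a numerical zero (or a user-supplied scalar) for coordinates whose offset is not valid, instead of faulting. A greatly simplified runtime analogue of that pattern; the real StaticTensor resolves the coordinate, offset, and validity entirely at compile time on register-backed buffers.

#include <array>
#include <cstdio>

// Simplified stand-in: out-of-range reads yield a configurable invalid value.
template <typename T, int N>
struct BoundedBuffer
{
    std::array<T, N> data_{};
    T invalid_value_{};

    const T& at(int i) const { return (i >= 0 && i < N) ? data_[i] : invalid_value_; }
};

int main()
{
    BoundedBuffer<float, 4> buf;
    buf.data_ = {1.f, 2.f, 3.f, 4.f};
    std::printf("%f %f\n", buf.at(2), buf.at(7)); // the second read is invalid -> 0
}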
-#include "common_header.hpp" -#include "tensor_adaptor.hpp" +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" namespace ck { @@ -30,4 +32,3 @@ __host__ __device__ constexpr auto make_cluster_descriptor( } } // namespace ck -#endif diff --git a/composable_kernel/include/tensor_description/multi_index_transform.hpp b/include/ck/tensor_description/multi_index_transform.hpp similarity index 95% rename from composable_kernel/include/tensor_description/multi_index_transform.hpp rename to include/ck/tensor_description/multi_index_transform.hpp index 1a25e99f..4e4d7593 100644 --- a/composable_kernel/include/tensor_description/multi_index_transform.hpp +++ b/include/ck/tensor_description/multi_index_transform.hpp @@ -1,8 +1,10 @@ -#ifndef CK_MULTI_INDEX_TRANSFORM_HPP -#define CK_MULTI_INDEX_TRANSFORM_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "common_header.hpp" -#include "multi_index.hpp" +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/multi_index.hpp" namespace ck { @@ -30,7 +32,8 @@ struct PassThrough __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } template - __host__ __device__ static void CalculateLowerIndex(LowIdx& idx_low, const UpIdx& idx_up) + __host__ __device__ static constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) { static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, "wrong! inconsistent # of dimension"); @@ -1708,7 +1711,8 @@ struct Vectorize __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } template - __host__ __device__ void CalculateLowerIndex(LowIdx& idx_low, const UpIdx& idx_up) const + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const { static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, "wrong! inconsistent # of dimension"); @@ -1860,5 +1864,91 @@ struct Slice } }; +/* + * \brief lower_idx = upper_idx % modulus. + * TODO: Need an improved implementation since the modulo operation is expensive. + */ +template +struct Modulo +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + using UpLengths = decltype(make_tuple(UpLength{})); + + Modulus modulus_; + UpLengths up_lengths_; + + __host__ __device__ constexpr Modulo() = default; + + __host__ __device__ constexpr Modulo(const Modulus& modulus, const UpLength& up_length) + : modulus_{modulus}, up_lengths_{make_tuple(up_length)} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}] % modulus_; + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& up_idx, + Number) const + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + const auto idx_low_old = idx_low; + idx_low(I0) = (up_idx(I0) + idx_diff_up(I0)) % modulus_; + idx_diff_low(I0) = idx_low - idx_low_old; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return false; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Modulus, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("}"); + } +}; } // namespace ck -#endif diff --git a/composable_kernel/include/tensor_description/multi_index_transform_helper.hpp b/include/ck/tensor_description/multi_index_transform_helper.hpp similarity index 86% rename from composable_kernel/include/tensor_description/multi_index_transform_helper.hpp rename to include/ck/tensor_description/multi_index_transform_helper.hpp index 9a737991..044a9037 100644 --- a/composable_kernel/include/tensor_description/multi_index_transform_helper.hpp +++ b/include/ck/tensor_description/multi_index_transform_helper.hpp @@ -1,8 +1,10 @@ -#ifndef CK_MULTI_INDEX_TRANSFORM_HELPER_HPP -#define CK_MULTI_INDEX_TRANSFORM_HELPER_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "common_header.hpp" -#include "multi_index_transform.hpp" +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform.hpp" namespace ck { @@ -98,6 +100,12 @@ __host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_i return Freeze{low_idx}; } +template +__host__ __device__ constexpr auto make_insert_transform(const UpperIndex& up_idx) +{ + return Insert{up_idx}; +} + template __host__ __device__ constexpr auto make_slice_transform(const LowLength& low_length, const SliceBegin& slice_begin, @@ -113,5 +121,10 @@ __host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& ve return Vectorize{vector_size, up_length}; } +template +__host__ __device__ constexpr auto make_modulo_transform(const Modulus& modulus, + const UpLength& up_length) +{ + return Modulo{modulus, up_length}; +} } // namespace ck -#endif diff --git a/composable_kernel/include/tensor_description/tensor_adaptor.hpp b/include/ck/tensor_description/tensor_adaptor.hpp similarity index 95% rename from composable_kernel/include/tensor_description/tensor_adaptor.hpp rename to include/ck/tensor_description/tensor_adaptor.hpp index 50a8088b..d42e0a6f 100644 --- a/composable_kernel/include/tensor_description/tensor_adaptor.hpp +++ b/include/ck/tensor_description/tensor_adaptor.hpp @@ -1,9 +1,11 @@ -#ifndef CK_TENSOR_ADAPTOR_HPP -#define CK_TENSOR_ADAPTOR_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
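Note why the Modulo transform's UpdateLowerIndex recomputes the modulo rather than adding idx_diff_up: the lower-index delta can be negative when the upper index wraps past the modulus. A scalar model of that behavior:

#include <cstdio>

int main()
{
    const int modulus = 4;
    int idx_up  = 2;
    int idx_low = idx_up % modulus;                 // CalculateLowerIndex

    const int idx_diff_up  = 3;                     // move the upper index forward
    const int idx_low_old  = idx_low;
    idx_up  += idx_diff_up;
    idx_low  = idx_up % modulus;                    // 5 % 4 == 1
    const int idx_diff_low = idx_low - idx_low_old; // 1 - 2 == -1: wraps backwards

    std::printf("up=%d low=%d diff_low=%d\n", idx_up, idx_low, idx_diff_low);
}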
-#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" namespace ck { @@ -136,7 +138,11 @@ struct TensorAdaptor using ElementSize = remove_cv_t; public: +#if 0 // workaround compiler complaint about constexpr __host__ __device__ constexpr TensorAdaptor() = default; +#else + __host__ __device__ constexpr TensorAdaptor() : transforms_{}, element_size_{} {} +#endif __host__ __device__ constexpr TensorAdaptor(const Transforms& transforms) : transforms_{transforms}, element_size_{InitializeElementSize(transforms)} @@ -151,6 +157,20 @@ struct TensorAdaptor __host__ __device__ constexpr auto GetElementSize() const { return element_size_; } +#if 0 // debug + template + __host__ __device__ constexpr index_t GetTopDimensionLength(Number idim) const + { + // TODO: not implemented + } + + template + __host__ __device__ constexpr index_t GetBottomDimensionLength(Number idim) const + { + // TODO: not implemented + } +#endif + template __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const { @@ -460,4 +480,3 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&.. } } // namespace ck -#endif diff --git a/composable_kernel/include/tensor_description/tensor_descriptor.hpp b/include/ck/tensor_description/tensor_descriptor.hpp similarity index 97% rename from composable_kernel/include/tensor_description/tensor_descriptor.hpp rename to include/ck/tensor_description/tensor_descriptor.hpp index 8f6a5a3e..1e69736e 100644 --- a/composable_kernel/include/tensor_description/tensor_descriptor.hpp +++ b/include/ck/tensor_description/tensor_descriptor.hpp @@ -1,8 +1,10 @@ -#ifndef CK_TENSOR_DESCRIPTOR_HPP -#define CK_TENSOR_DESCRIPTOR_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "common_header.hpp" -#include "multi_index_transform.hpp" +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform.hpp" namespace ck { @@ -111,7 +113,14 @@ struct TensorDescriptor using ElementSize = remove_cv_t; public: +#if 0 // workaround compiler complaint about constexpr __host__ __device__ constexpr TensorDescriptor() = default; +#else + __host__ __device__ constexpr TensorDescriptor() + : transforms_{}, element_size_{}, element_space_size_{} + { + } +#endif __host__ __device__ constexpr TensorDescriptor(const Transforms& transforms, ElementSpaceSize element_space_size) @@ -307,6 +316,10 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, { // sanity check { + static_assert(NewTransforms::Size() == NewLowerDimensionOldVisibleIdss::Size() && + NewTransforms::Size() == NewUpperDimensionNewVisibleIdss::Size(), + "wrong! inconsitent number of transform"); + constexpr auto all_old_top_ids = unpack([](auto... 
xs) { return merge_sequences(xs...); }, NewLowerDimensionOldVisibleIdss{}); @@ -593,4 +606,3 @@ using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step( TensorDesc{}, MultiIndex::GetNumOfDimension()>{})); } // namespace ck -#endif diff --git a/composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp b/include/ck/tensor_description/tensor_descriptor_helper.hpp similarity index 84% rename from composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp rename to include/ck/tensor_description/tensor_descriptor_helper.hpp index ad75f924..461aae72 100644 --- a/composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp +++ b/include/ck/tensor_description/tensor_descriptor_helper.hpp @@ -1,9 +1,11 @@ -#ifndef CK_TENSOR_DESCRIPTOR_HELPER_HPP -#define CK_TENSOR_DESCRIPTOR_HELPER_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "multi_index_transform_helper.hpp" +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" namespace ck { @@ -35,6 +37,12 @@ __host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengt } #endif +// Lengths..., Strides... could be: +// 1) index_t, which is known at run-time, or +// 2) Number<>, which is known at compile-time +// element_space_size could be: +// 1) long_index_t, or +// 2) LongNumber<> template ::type = false> @@ -68,10 +76,10 @@ __host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple{}, Number<1>{}); + const auto element_space_size = f(f, Number<0>{}, LongNumber<1>{}); #else const auto element_space_size = - calculate_element_space_size_impl(lengths, strides, Number<0>{}, Number<1>{}); + calculate_element_space_size_impl(lengths, strides, Number<0>{}, LongNumber<1>{}); #endif return TensorDescriptor, @@ -82,9 +90,12 @@ __host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple, which is known at compile-time +// element_space_size could be: +// 1) long_index_t, or +// 2) LongNumber<> template __host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple& lengths) @@ -100,7 +111,7 @@ make_naive_tensor_descriptor_packed(const Tuple& lengths) constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{}; - const auto element_space_size = container_reduce(lengths, math::multiplies{}, Number<1>{}); + const auto element_space_size = container_reduce(lengths, math::multiplies{}, LongNumber<1>{}); return TensorDescriptor, remove_cv_t, @@ -110,6 +121,12 @@ make_naive_tensor_descriptor_packed(const Tuple& lengths) element_space_size}; } +// Lengths... 
could be: +// 1) index_t, which is known at run-time, or +// 2) Number<>, which is known at compile-time +// align could be: +// 1) index_t, or +// 2) Number<> template __host__ __device__ constexpr auto make_naive_tensor_descriptor_aligned(const Tuple& lengths, Align align) @@ -146,4 +163,3 @@ make_naive_tensor_descriptor_aligned(const Tuple& lengths, Align ali } } // namespace ck -#endif diff --git a/include/ck/tensor_description/tensor_space_filling_curve.hpp b/include/ck/tensor_description/tensor_space_filling_curve.hpp new file mode 100644 index 00000000..e9a990d8 --- /dev/null +++ b/include/ck/tensor_description/tensor_space_filling_curve.hpp @@ -0,0 +1,160 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/math.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/sequence_helper.hpp" +#include "ck/utility/statically_indexed_array_multi_index.hpp" +#include "ck/utility/tuple_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +namespace ck { + +template // # of scalars per access in each dimension +struct SpaceFillingCurve +{ + static constexpr index_t nDim = TensorLengths::Size(); + + using Index = MultiIndex; + + static constexpr index_t ScalarPerVector = + reduce_on_sequence(ScalarsPerAccess{}, math::multiplies{}, Number<1>{}); + + static constexpr auto access_lengths = TensorLengths{} / ScalarsPerAccess{}; + static constexpr auto dim_access_order = DimAccessOrder{}; + static constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + static constexpr auto to_index_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(ordered_access_lengths)), + make_tuple(typename arithmetic_sequence_gen<0, nDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr index_t GetNumOfAccess() + { + static_assert(TensorLengths::Size() == ScalarsPerAccess::Size()); + static_assert(TensorLengths{} % ScalarsPerAccess{} == + typename uniform_sequence_gen::type{}); + + return reduce_on_sequence(TensorLengths{}, math::multiplies{}, Number<1>{}) / + ScalarPerVector; + } + + template + static __device__ __host__ constexpr auto GetStepBetween(Number, + Number) + { + static_assert(AccessIdx1dBegin >= 0, "1D index should be non-negative"); + static_assert(AccessIdx1dBegin < GetNumOfAccess(), "1D index should be larger than 0"); + static_assert(AccessIdx1dEnd >= 0, "1D index should be non-negative"); + static_assert(AccessIdx1dEnd < GetNumOfAccess(), "1D index should be larger than 0"); + + constexpr auto idx_begin = GetIndex(Number{}); + constexpr auto idx_end = GetIndex(Number{}); + return idx_end - idx_begin; + } + + template + static __device__ __host__ constexpr auto GetForwardStep(Number) + { + static_assert(AccessIdx1d < GetNumOfAccess(), "1D index should be larger than 0"); + return GetStepBetween(Number{}, Number{}); + } + + template + static __device__ __host__ constexpr auto GetBackwardStep(Number) + { + static_assert(AccessIdx1d > 0, "1D index should be larger than 0"); + + return GetStepBetween(Number{}, Number{}); + } + + template + static __device__ __host__ constexpr Index GetIndex(Number) + { +#if 0 + /* + * \todo: TensorAdaptor::CalculateBottomIndex does NOT return constexpr as expected. 
+ */ + constexpr auto ordered_access_idx = to_index_adaptor.CalculateBottomIndex(make_multi_index(Number{})); +#else + + constexpr auto access_strides = container_reverse_exclusive_scan( + ordered_access_lengths, math::multiplies{}, Number<1>{}); + + constexpr auto idx_1d = Number{}; + // Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the + // idim-th element of multidimensional index. + // All constexpr variables have to be captured by VALUE. + constexpr auto compute_index = [ idx_1d, access_strides ](auto idim) constexpr + { + constexpr auto compute_index_impl = [ idx_1d, access_strides ](auto jdim) constexpr + { + auto res = idx_1d.value; + auto id = 0; + + static_for<0, jdim.value + 1, 1>{}([&](auto kdim) { + id = res / access_strides[kdim].value; + res -= id * access_strides[kdim].value; + }); + + return id; + }; + + constexpr auto id = compute_index_impl(idim); + return Number{}; + }; + + constexpr auto ordered_access_idx = generate_tuple(compute_index, Number{}); +#endif + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto idim) { + index_t tmp = ordered_access_idx[I0]; + + static_for<1, idim, 1>{}( + [&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; }); + + forward_sweep_(idim) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate multi-dim tensor index + auto idx_md = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto idim) { + ordered_idx(idim) = forward_sweep[idim] ? ordered_access_idx[idim] + : ordered_access_lengths[idim] - 1 - + ordered_access_idx[idim]; + }); + + return container_reorder_given_old2new(ordered_idx, dim_access_order) * + ScalarsPerAccess{}; + }(); + return idx_md; + } + + // FIXME: rename this function + template + static __device__ __host__ constexpr auto GetIndexTupleOfNumber(Number) + { + constexpr auto idx = GetIndex(Number{}); + + return generate_tuple([&](auto i) { return Number{}; }, Number{}); + } +}; + +} // namespace ck diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp similarity index 95% rename from composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r3.hpp rename to include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp index 26ca0bf1..8b1b7be1 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp @@ -1,10 +1,12 @@ -#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP -#define CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
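The forward_sweep logic in the space-filling curve above reverses the traversal direction along a dimension whenever the prefix of the ordered index is odd, so consecutive 1-D access indices land on adjacent tensor elements. A standalone 2-D model of that snake ordering:

#include <cstdio>

int main()
{
    const int L0 = 3, L1 = 4; // access lengths per dimension, slowest first
    for(int i = 0; i < L0 * L1; ++i)
    {
        const int i0 = i / L1;
        const int r1 = i % L1;
        const int i1 = (i0 % 2 == 0) ? r1 : L1 - 1 - r1; // backward sweep on odd rows
        std::printf("access %2d -> (%d, %d)\n", i, i0, i1);
    }
}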
-#include "common_header.hpp" -#include "tensor_adaptor.hpp" -#include "threadwise_tensor_slice_transfer_v2.hpp" -#include "threadwise_contraction_dlops.hpp" +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp" namespace ck { @@ -41,7 +43,7 @@ template ::type = false> -struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2 +struct BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2 { using AIndex = MultiIndex<3>; using BIndex = MultiIndex<3>; @@ -148,7 +150,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{}); public: - __device__ BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2() + __device__ BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2() : c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( get_thread_local_1d_id())}, a_thread_copy_{ @@ -175,6 +177,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B "wrong!"); // TODO: remove this restriction + static_assert(BM0 == 2, "wrong"); static_assert(BM0 == 2 && BN0 == 2, "wrong"); } @@ -220,13 +223,13 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I2) == BN0, "wrong"); - auto a_thread_buf = make_static_buffer( + auto a_thread_buf = make_static_buffer( a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize()); constexpr auto threadwise_contraction = - ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1< + ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1< FloatA, FloatB, FloatC, @@ -407,4 +410,3 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B }; } // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r2.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp similarity index 98% rename from composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r2.hpp rename to include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp index 35ff66a2..33120bd8 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r2.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP #define CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP @@ -207,9 +210,9 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2 CM0M1N0N1ThreadDesc{}.GetLength(I2) == N0, "wrong"); - auto a_thread_buf = make_static_buffer( + auto a_thread_buf = make_static_buffer( a_k_m0_m1_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_k_n0_n1_thread_desc_.GetElementSpaceSize()); constexpr auto threadwise_gemm = diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp new file mode 100644 index 00000000..f4565572 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp @@ -0,0 +1,178 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP +#define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP + +#include "common_header.hpp" +#include "threadwise_gemm_dlops_v3.hpp" + +namespace ck { + +template +struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + + using AIndex = MultiIndex<3>; + using BIndex = MultiIndex<3>; + using CIndex = MultiIndex<4>; + + static constexpr auto E1 = ABlockDesc_E1_K1_E2{}.GetLength(I0); + static constexpr auto KPerBlock = ABlockDesc_E1_K1_E2{}.GetLength(I1); + static constexpr auto E2 = ABlockDesc_E1_K1_E2{}.GetLength(I2); + + static constexpr auto HoPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I2); + static constexpr auto WoPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I3); + + static constexpr auto KPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I0); + static constexpr auto HoPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I2); + static constexpr auto WoPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I3); + + static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, Number{})); + + static constexpr auto b_thread_mtx_ = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number<1>{}, + Number{}, + Number{}, + Number{})); + + static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number<1>{}, Number{}, Number{})); + + __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3() + : c_thread_origin_data_idx_{GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id())}, + a_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I0] * KPerThread, 0)} + { + static_assert(ABlockDesc_E1_K1_E2::IsKnownAtCompileTime() && + BBlockDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() && + CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert( + ABlockDesc_E1_K1_E2{}.GetLength(I0) == BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I0) && + ABlockDesc_E1_K1_E2{}.GetLength(I2) == BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I4), + "wrong! E dimension not consistent\n"); + + static_assert(E1 % EPerThreadLoop == 0, ""); + static_assert(KPerThread % KPerThreadLoop == 0, ""); + + static_assert(KPerBlock % KPerThread == 0 && HoPerBlock % HoPerThread == 0 && + WoPerBlock % WoPerThread == 0, + "wrong! 
Cannot evenly divide work among\n"); + + constexpr auto KThreadCluster = KPerBlock / KPerThread; + constexpr auto HThreadCluster = HoPerBlock / HoPerThread; + constexpr auto WThreadCluster = WoPerBlock / WoPerThread; + + static_assert(BlockSize == KThreadCluster * HThreadCluster * WThreadCluster, + "wrong! wrong blocksize\n"); + } + + __device__ static constexpr auto GetCThreadDesc_K_N_Ho_WoLengths() + { + return Sequence{}; + } + + __device__ static CIndex GetBeginOfCThreadDesc_K_N_Ho_Wo(index_t thread_id) + { + constexpr auto K0 = KPerBlock / KPerThread; + constexpr auto N0 = I1; + constexpr auto H0 = HoPerBlock / HoPerThread; + constexpr auto W0 = WoPerBlock / WoPerThread; + + constexpr auto c_threadid_to_k_n_h_w_thread_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(K0, N0, H0, W0))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto c_k_n_h_w_thread_cluster_idx = + c_threadid_to_k_n_h_w_thread_cluster_adaptor.CalculateBottomIndex( + make_multi_index(thread_id)); + + return c_k_n_h_w_thread_cluster_idx; + } + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BThreadBuffer& b_thread_buf, + CThreadBuffer& c_thread_buf) const + { + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + "wrong! inconsistent type"); + + constexpr auto a_block_mtx = ABlockDesc_E1_K1_E2{}; + + // thread A buffer for GEMM + StaticBuffer + a_thread_buf; + + constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3{}; + + static_for<0, E1, EPerThreadLoop>{}([&](auto e_begin) { + static_for<0, KPerThread, KPerThreadLoop>{}([&](auto k_begin) { + a_thread_copy_.Run(a_block_mtx, + make_tuple(e_begin, k_begin, I0), + a_block_buf, + a_thread_mtx_, + make_tuple(I0, I0, I0), + a_thread_buf); + + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0, I0), + b_thread_buf, + make_tuple(e_begin, I0, I0, I0, I0), + c_thread_buf, + make_tuple(k_begin, I0, I0, I0)); + }); + }); + } + + template + __device__ void MoveABlockSliceWindow(const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx) + { + a_thread_copy_.MoveSrcSliceWindow(ABlockDesc_E1_K1_E2{}, a_block_slice_move_step_idx); + } + + private: + using AThreadCopy = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2>, + 2, + E2, + E2>; + + CIndex c_thread_origin_data_idx_; + + AThreadCopy a_thread_copy_; +}; + +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp new file mode 100644 index 00000000..9720db4a --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -0,0 +1,588 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
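
// Illustration (not part of the patch): GetBeginOfCThreadDesc_K_N_Ho_Wo above maps
// a linear thread id onto the (K0, N0, H0, W0) thread cluster through a merge
// transform; the inverse is plain mixed-radix divide/modulo, as in this host-side
// sketch (illustrative names and sizes, not the CK adaptor API):
#include <array>
#include <cstdio>

std::array<int, 4> decode_thread_id(int tid, const std::array<int, 4>& dims)
{
    // Merge transform: tid = ((k*N0 + n)*H0 + h)*W0 + w  ->  invert by div/mod.
    std::array<int, 4> idx{};
    for(int d = 3; d >= 0; --d)
    {
        idx[d] = tid % dims[d];
        tid /= dims[d];
    }
    return idx;
}

int main()
{
    const std::array<int, 4> dims = {4, 1, 4, 4}; // K0, N0, H0, W0 -> BlockSize = 64
    for(int tid : {0, 1, 17, 63})
    {
        auto idx = decode_thread_id(tid, dims);
        std::printf("tid %2d -> (k=%d, n=%d, h=%d, w=%d)\n", tid, idx[0], idx[1], idx[2], idx[3]);
    }
    return 0;
}
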
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +namespace ck { + +enum struct LoopScheduler +{ + Default, + Interwave, +}; + +constexpr LoopScheduler make_default_loop_scheduler() +{ +#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING + return LoopScheduler::Interwave; +#else + return LoopScheduler::Default; +#endif // if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING +} + +template +struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t WaveSize = get_warp_size(); + + static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); + static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); + static constexpr index_t KPerBlock = + BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0); + static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); + static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr auto xdlops_gemm = XdlopsGemm{}; + + static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); + + return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + + constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, 
NWaves, NPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1() + { + static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && + BK0NK1BlockDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, + "wrong!"); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_block_desc_g_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n) + { + const auto G = c_grid_desc_g_m_n.GetLength(I0); 
+ const auto M = c_grid_desc_g_m_n.GetLength(I1); + const auto N = c_grid_desc_g_m_n.GetLength(I2); + + const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_g_m_n, + make_tuple(make_pass_through_transform(G), + make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_grid_desc_g_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K() + { + return transform_tensor_descriptor( + AK0MK1BlockDesc{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K() + { + return transform_tensor_descriptor( + BK0NK1BlockDesc{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K(); + static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K(); + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + static_for<0, KPerThread, KPack>{}([&](auto k) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(i) = b_thread_buf + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + + protected: + // A[M0, M1, M2, KPerThread] + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // B[N0, N1, N2, KPerThread] + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // C[M, N, NumRegXdlops] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + 
+ using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; +}; + +// Note: To facilitate the inter-wave loop scheduler, we need to explicitly set the macro +// CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=1 as a few intrinsics are not yet available in +// the latest ROCm release. For unsupported compilers, inter-wave loop scheduler falls back to the +// default loop scheduler which is given by the macro CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=0 +template +struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 + : public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 +{ + using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; + +#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING + using Base::a_block_desc_m0_m1_m2_k; + using Base::A_K1; + using Base::b_block_desc_n0_n1_n2_k; + using Base::B_K1; + using Base::c_thread_buf_; + using Base::c_thread_desc_; + using Base::CalculateAThreadOriginDataIndex; + using Base::CalculateBThreadOriginDataIndex; + using Base::I0; + using Base::I1; + using Base::KPerThread; + using Base::xdlops_gemm; + + static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack); + + // 2-wave optimized blockwise gemm + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, k), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, k), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, I0, I0), + b_thread_buf); + }); + __builtin_amdgcn_sched_barrier(0); + // NOTE: Synchronize threads in a workgroup at the start of each MAC cluster, but except + // the first, as we can shorten non-MAC cluster a bit and there's no observable negative + // impact. The desired effect is waves in a workgroup executing MAC in sync. This avoids + // some out-of-sync waves hijacking MAC resource from other workgroups and reducing the + // chance of latency hiding by waiting for the rest of the workgroup at the eventual + // sync point. 
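// Illustration (not part of the patch): KPerInnerLoop above is
// math::max(KPerThread / NumMacClusters, KPack), i.e. the per-thread K range is
// split into NumMacClusters chunks but never below one KPack of MFMA input. For
// example (assumed numbers), KPerThread = 16, KPack = 4, NumMacClusters = 2 gives
// KPerInnerLoop = 8, so the enclosing k loop runs twice and each iteration issues
// one MAC cluster bracketed by the s_barrier/setprio scheduling hints below.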
+ if constexpr(k.value != 0 || KPerInnerLoop == KPerThread) + { + asm volatile("s_barrier" ::); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // The block_sync_lds() here performs double duty: + // A) safeguard against data hazard because barrier from blockwise_gemm is + // moved here B) reduce VMEM FIFO congestion by applying small delays to + // different wavefronts It is performed near the end of MAC cluster to + // minimize lgkmcnt penalty + if constexpr(k.value == KPerThread - KPerInnerLoop && + k_.value == KPerInnerLoop - KPack && m0.value == MRepeat - 1 && + n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + + // TODO: insert setprio in more precise manner since we + // could have more than >1 MFMA instructions in single call + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + } + + protected: + // A[M0, M1, M2, KPerInnerLoop] + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, Number{})); + + // B[N0, N1, N2, KPerInnerLoop] + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; + +#endif // #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING +}; + +template +constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() +{ + if constexpr(LoopSched == LoopScheduler::Default) + { + return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + } + else if constexpr(LoopSched == LoopScheduler::Interwave) + { + return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + } +}; + +} // namespace ck diff --git a/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v2.hpp b/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp similarity index 85% rename from composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v2.hpp rename to include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp index 6b2d2d52..03e4d42d 100644 --- a/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v2.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp @@ -1,11 +1,13 @@ -#ifndef 
CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP -#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "cluster_descriptor.hpp" -#include "threadwise_tensor_slice_transfer_v2.hpp" +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp" namespace ck { @@ -14,7 +16,7 @@ namespace ck { // 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate template -struct BlockwiseTensorSliceTransfer_v4r1 +struct BlockwiseTensorSliceTransfer_v5r1 { static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); using Index = MultiIndex; - __device__ constexpr BlockwiseTensorSliceTransfer_v4r1(const SrcDesc& src_desc, + __device__ constexpr BlockwiseTensorSliceTransfer_v5r1(const SrcDesc& src_desc, const Index& src_block_slice_origin, const DstDesc& dst_desc, const Index& dst_block_slice_origin) @@ -45,8 +47,8 @@ struct BlockwiseTensorSliceTransfer_v4r1 src_desc, make_zero_multi_index(), dst_desc, make_zero_multi_index()) { - static_assert(nDim == remove_reference_t>::GetNumOfDimension() && - nDim == remove_reference_t>::GetNumOfDimension() && + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() && nDim == ThreadClusterLengths::Size() && nDim == ThreadClusterArrangeOrder::Size() && @@ -75,14 +77,13 @@ struct BlockwiseTensorSliceTransfer_v4r1 } } - template - __device__ void - RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks) + template + __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) { if(BlockSize == thread_cluster_desc_.GetElementSize() or get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) { - threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks); + threadwise_transfer_.RunRead(src_desc, src_buf); } } @@ -134,7 +135,7 @@ struct BlockwiseTensorSliceTransfer_v4r1 make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); using ThreadwiseTransfer = - ThreadwiseTensorSliceTransfer_v3r1> +struct PartitionedBlockwiseReduction +{ + static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1), + "The product of cluster lengths should be same as BlockSize!"); + + static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0); + static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1); + + static_assert(BufferLength_K > 1, "Parallel reduction need work on at least two elements"); + + static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + template + __device__ static void Reduce(BufferType& work_buffer, AccDataType& in_out_value) + { + static_assert(is_same{}, + "Buffer data type should be consistent as AccDataType!"); + + constexpr auto 
cluster_len_shift = get_shift(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id())); + + const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}]; + const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}]; + + work_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value; + + __syncthreads(); + + static_for<0, cluster_len_shift, 1>{}([&](auto I) { + constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I()); + + if(thread_k_cluster_id < indOffset) + { + index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx); + index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx + + make_tuple(0, indOffset)); + + AccDataType opData1 = work_buffer[offset1]; + AccDataType opData2 = work_buffer[offset2]; + Accumulation::Calculate(opData1, opData2); + work_buffer(offset1) = opData1; + } + + __syncthreads(); + }); + + index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0)); + + in_out_value = work_buffer[offset]; + }; +}; + +// clang-format off +// Assume: +// 1) work_val_buffer/work_idx_buffer is buffer (typically LDS) allocated outside as workspace, does not include any in/out data +// 2) work_val_buffer/work_idx_buffer has AccDataType/IndexDataType elements, and space size is no less than BlockSize +// 3) in_out_value/in_out_index is the input data in vgpr from each thread +// 4) in_out_value/in_out_index is the over-written reduced output in vgpr for each thread +// clang-format on +template < + typename AccDataType, + typename IndexDataType, + index_t BlockSize, + typename ThreadClusterLengths_M_K, + typename ThreadClusterArrangeOrder, + typename OpReduce, + bool PropagateNan, + typename Accumulation = + detail::AccumulateWithIndexAndNanCheck> +struct PartitionedBlockwiseReductionWithIndex +{ + static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1), + "The product of cluster lengths should be same as BlockSize!"); + + static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0); + static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1); + + static_assert(BufferLength_K > 1, "Parallel reduction need work on at least two elements"); + + static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + // This interface accumulates on both data values and indices + template + __device__ static void Reduce(BufferType& work_val_buffer, + IdxBufferType& work_idx_buffer, + AccDataType& in_out_value, + IndexDataType& in_out_index) + { + static_assert(is_same{}, + "Buffer data type should be consistent as AccDataType!"); + static_assert(is_same{}, + "Buffer data type should be consistent as IndexDataType!"); + + constexpr auto cluster_len_shift = get_shift(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id())); + + const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}]; + const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}]; + + work_val_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value; + work_idx_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_index; + + __syncthreads(); + + static_for<0, cluster_len_shift, 1>{}([&](auto I) { + constexpr index_t 
indOffset = 1 << I(); + + if(thread_k_cluster_id % (indOffset * 2) == 0) + { + index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx); + index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx + + make_tuple(0, indOffset)); + + AccDataType opData1 = work_val_buffer[offset1]; + AccDataType opData2 = work_val_buffer[offset2]; + IndexDataType currIndex1 = work_idx_buffer[offset1]; + IndexDataType currIndex2 = work_idx_buffer[offset2]; + + Accumulation::Calculate(opData1, opData2, currIndex1, currIndex2); + work_val_buffer(offset1) = opData1; + work_idx_buffer(offset1) = currIndex1; + } + + __syncthreads(); + }); + + index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0)); + + in_out_value = work_val_buffer[offset]; + in_out_index = work_idx_buffer[offset]; + }; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp new file mode 100644 index 00000000..0e5dfb35 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp @@ -0,0 +1,173 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp" + +namespace ck { + +// this version does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray instead of C array for thread buffer +// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor +// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate +template +struct ThreadGroupTensorSliceTransfer_v4r1 +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{}; + + using Index = MultiIndex; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v4r1( + const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const SrcElementwiseOperation& src_element_op, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const DstElementwiseOperation& dst_element_op) + : threadwise_transfer_(src_desc, + make_zero_multi_index(), + src_element_op, + dst_desc, + make_zero_multi_index(), + dst_element_op) + + { + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! 
ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetSrcSliceOrigin(src_desc, + src_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ void RunRead(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id); + } + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, + DstBuffer& dst_buf, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id); + } + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf, + Number thread_scratch_id) + { + RunRead(src_desc, src_buf, thread_scratch_id); + RunWrite(dst_desc, dst_buf, thread_scratch_id); + } + + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v3r1; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp new file mode 100644 index 00000000..5c47a49b --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp @@ -0,0 +1,134 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp" + +namespace ck { + +// this version does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray instead of C array for thread buffer +// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor +// 3. 
ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate +template +struct ThreadGroupTensorSliceTransfer_v6r1 +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; + + using Index = MultiIndex; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v6r1(const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src_desc, + make_zero_multi_index(), + dst_desc, + make_zero_multi_index(), + element_op) + + { + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == DimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetSrcSliceOrigin(src_desc, + src_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.Run(src_desc, src_buf, dst_desc, dst_buf); + } + } + + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v6r1; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp new file mode 100644 index 00000000..aa33fc08 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp @@ -0,0 +1,158 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
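
// Illustration (not part of the patch): the ThreadGroupTensorSliceTransfer variants
// above all share one origin computation: decode the thread id by
// ThreadClusterLengths, multiply by thread_slice_lengths to get each thread's slice
// origin, and idle any thread beyond the cluster size. Host-side sketch assuming a
// row-major 2D cluster (illustrative sizes):
#include <cstdio>

int main()
{
    // Block slice 8x64 copied by a 4x16 thread cluster -> 2x4 slice per thread,
    // mirroring thread_slice_lengths = BlockSliceLengths / ThreadClusterLengths.
    const int cluster_lengths[2] = {4, 16};
    const int thread_slice[2]    = {2, 4};
    const int cluster_size       = cluster_lengths[0] * cluster_lengths[1]; // 64

    for(int tid : {0, 5, 63, 64})
    {
        if(tid >= cluster_size)
        {
            // matches the guard: only threads below GetElementSize() participate
            std::printf("tid %2d: idle (ThreadGroup larger than cluster)\n", tid);
            continue;
        }
        const int c0 = tid / cluster_lengths[1];
        const int c1 = tid % cluster_lengths[1];
        std::printf("tid %2d -> cluster (%d,%d) -> data origin (%d,%d)\n",
                    tid, c0, c1, c0 * thread_slice[0], c1 * thread_slice[1]);
    }
    return 0;
}
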
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp" + +namespace ck { + +// this version does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray instead of C array for thread buffer +// 2. It does not keep reference to tensor descriptor +// 3. Run() does not construct new tensor coordinate +template +struct ThreadGroupTensorSliceTransfer_v6r2 +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; + + using Index = MultiIndex; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v6r2(const Src0Desc& src0_desc, + const Index& src0_block_slice_origin, + const Src1Desc& src1_desc, + const Index& src1_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src0_desc, + make_zero_multi_index(), + src1_desc, + make_zero_multi_index(), + dst_desc, + make_zero_multi_index(), + element_op) + + { + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == DimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! 
ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetSrc0SliceOrigin( + src0_desc, src0_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetSrc1SliceOrigin( + src1_desc, src1_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ void Run(const Src0Desc& src0_desc, + const Src0Buffer& src0_buf, + const Src1Desc& src1_desc, + const Src1Buffer& src1_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.Run(src0_desc, src0_buf, src1_desc, src1_buf, dst_desc, dst_buf); + } + } + + __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step); + } + } + + __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v6r2; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp new file mode 100644 index 00000000..eb5f589a --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp @@ -0,0 +1,183 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp" + +namespace ck { + +// this version does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray instead of C array for thread buffer +// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor +// 3. 
ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate +template +struct ThreadGroupTensorSliceTransfer_v6r3 +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; + + using Index = MultiIndex; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v6r3(const Src0Desc& src0_desc, + const Index& src0_block_slice_origin, + const Src1Desc& src1_desc, + const Index& src1_block_slice_origin, + const Src2Desc& src2_desc, + const Index& src2_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src0_desc, + make_zero_multi_index(), + src1_desc, + make_zero_multi_index(), + src2_desc, + make_zero_multi_index(), + dst_desc, + make_zero_multi_index(), + element_op) + + { + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == DimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetSrc0SliceOrigin( + src0_desc, src0_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetSrc1SliceOrigin( + src1_desc, src1_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetSrc2SliceOrigin( + src2_desc, src2_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ void Run(const Src0Desc& src0_desc, + const Src0Buffer& src0_buf, + const Src1Desc& src1_desc, + const Src1Buffer& src1_buf, + const Src2Desc& src2_desc, + const Src2Buffer& src2_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.Run( + src0_desc, src0_buf, src1_desc, src1_buf, src2_desc, src2_buf, dst_desc, dst_buf); + } + } + + __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step); + } + } + + __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step); + } + } + + __device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, const Index& 
step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrc2SliceWindow(src2_desc, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v6r3; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp new file mode 100644 index 00000000..3bd78063 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp @@ -0,0 +1,172 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp" + +namespace ck { + +// Thread-group level multi-source, multi-destination tensor slice data movement +// Assume: +// 1. All sources and destinations are DynamicBuffer +// 2. Same VectorDim and ScalerPerVector for all sources and destinations +// 3. DstInMemOps are per destination tensor +// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor +// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor +// +// Does following things to avoid scratch memory issue +// 1. Pass tensor descritpors by reference (or tuple of references) +// 2. Does not keep reference to tensor descriptor +// 3. 
Does not construct new tensor coordinate when call Run() +template + typename SliceLengths, + typename ThreadClusterLengths, + typename ThreadClusterArrangeOrder, + typename DimAccessOrder, + index_t VectorDim, + index_t ScalarPerVector, + typename ThreadTransferSrcResetCoordinateAfterRunFlags, + typename ThreadTransferDstResetCoordinateAfterRunFlags> +struct ThreadGroupTensorSliceTransfer_v7 +{ + static constexpr index_t nDim = + remove_cvref_t>::GetNumOfDimension(); + + static constexpr index_t nSrc = remove_cvref_t::Size(); + static constexpr index_t nDst = remove_cvref_t::Size(); + + using Index = MultiIndex; + + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v7( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_block_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_block_slice_origins, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src_descs, + StaticallyIndexedArray{}, + dst_descs, + StaticallyIndexedArray{}, + element_op) + { + static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() && + nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() && + nDst == DstDatas::Size() && nDst == DstDescs::Size() && + nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(), + "wrong!"); + + static_for<0, nSrc, 1>{}([&](auto i) { + static_assert( + nDim == remove_cvref_t>::GetNumOfDimension(), + "wrong!"); + }); + + static_for<0, nDst, 1>{}([&](auto i) { + static_assert( + nDim == remove_cvref_t>::GetNumOfDimension(), + "wrong!"); + }); + + static_assert(nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == DimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! 
ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + const auto src_thread_slice_origins = generate_tuple( + [&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + const auto dst_thread_slice_origins = generate_tuple( + [&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins); + threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins); + } + } + + template + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.Run(src_descs, src_bufs, dst_descs, dst_bufs); + } + } + + template + __device__ void + MoveSrcSliceWindow(const SrcDescs& src_descs, Number iSrc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step); + } + } + + template + __device__ void + MoveDstSliceWindow(const DstDescs& dst_descs, Number iDst, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v7; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp new file mode 100644 index 00000000..6a226b0c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CONVOLUTION_BACKWARD_DATA_SPECIALIZATION +#define CONVOLUTION_BACKWARD_DATA_SPECIALIZATION + +namespace ck { +namespace tensor_operation { +namespace device { + +enum struct ConvolutionBackwardDataSpecialization +{ + Default, + Filter1x1Stride1Pad0, +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp new file mode 100644 index 00000000..f4607ee6 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
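
// Illustration (not part of the patch): the Filter1x1Stride1Pad0 specializations in
// these enums exist because a 1x1 convolution with unit stride and zero padding
// degenerates into a GEMM: every output pixel reads exactly one input pixel. A
// minimal reference check assuming NHWC input and KC (1x1) weights; layouts and
// sizes here are illustrative, not the CK device op:
#include <cstddef>
#include <cstdio>
#include <vector>

int main()
{
    const int N = 1, H = 2, W = 2, C = 3, K = 2;
    std::vector<float> in(N * H * W * C), wei(K * C), out(N * H * W * K, 0.f);
    for(std::size_t i = 0; i < in.size(); ++i)
        in[i] = 0.1f * static_cast<float>(i);
    for(std::size_t i = 0; i < wei.size(); ++i)
        wei[i] = 1.0f + static_cast<float>(i);

    // 1x1/stride1/pad0 forward conv == GEMM: out[p, k] = sum_c in[p, c] * wei[k, c]
    const int P = N * H * W; // flattened output pixels
    for(int p = 0; p < P; ++p)
        for(int k = 0; k < K; ++k)
            for(int c = 0; c < C; ++c)
                out[p * K + k] += in[p * C + c] * wei[k * C + c];

    std::printf("out[0,0] = %f\n", out[0]);
    return 0;
}
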
+
+#pragma once
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+enum struct ConvolutionBackwardWeightSpecialization
+{
+    Default,
+    Filter1x1Stride1Pad0,
+    Filter1x1Pad0,
+    OddC,
+};
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp
new file mode 100644
index 00000000..ea60a4e6
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp
@@ -0,0 +1,36 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#ifndef CONVOLUTION_FORWARD_SPECIALIZATION
+#define CONVOLUTION_FORWARD_SPECIALIZATION
+
+#include <string>
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+enum struct ConvolutionForwardSpecialization
+{
+    Default,
+    Filter1x1Pad0,
+    Filter1x1Stride1Pad0,
+    OddC,
+};
+
+inline std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization& s)
+{
+    switch(s)
+    {
+    case ConvolutionForwardSpecialization::Default: return "Default";
+    case ConvolutionForwardSpecialization::Filter1x1Pad0: return "Filter1x1Pad0";
+    case ConvolutionForwardSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
+    case ConvolutionForwardSpecialization::OddC: return "OddC";
+    default: return "Unrecognized specialization!";
+    }
+}
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
+#endif
diff --git a/include/ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp b/include/ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp
new file mode 100644
index 00000000..c228045b
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp
@@ -0,0 +1,353 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <array>
+#include <memory>
+#include <sstream>
+#include <vector>
+
+#include "ck/utility/common_header.hpp"
+#include "ck/tensor_description/tensor_descriptor.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp"
+#include "ck/device_utility/device_prop.hpp"
+#include "ck/device_utility/kernel_launch.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+template <typename ADataType,
+          typename BDataType,
+          typename CDataType,
+          typename DDataType,
+          typename EDataType,
+          typename FDataType,
+          typename ElementwiseFunctor,
+          index_t NDim,
+          index_t MPerThread,
+          index_t AScalarPerVector,
+          index_t BScalarPerVector,
+          index_t CScalarPerVector,
+          index_t DScalarPerVector,
+          index_t EScalarPerVector,
+          index_t FScalarPerVector>
+struct Device5AryElementwise : public DeviceElementwise<5, 1, NDim, ElementwiseFunctor>
+{
+    static constexpr auto I0 = Number<0>{};
+
+    template <typename Desc_M>
+    static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
+    {
+        const auto m = desc_m.GetLength(I0);
+        const index_t loop_step = gridSize * blockSize * MPerThread;
+        const auto pad = math::integer_least_multiple(m, loop_step) - m;
+        const auto desc_m_pad =
+            transform_tensor_descriptor(desc_m,
+                                        make_tuple(make_right_pad_transform(m, pad)),
+                                        make_tuple(Sequence<0>{}),
+                                        make_tuple(Sequence<0>{}));
+        return desc_m_pad;
+    }
+
+    static auto MakeDescriptor_M(const std::vector<index_t>& lengths,
+                                 const std::vector<index_t>& stride,
+                                 index_t gridSize,
+                                 index_t blockSize)
+    {
+        auto tupleOfShape  = generate_tuple([&](auto I) { return lengths[I]; }, Number<NDim>{});
+        auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<NDim>{});
+
+        // nd desc - [s0, s1, s2, ...]
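+        // The raw n-d descriptor built below is merged to a single M dimension and
+        // then right-padded (see PadDescriptor_M_1d above) to a multiple of
+        // gridSize * blockSize * MPerThread, so the kernel's grid-stride loop runs
+        // the same number of vectorized iterations in every thread.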
+ const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); + + // merge nd to 1d desc - [s0 * s1 * ...] + if constexpr(NDim > 1) + { + const auto desc_m = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(tupleOfShape)), + make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number{})), + make_tuple(Sequence<0>{})); + + return PadDescriptor_M_1d(desc_m, gridSize, blockSize); + } + else + return PadDescriptor_M_1d(desc, gridSize, blockSize); + } + + using AGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); + using BGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); + using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); + using DGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); + using EGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); + using FGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); + + using Gridwise5AryEltwise = Gridwise5AryElementwise_1D; + + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a, + const BDataType* p_b, + const CDataType* p_c, + const DDataType* p_d, + const EDataType* p_e, + FDataType* p_f, + const std::vector& lengths, + const std::vector& a_strides, + const std::vector& b_strides, + const std::vector& c_strides, + const std::vector& d_strides, + const std::vector& e_strides, + const std::vector& f_strides, + ElementwiseFunctor functor) + : p_a_(p_a), + p_b_(p_b), + p_c_(p_c), + p_d_(p_d), + p_e_(p_e), + p_f_(p_f), + lengths_(lengths), + a_strides_(a_strides), + b_strides_(b_strides), + c_strides_(c_strides), + d_strides_(d_strides), + e_strides_(e_strides), + f_strides_(f_strides), + functor_(functor), + blockSize_(256), + gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future + { + a_grid_desc_m_ = MakeDescriptor_M(lengths, a_strides, gridSize_, blockSize_); + b_grid_desc_m_ = MakeDescriptor_M(lengths, b_strides, gridSize_, blockSize_); + c_grid_desc_m_ = MakeDescriptor_M(lengths, c_strides, gridSize_, blockSize_); + d_grid_desc_m_ = MakeDescriptor_M(lengths, d_strides, gridSize_, blockSize_); + e_grid_desc_m_ = MakeDescriptor_M(lengths, e_strides, gridSize_, blockSize_); + f_grid_desc_m_ = MakeDescriptor_M(lengths, f_strides, gridSize_, blockSize_); + } + + const ADataType* p_a_; + const BDataType* p_b_; + const CDataType* p_c_; + const DDataType* p_d_; + const EDataType* p_e_; + FDataType* p_f_; + std::vector lengths_; + AGridDesc_M a_grid_desc_m_; + BGridDesc_M b_grid_desc_m_; + CGridDesc_M c_grid_desc_m_; + DGridDesc_M d_grid_desc_m_; + EGridDesc_M e_grid_desc_m_; + FGridDesc_M f_grid_desc_m_; + std::vector a_strides_; + std::vector b_strides_; + std::vector c_strides_; + std::vector d_strides_; + std::vector e_strides_; + std::vector f_strides_; + ElementwiseFunctor functor_; + index_t blockSize_; + index_t gridSize_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto kernel = kernel_5ary_elementwise_1d; + + float elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(arg.gridSize_), + dim3(arg.blockSize_), + 0, + arg.p_a_, + arg.p_b_, + arg.p_c_, + arg.p_d_, + arg.p_e_, + arg.p_f_, + arg.a_grid_desc_m_, + arg.b_grid_desc_m_, + arg.c_grid_desc_m_, + arg.d_grid_desc_m_, + arg.e_grid_desc_m_, + arg.f_grid_desc_m_, + arg.functor_); + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) 
override
+        {
+            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
+        }
+    };
+
+    bool IsSupportedArgument(const BaseArgument& p_arg) { return IsSupportedArgument(&p_arg); }
+
+    bool IsSupportedArgument(const BaseArgument* p_arg) override
+    {
+        const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
+
+        if(pArg == nullptr)
+            return false;
+
+        if(pArg->lengths_.size() != NDim)
+            return false;
+
+        if(pArg->lengths_.back() % MPerThread != 0)
+            return false;
+
+        auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) {
+            bool ret = true;
+
+            if(!isLastDimensionCoalesced)
+                ret = scalarPerVector == 1;
+            else
+                ret = MPerThread % scalarPerVector == 0;
+
+            return ret;
+        };
+
+        if(!IsScalarPerVectorValid(pArg->a_strides_.back() == 1, AScalarPerVector))
+            return false;
+
+        if(!IsScalarPerVectorValid(pArg->b_strides_.back() == 1, BScalarPerVector))
+            return false;
+
+        if(!IsScalarPerVectorValid(pArg->c_strides_.back() == 1, CScalarPerVector))
+            return false;
+
+        if(!IsScalarPerVectorValid(pArg->d_strides_.back() == 1, DScalarPerVector))
+            return false;
+
+        if(!IsScalarPerVectorValid(pArg->e_strides_.back() == 1, EScalarPerVector))
+            return false;
+
+        if(!IsScalarPerVectorValid(pArg->f_strides_.back() == 1, FScalarPerVector))
+            return false;
+
+        return true;
+    }
+
+    static auto MakeArgument(std::array<const void*, 5> p_inputs,
+                             std::array<void*, 1> p_outputs,
+                             std::vector<index_t> lengths,
+                             std::vector<index_t> a_strides,
+                             std::vector<index_t> b_strides,
+                             std::vector<index_t> c_strides,
+                             std::vector<index_t> d_strides,
+                             std::vector<index_t> e_strides,
+                             std::vector<index_t> f_strides,
+                             ElementwiseFunctor functor)
+    {
+        return Argument{static_cast<const ADataType*>(p_inputs[0]),
+                        static_cast<const BDataType*>(p_inputs[1]),
+                        static_cast<const CDataType*>(p_inputs[2]),
+                        static_cast<const DDataType*>(p_inputs[3]),
+                        static_cast<const EDataType*>(p_inputs[4]),
+                        static_cast<FDataType*>(p_outputs[0]),
+                        lengths,
+                        a_strides,
+                        b_strides,
+                        c_strides,
+                        d_strides,
+                        e_strides,
+                        f_strides,
+                        functor};
+    }
+
+    std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(std::array<const void*, 5> p_inputs,
+                        std::array<void*, 1> p_outputs,
+                        std::vector<index_t> lengths,
+                        std::vector<std::vector<index_t>> input_strides,
+                        std::vector<std::vector<index_t>> output_strides,
+                        ElementwiseFunctor functor) override
+    {
+        return std::make_unique<Argument>(static_cast<const ADataType*>(p_inputs[0]),
+                                          static_cast<const BDataType*>(p_inputs[1]),
+                                          static_cast<const CDataType*>(p_inputs[2]),
+                                          static_cast<const DDataType*>(p_inputs[3]),
+                                          static_cast<const EDataType*>(p_inputs[4]),
+                                          static_cast<FDataType*>(p_outputs[0]),
+                                          lengths,
+                                          input_strides[0],
+                                          input_strides[1],
+                                          input_strides[2],
+                                          input_strides[3],
+                                          input_strides[4],
+                                          output_strides[0],
+                                          functor);
+    }
+
+    static auto MakeInvoker() { return Invoker{}; }
+
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
+    {
+        return std::make_unique<Invoker>();
+    }
+
+    // polymorphic
+    std::string GetTypeString() const override
+    {
+        auto str = std::stringstream();
+
+        // clang-format off
+        str << "Device5AryElementwise"
+            << "<"
+            << "NDim = " << NDim << ", "
+            << "MPerThread = " << MPerThread << ", "
+            << "AScalarPerVector = " << AScalarPerVector << ", "
+            << "BScalarPerVector = " << BScalarPerVector << ", "
+            << "CScalarPerVector = " << CScalarPerVector << ", "
+            << "DScalarPerVector = " << DScalarPerVector << ", "
+            << "EScalarPerVector = " << EScalarPerVector << ", "
+            << "FScalarPerVector = " << FScalarPerVector
+            << ">";
+        // clang-format on
+
+        return str.str();
+    }
+};
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp
new file mode 100644
index 00000000..f41f65d7
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/device_base.hpp
@@ -0,0 +1,61 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <cassert>
+#include <string>
+
+#include "ck/stream_config.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+struct BaseArgument
+{
+    BaseArgument() = default;
+    BaseArgument(const BaseArgument&) = default;
+    BaseArgument& operator=(const BaseArgument&) = default;
+
+    virtual ~BaseArgument() {}
+
+    void* p_workspace_ = nullptr;
+};
+
+struct BaseInvoker
+{
+    BaseInvoker() = default;
+    BaseInvoker(const BaseInvoker&) = default;
+    BaseInvoker& operator=(const BaseInvoker&) = default;
+
+    virtual float Run(const BaseArgument*, const StreamConfig& = StreamConfig{})
+    {
+        return float{0};
+    }
+
+    virtual ~BaseInvoker() {}
+};
+
+struct BaseOperator
+{
+    BaseOperator() = default;
+    BaseOperator(const BaseOperator&) = default;
+    BaseOperator& operator=(const BaseOperator&) = default;
+
+    virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
+    virtual std::string GetTypeString() const { return ""; }
+
+    virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
+
+    virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const
+    {
+        assert(p_arg);
+        p_arg->p_workspace_ = p_workspace;
+    }
+
+    virtual ~BaseOperator() {}
+};
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp
new file mode 100644
index 00000000..e7559132
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp
@@ -0,0 +1,68 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <memory>
+
+#include "device_base.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+struct DeviceBatchedGemm : public BaseOperator
+{
+    virtual std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(const void* p_a,
+                        const void* p_b,
+                        void* p_c,
+                        ck::index_t M,
+                        ck::index_t N,
+                        ck::index_t K,
+                        ck::index_t StrideA,
+                        ck::index_t StrideB,
+                        ck::index_t StrideC,
+                        ck::index_t BatchStrideA,
+                        ck::index_t BatchStrideB,
+                        ck::index_t BatchStrideC,
+                        ck::index_t Batch,
+                        AElementwiseOperation a_element_op,
+                        BElementwiseOperation b_element_op,
+                        CElementwiseOperation c_element_op) = 0;
+
+    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+};
+
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+using DeviceBatchedGemmPtr = std::unique_ptr<
+    DeviceBatchedGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_c_permute.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_c_permute.hpp
new file mode 100644
index 00000000..90c8f79d
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_c_permute.hpp
@@ -0,0 +1,48 @@
+#pragma once
+
+#include <memory>
+
+#include "device_base.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+struct BatchedGemmCPermuteDesc
+{
+    ck::index_t G0_, G1_, M_, N_;
+    ck::index_t stride_G0_, stride_G1_, stride_M_, stride_N_;
+};
+
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+struct DeviceBatchedGemmCPermute : public BaseOperator
+{
+    virtual std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(const void* p_a,
+                        const void* p_b,
+                        void* p_c,
+                        index_t M,
+                        index_t N,
+                        index_t K,
+                        index_t stride_A,
+                        index_t stride_B,
+                        BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
+                        AElementwiseOperation a_element_op,
+                        BElementwiseOperation b_element_op,
+                        CElementwiseOperation c_element_op,
+                        ck::index_t BatchCount) = 0;
+
+    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+};
+
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+using DeviceBatchedGemmCPermutePtr = std::unique_ptr<
+    DeviceBatchedGemmCPermute<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_c_permute_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_c_permute_xdl.hpp
new file mode 100644
index 00000000..fc65c811
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_c_permute_xdl.hpp
@@ -0,0 +1,860 @@
+#pragma once
+
+#include <iostream>
+#include <sstream>
+
+#include "ck/utility/common_header.hpp"
+#include "ck/tensor_description/tensor_descriptor.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_batched_gemm_c_permute.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
+#include "ck/device_utility/device_prop.hpp"
+#include "ck/device_utility/kernel_launch.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+/*
+ * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
+ *
+ * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of the A, B, C
+ * matrices for a given batch index. For example, ComputePtrOffsetOfStridedBatch() computes the
+ * offsets of evenly strided batches, but we can easily extend to other layouts. The returned
+ * offset can be either \p index_t or \p long_index_t. If it returns \p long_index_t, we are not
+ * subject to the 2GB limitations.
+ *
+ * \tparam Block2CTileMap Block2CTileMap::CalculateBottomIndex() takes in the id of a workgroup
+ * and returns the 2D index of the tile that it computes. \see
+ * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
+ *
+ * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
+ * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
+ * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
+ * device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for
+ * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the
+ * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch.
+ *
+ * \note \p Block2CTileMap allows customized mapping between a workgroup and the C-tile it
+ * computes. Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm
+ * fusion) to realize BatchedGemmCPermute and GroupedGemm (and the corresponding GEMM fusion).
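+ *
+ * \note In the plain strided-batch case the offset functor is just a multiply,
+ *       e.g. GetAPtrOffset(g_idx) = g_idx * BatchStrideA. The permuted variant
+ *       defined further below instead decomposes g_idx into (g_idx / G1,
+ *       g_idx % G1) and asks the G0 x G1 x M x N output descriptor for the
+ *       offset of that batch slice.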
+ * + */ +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_c_permute_xdl(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const index_t batch_count, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + ck::Tuple<>{}, + p_c_grid + c_batch_offset, + p_shared, + a_element_op, + b_element_op, + c_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ck::StaticallyIndexedArray< + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + 0>{}, + c_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = batch_count; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = compute_ptr_offset_of_batch; + ignore = block_2_ctile_map; +#endif +} + +template +struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + 
make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // 
pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + static auto + MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t stride_M, index_t stride_N) + { + const auto c_grid_desc_mraw_nraw = [&]() { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(stride_M, stride_N)); + }(); + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + static auto MakeEGridDescriptor_G0_G1_M_N(index_t G0, + index_t G1, + index_t MRaw, + index_t NRaw, + index_t stride_G0, + index_t stride_G1, + index_t stride_M, + index_t stride_N) + { + const auto e_grid_desc_g0_g1_mraw_nraw = [&]() { + return make_naive_tensor_descriptor( + make_tuple(G0, G1, MRaw, 
NRaw), + make_tuple(stride_G0, stride_G1, stride_M, stride_N)); + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor( + e_grid_desc_g0_g1_mraw_nraw, + make_tuple(make_pass_through_transform(G0), + make_pass_through_transform(G1), + make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + e_grid_desc_g0_g1_mraw_nraw, + make_tuple(make_pass_through_transform(G0), + make_pass_through_transform(G1), + make_right_pad_transform(MRaw, MPad), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + e_grid_desc_g0_g1_mraw_nraw, + make_tuple(make_pass_through_transform(G0), + make_pass_through_transform(G1), + make_pass_through_transform(MRaw), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + } + else + { + // not pad M or N + return e_grid_desc_g0_g1_mraw_nraw; + } + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1)); + using EGridDesc_G0_G1_M_N = decltype(MakeEGridDescriptor_G0_G1_M_N(1, 1, 1, 1, 1, 1, 1, 1)); + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch(index_t Batchstride_A, + index_t Batchstride_B, + EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n) + : Batchstride_A_(Batchstride_A), + Batchstride_B_(Batchstride_B), + e_grid_desc_g0_g1_m_n_(e_grid_desc_g0_g1_m_n) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(Batchstride_A_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(Batchstride_B_); + } + + __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const + { + const index_t G1 = e_grid_desc_g0_g1_m_n_.GetLength(I1); + index_t b0 = g_idx / G1; + index_t b1 = g_idx - b0 * G1; // g_idx % G1 + return e_grid_desc_g0_g1_m_n_.CalculateOffset(make_multi_index(b0, b1, 0, 0)); + } + + private: + index_t Batchstride_A_; + index_t Batchstride_B_; + EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_; + }; + + using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle< + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, // CShuffleDataType, + ck::Tuple<>, // DsDataType, + CDataType, // EDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + 
AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + NumPrefetch, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype( + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); + using Block2CTileMap = typename GridwiseGemm::DefaultBlock2ETileMap; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t stride_A, + index_t stride_B, + BatchedGemmCPermuteDesc batched_gemm_c_permute_desc, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t BatchCount) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + BatchCount_(BatchCount), + a_grid_desc_k0_m_k1_{ + DeviceBatchedGemmCPermuteXdl::MakeAGridDescriptor_AK0_M_AK1(M, K, stride_A)}, + b_grid_desc_k0_n_k1_{ + DeviceBatchedGemmCPermuteXdl::MakeBGridDescriptor_BK0_N_BK1(K, N, stride_B)}, + c_grid_desc_m_n_{DeviceBatchedGemmCPermuteXdl::MakeCGridDescriptor_M_N( + batched_gemm_c_permute_desc.M_, + batched_gemm_c_permute_desc.N_, + batched_gemm_c_permute_desc.stride_M_, + batched_gemm_c_permute_desc.stride_N_)}, + e_grid_desc_g0_g1_m_n_{DeviceBatchedGemmCPermuteXdl::MakeEGridDescriptor_G0_G1_M_N( + batched_gemm_c_permute_desc.G0_, + batched_gemm_c_permute_desc.G1_, + batched_gemm_c_permute_desc.M_, + batched_gemm_c_permute_desc.N_, + batched_gemm_c_permute_desc.stride_G0_, + batched_gemm_c_permute_desc.stride_G1_, + batched_gemm_c_permute_desc.stride_M_, + batched_gemm_c_permute_desc.stride_N_)}, + c_grid_desc_mblock_mperblock_nblock_nperblock{}, + compute_ptr_offset_of_batch_{ + type_convert(a_grid_desc_k0_m_k1_.GetElementSpaceSize()), + type_convert(b_grid_desc_k0_n_k1_.GetElementSpaceSize()), + e_grid_desc_g0_g1_m_n_}, + block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(c_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + index_t BatchCount_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_; + 
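+        // Blocked view of the C descriptor; left default-constructed and only filled
+        // in by the constructor once GridwiseGemm::CheckValidity accepts the shape: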
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + Block2CTileMap block_2_ctile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceBatchedGemmCPermuteXdl::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseBatchedGemmCPermute_km_kn_m0m1n0n1_xdlops_v2r3 has invalid " + "setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_; + + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + float ave_time = 0; + + auto launch_kernel = [&](auto has_main_k_block_loop_) { + const auto kernel = kernel_batched_gemm_c_permute_xdl< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + ComputePtrOffsetOfStridedBatch, + remove_reference_t, + has_main_k_block_loop_>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.BatchCount_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.compute_ptr_offset_of_batch_, + arg.block_2_ctile_map_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + ave_time = launch_kernel(integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t stride_A, + index_t stride_B, + BatchedGemmCPermuteDesc 
batched_gemm_c_permute_desc, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t BatchCount) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + stride_A, + stride_B, + batched_gemm_c_permute_desc, + a_element_op, + b_element_op, + c_element_op, + BatchCount}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t stride_A, + index_t stride_B, + BatchedGemmCPermuteDesc batched_gemm_c_permute_desc, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t BatchCount) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + stride_A, + stride_B, + batched_gemm_c_permute_desc, + a_element_op, + b_element_op, + c_element_op, + BatchCount); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchedGemmCPermuteXdl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp new file mode 100644 index 00000000..1486f0ac --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp @@ -0,0 +1,1001 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
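
Every device operation in this patch is driven through the BaseOperator / BaseArgument /
BaseInvoker protocol from device_base.hpp: build an argument, check support, then launch
through an invoker. Below is a minimal sketch of that calling sequence; run_if_supported is
an illustrative helper name (not part of this patch), and StreamConfig{nullptr, true}
requests kernel timing on the default stream:

#include <memory>
#include <utility>

#include "ck/stream_config.hpp"

// Sketch only: DeviceOp stands for any concrete operation in this patch (e.g. an
// instantiation of DeviceBatchedGemmCPermuteXdl); make_args are whatever its
// MakeArgumentPointer() overload expects: device pointers, sizes, strides, elementwise ops.
template <typename DeviceOp, typename... MakeArgs>
float run_if_supported(DeviceOp& op, MakeArgs&&... make_args)
{
    auto arg = op.MakeArgumentPointer(std::forward<MakeArgs>(make_args)...);

    // Reject shapes/strides this instance cannot tile or vectorize.
    if(!op.IsSupportedArgument(arg.get()))
        return -1.0f; // caller falls back to another instance

    auto invoker = op.MakeInvokerPointer();

    return invoker->Run(arg.get(), StreamConfig{nullptr, true});
}
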
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_reduce_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + ReducePtrsGlobal p_reduces_grid, + const index_t batch_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const ReduceInElementwiseOperations reduce_in_element_ops, + const ReduceAccElementwiseOperations reduce_out_element_ops, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, + const ComputeBasePrtOfBatch compute_base_ptr_of_batch_, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch_.GetABasePtr(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch_.GetBBasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch_.GetCBasePtr(g_idx))); + + static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) { + const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch_.GetDBasePtr(g_idx, In))); + p_reduces_grid(In) = p_reduces_grid(In) + d_batch_offset; + }); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_c_grid + c_batch_offset, + p_reduces_grid, + p_shared, + a_element_op, + b_element_op, + c_element_op, + reduce_in_element_ops, + reduce_out_element_ops, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + reduce_grid_desc_mblock_mperblock, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = p_reduces_grid; + ignore = batch_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = reduce_in_element_ops; + ignore = reduce_out_element_ops; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = reduce_grid_desc_mblock_mperblock; + ignore = compute_base_ptr_of_batch_; + ignore 
= block_2_ctile_map; +#endif +} + +// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle +// version currently has compiler issues with register spill which further causes validation +// failures. +template +struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperations::Size()> +{ + using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, 
Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + const 
auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + // assume D is packed tensor + static auto MakeReduceGridDescriptor_M(index_t MRaw) + { + const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw)); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto MPad = M - MRaw; + + if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M + return transform_tensor_descriptor(d_grid_desc_mraw, + make_tuple(make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + } + else + { + // not pad M + return d_grid_desc_mraw; + } + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1)); + + struct ComputeBasePtrOfStridedBatch + { + ComputeBasePtrOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t BatchStrideD) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideC_(BatchStrideC), + BatchStrideD_(BatchStrideD) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideC_); + } + + template + __host__ __device__ constexpr long_index_t GetDBasePtr(index_t g_idx, + Number reduction_idx) const + { + // TODO - Support sequence of StrideD in MakeArgument() + (void)reduction_idx; + return g_idx * static_cast(BatchStrideD_); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + index_t BatchStrideC_; + index_t BatchStrideD_; + }; + + // GridwiseGemm + using GridwiseGemm = 
GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + ReduceAccDataType, + ReducePtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + ReduceOperations, + ReduceInElementwiseOperations, + ReduceAccElementwiseOperations, + InMemoryDataOperationEnum::Set, + ReduceGlobalMemoryDataOperation, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + CGridDesc_M_N, + ReduceGridDesc_M, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + CReduceThreadClusterLengths_MPerBlock_NPerBlock, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, + LoopSched>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + ReducePtrsGlobal p_reduces_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ReduceInElementwiseOperations reduce_in_element_ops, + ReduceAccElementwiseOperations reduce_out_element_ops, + index_t Batch) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + p_reduces_grid_{p_reduces_grid}, + Batch_(Batch), + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, + reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + reduce_grid_desc_mblock_mperblock_{}, + compute_base_ptr_of_batch_{ + type_convert(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()), + type_convert(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()), + type_convert(c_grid_desc_m_n_.GetElementSpaceSize()), + type_convert(reduce_grid_desc_m_.GetElementSpaceSize())}, + block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + reduce_in_element_ops_{reduce_in_element_ops}, + reduce_out_element_ops_{reduce_out_element_ops} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + + reduce_grid_desc_mblock_mperblock_ = + GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_); + } + } 
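+
+        // Everything above is computed once on the host; the kernel receives only
+        // these cached, pre-validated descriptors (by value) at launch time.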
+ + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + ReducePtrsGlobal p_reduces_grid_; + index_t Batch_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + CGridDesc_M_N c_grid_desc_m_n_; + ReduceGridDesc_M reduce_grid_desc_m_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock + reduce_grid_desc_mblock_mperblock_; + ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + ReduceInElementwiseOperations reduce_in_element_ops_; + ReduceAccElementwiseOperations reduce_out_element_ops_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << "arg.Batch_ = " << arg.Batch_ << std::endl; + + std::cout << "arg.a_grid_desc_ak0_m_ak1_{" + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_bk0_n_bk1_{" + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0) << "}" + << std::endl; + } +#endif + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Batch_; + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + float elapsed_time = 0.0f; + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_batched_gemm_reduce_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + ReducePtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + ReduceInElementwiseOperations, + ReduceAccElementwiseOperations, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock, + ComputeBasePtrOfStridedBatch, + typename GridwiseGemm::DefaultBlock2CTileMap, + true>; + + elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_reduces_grid_, + arg.Batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.reduce_in_element_ops_, + arg.reduce_out_element_ops_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.reduce_grid_desc_mblock_mperblock_, + arg.compute_base_ptr_of_batch_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_batched_gemm_reduce_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + ReducePtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + ReduceInElementwiseOperations, + ReduceAccElementwiseOperations, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock, + ComputeBasePtrOfStridedBatch, + typename GridwiseGemm::DefaultBlock2CTileMap, + false>; + + elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_reduces_grid_, + arg.Batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.reduce_in_element_ops_, + arg.reduce_out_element_ops_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.reduce_grid_desc_mblock_mperblock_, + arg.compute_base_ptr_of_batch_, + arg.block_2_ctile_map_); + } + + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + auto casted_p_arg = dynamic_cast(p_arg); + if(casted_p_arg == nullptr) + { + return false; + } + else + { + return IsSupportedArgument(*casted_p_arg); + } + } + + static constexpr int NumReduce = ReduceOperations::Size(); + static auto MakeArgument(const void* p_a, + const void* 
p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op, + index_t Batch) + { + (void)p_bias; + (void)p_ds; + (void)StrideDs; + (void)d_element_ops; + + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + + return Argument{static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + reduce_in_element_ops, + reduce_out_element_ops, + Batch}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op, + index_t Batch = 1) override + { + (void)p_bias; + (void)p_ds; + (void)StrideDs; + (void)d_element_ops; + + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + reduce_in_element_ops, + reduce_out_element_ops, + Batch); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + 
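+        // Editorial note: the signature below encodes the tile configuration; with
+        // illustrative (not prescribed) template values it would render as, e.g.,
+        //
+        //     "DeviceBatchedGemmReduce_Xdl_CShuffle<256, 256, 128, 32, 8, 8>"
+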
auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchedGemmReduce_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp new file mode 100644 index 00000000..ee94290a --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp @@ -0,0 +1,641 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/* + * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. + * + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB + * limitations. + * + * \tparam Block2CTileMap Block2CTileMap::CalculateBottomIndex() takes in id of a workgroup and + * returns the 2D index of the tile that it computes. \see + * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). + * + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link + * device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link + * DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of + * pointer offset into \p ComputePtrOffsetOfStridedBatch. + * + * \note \p Block2CTileMap allows customized mapping between a workgroup and the C-tile it computes. + * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to + * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion). 
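+ *
+ * \par Example (editorial sketch, not part of the original interface)
+ * A minimal functor satisfying the \p ComputePtrOffsetOfBatch concept for evenly
+ * strided batches could look as follows; the name \c SimpleStridedOffset is
+ * hypothetical:
+ * \code
+ * struct SimpleStridedOffset
+ * {
+ *     long_index_t stride_a_, stride_b_, stride_c_;
+ *
+ *     __host__ __device__ long_index_t GetAPtrOffset(index_t g) const { return g * stride_a_; }
+ *     __host__ __device__ long_index_t GetBPtrOffset(index_t g) const { return g * stride_b_; }
+ *     __host__ __device__ long_index_t GetCPtrOffset(index_t g) const { return g * stride_c_; }
+ * };
+ * \endcode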
+ * + */ +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_xdlops_v2r3( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const index_t batch_count, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = batch_count; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = compute_ptr_offset_of_batch; + ignore = block_2_ctile_map; +#endif +} + +template +struct DeviceBatchedGemmXdl : public DeviceBatchedGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto K1Number = Number{}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + + const auto a_grid_desc_k0_mp_k1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_k0_mp_k1; + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return 
make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + const auto b_grid_desc_k0_np_k1 = + transform_tensor_descriptor(b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_k0_np_k1; + } + + static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + const auto c_grid_desc_mp_np = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return c_grid_desc_mp_np; + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC) + : BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideC_(BatchStrideC) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideC_); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + index_t BatchStrideC_; + }; + + // GridwiseGemm + using GridwiseGemm = + GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; + + using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = + decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); + using Block2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t Batch, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + Batch_(Batch), + a_grid_desc_k0_m_k1_{ + DeviceBatchedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA)}, + b_grid_desc_k0_n_k1_{ + DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)}, + c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)}, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, + compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideC}, + block_2_ctile_map_{ 
+ GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + index_t Batch_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + Block2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceBatchedGemmXdl::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! 
GridwiseBatchedGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Batch_; + + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_batched_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + ComputePtrOffsetOfStridedBatch, + remove_reference_t, + true>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.Batch_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.compute_ptr_offset_of_batch_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_batched_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + ComputePtrOffsetOfStridedBatch, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.Batch_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.compute_ptr_offset_of_batch_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t Batch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + BatchStrideA, + BatchStrideB, + BatchStrideC, + Batch, + 1, + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t Batch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + 
CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + BatchStrideA, + BatchStrideB, + BatchStrideC, + Batch, + 1, + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchedGemmXdl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp b/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp new file mode 100644 index 00000000..99be946e --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp @@ -0,0 +1,247 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/device/device_elementwise.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBinaryElementwise : public DeviceElementwise<2, 1, NDim, ElementwiseFunctor> +{ + static constexpr auto I0 = Number<0>{}; + + template + static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize) + { + const auto M = desc_m.GetLength(I0); + const index_t loop_step = gridSize * blockSize * MPerThread; + const auto pad = math::integer_least_multiple(M, loop_step) - M; + const auto desc_m_pad = + transform_tensor_descriptor(desc_m, + make_tuple(make_right_pad_transform(M, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return desc_m_pad; + } + + static auto MakeDescriptor_M(const std::vector& lengths, + const std::vector& strides, + index_t gridSize, + index_t blockSize) + { + auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number{}); + auto tupleOfStride = generate_tuple([&](auto I) { return strides[I]; }, Number{}); + + // nd desc - [s0, s1, s2, ...] + const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); + + // merge nd to 1d desc - [s0 * s1 * ...] 
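+        // Editorial example (values illustrative): for lengths = {4, 8} and strides = {8, 1},
+        // the merge transform below collapses the 2-d descriptor into one dimension of
+        // length 4 * 8 = 32; PadDescriptor_M_1d then right-pads that length up to a
+        // multiple of gridSize * blockSize * MPerThread so the grid-stride loop stays
+        // in bounds.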
+ if constexpr(NDim > 1) + { + const auto desc_m = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(tupleOfShape)), + make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number{})), + make_tuple(Sequence<0>{})); + + return PadDescriptor_M_1d(desc_m, gridSize, blockSize); + } + else + return PadDescriptor_M_1d(desc, gridSize, blockSize); + } + + using AGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); + using BGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); + using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); + using GridwiseBinEltwise = GridwiseBinaryElementwise_1D; + + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + const std::vector& lengths, + const std::vector& a_strides, + const std::vector& b_strides, + const std::vector& c_strides, + ElementwiseFunctor functor) + : p_a_(p_a), + p_b_(p_b), + p_c_(p_c), + lengths_(lengths), + a_strides_(a_strides), + b_strides_(b_strides), + c_strides_(c_strides), + functor_(functor), + blockSize_(256), + gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future + { + a_grid_desc_m_ = MakeDescriptor_M(lengths, a_strides, gridSize_, blockSize_); + b_grid_desc_m_ = MakeDescriptor_M(lengths, b_strides, gridSize_, blockSize_); + c_grid_desc_m_ = MakeDescriptor_M(lengths, c_strides, gridSize_, blockSize_); + } + + const ADataType* p_a_; + const BDataType* p_b_; + CDataType* p_c_; + std::vector lengths_; + AGridDesc_M a_grid_desc_m_; + BGridDesc_M b_grid_desc_m_; + CGridDesc_M c_grid_desc_m_; + std::vector a_strides_; + std::vector b_strides_; + std::vector c_strides_; + ElementwiseFunctor functor_; + index_t blockSize_; + index_t gridSize_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto kernel = kernel_binary_elementwise_1d; + + float elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(arg.gridSize_), + dim3(arg.blockSize_), + 0, + arg.p_a_, + arg.p_b_, + arg.p_c_, + arg.a_grid_desc_m_, + arg.b_grid_desc_m_, + arg.c_grid_desc_m_, + arg.functor_); + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + if(pArg == nullptr) + return false; + + if(pArg->lengths_.size() != NDim) + return false; + + if(pArg->lengths_.back() % MPerThread != 0) + return false; + + auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) { + bool ret = true; + + if(!isLastDimensionCoalesced) + ret = scalarPerVector == 1; + else + ret = MPerThread % scalarPerVector == 0; + + return ret; + }; + + if(!IsScalarPerVectorValid(pArg->a_strides_.back() == 1, AScalarPerVector)) + return false; + + if(!IsScalarPerVectorValid(pArg->b_strides_.back() == 1, BScalarPerVector)) + return false; + + if(!IsScalarPerVectorValid(pArg->c_strides_.back() == 1, CScalarPerVector)) + return false; + + return true; + }; + + virtual std::unique_ptr + MakeArgumentPointer(std::array p_inputs, + std::array p_outputs, + std::vector lengths, + std::vector> input_strides, + std::vector> output_strides, + ElementwiseFunctor functor) override + { + return std::make_unique(static_cast(p_inputs[0]), + static_cast(p_inputs[1]), + 
+                                          static_cast<CDataType*>(p_outputs[0]),
+                                          lengths,
+                                          input_strides[0],
+                                          input_strides[1],
+                                          output_strides[0],
+                                          functor);
+    }
+
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
+    {
+        return std::make_unique<Invoker>();
+    }
+
+    // polymorphic
+    std::string GetTypeString() const override
+    {
+        auto str = std::stringstream();
+
+        // clang-format off
+        str << "DeviceBinaryElementwise"
+            << "<"
+            << "NDim = " << NDim << ", "
+            << "MPerThread = " << MPerThread << ", "
+            << "AScalarPerVector = " << AScalarPerVector << ", "
+            << "BScalarPerVector = " << BScalarPerVector << ", "
+            << "CScalarPerVector = " << CScalarPerVector
+            << ">";
+        // clang-format on
+
+        return str.str();
+    }
+};
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/device_cgemm.hpp b/include/ck/tensor_operation/gpu/device/device_cgemm.hpp
new file mode 100644
index 00000000..aedae538
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/device_cgemm.hpp
@@ -0,0 +1,51 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+#include "device_base.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+struct DeviceCGemm : public BaseOperator
+{
+    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a_real,
+                                                              const void* p_a_imag,
+                                                              const void* p_b_real,
+                                                              const void* p_b_imag,
+                                                              void* p_c_real,
+                                                              void* p_c_imag,
+                                                              void* p_workspace,
+                                                              ck::index_t M,
+                                                              ck::index_t N,
+                                                              ck::index_t K,
+                                                              ck::index_t StrideA,
+                                                              ck::index_t StrideB,
+                                                              ck::index_t StrideC,
+                                                              AElementwiseOperation a_element_op,
+                                                              BElementwiseOperation b_element_op,
+                                                              CElementwiseOperation c_element_op,
+                                                              ck::index_t KBatch = 1) = 0;
+
+    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+    virtual std::size_t GetWorkspaceSize(index_t MRaw,
+                                         index_t NRaw,
+                                         index_t KRaw,
+                                         index_t StrideA,
+                                         index_t StrideB,
+                                         index_t StrideC) = 0;
+};
+
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+using DeviceCGemmPtr = std::unique_ptr<
+    DeviceCGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp
new file mode 100644
index 00000000..ac6b2347
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp
@@ -0,0 +1,953 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
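+
+// Editorial note: this device op realizes a complex GEMM with four real GEMMs plus two
+// elementwise passes. Writing A = Ar + i*Ai and B = Br + i*Bi:
+//
+//     C = A * B = (Ar*Br - Ai*Bi) + i*(Ar*Bi + Ai*Br)
+//
+// The Invoker below runs Ar*Br and Ai*Bi into two workspace buffers and subtracts them to
+// form the real part, then runs Ar*Bi and Ai*Br and adds them to form the imaginary part;
+// GetWorkspaceSize() accordingly reserves two C-sized scratch tensors.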
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/device_cgemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template < + typename ALayout, + typename BLayout, + typename CLayout, + typename ADataType, + typename BDataType, + typename CDataType, + typename GemmAccDataType, + typename CShuffleDataType, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + GemmSpecialization GemmSpec, + index_t NumGemmKPrefetchStage, + index_t BlockSize, + index_t MPerBlock, + index_t NPerBlock, + index_t KPerBlock, + index_t AK1, + index_t BK1, + index_t MPerXDL, + index_t NPerXDL, + index_t MXdlPerWave, + index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_AK0_M_AK1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + index_t ABlockTransferSrcVectorDim, + index_t ABlockTransferSrcScalarPerVector, + index_t ABlockTransferDstScalarPerVector_AK1, + bool ABlockLdsExtraM, + typename BBlockTransferThreadClusterLengths_BK0_N_BK1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + index_t BBlockTransferSrcVectorDim, + index_t BBlockTransferSrcScalarPerVector, + index_t BBlockTransferDstScalarPerVector_BK1, + bool BBlockLdsExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + index_t CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopScheduler LoopSched = make_default_loop_scheduler(), + enable_if_t< + is_same_v && + is_same_v && + is_same_v, + bool> = false> +struct DeviceCGemm_4Gemm_Xdl_CShuffle + : public DeviceCGemm +{ + using DeviceOp = DeviceCGemm_4Gemm_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto MPerThread = Number<4>{}; + static constexpr auto AScalarPerVector = Number<4>{}; + static constexpr auto BScalarPerVector = Number<4>{}; + static constexpr auto CScalarPerVector = Number<4>{}; + + template + static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize) + { + const auto M = desc_m.GetLength(I0); + const index_t loop_step = gridSize * blockSize * MPerThread; + const auto pad = math::integer_least_multiple(M, loop_step) - M; + const auto desc_m_pad = + transform_tensor_descriptor(desc_m, + make_tuple(make_right_pad_transform(M, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return desc_m_pad; + } + + static auto MakeDescriptor_M(const std::vector& lengths, + const std::vector& strides, + index_t gridSize, + index_t blockSize) + { + auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<2>{}); + auto tupleOfStride = 
generate_tuple([&](auto I) { return strides[I]; }, Number<2>{}); + + // nd desc - [s0, s1, s2, ...] + const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); + const auto desc_m = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(tupleOfShape)), + make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<2>{})), + make_tuple(Sequence<0>{})); + + return PadDescriptor_M_1d(desc_m, gridSize, blockSize); + } + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto 
MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const 
auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + CGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid_real, + const ADataType* p_a_grid_imag, + const BDataType* p_b_grid_real, + const BDataType* p_b_grid_imag, + CDataType* p_c_grid_real, + CDataType* p_c_grid_imag, + CDataType* p_workspace, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_real_{p_a_grid_real}, + p_a_grid_imag_{p_a_grid_imag}, + p_b_grid_real_{p_b_grid_real}, + p_b_grid_imag_{p_b_grid_imag}, + p_c_grid_real_{p_c_grid_real}, + p_c_grid_imag_{p_c_grid_imag}, 
+ p_aux_grid_{p_workspace}, + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + + const index_t grid_size = block_2_ctile_map_.CalculateGridSize(c_grid_desc_m_n_); + + if constexpr(is_same::value) + { + c_grid_desc_m_ = + DeviceOp::MakeDescriptor_M({MRaw, NRaw}, {StrideC, I1}, grid_size, BlockSize); + } + else if constexpr(is_same::value) + { + c_grid_desc_m_ = + DeviceOp::MakeDescriptor_M({MRaw, NRaw}, {I1, StrideC}, grid_size, BlockSize); + } + + p_aux_2_grid_ = p_workspace + c_grid_desc_m_n_.GetElementSpaceSize(); + } + + // private: + const ADataType* p_a_grid_real_; + const ADataType* p_a_grid_imag_; + const BDataType* p_b_grid_real_; + const BDataType* p_b_grid_imag_; + CDataType* p_c_grid_real_; + CDataType* p_c_grid_imag_; + CDataType* p_aux_grid_; + CDataType* p_aux_2_grid_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_M c_grid_desc_m_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + float ave_time = 0; + + using Add = ck::tensor_operation::element_wise::Add; + using Subtract = ck::tensor_operation::element_wise::Subtract; + using GridwiseBinAdd = GridwiseBinaryElementwise_1D; + using GridwiseBinSubtract = GridwiseBinaryElementwise_1D; + const auto add_kernel = kernel_binary_elementwise_1d; + const auto subtract_kernel = kernel_binary_elementwise_1d; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + true>; + + ave_time += + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_real_, + arg.p_b_grid_real_, + arg.p_aux_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); + + ave_time += + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_imag_, + arg.p_b_grid_imag_, + arg.p_aux_2_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); + + // c_real = aux - aux_2 + ave_time += launch_and_time_kernel(stream_config, + subtract_kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_aux_grid_, + arg.p_aux_2_grid_, + arg.p_c_grid_real_, + arg.c_grid_desc_m_, + arg.c_grid_desc_m_, + arg.c_grid_desc_m_, + Subtract{}); + + ave_time += + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_real_, + arg.p_b_grid_imag_, + arg.p_aux_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); + + ave_time += + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_imag_, + arg.p_b_grid_real_, + arg.p_aux_2_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); + + // c_imag = aux + aux_2 + ave_time += launch_and_time_kernel(stream_config, + add_kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_aux_grid_, + arg.p_aux_2_grid_, + arg.p_c_grid_imag_, + arg.c_grid_desc_m_, + arg.c_grid_desc_m_, + arg.c_grid_desc_m_, + Add{}); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + 
false>; + + ave_time += + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_real_, + arg.p_b_grid_real_, + arg.p_aux_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); + + ave_time += + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_imag_, + arg.p_b_grid_imag_, + arg.p_aux_2_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); + + // c_real = aux - aux_2 + ave_time += launch_and_time_kernel(stream_config, + subtract_kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_aux_grid_, + arg.p_aux_2_grid_, + arg.p_c_grid_real_, + arg.c_grid_desc_m_, + arg.c_grid_desc_m_, + arg.c_grid_desc_m_, + Subtract{}); + + ave_time += + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_real_, + arg.p_b_grid_imag_, + arg.p_aux_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); + + ave_time += + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_imag_, + arg.p_b_grid_real_, + arg.p_aux_2_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); + + // c_imag = aux + aux_2 + ave_time += launch_and_time_kernel(stream_config, + add_kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_aux_grid_, + arg.p_aux_2_grid_, + arg.p_c_grid_imag_, + arg.c_grid_desc_m_, + arg.c_grid_desc_m_, + arg.c_grid_desc_m_, + Add{}); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a_real, + const ADataType* p_a_imag, + const BDataType* p_b_real, + const BDataType* p_b_imag, + CDataType* p_c_real, + CDataType* p_c_imag, + CDataType* p_workspace, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a_real, + p_a_imag, + p_b_real, + p_b_imag, + p_c_real, + p_c_imag, + p_workspace, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a_real, + const 
void* p_a_imag, + const void* p_b_real, + const void* p_b_imag, + void* p_c_real, + void* p_c_imag, + void* p_workspace, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t /* KBatch */ = 1) override + { + return std::make_unique(static_cast(p_a_real), + static_cast(p_a_imag), + static_cast(p_b_real), + static_cast(p_b_imag), + static_cast(p_c_real), + static_cast(p_c_imag), + static_cast(p_workspace), + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceCGemm_4Gemm_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 + << ">"; + // clang-format on + + return str.str(); + } + + std::size_t GetWorkspaceSize(index_t MRaw, + index_t NRaw, + [[maybe_unused]] index_t KRaw, + [[maybe_unused]] index_t StrideA, + [[maybe_unused]] index_t StrideB, + index_t StrideC) override + { + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC); + + return 2 * sizeof(CDataType) * c_grid_desc_m_n.GetElementSpaceSize(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp new file mode 100644 index 00000000..fa0f07d3 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Tensor Contraction: +// input : A +// input : B +// input : D0, D1, ... +// output : E +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// A[M0, M1, M2, ..., K0, K1, K2, ...] +// B[N0, N1, N2, ..., K0, K1, K2, ...] +// D[M0, M1, M2, ..., N0, N1, N2, ...] +// E[M0, M1, M2, ..., N0, N1, N2, ...] 
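// (editorial annotation, not part of the original patch) The interface comment
// above pins down the operator's semantics. As a reading aid, a minimal
// host-side reference for the simplest case -- NumDimM = NumDimN = NumDimK = 1
// with a single D tensor -- could look like the sketch below. The function
// name, row-major layouts, and float element type are illustrative assumptions
// only, and <vector>/<cstddef> are assumed to be available.

// E[m, n] = cde_op( sum_k a_op(A[m, k]) * b_op(B[n, k]), D[m, n] )
template <typename AOp, typename BOp, typename CDEOp>
void reference_contraction_sketch(const std::vector<float>& a, // M x K, K contiguous
                                  const std::vector<float>& b, // N x K, K contiguous
                                  const std::vector<float>& d, // M x N
                                  std::vector<float>& e,       // M x N
                                  std::size_t M, std::size_t N, std::size_t K,
                                  AOp a_op, BOp b_op, CDEOp cde_op)
{
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            float c = 0.f;
            for(std::size_t k = 0; k < K; ++k)
                c += a_op(a[m * K + k]) * b_op(b[n * K + k]);
            // e.g. a Bilinear cde_op would compute alpha * c + beta * d
            e[m * N + n] = cde_op(c, d[m * N + n]);
        }
}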
+template +struct DeviceContractionMultipleD : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + std::vector a_ms_ks_lengths, + std::vector a_ms_ks_strides, + std::vector b_ns_ks_lengths, + std::vector b_ns_ks_strides, + std::array, NumDTensor> ds_ms_ns_lengths, + std::array, NumDTensor> ds_ms_ns_strides, + std::vector e_ms_ns_lengths, + std::vector e_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp new file mode 100644 index 00000000..b130290f --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,981 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_contraction_multiple_d_xdl_cshuffle( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatDsPointer p_ds_grid, + FloatE* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_etile_map; +#endif +} + +} // namespace ck + 
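// (editorial annotation, not part of the original patch) The __global__
// wrapper above follows the usual CK dispatch pattern: LDS is sized at compile
// time from GridwiseGemm::GetSharedMemoryNumberOfByte(), and the preprocessor
// guard emits a real body only during host compilation passes and on
// xdlops-capable targets (gfx908/gfx90a); on any other target the arguments
// are swallowed by `ignore`, so the kernel symbol still links but does
// nothing. A minimal sketch of the same guard idiom with a hypothetical toy
// kernel (plain void-casts stand in for ck::ignore):

__global__ void toy_scale_kernel_sketch(float* p, int n)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
    const int i = blockIdx.x * blockDim.x + threadIdx.x; // one element per thread
    if(i < n)
        p[i] *= 2.f;
#else
    (void)p; // body compiled out on targets without xdlops support
    (void)n;
#endif
}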
+namespace ck { +namespace tensor_operation { +namespace device { + +// Tensor Contraction: +// input : A +// input : B +// input : D0, D1, ... +// output : E +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// A[M0, M1, M2, ..., K0, K1, K2, ...] +// B[N0, N1, N2, ..., K0, K1, K2, ...] +// D[M0, M1, M2, ..., N0, N1, N2, ...] +// E[M0, M1, M2, ..., N0, N1, N2, ...] +template +struct DeviceContractionMultipleD_Xdl_CShuffle + : public DeviceContractionMultipleD +{ + using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + // Assume: A[M0, M1, M2, ..., K0, K1, K2, ...] + static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector& a_ms_ks_lengths_vec, + const std::vector& a_ms_ks_strides_vec) + { + assert(a_ms_ks_lengths_vec.size() == NumDimM + NumDimK && + a_ms_ks_strides_vec.size() == NumDimM + NumDimK); + + const auto to_tuple = [&](auto& vec, auto num) { + return generate_tuple([&](auto i) { return vec[i]; }, num); + }; + + const auto a_ms_ns_lengths = to_tuple(a_ms_ks_lengths_vec, Number{}); + const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_vec, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(a_ms_ns_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(a_ms_ns_lengths, kDimIds); + + // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] + const auto a_grid_desc_ms_ks = + make_naive_tensor_descriptor(a_ms_ns_lengths, a_ms_ks_strides); + + // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...] 
+ const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( + a_grid_desc_ms_ks, + make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)), + make_tuple(mDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto MRaw = a_grid_desc_mraw_kraw.GetLength(I0); + const auto KRaw = a_grid_desc_mraw_kraw.GetLength(I1); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + // Assume: B[N0, N1, N2, ..., K0, K1, K2, ...] 
+ static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector& b_ns_ks_lengths_vec, + const std::vector& b_ns_ks_strides_vec) + { + assert(b_ns_ks_lengths_vec.size() == NumDimN + NumDimK && + b_ns_ks_strides_vec.size() == NumDimN + NumDimK); + + const auto to_tuple = [&](auto& vec, auto num) { + return generate_tuple([&](auto i) { return vec[i]; }, num); + }; + + const auto b_ns_ks_lengths = to_tuple(b_ns_ks_lengths_vec, Number{}); + const auto b_ns_ks_strides = to_tuple(b_ns_ks_strides_vec, Number{}); + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds); + + // lengths for N0, N1, ... + const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds); + + // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...] + const auto b_grid_desc_ns_ks = + make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides); + + // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...] + const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor( + b_grid_desc_ns_ks, + make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)), + make_tuple(nDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto NRaw = b_grid_desc_nraw_kraw.GetLength(I0); + const auto KRaw = b_grid_desc_nraw_kraw.GetLength(I1); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + 
transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + // assume E[M0, M1, M2, ..., N0, N1, N2...] + static auto MakeEGridDescriptor_M_N(const std::vector& e_ms_ns_lengths_vec, + const std::vector& e_ms_ns_strides_vec) + { + assert(e_ms_ns_lengths_vec.size() == NumDimM + NumDimN && + e_ms_ns_strides_vec.size() == NumDimM + NumDimN); + + const auto to_tuple = [&](auto& vec, auto num) { + return generate_tuple([&](auto i) { return vec[i]; }, num); + }; + + const auto e_ms_ns_lengths = to_tuple(e_ms_ns_lengths_vec, Number{}); + const auto e_ms_ns_strides = to_tuple(e_ms_ns_strides_vec, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds); + + // lengths for N0, N1, ... + const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds); + + // naive tensor E[M0, M1, M2, ..., N0, N1, N2...] + const auto e_grid_desc_ms_ns = + make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides); + + // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
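// (annotation) unlike the A/B descriptors, E (and each D) is only ever padded
// in M and N -- K padding does not apply to the output -- which is why the
// K-padding GemmSpecialization variants below collapse onto the plain M/N
// cases.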
+ const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor( + e_grid_desc_ms_ns, + make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), + make_tuple(mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto MRaw = e_grid_desc_mraw_nraw.GetLength(I0); + const auto NRaw = e_grid_desc_mraw_nraw.GetLength(I1); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(e_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + e_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + e_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return e_grid_desc_mraw_nraw; + } + } + + using AGridDesc_AK0_M_AK1 = + decltype(MakeAGridDescriptor_AK0_M_AK1(std::vector{}, std::vector{})); + using BGridDesc_BK0_N_BK1 = + decltype(MakeBGridDescriptor_BK0_N_BK1(std::vector{}, std::vector{})); + using EGridDesc_M_N = + decltype(MakeEGridDescriptor_M_N(std::vector{}, std::vector{})); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + EGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + void* p_e_grid, + std::vector a_ms_ns_lengths, + std::vector a_ms_ks_strides, + std::vector b_ns_ks_lengths, + 
std::vector b_ns_ks_strides, + std::array, NumDTensor> ds_ms_ns_lengths, + std::array, NumDTensor> ds_ms_ns_strides, + std::vector e_ms_ns_lengths, + std::vector e_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, // FIXME + p_e_grid_{static_cast(p_e_grid)}, + a_grid_desc_ak0_m_ak1_{ + DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_ms_ns_lengths, a_ms_ks_strides)}, + b_grid_desc_bk0_n_bk1_{ + DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_ns_ks_lengths, b_ns_ks_strides)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_ms_ns_lengths, e_ms_ns_strides)}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_mz_stride_{}, + a_kz_stride_{}, + b_nz_stride_{}, + b_kz_stride_{}, + ds_nz_stride_{}, + e_nz_stride_{} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + const auto d_grid_desc_m_n = + DeviceOp::MakeEGridDescriptor_M_N(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + d_grid_desc_m_n); + }); + } + + // for sanity check of vector memory access + a_mz_stride_ = a_ms_ks_strides[NumDimM - 1]; + a_kz_stride_ = a_ms_ks_strides[NumDimM + NumDimK - 1]; + + b_nz_stride_ = b_ns_ks_strides[NumDimN - 1]; + b_kz_stride_ = b_ns_ks_strides[NumDimN + NumDimK - 1]; + + for(index_t i = 0; i < NumDTensor; ++i) + { + ds_nz_stride_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1]; + } + + e_nz_stride_ = e_ms_ns_strides[NumDimM + NumDimN - 1]; + } + + // private: + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + StaticallyIndexedArray< + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + NumDTensor> + ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different + // type from E + EGridDesc_M_N e_grid_desc_m_n_; + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // Strides for the last M/N/K dimensions of A/B/Ds/E + // for sanity check of vector load/store + index_t a_mz_stride_; + index_t a_kz_stride_; + index_t b_nz_stride_; + index_t b_kz_stride_; + std::array ds_nz_stride_; + index_t e_mz_stride_; + index_t e_nz_stride_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const 
StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << "arg.a_grid_desc_ak0_m_ak1_{" + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_bk0_n_bk1_{" + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", " + << arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } +#endif + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_contraction_multiple_d_xdl_cshuffle< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + ck::StaticallyIndexedArray< + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + NumDTensor>, + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DefaultBlock2ETileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_); + }; + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + ave_time = launch_kernel(integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + return false; + } + + // check vector access + static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) && + (BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2), + "wrong!"); + + // vector memory access of A: could be on M or AK1 dimension + if constexpr(ABlockTransferSrcVectorDim == 1) + { + if(!(arg.a_mz_stride_ == 1 && + arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + 
} + } + else + { + if(!(arg.a_kz_stride_ == 1 && + arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + + // vector memory access of B: could be on N or BK1 dimension + if constexpr(BBlockTransferSrcVectorDim == 1) + { + if(!(arg.b_nz_stride_ == 1 && + arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + if(!(arg.b_kz_stride_ == 1 && + arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + + // vector memory access of Ds: always on NPerBlock dimension + bool valid_d_access = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + if(!(arg.ds_nz_stride_[i] == 1 && + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) % + CDEBlockTransferScalarPerVector_NPerBlock == + 0)) + { + valid_d_access = false; + } + }); + + if(valid_d_access == false) + { + return false; + } + + // vector memory access of E: always on NPerBlock dimension + if(!(arg.e_nz_stride_ == 1 && + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) % + CDEBlockTransferScalarPerVector_NPerBlock == + 0)) + { + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + std::vector a_ms_ns_lengths, + std::vector a_ms_ks_strides, + std::vector b_ns_ks_lengths, + std::vector b_ns_ks_strides, + std::array, NumDTensor> ds_ms_ns_lengths, + std::array, NumDTensor> ds_ms_ns_strides, + std::vector e_ms_ns_lengths, + std::vector e_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_ms_ns_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + ds_ms_ns_lengths, + ds_ms_ns_strides, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + std::vector a_ms_ns_lengths, + std::vector a_ms_ks_strides, + std::vector b_ns_ks_lengths, + std::vector b_ns_ks_strides, + std::array, NumDTensor> ds_ms_ns_lengths, + std::array, NumDTensor> ds_ms_ns_strides, + std::vector e_ms_ns_lengths, + std::vector e_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_ms_ns_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + ds_ms_ns_lengths, + ds_ms_ns_strides, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceContractionMultipleD_Xdl_CShuffle" + << "<" + << NumDimM << ", " + << NumDimN << ", " + << NumDimK << ", " + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << ABlockTransferSrcVectorDim << ", " + << BBlockTransferSrcVectorDim + << ">"; + // 
clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp new file mode 100644 index 00000000..31b2ca05 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,779 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_backward_weight.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template +struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvBwdWeight +{ + using DeviceOp = + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = OutDataType; + using BDataType = InDataType; + using CDataType = WeiDataType; + + using AElementwiseOperation = OutElementwiseOperation; + using BElementwiseOperation = InElementwiseOperation; + using CElementwiseOperation = WeiElementwiseOperation; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static constexpr auto N1Number = K1Number; + + // Bytes per 32 lds bank: 32 * 4 bytes + static constexpr auto BankLength = 128; + static constexpr auto ElePerBank = BankLength / sizeof(ADataType); + + // M1 & M0 + static constexpr auto ABlockLdsM1PerBlock = ElePerBank / K1; + static constexpr auto ABlockLdsM0PerBlock = MPerBlock / ABlockLdsM1PerBlock; + static constexpr auto ABlockLdsM1Padding = 4; + + // N1 & N0 + static constexpr auto BBlockLdsN1PerBlock = ElePerBank / K1; + static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock; + static constexpr auto BBlockLdsN1Padding = 4; + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t batch_k) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const 
index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmKTotal = N * Ho * Wo; + const index_t GemmM = K; + const index_t GemmN = C * X * Y; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + // A: output tensor + const index_t N0 = N / N1Number; + const index_t GemmK0Total = N0 * Ho * Wo; + + const index_t GemmK0S = + math::integer_divide_ceil(GemmK0Total, K0PerBlock * GemmKBatch) * K0PerBlock; + const index_t GemmK0Pad = GemmKBatch * GemmK0S; + const auto out_n_ho_wo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Ho * Wo, K)); + + const auto out_n0_ho_wo_k_n1_grid_desc = + transform_tensor_descriptor(out_n_ho_wo_k_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(N0, N1Number)), + make_pass_through_transform(Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{})); + + const auto out_gemmk0total_gemmm_gemmk1_grid_desc = + transform_tensor_descriptor(out_n0_ho_wo_k_n1_grid_desc, + make_tuple(make_merge_transform(make_tuple(N0, Ho * Wo)), + make_pass_through_transform(K), + make_pass_through_transform(N1Number)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto out_gemmk0pad_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmk0total_gemmm_gemmk1_grid_desc, + make_tuple(make_right_pad_transform(GemmK0Total, GemmK0Pad - GemmK0Total), + make_pass_through_transform(GemmM), + make_pass_through_transform(N1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmk0pad_gemmm_gemmk1_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0)), + make_pass_through_transform(GemmM), + make_pass_through_transform(N1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{})); + + // B: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + 
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_n0_y_ho_x_wo_c_n1_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(N0, N1Number)), + make_pass_through_transform(Y), + make_pass_through_transform(Ho), + make_pass_through_transform(X), + make_pass_through_transform(Wo), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0, 6>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); + + const auto in_gemmk0total_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_n0_y_ho_x_wo_c_n1_grid_desc, + make_tuple(make_merge_transform(make_tuple(N0, Ho, Wo)), + make_merge_transform(make_tuple(Y, X, C)), + make_pass_through_transform(N1Number)), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmk0pad_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0total_gemmn_gemmk1_grid_desc, + make_tuple(make_right_pad_transform(GemmK0Total, GemmK0Pad - GemmK0Total), + make_pass_through_transform(GemmN), + make_pass_through_transform(N1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0pad_gemmn_gemmk1_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0)), + make_pass_through_transform(GemmN), + make_pass_through_transform(N1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); + } + + using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1)); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + ABlockLdsM1PerBlock, + ABlockLdsM0PerBlock, + ABlockLdsM1Padding, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + 
BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + BBlockLdsN1PerBlock, + BBlockLdsN0PerBlock, + BBlockLdsN1Padding, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXdl, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + true, + true>; + + using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::AtomicAdd, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + ABlockLdsM1PerBlock, + ABlockLdsM0PerBlock, + ABlockLdsM1Padding, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + BBlockLdsN1PerBlock, + BBlockLdsN0PerBlock, + BBlockLdsN1Padding, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXdl, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + true, + true>; + // Argument + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); + + using Block2CTileMap = + decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t M01, + ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_in_grid}, + p_c_grid_{p_wei_grid}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{out_element_op}, + b_element_op_{in_element_op}, + c_element_op_{wei_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + output_spatial_lengths_{output_spatial_lengths}, + filter_spatial_lengths_{filter_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + k_batch_{split_k} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + 
conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_); + + a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; + b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + block_2_ctile_map_ = + GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); + + if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, + b_grid_desc_kbatch_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_); + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; + Block2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + InElementwiseOperation a_element_op_; + OutElementwiseOperation b_element_op_; + WeiElementwiseOperation c_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::vector output_spatial_lengths_; + std::vector filter_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector input_left_pads_; + std::vector input_right_pads_; + index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + void ShowInfo(const Argument& arg) + { + std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + ShowInfo(arg); + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight has invalid setting"); + } + const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0); + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; + + const auto Run = [&](const auto& kernel) { + hipGetErrorString(hipMemset( + arg.p_c_grid_, + 0, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * + sizeof(CDataType))); + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + }; + + if(has_main_k0_block_loop) + { + if(kbatch == 1) + { + const auto kernel = kernel_gemm_xdlops_bwd_weight< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + true>; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_bwd_weight< + GridwiseGemmAtomicAdd, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + true>; + + Run(kernel); + } + } + else + { + if(kbatch == 1) + { + const auto kernel = kernel_gemm_xdlops_bwd_weight< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + false>; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_bwd_weight< + GridwiseGemmAtomicAdd, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + false>; + + Run(kernel); + } + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // unmerge N to N0 and N1, where N1 equals to K1 + if(!(arg.Conv_N_ % K1 == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + bool IsSupportedArgument(const 
BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + void* p_wei_grid, + const void* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp new file mode 100644 index 00000000..37ef8db3 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,826 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
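//
// (editorial annotation, not part of the original patch) This backward-data op
// decomposes the transposed convolution into YTilde * XTilde independent
// GEMMs, selected by the (i_ytilde, i_xtilde) arguments of the descriptor
// maker below; the tile counts are derived from the stride/dilation gcd. A
// hedged sketch of that tile arithmetic, with hypothetical free functions
// mirroring the math::gcd / integer_divide_ceil calls in this file:
//
//     #include <numeric> // std::gcd
//
//     // number of independent GEMMs along H
//     inline int ytilde(int conv_stride_h, int conv_dilation_h)
//     {
//         return conv_stride_h / std::gcd(conv_stride_h, conv_dilation_h);
//     }
//
//     // filter taps each GEMM consumes along H: ceil(Y / YTilde)
//     inline int ydot(int y, int conv_stride_h, int conv_dilation_h)
//     {
//         const int yt = ytilde(conv_stride_h, conv_dilation_h);
//         return (y + yt - 1) / yt;
//     }
//
// e.g. stride 2 with dilation 1 gives YTilde = XTilde = 2, so a 3x3 filter is
// handled as 2 * 2 = 4 GEMMs indexed by (i_ytilde, i_xtilde).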
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template +struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvBwdData +{ + using DeviceOp = DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = OutDataType; + using BDataType = WeiDataType; + using CDataType = InDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr index_t NDimSpatial = 2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static_assert((K1 % ABlockTransferThreadClusterLengths_K0_M_K1{}[I2]) % + ABlockTransferSrcScalarPerVector == + 0); + static_assert((NPerBlock / BBlockTransferThreadClusterLengths_K0_N_K1{}[I1]) % + BBlockTransferSrcScalarPerVector == + 0); + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + index_t i_ytilde, + index_t i_xtilde) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const auto K0 = K / K1; + + const auto out_n_ho_wo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K)); + const auto wei_k_y_x_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C)); + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * 
+                make_tuple(make_pass_through_transform(N * Ho * Wo),
+                           make_unmerge_transform(make_tuple(K0, K1))),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
+
+            // B: weight tensor
+            const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
+                transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
+                                            make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
+                                                       make_pass_through_transform(C)),
+                                            make_tuple(Sequence<0>{}, Sequence<1>{}),
+                                            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+            // C: input tensor
+            const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
+                in_n_hi_wi_c_grid_desc,
+                make_tuple(make_pass_through_transform(N),
+                           make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
+                           make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
+                           make_pass_through_transform(C)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
+
+            const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
+                in_n_y_ho_x_wo_c_grid_desc,
+                make_tuple(make_freeze_transform(I0),
+                           make_freeze_transform(I0),
+                           make_merge_transform(make_tuple(N, Ho, Wo)),
+                           make_pass_through_transform(C)),
+                make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
+                make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
+
+            return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
+                              wei_gemmk0_gemmn_gemmk1_grid_desc,
+                              in_gemmm_gemmn_grid_desc);
+        }
+        else
+        {
+            const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
+            const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
+
+            const auto YTilde = ConvStrideH / GcdStrideDilationH;
+            const auto XTilde = ConvStrideW / GcdStrideDilationW;
+
+            const auto YDot = math::integer_divide_ceil(Y, YTilde);
+            const auto XDot = math::integer_divide_ceil(X, XTilde);
+
+            const auto HTilde =
+                Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
+            const auto WTilde =
+                Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
+
+            // only work on HTilde and WTilde that contribute to non-padding area of input tensor
+            const auto IHTildeSliceBegin = math::integer_divide_floor(
+                math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
+            const auto IWTildeSliceBegin = math::integer_divide_floor(
+                math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
+
+            const auto IHTildeSliceEnd = math::min(
+                HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
+            const auto IWTildeSliceEnd = math::min(
+                WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
+
+            const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
+            const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
+
+            // GemmK is different for each GEMM
+            const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
+            const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
+
+            // A: output tensor
+            const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
+                out_n_ho_wo_k_grid_desc,
+                make_tuple(make_pass_through_transform(N),
+                           make_pad_transform(Ho, I0, I0),
+                           make_pad_transform(Wo, I0, I0),
+                           make_pass_through_transform(K)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
+
+            const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
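+            // The backward-data path splits the filter into YTilde * XTilde sub-filters
+            // and runs one GEMM per (i_ytilde, i_xtilde) pair (see the loop in Argument's
+            // constructor); each GEMM reduces over YDotSlice * XDotSlice * K0 steps.
+            // The embed transforms below walk the output with a negative stride
+            // (-ConvDilationH / GcdStrideDilationH) along the YDot direction.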
+                out_n_hop_wop_k_grid_desc,
+                make_tuple(
+                    make_pass_through_transform(N),
+                    make_embed_transform(make_tuple(YDot, HTilde),
+                                         make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
+                    make_embed_transform(make_tuple(XDot, WTilde),
+                                         make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
+                    make_pass_through_transform(K)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
+
+            const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
+                transform_tensor_descriptor(
+                    out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
+                    make_tuple(make_pass_through_transform(N),
+                               make_slice_transform(YDot, I0, YDotSlice),
+                               make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
+                               make_slice_transform(XDot, I0, XDotSlice),
+                               make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
+                               make_unmerge_transform(make_tuple(K0, K1))),
+                    make_tuple(Sequence<0>{},
+                               Sequence<1>{},
+                               Sequence<2>{},
+                               Sequence<3>{},
+                               Sequence<4>{},
+                               Sequence<5>{}),
+                    make_tuple(Sequence<0>{},
+                               Sequence<1>{},
+                               Sequence<2>{},
+                               Sequence<3>{},
+                               Sequence<4>{},
+                               Sequence<5, 6>{}));
+
+            const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
+                out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
+                make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
+                           make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
+                           make_pass_through_transform(K1)),
+                make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+
+            // B: weight tensor
+            const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
+                wei_k_y_x_c_grid_desc,
+                make_tuple(make_pass_through_transform(K),
+                           make_embed_transform(make_tuple(YDot, YTilde),
+                                                make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
+                           make_embed_transform(make_tuple(XDot, XTilde),
+                                                make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
+                           make_pass_through_transform(C)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
+
+            const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
+                transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
+                                            make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
+                                                       make_slice_transform(YDot, I0, YDotSlice),
+                                                       make_slice_transform(XDot, I0, XDotSlice),
+                                                       make_freeze_transform(i_ytilde),
+                                                       make_freeze_transform(i_xtilde),
+                                                       make_pass_through_transform(C)),
+                                            make_tuple(Sequence<0>{},
+                                                       Sequence<1>{},
+                                                       Sequence<3>{},
+                                                       Sequence<2>{},
+                                                       Sequence<4>{},
+                                                       Sequence<5>{}),
+                                            make_tuple(Sequence<0, 1>{},
+                                                       Sequence<2>{},
+                                                       Sequence<3>{},
+                                                       Sequence<>{},
+                                                       Sequence<>{},
+                                                       Sequence<4>{}));
+
+            const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
+                wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
+                make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
+                           make_pass_through_transform(C),
+                           make_pass_through_transform(K1)),
+                make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+
+            // C: input tensor
+            const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
+                in_n_hi_wi_c_grid_desc,
+                make_tuple(make_pass_through_transform(N),
+                           make_pad_transform(Hi, InLeftPadH, InRightPadH),
+                           make_pad_transform(Wi, InLeftPadW, InRightPadW),
+                           make_pass_through_transform(C)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
+
+            const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
+                in_n_hip_wip_c_grid_desc,
+                make_tuple(make_pass_through_transform(N),
+                           make_embed_transform(make_tuple(YTilde, HTilde),
+                                                make_tuple(ConvDilationH, ConvStrideH)),
+                           make_embed_transform(make_tuple(XTilde, WTilde),
+                                                make_tuple(ConvDilationW, ConvStrideW)),
+                           make_pass_through_transform(C)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
+
+            const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
+                in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
+                make_tuple(make_pass_through_transform(N),
+                           make_freeze_transform(i_ytilde),
+                           make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
+                           make_freeze_transform(i_xtilde),
+                           make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
+                           make_pass_through_transform(C)),
+                make_tuple(Sequence<0>{},
+                           Sequence<1>{},
+                           Sequence<2>{},
+                           Sequence<3>{},
+                           Sequence<4>{},
+                           Sequence<5>{}),
+                make_tuple(Sequence<0>{},
+                           Sequence<>{},
+                           Sequence<1>{},
+                           Sequence<>{},
+                           Sequence<2>{},
+                           Sequence<3>{}));
+
+            const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
+                in_n_htildeslice_wtildeslice_c_grid_desc,
+                make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
+                           make_pass_through_transform(C)),
+                make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+            return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
+                              wei_gemmk0_gemmn_gemmk1_grid_desc,
+                              in_gemmm_gemmn_grid_desc);
+        }
+
+    } // function end
+
+    using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
+        1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 0, 0));
+
+    using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
+    using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
+    using CGridDesc_M_N     = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
+
+    // GridwiseGemm
+    using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
+        BlockSize,
+        ABDataType, // TODO: distinguish A/B datatype
+        AccDataType,
+        CDataType,
+        InMemoryDataOperationEnum::Set,
+        AGridDesc_K0_M_K1,
+        BGridDesc_K0_N_K1,
+        CGridDesc_M_N,
+        InElementwiseOperation,
+        WeiElementwiseOperation,
+        OutElementwiseOperation,
+        MPerBlock,
+        NPerBlock,
+        K0PerBlock,
+        MPerXdl,
+        NPerXdl,
+        K1,
+        MXdlPerWave,
+        NXdlPerWave,
+        ABlockTransferThreadClusterLengths_K0_M_K1,
+        ABlockTransferThreadClusterArrangeOrder,
+        ABlockTransferSrcAccessOrder,
+        ABlockTransferSrcVectorDim,
+        ABlockTransferSrcScalarPerVector,
+        ABlockTransferDstScalarPerVector_K1,
+        false, // AThreadTransferSrcResetCoordinateAfterRun,
+        ABlockLdsAddExtraM,
+        BBlockTransferThreadClusterLengths_K0_N_K1,
+        BBlockTransferThreadClusterArrangeOrder,
+        BBlockTransferSrcAccessOrder,
+        BBlockTransferSrcVectorDim,
+        BBlockTransferSrcScalarPerVector,
+        BBlockTransferDstScalarPerVector_K1,
+        false, // BThreadTransferSrcResetCoordinateAfterRun,
+        BBlockLdsAddExtraN,
+        Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder,
+        7,                                // CThreadTransferSrcDstVectorDim,
+        CThreadTransferDstScalarPerVector>;
+
+    // Argument
+    struct Argument : public BaseArgument
+    {
+        Argument(InDataType* p_in_grid,
+                 const WeiDataType* p_wei_grid,
+                 const OutDataType* p_out_grid,
+                 ck::index_t N,
+                 ck::index_t K,
+                 ck::index_t C,
+                 std::vector<ck::index_t> input_spatial_lengths,
+                 std::vector<ck::index_t> filter_spatial_lengths,
+                 std::vector<ck::index_t> output_spatial_lengths,
+                 std::vector<ck::index_t> conv_filter_strides,
+                 std::vector<ck::index_t> conv_filter_dilations,
+                 std::vector<ck::index_t> input_left_pads,
+                 std::vector<ck::index_t> input_right_pads,
+                 ck::index_t M01,
+                 ck::index_t N01,
+                 InElementwiseOperation in_element_op,
+                 WeiElementwiseOperation wei_element_op,
+                 OutElementwiseOperation out_element_op)
+            : p_a_grid_{p_out_grid},
+              p_b_grid_{p_wei_grid},
+              p_c_grid_{p_in_grid},
+              M01_{M01},
+              N01_{N01},
+              a_element_op_{out_element_op},
+              b_element_op_{wei_element_op},
+              c_element_op_{in_element_op},
+              Conv_N_{N},
+              Conv_K_{K},
+              Conv_C_{C},
+              input_spatial_lengths_{input_spatial_lengths},
+              filter_spatial_lengths_{filter_spatial_lengths},
+              output_spatial_lengths_{output_spatial_lengths},
+              conv_filter_strides_{conv_filter_strides},
+              conv_filter_dilations_{conv_filter_dilations},
+              input_left_pads_{input_left_pads},
+              input_right_pads_{input_right_pads}
+        {
+            const index_t ConvStrideH = conv_filter_strides[0];
+            const index_t ConvStrideW = conv_filter_strides[1];
+
+            const index_t ConvDilationH = conv_filter_dilations[0];
+            const index_t ConvDilationW = conv_filter_dilations[1];
+
+            const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
+            const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
+
+            const auto YTilde = ConvStrideH / GcdStrideDilationH;
+            const auto XTilde = ConvStrideW / GcdStrideDilationW;
+
+            for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
+            {
+                for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
+                {
+                    // check slice is valid
+                    const index_t Y = filter_spatial_lengths_[0];
+                    const index_t X = filter_spatial_lengths_[1];
+                    const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
+                    const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
+                    if(YDotSlice * XDotSlice <= 0)
+                    {
+                        continue;
+                    }
+
+                    const auto descs = DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
+                        N,
+                        K,
+                        C,
+                        input_spatial_lengths,
+                        filter_spatial_lengths,
+                        output_spatial_lengths,
+                        conv_filter_strides,
+                        conv_filter_dilations,
+                        input_left_pads,
+                        input_right_pads,
+                        i_ytilde,
+                        i_xtilde);
+                    a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
+                    b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
+                    c_grid_desc_m_n_container_.push_back(descs[I2]);
+
+                    auto block_2_ctile_map =
+                        GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01, N01);
+
+                    if(GridwiseGemm::CheckValidity(
+                           descs[I0], descs[I1], descs[I2], block_2_ctile_map))
+                    {
+                        c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back(
+                            GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2]));
+
+                        block_2_ctile_map_container_.push_back(block_2_ctile_map);
+                    }
+                }
+            }
+        }
+
+        const ADataType* p_a_grid_;
+        const BDataType* p_b_grid_;
+        CDataType* p_c_grid_;
+        std::vector<AGridDesc_K0_M_K1> a_grid_desc_k0_m_k1_container_;
+        std::vector<BGridDesc_K0_N_K1> b_grid_desc_k0_n_k1_container_;
+        std::vector<CGridDesc_M_N> c_grid_desc_m_n_container_;
+        std::vector<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>
+            c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_;
+        std::vector<typename GridwiseGemm::DefaultBlock2CTileMap> block_2_ctile_map_container_;
+        index_t M01_;
+        index_t N01_;
+        OutElementwiseOperation a_element_op_;
+        WeiElementwiseOperation b_element_op_;
+        InElementwiseOperation c_element_op_;
+        // for checking IsSupportedArgument()
+        index_t Conv_N_;
+        index_t Conv_K_;
+        index_t Conv_C_;
+
+        std::vector<index_t> input_spatial_lengths_;
+        std::vector<index_t> filter_spatial_lengths_;
+        std::vector<index_t> output_spatial_lengths_;
+        std::vector<index_t> conv_filter_strides_;
+        std::vector<index_t> conv_filter_dilations_;
+        std::vector<index_t> input_left_pads_;
+        std::vector<index_t> input_right_pads_;
+    };
+
+    // Invoker
+    struct Invoker : public BaseInvoker
+    {
+        using Argument = DeviceOp::Argument;
+
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+        {
+            float ave_time = 0;
+            for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
+            {
+                {
+                    std::cout << "arg.a_grid_desc_k0_m_k1_container_{"
+                              << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
+                              << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", "
+                              << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}"
+                              << std::endl;
+
+                    std::cout << "arg.b_grid_desc_k0_n_k1_container_{"
+                              << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", "
+                              << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", "
+                              << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}"
+                              << std::endl;
+
+                    std::cout << "arg.c_grid_desc_m_n_container_{ "
+                              << arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", "
+                              << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}"
+                              << std::endl;
+
+                    std::cout << "arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( "
+                              << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I0)
+                              << ", "
+                              << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I1)
+                              << ", "
+                              << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I2)
+                              << ", "
+                              << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I3)
+                              << ", "
+                              << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I4)
+                              << ", "
+                              << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I5)
+                              << " ) " << std::endl;
+                }
+
+                if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
+                                                arg.b_grid_desc_k0_n_k1_container_[i],
+                                                arg.c_grid_desc_m_n_container_[i],
+                                                arg.block_2_ctile_map_container_[i]))
+                {
+                    throw std::runtime_error(
+                        "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
+                }
+
+                const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize(
+                    arg.c_grid_desc_m_n_container_[i]);
+
+                const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) *
+                               arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2);
+
+                if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
+                {
+                    const auto kernel = kernel_gemm_xdlops_v2r3<
+                        GridwiseGemm,
+                        ADataType, // TODO: distinguish A/B datatype
+                        CDataType,
+                        remove_reference_t<AGridDesc_K0_M_K1>,
+                        remove_reference_t<BGridDesc_K0_N_K1>,
+                        remove_reference_t<
+                            typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+                        OutElementwiseOperation,
+                        WeiElementwiseOperation,
+                        InElementwiseOperation,
+                        remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
+                        true>;
+
+                    ave_time += launch_and_time_kernel(
+                        stream_config,
+                        kernel,
+                        dim3(grid_size),
+                        dim3(BlockSize),
+                        0,
+                        arg.p_a_grid_,
+                        arg.p_b_grid_,
+                        arg.p_c_grid_,
+                        arg.a_grid_desc_k0_m_k1_container_[i],
+                        arg.b_grid_desc_k0_n_k1_container_[i],
+                        arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i],
+                        arg.a_element_op_,
+                        arg.b_element_op_,
+                        arg.c_element_op_,
+                        arg.block_2_ctile_map_container_[i]);
+                }
+                else
+                {
+                    const auto kernel = kernel_gemm_xdlops_v2r3<
+                        GridwiseGemm,
+                        ADataType, // TODO: distinguish A/B datatype
+                        CDataType,
+                        remove_reference_t<AGridDesc_K0_M_K1>,
+                        remove_reference_t<BGridDesc_K0_N_K1>,
+                        remove_reference_t<
+                            typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+                        OutElementwiseOperation,
+                        WeiElementwiseOperation,
+                        InElementwiseOperation,
+                        remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
+                        false>;
+
+                    ave_time += launch_and_time_kernel(
+                        stream_config,
+                        kernel,
+                        dim3(grid_size),
+                        dim3(BlockSize),
+                        0,
+                        arg.p_a_grid_,
+                        arg.p_b_grid_,
+                        arg.p_c_grid_,
+                        arg.a_grid_desc_k0_m_k1_container_[i],
+                        arg.b_grid_desc_k0_n_k1_container_[i],
+                        arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i],
+                        arg.a_element_op_,
+                        arg.b_element_op_,
+                        arg.c_element_op_,
+                        arg.block_2_ctile_map_container_[i]);
+                }
+            }
+            return ave_time;
+        }
+
+        float Run(const BaseArgument* p_arg,
+                  const StreamConfig& stream_config = StreamConfig{}) override
+        {
+            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
+        }
+    };
+
+    static constexpr bool IsValidCompilationParameter()
+    {
+        // TODO: properly implement this check
+        return true;
+    }
+
+    static bool IsSupportedArgument(const Argument& arg)
+    {
+        if constexpr(ConvBackwardDataSpecialization ==
+                     ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
+        {
+            // check if it's 1x1, stride=1, pad=0 conv
+            if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
+                 arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
+                 arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
+                 arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
+            {
+                return false;
+            }
+        }
+
+        // vector load A/B matrix from global memory
+        if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 1 &&
+             arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 &&
+             arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
+        {
+            return false;
+        }
+
+        // vector store C matrix into global memory
+        if(!(arg.Conv_C_ % CThreadTransferDstScalarPerVector == 0))
+        {
+            return false;
+        }
+
+        // Gridwise GEMM size
+        for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
+        {
+            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
+                                            arg.b_grid_desc_k0_n_k1_container_[i],
+                                            arg.c_grid_desc_m_n_container_[i],
+                                            arg.block_2_ctile_map_container_[i]))
+            {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    bool IsSupportedArgument(const BaseArgument* p_arg) override
+    {
+        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
+    }
+
+    static auto MakeArgument(InDataType* p_in_grid,
+                             const WeiDataType* p_wei_grid,
+                             const OutDataType* p_out_grid,
+                             ck::index_t N,
+                             ck::index_t K,
+                             ck::index_t C,
+                             std::vector<ck::index_t> input_spatial_lengths,
+                             std::vector<ck::index_t> filter_spatial_lengths,
+                             std::vector<ck::index_t> output_spatial_lengths,
+                             std::vector<ck::index_t> conv_filter_strides,
+                             std::vector<ck::index_t> conv_filter_dilations,
+                             std::vector<ck::index_t> input_left_pads,
+                             std::vector<ck::index_t> input_right_pads,
+                             InElementwiseOperation in_element_op,
+                             WeiElementwiseOperation wei_element_op,
+                             OutElementwiseOperation out_element_op)
+    {
+        return Argument{p_in_grid,
+                        p_wei_grid,
+                        p_out_grid,
+                        N,
+                        K,
+                        C,
+                        input_spatial_lengths,
+                        filter_spatial_lengths,
+                        output_spatial_lengths,
+                        conv_filter_strides,
+                        conv_filter_dilations,
+                        input_left_pads,
+                        input_right_pads,
+                        1,
+                        1,
+                        in_element_op,
+                        wei_element_op,
+                        out_element_op};
+    }
+
+    static auto MakeInvoker() { return Invoker{}; }
+
+    std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(void* p_in_grid,
+                        const void* p_wei_grid,
+                        const void* p_out_grid,
+                        ck::index_t N,
+                        ck::index_t K,
+                        ck::index_t C,
+                        std::vector<ck::index_t> input_spatial_lengths,
+                        std::vector<ck::index_t> filter_spatial_lengths,
+                        std::vector<ck::index_t> output_spatial_lengths,
+                        std::vector<ck::index_t> conv_filter_strides,
+                        std::vector<ck::index_t> conv_filter_dilations,
+                        std::vector<ck::index_t> input_left_pads,
+                        std::vector<ck::index_t> input_right_pads,
+                        InElementwiseOperation in_element_op,
+                        WeiElementwiseOperation wei_element_op,
+                        OutElementwiseOperation out_element_op) override
+    {
+        return std::make_unique<Argument>(static_cast<InDataType*>(p_in_grid),
+                                          static_cast<const WeiDataType*>(p_wei_grid),
+                                          static_cast<const OutDataType*>(p_out_grid),
+                                          N,
+                                          K,
+                                          C,
+                                          input_spatial_lengths,
+                                          filter_spatial_lengths,
+                                          output_spatial_lengths,
+                                          conv_filter_strides,
+                                          conv_filter_dilations,
+                                          input_left_pads,
+                                          input_right_pads,
+                                          1,
+                                          1,
+                                          in_element_op,
+                                          wei_element_op,
+                                          out_element_op);
+    }
+
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
+    {
+        return std::make_unique<Invoker>(Invoker{});
+    }
+
+    std::string GetTypeString() const override
+    {
+        auto str = std::stringstream();
+
+        // clang-format off
+        str << "DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
+            << "<"
+            << BlockSize << ", "
+            << MPerBlock << ", "
+            << NPerBlock << ", "
+            << K0PerBlock
+            << ">";
+        // clang-format on
+
+        return str.str();
+    }
+};
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
new file mode 100644
index 00000000..5b880b1f
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
@@ -0,0 +1,968 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
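+//
+// Note on the GEMM view used below: this operator lowers NHWC forward
+// convolution to an implicit GEMM with
+//
+//     GemmM = N * Ho * Wo,   GemmN = K,   GemmK = Y * X * C,
+//
+// where GemmM is padded up to a multiple of MPerBlock and GemmK is split as
+// GemmK0 * GemmK1 (K1 is a template parameter). As an illustration (numbers
+// not from the source): N = 128, Ho = Wo = 28, K = 256, Y = X = 3, C = 192
+// gives a 100352 x 256 GEMM with a reduction length of 1728, i.e. GemmK0 = 216
+// when K1 = 8.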
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = +// activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K]) + residual[N, Ho, Wo, K] +template < + typename InDataType, + typename WeiDataType, + typename OutDataType, + typename AccDataType, + typename InElementwiseOperation, + typename WeiElementwiseOperation, + typename OutElementwiseOperation, + ConvolutionForwardSpecialization ConvForwardSpecialization, + ck::index_t BlockSize, + ck::index_t MPerBlock, + ck::index_t NPerBlock, + ck::index_t K0PerBlock, + ck::index_t K1, + ck::index_t MPerXDL, + ck::index_t NPerXDL, + ck::index_t MXdlPerWave, + ck::index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + ck::index_t ABlockTransferSrcVectorDim, + ck::index_t ABlockTransferSrcScalarPerVector, + ck::index_t ABlockTransferDstScalarPerVector_K1, + bool ABlockLdsAddExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + ck::index_t BBlockTransferSrcVectorDim, + ck::index_t BBlockTransferSrcScalarPerVector, + ck::index_t BBlockTransferDstScalarPerVector_K1, + bool BBlockLdsAddExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl> +struct + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvFwdBiasActivationAdd +{ + using DeviceOp = + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + // TODO make it support any # of spatial dimensions + static constexpr index_t NDimSpatial = 2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = 
output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmMRaw = N * Ho * Wo; + const index_t GemmN = K; + + const auto GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock); + const auto GemmMPad = GemmM - GemmMRaw; + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { // 1x1, stride=1, pad=0 + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_gemmmraw_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmmraw_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + // C1: residual tensor: assume same layout as output tensor + const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc; + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn, + resi_grid_desc_gemmm_gemmn); + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { // 1x1, pad=0 + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + 
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_n_ho_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + // C1: residual tensor: assume same layout as output tensor + const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc; + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn, + resi_grid_desc_gemmm_gemmn); + } + else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC) + { // C = odd value + const index_t GemmKRaw = Y * X * C; + const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number); + const index_t GemmKPad = GemmK - GemmKRaw; + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto 
in_gemmkraw_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk_gemmm_grid_desc = transform_tensor_descriptor( + in_gemmkraw_gemmmraw_grid_desc, + make_tuple(make_right_pad_transform(GemmKRaw, GemmKPad), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), + make_right_pad_transform(GemmKRaw, GemmKPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + // C1: residual tensor: assume same layout as output tensor + const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc; + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn, + resi_grid_desc_gemmm_gemmn); + } + else + { + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = 
transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmMRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + // C1: residual tensor: assume same layout as output tensor + const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc; + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn, + resi_grid_desc_gemmm_gemmn); + } + } + + using GridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})); + + using AGridDesc_K0_M_K1 = 
remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + using C0GridDesc_M_N = remove_cvref_t; + using C1GridDesc_M_N = remove_cvref_t; + + using Block2CTileMap = BlockToCTileMap_M00_N0_M01; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + C0GridDesc_M_N, + C1GridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, + 2, // BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + CBlockTransferScalarPerVector_NWaveNPerXdl>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + const OutDataType* p_bias_grid, + const OutDataType* p_resi_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_in_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_out_grid}, + p_c0_grid_{p_bias_grid}, + p_c1_grid_{p_resi_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c0_grid_desc_m_n_{}, + c1_grid_desc_m_n_{}, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + block_2_ctile_map_{}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + input_spatial_lengths_{input_spatial_lengths}, + filter_spatial_lengths_{filter_spatial_lengths}, + output_spatial_lengths_{output_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + 
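+            // descs packs the five grid descriptors built above, in order:
+            //   [I0] A  = input  viewed as (GemmK0, GemmM, GemmK1)
+            //   [I1] B  = weight viewed as (GemmK0, GemmN, GemmK1)
+            //   [I2] C  = output viewed as (GemmM, GemmN)
+            //   [I3] C0 = bias, broadcast along GemmM via a zero stride
+            //   [I4] C1 = residual, assumed to share the output layout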
a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + c0_grid_desc_m_n_ = descs[I3]; + c1_grid_desc_m_n_ = descs[I4]; + + block_2_ctile_map_ = Block2CTileMap{c_grid_desc_m_n_}; + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c_grid_desc_m_n_); + + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c0_grid_desc_m_n_); + + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c1_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + const CDataType* p_c0_grid_; + const CDataType* p_c1_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + C0GridDesc_M_N c0_grid_desc_m_n_; + C1GridDesc_M_N c1_grid_desc_m_n_; + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm:: + C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + Block2CTileMap block_2_ctile_map_; + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::vector input_spatial_lengths_; + std::vector filter_spatial_lengths_; + std::vector output_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector conv_filter_dilations_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << DeviceOp{}.GetTypeString() << std::endl; + std::cout << "N " << arg.Conv_N_ << ", " + << "K " << arg.Conv_K_ << ", " + << "C " << arg.Conv_C_ << ", " << std::endl; + std::cout << "Y X " << arg.filter_spatial_lengths_[0] << ", " + << arg.filter_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Hi Wi " << arg.input_spatial_lengths_[0] << ", " + << arg.input_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Ho Wo " << arg.output_spatial_lengths_[0] << ", " + << arg.output_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Strides " << arg.conv_filter_strides_[0] << ", " + << arg.conv_filter_strides_[1] << ", " << std::endl; + std::cout << "Dilations " << arg.conv_filter_dilations_[0] << ", " + << arg.conv_filter_dilations_[1] << ", " << std::endl; + std::cout << "InLeftPads " << arg.input_left_pads_[0] << ", " + << arg.input_left_pads_[1] << ", " << std::endl; + std::cout << 
"InLeftPads " << arg.input_right_pads_[0] << ", " + << arg.input_right_pads_[1] << ", " << std::endl; + } + + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0) + << ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout << "arg.c1_grid_desc_m_n_{ " << arg.c1_grid_desc_m_n_.GetLength(I0) + << ", " << arg.c1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } +#endif + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r3 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdlops_v3r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + Block2CTileMap, + true>; + + ave_time = launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_, + arg.p_c1_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v3r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + Block2CTileMap, + false>; + + ave_time = launch_and_time_kernel( + stream_config, + kernel, + 
dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_, + arg.p_c1_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_K_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + const OutDataType* p_bias_grid, + const OutDataType* p_resi_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + p_bias_grid, + p_resi_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + const void* p_wei_grid, + void* 
p_out_grid, + const void* p_bias_grid, + const void* p_resi_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + static_cast(p_bias_grid), + static_cast(p_resi_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp new file mode 100644 index 00000000..bab98987 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,925 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
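+//
+// Typical host-side use of this device op (a minimal sketch; the pointer names
+// and problem-size arguments are placeholders):
+//
+//     auto op       = /* fully specialized device op type */{};
+//     auto argument = op.MakeArgumentPointer(p_in, p_wei, p_out, p_bias,
+//                                            N, K, C, /* spatial lengths,
+//                                            strides, dilations, pads */
+//                                            in_op, wei_op, out_op);
+//     if(op.IsSupportedArgument(argument.get()))
+//     {
+//         auto invoker   = op.MakeInvokerPointer();
+//         float ave_time = invoker->Run(argument.get(), StreamConfig{});
+//     }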
+ +#pragma once + +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = +// activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K]) +template < + typename InDataType, + typename WeiDataType, + typename OutDataType, + typename AccDataType, + typename InElementwiseOperation, + typename WeiElementwiseOperation, + typename OutElementwiseOperation, + InMemoryDataOperationEnum OutGlobalMemoryDataOperation, + ConvolutionForwardSpecialization ConvForwardSpecialization, + ck::index_t BlockSize, + ck::index_t MPerBlock, + ck::index_t NPerBlock, + ck::index_t K0PerBlock, + ck::index_t K1, + ck::index_t MPerXDL, + ck::index_t NPerXDL, + ck::index_t MXdlPerWave, + ck::index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + ck::index_t ABlockTransferSrcVectorDim, + ck::index_t ABlockTransferSrcScalarPerVector, + ck::index_t ABlockTransferDstScalarPerVector_K1, + bool ABlockLdsAddExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + ck::index_t BBlockTransferSrcVectorDim, + ck::index_t BBlockTransferSrcScalarPerVector, + ck::index_t BBlockTransferDstScalarPerVector_K1, + bool BBlockLdsAddExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl> +struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvFwdBiasActivation +{ + using DeviceOp = + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + // TODO make it support any # of spatial dimensions + static constexpr index_t NDimSpatial = 2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = 
output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmMRaw = N * Ho * Wo; + const index_t GemmN = K; + + const auto GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock); + const auto GemmMPad = GemmM - GemmMRaw; + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { // 1x1, stride=1, pad=0 + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_gemmmraw_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmmraw_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn); + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { // 1x1, pad=0 + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + 
in_n_ho_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn); + } + else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC) + { // C = odd value + const index_t GemmKRaw = Y * X * C; + const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number); + const index_t GemmKPad = GemmK - GemmKRaw; + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmkraw_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk_gemmm_grid_desc = 
transform_tensor_descriptor( + in_gemmkraw_gemmmraw_grid_desc, + make_tuple(make_right_pad_transform(GemmKRaw, GemmKPad), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), + make_right_pad_transform(GemmKRaw, GemmKPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn); + } + else + { + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmmraw_grid_desc = + 
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmMRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn); + } + } + + using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + using C0GridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + OutGlobalMemoryDataOperation, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + C0GridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + Sequence<1, 0, 2>, // 
ABlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, + 2, // BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + CBlockTransferScalarPerVector_NWaveNPerXdl>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + const OutDataType* p_bias_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t M01, + ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_in_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_out_grid}, + p_c0_grid_{p_bias_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c0_grid_desc_m_n_{}, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + input_spatial_lengths_{input_spatial_lengths}, + filter_spatial_lengths_{filter_spatial_lengths}, + output_spatial_lengths_{output_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + c0_grid_desc_m_n_ = descs[I3]; + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c_grid_desc_m_n_); + + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c0_grid_desc_m_n_); + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + const CDataType* p_c0_grid_; + 
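+        // Implicit-GEMM grid descriptors; the constructor derives these from the
+        // raw convolution sizes via MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N.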
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + C0GridDesc_M_N c0_grid_desc_m_n_; + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::vector input_spatial_lengths_; + std::vector filter_spatial_lengths_; + std::vector output_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector conv_filter_dilations_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << DeviceOp{}.GetTypeString() << std::endl; + std::cout << "N " << arg.Conv_N_ << ", " + << "K " << arg.Conv_K_ << ", " + << "C " << arg.Conv_C_ << ", " << std::endl; + std::cout << "Y X " << arg.filter_spatial_lengths_[0] << ", " + << arg.filter_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Hi Wi " << arg.input_spatial_lengths_[0] << ", " + << arg.input_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Ho Wo " << arg.output_spatial_lengths_[0] << ", " + << arg.output_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Strides " << arg.conv_filter_strides_[0] << ", " + << arg.conv_filter_strides_[1] << ", " << std::endl; + std::cout << "Dilations " << arg.conv_filter_dilations_[0] << ", " + << arg.conv_filter_dilations_[1] << ", " << std::endl; + std::cout << "InLeftPads " << arg.input_left_pads_[0] << ", " + << arg.input_left_pads_[1] << ", " << std::endl; + std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", " + << arg.input_right_pads_[1] << ", " << std::endl; + } + + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0) + << ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } +#endif + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r2 has invalid setting");
+            }
+
+            const index_t grid_size =
+                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
+
+            const auto K =
+                arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
+
+            float ave_time = 0;
+
+            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
+            {
+                const auto kernel = kernel_gemm_xdlops_v3r2<
+                    GridwiseGemm,
+                    ADataType, // TODO: distinguish A/B datatype
+                    CDataType,
+                    remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                    remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                    remove_reference_t<
+                        typename GridwiseGemm::
+                            CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
+                    remove_reference_t<
+                        typename GridwiseGemm::
+                            C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
+                    InElementwiseOperation,
+                    WeiElementwiseOperation,
+                    OutElementwiseOperation,
+                    remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
+                    true>;
+
+                ave_time = launch_and_time_kernel(
+                    stream_config,
+                    kernel,
+                    dim3(grid_size),
+                    dim3(BlockSize),
+                    0,
+                    arg.p_a_grid_,
+                    arg.p_b_grid_,
+                    arg.p_c_grid_,
+                    arg.p_c0_grid_,
+                    arg.a_grid_desc_k0_m_k1_,
+                    arg.b_grid_desc_k0_n_k1_,
+                    arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
+                    arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
+                    arg.in_element_op_,
+                    arg.wei_element_op_,
+                    arg.out_element_op_,
+                    arg.block_2_ctile_map_);
+            }
+            else
+            {
+                const auto kernel = kernel_gemm_xdlops_v3r2<
+                    GridwiseGemm,
+                    ADataType, // TODO: distinguish A/B datatype
+                    CDataType,
+                    remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                    remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                    remove_reference_t<
+                        typename GridwiseGemm::
+                            CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
+                    remove_reference_t<
+                        typename GridwiseGemm::
+                            C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
+                    InElementwiseOperation,
+                    WeiElementwiseOperation,
+                    OutElementwiseOperation,
+                    remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
+                    false>;
+
+                ave_time = launch_and_time_kernel(
+                    stream_config,
+                    kernel,
+                    dim3(grid_size),
+                    dim3(BlockSize),
+                    0,
+                    arg.p_a_grid_,
+                    arg.p_b_grid_,
+                    arg.p_c_grid_,
+                    arg.p_c0_grid_,
+                    arg.a_grid_desc_k0_m_k1_,
+                    arg.b_grid_desc_k0_n_k1_,
+                    arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
+                    arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
+                    arg.in_element_op_,
+                    arg.wei_element_op_,
+                    arg.out_element_op_,
+                    arg.block_2_ctile_map_);
+            }
+
+            return ave_time;
+        }
+
+        float Run(const BaseArgument* p_arg,
+                  const StreamConfig& stream_config = StreamConfig{}) override
+        {
+            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
+        }
+    };
+
+    static constexpr bool IsValidCompilationParameter()
+    {
+        // TODO: properly implement this check
+        return true;
+    }
+
+    static bool IsSupportedArgument(const Argument& arg)
+    {
+        if constexpr(ConvForwardSpecialization ==
+                     ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
+        {
+            // check if it's 1x1, stride=1 conv
+            if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
+                 arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
+                 arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
+                 arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
+            {
+                return false;
+            }
+        }
+        else if constexpr(ConvForwardSpecialization ==
+                          ConvolutionForwardSpecialization::Filter1x1Pad0)
+        {
+            // check if it's 1x1 conv
+            if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
+                 arg.input_left_pads_[0] == 0 &&
arg.input_left_pads_[1] == 0 &&
+                 arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
+            {
+                return false;
+            }
+        }
+
+        // vector load A/B matrix from global memory
+        if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 &&
+             arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 &&
+             arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
+        {
+            return false;
+        }
+
+        // vector store C matrix into global memory
+        if(!(arg.Conv_K_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0))
+        {
+            return false;
+        }
+
+        // Gridwise GEMM size
+        return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
+                                           arg.b_grid_desc_k0_n_k1_,
+                                           arg.c_grid_desc_m_n_,
+                                           arg.block_2_ctile_map_);
+    }
+
+    bool IsSupportedArgument(const BaseArgument* p_arg) override
+    {
+        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
+    }
+
+    static auto MakeArgument(const InDataType* p_in_grid,
+                             const WeiDataType* p_wei_grid,
+                             OutDataType* p_out_grid,
+                             const OutDataType* p_bias_grid,
+                             ck::index_t N,
+                             ck::index_t K,
+                             ck::index_t C,
+                             std::vector<ck::index_t> input_spatial_lengths,
+                             std::vector<ck::index_t> filter_spatial_lengths,
+                             std::vector<ck::index_t> output_spatial_lengths,
+                             std::vector<ck::index_t> conv_filter_strides,
+                             std::vector<ck::index_t> conv_filter_dilations,
+                             std::vector<ck::index_t> input_left_pads,
+                             std::vector<ck::index_t> input_right_pads,
+                             InElementwiseOperation in_element_op,
+                             WeiElementwiseOperation wei_element_op,
+                             OutElementwiseOperation out_element_op)
+    {
+        return Argument{p_in_grid,
+                        p_wei_grid,
+                        p_out_grid,
+                        p_bias_grid,
+                        N,
+                        K,
+                        C,
+                        input_spatial_lengths,
+                        filter_spatial_lengths,
+                        output_spatial_lengths,
+                        conv_filter_strides,
+                        conv_filter_dilations,
+                        input_left_pads,
+                        input_right_pads,
+                        1,
+                        1,
+                        in_element_op,
+                        wei_element_op,
+                        out_element_op};
+    }
+
+    static auto MakeInvoker() { return Invoker{}; }
+
+    std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(const void* p_in_grid,
+                        const void* p_wei_grid,
+                        void* p_out_grid,
+                        const void* p_bias_grid,
+                        ck::index_t N,
+                        ck::index_t K,
+                        ck::index_t C,
+                        std::vector<ck::index_t> input_spatial_lengths,
+                        std::vector<ck::index_t> filter_spatial_lengths,
+                        std::vector<ck::index_t> output_spatial_lengths,
+                        std::vector<ck::index_t> conv_filter_strides,
+                        std::vector<ck::index_t> conv_filter_dilations,
+                        std::vector<ck::index_t> input_left_pads,
+                        std::vector<ck::index_t> input_right_pads,
+                        InElementwiseOperation in_element_op,
+                        WeiElementwiseOperation wei_element_op,
+                        OutElementwiseOperation out_element_op) override
+    {
+        return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
+                                          static_cast<const WeiDataType*>(p_wei_grid),
+                                          static_cast<OutDataType*>(p_out_grid),
+                                          static_cast<const OutDataType*>(p_bias_grid),
+                                          N,
+                                          K,
+                                          C,
+                                          input_spatial_lengths,
+                                          filter_spatial_lengths,
+                                          output_spatial_lengths,
+                                          conv_filter_strides,
+                                          conv_filter_dilations,
+                                          input_left_pads,
+                                          input_right_pads,
+                                          1,
+                                          1,
+                                          in_element_op,
+                                          wei_element_op,
+                                          out_element_op);
+    }
+
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
+    {
+        return std::make_unique<Invoker>(Invoker{});
+    }
+
+    std::string GetTypeString() const override
+    {
+        auto str = std::stringstream();
+
+        // clang-format off
+        str << "DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
+            << "<"
+            << BlockSize << ", "
+            << MPerBlock << ", "
+            << NPerBlock << ", "
+            << K0PerBlock
+            << ">";
+        // clang-format on
+
+        return str.str();
+    }
+};
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
new file mode 100644
index 00000000..e8b6a88f
--- /dev/null
+++
b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,893 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template < + typename InDataType, + typename WeiDataType, + typename OutDataType, + typename AccDataType, + typename InElementwiseOperation, + typename WeiElementwiseOperation, + typename OutElementwiseOperation, + ConvolutionForwardSpecialization ConvForwardSpecialization, + ck::index_t BlockSize, + ck::index_t MPerBlock, + ck::index_t NPerBlock, + ck::index_t K0PerBlock, + ck::index_t K1, + ck::index_t MPerXdl, + ck::index_t NPerXdl, + ck::index_t MXdlPerWave, + ck::index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + ck::index_t ABlockTransferSrcVectorDim, + ck::index_t ABlockTransferSrcScalarPerVector, + ck::index_t ABlockTransferDstScalarPerVector_K1, + bool ABlockLdsAddExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + ck::index_t BBlockTransferSrcVectorDim, + ck::index_t BBlockTransferSrcScalarPerVector, + ck::index_t BBlockTransferDstScalarPerVector_K1, + bool BBlockLdsAddExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl> +struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvFwd<2, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK, + InDataType, + WeiDataType, + OutDataType, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> +{ + using DeviceOp = DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr index_t NDimSpatial = 2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + 
std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmMRaw = N * Ho * Wo; + const index_t GemmN = K; + + const auto GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock); + const auto GemmMPad = GemmM - GemmMRaw; + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { // 1x1, stride=1, pad=0 + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_gemmmraw_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmmraw_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { // 1x1, pad=0 + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, 
Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_n_ho_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC) + { // C = odd value + const index_t GemmKRaw = Y * X * C; + const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number); + const index_t GemmKPad = GemmK - GemmKRaw; + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmkraw_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk_gemmm_grid_desc = transform_tensor_descriptor( + in_gemmkraw_gemmmraw_grid_desc, + 
make_tuple(make_right_pad_transform(GemmKRaw, GemmKPad), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), + make_right_pad_transform(GemmKRaw, GemmKPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + else + { + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto 
in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmMRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + } + + using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + using Block2CTileMap = BlockToCTileMap_M00_N0_M01; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, // TODO: Add ShuffleType for DeviceConv2d + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock * K1, + K1, // AK1 + K1, // BK1 + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, + Sequence<1, 
0, 2>, // BBlockTransferSrcAccessOrder, + 2, // BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + CBlockTransferScalarPerVector_NWaveNPerXdl>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_in_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_out_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + block_2_ctile_map_{}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + input_spatial_lengths_{input_spatial_lengths}, + filter_spatial_lengths_{filter_spatial_lengths}, + output_spatial_lengths_{output_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + block_2_ctile_map_ = Block2CTileMap{c_grid_desc_m_n_}; + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c_grid_desc_m_n_); + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + Block2CTileMap block_2_ctile_map_; + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::vector input_spatial_lengths_; + std::vector filter_spatial_lengths_; + std::vector output_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector conv_filter_dilations_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, 
const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << DeviceOp{}.GetTypeString() << std::endl; + std::cout << "N " << arg.Conv_N_ << ", " + << "K " << arg.Conv_K_ << ", " + << "C " << arg.Conv_C_ << ", " << std::endl; + std::cout << "Y X " << arg.filter_spatial_lengths_[0] << ", " + << arg.filter_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Hi Wi " << arg.input_spatial_lengths_[0] << ", " + << arg.input_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Ho Wo " << arg.output_spatial_lengths_[0] << ", " + << arg.output_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Strides " << arg.conv_filter_strides_[0] << ", " + << arg.conv_filter_strides_[1] << ", " << std::endl; + std::cout << "Dilations " << arg.conv_filter_dilations_[0] << ", " + << arg.conv_filter_dilations_[1] << ", " << std::endl; + std::cout << "InLeftPads " << arg.input_left_pads_[0] << ", " + << arg.input_left_pads_[1] << ", " << std::endl; + std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", " + << arg.input_right_pads_[1] << ", " << std::endl; + } + + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout + << "arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_" + "nwavenperxdl_{ " + << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + .GetLength(I0) + << ", " + << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + .GetLength(I1) + << ", " + << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + .GetLength(I2) + << ", " + << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + .GetLength(I3) + << ", " + << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + .GetLength(I4) + << ", " + << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + .GetLength(I5) + << "}" << std::endl; + } +#endif + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
+            }
+
+            const index_t grid_size =
+                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
+
+            const auto K =
+                arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
+
+            float ave_time = 0;
+
+            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
+            {
+                const auto kernel = kernel_gemm_xdlops_v3r1<
+                    GridwiseGemm,
+                    ADataType, // TODO: distinguish A/B datatype
+                    CDataType,
+                    remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                    remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                    remove_reference_t<
+                        typename GridwiseGemm::
+                            CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
+                    InElementwiseOperation,
+                    WeiElementwiseOperation,
+                    OutElementwiseOperation,
+                    Block2CTileMap,
+                    true>;
+
+                ave_time = launch_and_time_kernel(
+                    stream_config,
+                    kernel,
+                    dim3(grid_size),
+                    dim3(BlockSize),
+                    0,
+                    arg.p_a_grid_,
+                    arg.p_b_grid_,
+                    arg.p_c_grid_,
+                    arg.a_grid_desc_k0_m_k1_,
+                    arg.b_grid_desc_k0_n_k1_,
+                    arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
+                    arg.in_element_op_,
+                    arg.wei_element_op_,
+                    arg.out_element_op_,
+                    arg.block_2_ctile_map_);
+            }
+            else
+            {
+                const auto kernel = kernel_gemm_xdlops_v3r1<
+                    GridwiseGemm,
+                    ADataType, // TODO: distinguish A/B datatype
+                    CDataType,
+                    remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                    remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                    remove_reference_t<
+                        typename GridwiseGemm::
+                            CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
+                    InElementwiseOperation,
+                    WeiElementwiseOperation,
+                    OutElementwiseOperation,
+                    Block2CTileMap,
+                    false>;
+
+                ave_time = launch_and_time_kernel(
+                    stream_config,
+                    kernel,
+                    dim3(grid_size),
+                    dim3(BlockSize),
+                    0,
+                    arg.p_a_grid_,
+                    arg.p_b_grid_,
+                    arg.p_c_grid_,
+                    arg.a_grid_desc_k0_m_k1_,
+                    arg.b_grid_desc_k0_n_k1_,
+                    arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
+                    arg.in_element_op_,
+                    arg.wei_element_op_,
+                    arg.out_element_op_,
+                    arg.block_2_ctile_map_);
+            }
+
+            return ave_time;
+        }
+
+        float Run(const BaseArgument* p_arg,
+                  const StreamConfig& stream_config = StreamConfig{}) override
+        {
+            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
+        }
+    };
+
+    static constexpr bool IsValidCompilationParameter()
+    {
+        // TODO: properly implement this check
+        return true;
+    }
+
+    static bool IsSupportedArgument(const Argument& arg)
+    {
+        if constexpr(ConvForwardSpecialization ==
+                     ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
+        {
+            // check if it's 1x1, stride=1 conv
+            if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
+                 arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
+                 arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
+                 arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
+            {
+                return false;
+            }
+        }
+        else if constexpr(ConvForwardSpecialization ==
+                          ConvolutionForwardSpecialization::Filter1x1Pad0)
+        {
+            // check if it's 1x1 conv
+            if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
+                 arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
+                 arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
+            {
+                return false;
+            }
+        }
+
+        // vector load A/B matrix from global memory
+        if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 &&
+             arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 &&
+             arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
+        {
+            return false;
+        }
+
+        // vector store C matrix into global memory
+        if(!(arg.Conv_K_ %
CBlockTransferScalarPerVector_NWaveNPerXdl == 0))
+        {
+            return false;
+        }
+
+        // Gridwise GEMM size
+        return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
+                                           arg.b_grid_desc_k0_n_k1_,
+                                           arg.c_grid_desc_m_n_,
+                                           arg.block_2_ctile_map_);
+    }
+
+    bool IsSupportedArgument(const BaseArgument* p_arg) override
+    {
+        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
+    }
+
+    static auto MakeArgument(const InDataType* p_in_grid,
+                             const WeiDataType* p_wei_grid,
+                             OutDataType* p_out_grid,
+                             ck::index_t N,
+                             ck::index_t K,
+                             ck::index_t C,
+                             std::vector<ck::index_t> input_spatial_lengths,
+                             std::vector<ck::index_t> filter_spatial_lengths,
+                             std::vector<ck::index_t> output_spatial_lengths,
+                             std::vector<ck::index_t> conv_filter_strides,
+                             std::vector<ck::index_t> conv_filter_dilations,
+                             std::vector<ck::index_t> input_left_pads,
+                             std::vector<ck::index_t> input_right_pads,
+                             InElementwiseOperation in_element_op,
+                             WeiElementwiseOperation wei_element_op,
+                             OutElementwiseOperation out_element_op)
+    {
+        return Argument{p_in_grid,
+                        p_wei_grid,
+                        p_out_grid,
+                        N,
+                        K,
+                        C,
+                        input_spatial_lengths,
+                        filter_spatial_lengths,
+                        output_spatial_lengths,
+                        conv_filter_strides,
+                        conv_filter_dilations,
+                        input_left_pads,
+                        input_right_pads,
+                        in_element_op,
+                        wei_element_op,
+                        out_element_op};
+    }
+
+    static auto MakeInvoker() { return Invoker{}; }
+
+    std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(const void* p_in_grid,
+                        const void* p_wei_grid,
+                        void* p_out_grid,
+                        ck::index_t N,
+                        ck::index_t K,
+                        ck::index_t C,
+                        std::vector<ck::index_t> input_spatial_lengths,
+                        std::vector<ck::index_t> filter_spatial_lengths,
+                        std::vector<ck::index_t> output_spatial_lengths,
+                        std::vector<ck::index_t> conv_filter_strides,
+                        std::vector<ck::index_t> conv_filter_dilations,
+                        std::vector<ck::index_t> input_left_pads,
+                        std::vector<ck::index_t> input_right_pads,
+                        InElementwiseOperation in_element_op,
+                        WeiElementwiseOperation wei_element_op,
+                        OutElementwiseOperation out_element_op) override
+    {
+        return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
+                                          static_cast<const WeiDataType*>(p_wei_grid),
+                                          static_cast<OutDataType*>(p_out_grid),
+                                          N,
+                                          K,
+                                          C,
+                                          input_spatial_lengths,
+                                          filter_spatial_lengths,
+                                          output_spatial_lengths,
+                                          conv_filter_strides,
+                                          conv_filter_dilations,
+                                          input_left_pads,
+                                          input_right_pads,
+                                          in_element_op,
+                                          wei_element_op,
+                                          out_element_op);
+    }
+
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
+    {
+        return std::make_unique<Invoker>(Invoker{});
+    }
+
+    std::string GetTypeString() const override
+    {
+        auto str = std::stringstream();
+
+        // clang-format off
+        str << "DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
+            << "<"
+            << BlockSize << ", "
+            << MPerBlock << ", "
+            << NPerBlock << ", "
+            << K0PerBlock << ", "
+            << getConvForwardSpecializationString(ConvForwardSpecialization)
+            << ">";
+        // clang-format on
+
+        return str.str();
+    }
+};
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp
new file mode 100644
index 00000000..84286320
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp
@@ -0,0 +1,733 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
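Like the two c_shuffle variants before it, the header that follows maps the convolution onto a GEMM with GemmM = N*Ho*Wo, GemmN = K and GemmK = Y*X*C, padding GemmM up to a multiple of MPerBlock via math::integer_least_multiple so the block tiling always divides evenly. A small standalone sketch of that shape bookkeeping, with hypothetical names and under the assumption that integer_least_multiple rounds up to the nearest multiple:

// Illustrative only: mirrors the shape arithmetic in the MakeABCGridDescriptor_*
// functions of these headers; not part of the library API.
struct ImplicitGemmShape
{
    int GemmM;    // N * Ho * Wo, rounded up to a multiple of MPerBlock
    int GemmMPad; // zero rows appended by make_right_pad_transform
    int GemmN;    // K
    int GemmK0;   // (Y * X * C) / K1
};

ImplicitGemmShape
MakeImplicitGemmShape(int N, int K, int C, int Y, int X, int Ho, int Wo, int MPerBlock, int K1)
{
    const int GemmMRaw = N * Ho * Wo;
    const int GemmM    = (GemmMRaw + MPerBlock - 1) / MPerBlock * MPerBlock;
    const int GemmK    = Y * X * C; // the non-OddC paths assert GemmK % K1 == 0
    return {GemmM, GemmM - GemmMRaw, K, GemmK / K1};
}

For example, with N=1, Ho=Wo=28 and MPerBlock=128, GemmMRaw=784 is padded to GemmM=896 (GemmMPad=112); the right-pad transform makes the extra 112 GEMM rows read as zeros so no real output elements are affected.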
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template +struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvFwd<2, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK, + InDataType, + WeiDataType, + OutDataType, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> +{ + using DeviceOp = DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr index_t NDimSpatial = 2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmMRaw = N * Ho * Wo; + const index_t GemmN = K; + const index_t GemmK = Y * X * C; + + const auto GemmMPad = math::integer_least_multiple(GemmMRaw, MPerBlock) - GemmMRaw; + + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // A: input tensor + const auto in_gemmmraw_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmmraw_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + 
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_n_ho_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + else + { + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto 
in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmMRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + } + + using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 
1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, + 2, // BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder, + 7, // CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t M01, + ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_in_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_out_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + filter_spatial_lengths_{filter_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 
b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::vector<ck::index_t> filter_spatial_lengths_; + std::vector<ck::index_t> conv_filter_strides_; + std::vector<ck::index_t> input_left_pads_; + std::vector<ck::index_t> input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } +#endif + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distinguish A/B datatype + CDataType, + remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>, + remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>, + remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>, + true>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distinguish A/B datatype + CDataType, + remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>, + remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>, + remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>, + false>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast<const Argument*>(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return
true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_K_ % CThreadTransferDstScalarPerVector == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + const void* p_wei_grid, + void* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + 
auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << getConvForwardSpecializationString(ConvForwardSpecialization) + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp new file mode 100644 index 00000000..f69d8f18 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp @@ -0,0 +1,268 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef DEVICE_CONV3D_FWD_NAIVE_HPP +#define DEVICE_CONV3D_FWD_NAIVE_HPP + +#include +#include +#include +#include "conv_util.hpp" +#include "device.hpp" +#include "device_conv_fwd.hpp" +#include "common_header.hpp" +#include "naive_conv_fwd.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// specialization for 3D conv: in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k] +template +struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K + : public DeviceConvFwd + +{ + using DeviceOp = DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K; + + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + // TODO make A/B datatype different + using ABDataType = InDataType; + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in, + const WeiDataType* p_wei, + OutDataType* p_out, + const index_t N, + const index_t K, + const index_t C, + std::vector<ck::index_t> input_spatial_lengths, + std::vector<ck::index_t> filter_spatial_lengths, + std::vector<ck::index_t> output_spatial_lengths, + std::vector<ck::index_t> conv_filter_strides, + std::vector<ck::index_t> conv_filter_dilations, + std::vector<ck::index_t> input_left_pads, + std::vector<ck::index_t> input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : params_{3, + N, + K, + C, + filter_spatial_lengths, + input_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}, + out_spatial_lengths_{output_spatial_lengths}, + N_{N}, + K_{K}, + C_{C}, + in_spatial_lengths_{input_spatial_lengths}, + filter_spatial_lengths_{filter_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + in_left_pads_{input_left_pads}, + p_in_{p_in}, + p_wei_{p_wei}, + p_out_{p_out}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op} + + { + } + + // private: + utils::conv::ConvParams params_; + std::vector<ck::index_t> out_spatial_lengths_; + + // flat copies of the problem description; Invoker::Run() indexes these directly + index_t N_; + index_t K_; + index_t C_; + std::vector<ck::index_t> in_spatial_lengths_; + std::vector<ck::index_t> filter_spatial_lengths_; + std::vector<ck::index_t> conv_filter_strides_; + std::vector<ck::index_t> conv_filter_dilations_; + std::vector<ck::index_t> in_left_pads_; + + const InDataType* p_in_; + const WeiDataType* p_wei_; + OutDataType* p_out_; + + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto naive_conv3d_fwd = + ref::naive_conv_fwd_ndhwc_kzyxc_ndhwk; + + float ave_time = launch_and_time_kernel(stream_config, + naive_conv3d_fwd, + dim3(256), + dim3(256), + 0, + arg.p_in_, + arg.p_wei_, + arg.p_out_, + arg.N_, + arg.K_, + arg.C_, + arg.in_spatial_lengths_[0], + arg.in_spatial_lengths_[1], + arg.in_spatial_lengths_[2], + arg.filter_spatial_lengths_[0], +
arg.filter_spatial_lengths_[1], + arg.filter_spatial_lengths_[2], + arg.out_spatial_lengths_[0], + arg.out_spatial_lengths_[1], + arg.out_spatial_lengths_[2], + arg.conv_filter_strides_[0], + arg.conv_filter_strides_[1], + arg.conv_filter_strides_[2], + arg.conv_filter_dilations_[0], + arg.conv_filter_dilations_[1], + arg.conv_filter_dilations_[2], + arg.in_left_pads_[0], + arg.in_left_pads_[1], + arg.in_left_pads_[2]); + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + std::vector out_spatial_lengths = arg.params_.GetOutputSpatialLengths(); + + bool out_lengths_are_consistent = out_spatial_lengths[0] == arg.out_spatial_lengths_[0] && + out_spatial_lengths[1] == arg.out_spatial_lengths_[1] && + out_spatial_lengths[2] == arg.out_spatial_lengths_[2]; + return out_lengths_are_consistent; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in, + const WeiDataType* p_wei, + OutDataType* p_out, + const index_t N, + const index_t K, + const index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in, + p_wei, + p_out, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_in, + const void* p_wei, + void* p_out, + const index_t N, + const index_t K, + const index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + + { + return std::make_unique(static_cast(p_in), + static_cast(p_wei), + static_cast(p_out), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K<>"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git 
a/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp new file mode 100644 index 00000000..b48cfac0 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -0,0 +1,642 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef DEVICE_CONV3D_FWD_XDL_HPP +#define DEVICE_CONV3D_FWD_XDL_HPP + +#include +#include +#include +#include "device.hpp" +#include "device_conv_fwd.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "convolution_forward_specialization.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp" +#include "gridwise_gemm_xdlops_v2r3.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/* + * \see \link device_batched_gemm_xdl.hpp kernel_batched_gemm_xdlops_v2r3() \endlink. + */ +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v2r3_for_conv3d( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const index_t num_batches, + const index_t a_batch_stride, + const index_t b_batch_stride, + const index_t c_batch_stride, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / num_batches); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = + __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(a_batch_stride) * g_idx); + const long_index_t b_batch_offset = + __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(b_batch_stride) * g_idx); + const long_index_t c_batch_offset = + __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(c_batch_stride) * g_idx); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); + +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = num_batches; + ignore = a_batch_stride; + ignore = b_batch_stride; + ignore = c_batch_stride; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +// specialization for 3D conv: in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k] +template +struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K + : public DeviceConvFwd + +{ + using DeviceOp =
DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K; + + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + /* + * \brief Split the number of batches, \p N, into N = B * N1, such that the memory + * space of input and output tensors stays within the value range of index_t, and each subbatch + * can be handled by GridwiseGemm. + */ + static index_t GetMaxAllowableSubBatchSize(const index_t N, + const index_t K, + const index_t C, + std::vector<ck::index_t> input_spatial_lengths, + std::vector<ck::index_t> output_spatial_lengths) + { + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + // N1 should satisfy that + // 1) N % N1 = 0; + // 2) N1 * (Do * Ho * Wo * K) < (2^31 - 1) + // 3) N1 * (Di * Hi * Wi * C) < (2^31 - 1) + // + // Do NOT confuse (B, N1) in this function with (B, N1) in gridwise GEMM. + auto N1 = N + 1; + + const auto stride = + math::max(long_index_t(Do) * Ho * Wo * K, long_index_t(Di) * Hi * Wi * C); + const index_t max_stride = NumericLimits<index_t>::Max(); + + for(index_t n0 = 1; n0 <= N; ++n0) + { + index_t n1 = N / n0; + if(n0 * n1 == N && long_index_t(n1) * long_index_t(stride) < max_stride) + { + N1 = n1; + break; + } + } + + const auto B = N / N1; + if(B * N1 != N) + { + throw std::runtime_error(__func__ + + std::string(": failed to find num_subbatches for conv3d.\n")); + } + + return N1; + } + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(const index_t N, + const index_t K, + const index_t C, + std::vector<ck::index_t> input_spatial_lengths, + std::vector<ck::index_t> filter_spatial_lengths, + std::vector<ck::index_t> output_spatial_lengths, + std::vector<ck::index_t> conv_filter_strides, + std::vector<ck::index_t> conv_filter_dilations, + std::vector<ck::index_t> input_left_pads, + std::vector<ck::index_t> input_right_pads) + { + assert(input_spatial_lengths.size() > 2); + assert(filter_spatial_lengths.size() > 2); + assert(conv_filter_strides.size() > 2); + assert(conv_filter_dilations.size() > 2); + assert(input_left_pads.size() > 2); + assert(input_right_pads.size() > 2); + + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + const index_t Z = filter_spatial_lengths[0]; + const index_t Y = filter_spatial_lengths[1]; + const index_t X = filter_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + static_assert(ConvForwardSpecialization == ConvolutionForwardSpecialization::Default, + "Wrong!
This specialization is not implemented!"); + + const auto in_desc_n_di_hi_wi_c = + make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); + const auto wei_desc_k_z_y_x_c = + make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C)); + const auto out_desc_n_do_ho_wo_k = + make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K)); + + const auto descs = transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk_pad( + in_desc_n_di_hi_wi_c, + wei_desc_k_z_y_x_c, + out_desc_n_do_ho_wo_k, + make_tuple(conv_filter_strides[0], conv_filter_strides[1], conv_filter_strides[2]), + make_tuple( + conv_filter_dilations[0], conv_filter_dilations[1], conv_filter_dilations[2]), + make_tuple(input_left_pads[0], input_left_pads[1], input_left_pads[2]), + make_tuple(input_right_pads[0], input_right_pads[1], input_right_pads[2]), + Number<K1>{}); + + return descs; + } + + using ABCGridDescs = remove_cvref_t<decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}))>; + + using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>; + using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>; + using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>; + + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + BlockSize, + InDataType, + AccDataType, + OutDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + 7, + CThreadTransferDstScalarPerVector>; + + using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = + decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); + using Block2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in, + const WeiDataType* p_wei, + OutDataType* p_out, + const index_t N, + const index_t K, + const index_t C, + std::vector<ck::index_t> input_spatial_lengths, + std::vector<ck::index_t> filter_spatial_lengths, + std::vector<ck::index_t> output_spatial_lengths, + std::vector<ck::index_t> conv_filter_strides, + std::vector<ck::index_t> conv_filter_dilations, + std::vector<ck::index_t> input_left_pads, + std::vector<ck::index_t> input_right_pads, + index_t M01, + index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_in}, + p_b_grid_{p_wei}, + p_c_grid_{p_out}, + M01_{M01}, + N01_{N01}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op} + { + const index_t subbatch_size = + GetMaxAllowableSubBatchSize(N, K, C, input_spatial_lengths, output_spatial_lengths); + num_subbatches_ = N / subbatch_size; + + const auto descs = + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(subbatch_size, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, +
conv_filter_dilations, + input_left_pads, + input_right_pads); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + a_batch_stride_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize(); + b_batch_stride_ = 0; + c_batch_stride_ = c_grid_desc_m_n_.GetElementSpaceSize(); + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); + } + } + + // private: + const InDataType* p_a_grid_; + const WeiDataType* p_b_grid_; + OutDataType* p_c_grid_; + index_t num_subbatches_; + index_t a_batch_stride_; + index_t b_batch_stride_; + index_t c_batch_stride_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; + Block2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + { + std::cout << "num_batches_of_GEMM = " << arg.num_subbatches_ << std::endl; + std::cout << "a_grid_desc_k0_m_k1{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "b_grid_desc_k0_n_k1{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "c_grid_desc_m_n{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * + arg.num_subbatches_; + + const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; + if(has_main_k0_block_loop) + { + const auto kernel = kernel_gemm_xdlops_v2r3_for_conv3d< + GridwiseGemm, + InDataType, + OutDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + true>; + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.num_subbatches_, + arg.a_batch_stride_, + arg.b_batch_stride_, + arg.c_batch_stride_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r3_for_conv3d< + GridwiseGemm, + InDataType, + OutDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.num_subbatches_, + arg.a_batch_stride_, + arg.b_batch_stride_, + arg.c_batch_stride_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in, + const WeiDataType* p_wei, + OutDataType* p_out, + const index_t N, + const index_t K, + const index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in, + p_wei, + p_out, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_in, + const void* p_wei, + void* p_out, + const index_t N, + const 
index_t K, + const index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + + { + return std::make_unique(static_cast(p_in), + static_cast(p_wei), + static_cast(p_out), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/device/device_conv_backward_weight.hpp b/include/ck/tensor_operation/gpu/device/device_conv_backward_weight.hpp new file mode 100644 index 00000000..f1712025 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_conv_backward_weight.hpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceConvBwdWeight : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in, + void* p_wei, + const void* p_out, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceConvBwdWeightPtr = std::unique_ptr< + DeviceConvBwdWeight>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp b/include/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp new file mode 100644 index 00000000..83c19703 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
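[Editor's note] device_conv_backward_weight.hpp above, like the bwd-data and fwd headers that follow, is a pure-virtual, type-erased interface: concrete tile configurations are selected at runtime through base-class pointers. A minimal sketch of the intended call sequence, assuming only the virtuals declared on these bases (MakeArgumentPointer, IsSupportedArgument, MakeInvokerPointer, BaseInvoker::Run); it compiles only against the headers introduced in this diff:

```cpp
#include <memory>
#include <utility>

// Sketch only: drives any CK device op through the type-erased interface.
// DeviceOpPtr is e.g. a DeviceConvBwdWeightPtr<...>; the parameter pack
// forwards whatever that op's MakeArgumentPointer() expects.
template <typename DeviceOpPtr, typename... MakeArgs>
float run_if_supported(DeviceOpPtr& op, MakeArgs&&... make_args)
{
    auto arg = op->MakeArgumentPointer(std::forward<MakeArgs>(make_args)...);

    // A given tile/vector configuration may not divide this problem size.
    if(!op->IsSupportedArgument(arg.get()))
    {
        return -1.f; // caller falls through to the next instance
    }

    auto invoker = op->MakeInvokerPointer();
    return invoker->Run(arg.get(), StreamConfig{}); // averaged kernel time
}
```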
+ +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceConvBwdData : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(void* p_in, + const void* p_wei, + const void* p_out, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceConvBwdDataPtr = std::unique_ptr< + DeviceConvBwdData>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp b/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp new file mode 100644 index 00000000..c30910ed --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceConvFwd : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in, + const void* p_wei, + void* p_out, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp b/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp new file mode 100644 index 00000000..5a627dee --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
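[Editor's note] Every MakeArgumentPointer() in these interfaces takes the same (N, K, C, spatial lengths, strides, dilations, pads) problem description. For the forward NHWC/KYXC ops earlier in this diff, that description maps onto a GEMM with one row per output pixel and one column per output channel, with M right-padded to the block tile. A standalone recomputation using hypothetical sizes:

```cpp
#include <cstdio>

int main()
{
    // Hypothetical NHWC problem, for illustration only.
    const int N = 4, Hi = 28, Wi = 28, C = 192, K = 256, Y = 3, X = 3;
    const int stride = 1, dilation = 1, pad_l = 1, pad_r = 1;
    const int MPerBlock = 128; // block tile; a template parameter in the real op

    // Standard convolution output size.
    const int Ho = (Hi + pad_l + pad_r - dilation * (Y - 1) - 1) / stride + 1; // 28
    const int Wo = (Wi + pad_l + pad_r - dilation * (X - 1) - 1) / stride + 1; // 28

    // Implicit-GEMM view used by the forward NHWC/KYXC device ops.
    const int GemmMRaw = N * Ho * Wo; // 3136
    const int GemmN    = K;           // 256
    const int GemmK    = Y * X * C;   // 1728

    // M is right-padded up to a multiple of MPerBlock, matching
    // make_right_pad_transform(GemmMRaw, GemmMPad) in the descriptors.
    const int GemmMPad = (GemmMRaw + MPerBlock - 1) / MPerBlock * MPerBlock - GemmMRaw;

    std::printf("GEMM: M=%d (+%d pad), N=%d, K=%d\n", GemmMRaw, GemmMPad, GemmN, GemmK);
    // prints: GEMM: M=3136 (+64 pad), N=256, K=1728
}
```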
+ +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceConvFwdBiasActivation : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in, + const void* p_wei, + void* p_out, + const void* p_bias, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceConvFwdBiasActivationPtr = + std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp b/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp new file mode 100644 index 00000000..cc139303 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceConvFwdBiasActivationAdd : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in, + const void* p_wei, + void* p_out, + const void* p_bias, + const void* p_resi, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceConvFwdBiasActivationAddPtr = + std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp new file mode 100644 index 00000000..32d91269 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,1465 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
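[Editor's note] The backward-weight file that follows swaps the GEMM roles: the reduction now runs over output pixels (GemmKTotal = N * Ho * Wo) and is additionally split across GemmKBatch groups (the split_k argument of the interface above), so the reduction length is padded up to KBatch * K0 * K1. A standalone recomputation of that padding arithmetic with hypothetical sizes; integer_divide_ceil stands in for ck::math::integer_divide_ceil:

```cpp
#include <cstdio>

// Stand-in for ck::math::integer_divide_ceil.
int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    // Hypothetical problem/tile sizes, for illustration only.
    const int N = 2, Ho = 30, Wo = 30, K = 256, C = 64, Y = 3, X = 3;
    const int K1 = 8, K0PerBlock = 4, KBatch = 4; // KBatch = split-K factor

    const int GemmKTotal = N * Ho * Wo; // reduction length before padding
    const int GemmM      = K;           // output channels
    const int GemmN      = C * X * Y;   // weight elements per output channel

    // Same padding rule as MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N below:
    // round the reduction length up so it splits evenly into KBatch * K0 * K1.
    const int GemmK0   = integer_divide_ceil(GemmKTotal, K1 * K0PerBlock * KBatch) * K0PerBlock;
    const int GemmKPad = KBatch * GemmK0 * K1;

    std::printf("GemmM=%d GemmN=%d KTotal=%d -> K0=%d KPad=%d (pad=%d)\n",
                GemmM, GemmN, GemmKTotal, GemmK0, GemmKPad, GemmKPad - GemmKTotal);
    // prints: GemmM=256 GemmN=576 KTotal=1800 -> K0=60 KPad=1920 (pad=120)
}
```

Each of the KBatch partial products then accumulates into the same weight tile, which is why the interface exposes split_k explicitly rather than deriving it.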
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_backward_weight.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template +struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvBwdWeight +{ + using DeviceOp = + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = OutDataType; + using BDataType = InDataType; + using CDataType = WeiDataType; + + using AElementwiseOperation = OutElementwiseOperation; + using BElementwiseOperation = InElementwiseOperation; + using CElementwiseOperation = WeiElementwiseOperation; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + // Bytes per 32 lds bank: 32 * 4 bytes + static constexpr auto BankLength = 128; + static constexpr auto ElePerBank = BankLength / sizeof(ADataType); + + // M1 & M0 + static constexpr auto ABlockLdsM1PerBlock = ElePerBank / K1; + static constexpr auto ABlockLdsM0PerBlock = MPerBlock / ABlockLdsM1PerBlock; + static constexpr auto ABlockLdsM1Padding = 4; + + // N1 & N0 + static constexpr auto BBlockLdsN1PerBlock = ElePerBank / K1; + static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock; + static constexpr auto BBlockLdsN1Padding = 4; + + template ::type = false> + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t batch_k) + { + using namespace ck; + + const index_t Wi = input_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[0]; + const index_t ConvStrideW = conv_filter_strides[0]; + const index_t ConvDilationW = conv_filter_dilations[0]; + const index_t InLeftPadW = input_left_pads[0]; + const index_t InRightPadW = input_right_pads[0]; + + const index_t GemmKTotal = N * Wo; + const index_t GemmM = K; + const index_t GemmN = C * X; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output 
tensor + const auto out_gemmktotal_gemmm_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)); + + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_gemmktotal_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Wi, C)); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); + } + else + { + const auto out_gemmktotal_gemmm_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)); + const auto in_n_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); + + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( + in_n_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto in_gemmktotal_gemmn_grid_desc = + transform_tensor_descriptor(in_n_x_wo_c_grid_desc, + 
make_tuple(make_merge_transform(make_tuple(X, C)), + make_merge_transform(make_tuple(N, Wo))), + make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); + } + } + + template ::type = false> + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t batch_k) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmKTotal = N * Ho * Wo; + const index_t GemmM = K; + const index_t GemmN = C * X * Y; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmktotal_gemmm_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto 
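+            // With a 1x1 filter, stride 1 and zero padding, Ho = Hi and Wo = Wi, so the
+            // im2col view degenerates: the input can be read directly as an
+            // (N * Hi * Wi, C) matrix without any gather transforms.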
in_gemmktotal_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Hi * Wi, C)); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); + } + else + { + const auto out_gemmktotal_gemmm_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmktotal_gemmn_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto 
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); + } + } + + template ::type = false> + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t batch_k) + { + using namespace ck; + + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + const index_t Z = filter_spatial_lengths[0]; + const index_t Y = filter_spatial_lengths[1]; + const index_t X = filter_spatial_lengths[2]; + + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + const index_t GemmKTotal = N * Do * Ho * Wo; + const index_t GemmM = K; + const index_t GemmN = C * Z * X * Y; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmktotal_gemmm_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)); + + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_gemmktotal_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Di * Hi * Wi, C)); + + const auto 
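+            // Right-padding GemmKTotal up to GemmKPad makes the reduction length evenly
+            // divisible across the split-K batches. With illustrative values
+            // GemmKTotal = 1000, GemmK1Number = 8, K0PerBlock = 4, GemmKBatch = 2:
+            // GemmK0 = ceil(1000 / (8 * 4 * 2)) * 4 = 64 and
+            // GemmKPad = 2 * 64 * 8 = 1024, i.e. 24 dummy K elements of padding.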
in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); + } + else + { + const auto out_gemmktotal_gemmm_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)); + const auto in_n_di_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); + + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_dip_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)), + make_merge_transform(make_tuple(N, Do, Ho, Wo))), + make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + 
make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); + } + } // function end + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( + 1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1, + 1, + 1, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + 1); + } + + // type convert descs + template + static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize) + { + const auto m0 = desc_m0.GetLength(I0); + const index_t loop_step = gridSize * blockSize * 4; + const auto pad = math::integer_least_multiple(m0, loop_step) - m0; + const auto desc_m0_pad = + transform_tensor_descriptor(desc_m0, + make_tuple(make_right_pad_transform(m0, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return desc_m0_pad; + } + + template + static auto MakeDescriptor_M0(const std::vector& shape, + const std::vector& stride, + index_t gridSize, + index_t blockSize) + { + auto tupleOfShape = generate_tuple([&](auto I) { return shape[I]; }, Number{}); + auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number{}); + + // nd desc - [s0, s1, s2, ...] + const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); + + // merge nd to 1d desc - [s0 * s1 * ...] 
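+        // For example (illustrative values): a 2-D tensor with shape {4, 6} and packed
+        // strides {6, 1} is first merged into a single dimension of length 24; with
+        // gridSize = 2 and blockSize = 64, PadDescriptor_M0_1d uses
+        // loop_step = 2 * 64 * 4 = 512 and right-pads the length from 24 to 512, so the
+        // grid-stride elementwise loop needs no tail handling.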
+ if constexpr(Dim > 1) + { + const auto desc_m0 = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(tupleOfShape)), + make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number{})), + make_tuple(Sequence<0>{})); + + return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize); + } + else + return PadDescriptor_M0_1d(desc, gridSize, blockSize); + } + + using TypeConvertFp32ToBf16Functor = + ck::tensor_operation::element_wise::UnaryTypeConvert; + using GridDesc_M0 = decltype(MakeDescriptor_M0<1>({1}, {1}, 1, 1)); + using GridwiseUEltwise = GridwiseUnaryElementwise_1D; + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + ABlockLdsM1PerBlock, + ABlockLdsM0PerBlock, + ABlockLdsM1Padding, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + BBlockLdsN1PerBlock, + BBlockLdsN0PerBlock, + BBlockLdsN1Padding, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXdl, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + true, + true>; + + using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::AtomicAdd, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + ABlockLdsM1PerBlock, + ABlockLdsM0PerBlock, + ABlockLdsM1Padding, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + BBlockLdsN1PerBlock, + BBlockLdsN0PerBlock, + BBlockLdsN1Padding, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXdl, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + true, + true>; + + using 
GridwiseGemmAtomicAddFloatBf16Splitk = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + AccDataType, + InMemoryDataOperationEnum::AtomicAdd, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + ABlockLdsM1PerBlock, + ABlockLdsM0PerBlock, + ABlockLdsM1Padding, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + BBlockLdsN1PerBlock, + BBlockLdsN0PerBlock, + BBlockLdsN1Padding, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXdl, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + true, + true>; + + // Argument + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); + + using Block2CTileMap = + decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); + + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t M01, + ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_in_grid}, + p_c_grid_{p_wei_grid}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{out_element_op}, + b_element_op_{in_element_op}, + c_element_op_{wei_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + output_spatial_lengths_{output_spatial_lengths}, + filter_spatial_lengths_{filter_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + k_batch_{split_k} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_); + + a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; + b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + block_2_ctile_map_ = + GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); + + if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, + b_grid_desc_kbatch_k0_n_k1_, + c_grid_desc_m_n_, + 
block_2_ctile_map_))
+            {
+                c_grid_desc_mblock_mperblock_nblock_nperblock_ =
+                    GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(
+                        c_grid_desc_m_n_);
+            }
+        }
+
+        const ADataType* p_a_grid_;
+        const BDataType* p_b_grid_;
+        CDataType* p_c_grid_;
+        AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
+        BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
+        CGridDesc_M_N c_grid_desc_m_n_;
+        CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;
+        Block2CTileMap block_2_ctile_map_;
+        index_t M01_;
+        index_t N01_;
+        OutElementwiseOperation a_element_op_; // A is the output tensor
+        InElementwiseOperation b_element_op_;  // B is the input tensor
+        WeiElementwiseOperation c_element_op_;
+        // for checking IsSupportedArgument()
+        index_t Conv_N_;
+        index_t Conv_K_;
+        index_t Conv_C_;
+        std::vector<ck::index_t> output_spatial_lengths_;
+        std::vector<ck::index_t> filter_spatial_lengths_;
+        std::vector<ck::index_t> conv_filter_strides_;
+        std::vector<ck::index_t> input_left_pads_;
+        std::vector<ck::index_t> input_right_pads_;
+        index_t k_batch_;
+    };
+
+    // Invoker
+    struct Invoker : public BaseInvoker
+    {
+        using Argument = DeviceOp::Argument;
+
+        void ShowInfo(const Argument& arg)
+        {
+            std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
+                      << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
+                      << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
+                      << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
+                      << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl;
+
+            std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
+                      << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
+                      << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
+                      << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
+                      << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
+
+            std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
+                      << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
+        }
+
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+        {
+            ShowInfo(arg);
+
+            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
+                                            arg.b_grid_desc_kbatch_k0_n_k1_,
+                                            arg.c_grid_desc_m_n_,
+                                            arg.block_2_ctile_map_))
+            {
+                throw std::runtime_error(
+                    "wrong! 
GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); + } + const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0); + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + + float ave_time = 0; + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + const auto run_conv = [&](const auto& kernel) { + hipGetErrorString(hipMemset( + arg.p_c_grid_, + 0, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * + sizeof(CDataType))); + float elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + + hipGetErrorString(hipMemset( + arg.p_c_grid_, + 0, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * + sizeof(CDataType))); + + launch_and_time_kernel(StreamConfig{nullptr, false}, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + + return elapsed_time; + }; + + // run kernel for bf16 with splitk + const auto run_bf16_splitk = [&](const auto& kernel) { + hipGetErrorString(hipMemset( + arg.p_workspace_, + 0, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * + sizeof(AccDataType))); + + float elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + static_cast(arg.p_workspace_), + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + + hipGetErrorString(hipMemset( + arg.p_workspace_, + 0, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * + sizeof(AccDataType))); + + launch_and_time_kernel(StreamConfig{nullptr, false}, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + static_cast(arg.p_workspace_), + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + + return elapsed_time; + }; + + // kernel for type conversion + std::vector filter_dims{static_cast(arg.Conv_K_), + static_cast(arg.Conv_C_)}; + + filter_dims.insert(std::end(filter_dims), + std::begin(arg.filter_spatial_lengths_), + std::end(arg.filter_spatial_lengths_)); + + int tensor_size = + std::accumulate(filter_dims.begin(), filter_dims.end(), 1, std::multiplies{}); + + const index_t type_convert_grid_size = GridwiseUEltwise::CalculateGridSize(tensor_size); + GridDesc_M0 a_grid_desc_m0_ = + MakeDescriptor_M0<1>({tensor_size}, {1}, type_convert_grid_size, 256); + GridDesc_M0 b_grid_desc_m0_ = + MakeDescriptor_M0<1>({tensor_size}, {1}, type_convert_grid_size, 256); + + if(!GridwiseUEltwise::CheckValidity(a_grid_desc_m0_, b_grid_desc_m0_)) + { + throw std::runtime_error("wrong! 
GridwiseUnaryElementwise_1D has invalid setting");
+            }
+
+            // run kernel for type conversion
+            void* p_c_grid_tmp_            = static_cast<void*>(arg.p_c_grid_);
+            InDataType* p_c_grid_tmp_bf16_ = static_cast<InDataType*>(p_c_grid_tmp_);
+            const auto run_type_convert = [&](const auto& kernel) {
+                float elapsed_time =
+                    launch_and_time_kernel(stream_config,
+                                           kernel,
+                                           dim3(type_convert_grid_size),
+                                           dim3(256),
+                                           0,
+                                           static_cast<AccDataType*>(arg.p_workspace_),
+                                           p_c_grid_tmp_bf16_,
+                                           a_grid_desc_m0_,
+                                           b_grid_desc_m0_,
+                                           TypeConvertFp32ToBf16Functor{});
+                return elapsed_time;
+            };
+
+            if constexpr(std::is_same<ADataType, ck::bhalf_t>::value)
+            {
+                auto launch_kernel = [&](auto has_main_k_block_loop) {
+                    constexpr bool has_main_loop = has_main_k_block_loop.value;
+
+                    if(kbatch == 1)
+                    {
+                        const auto kernel = kernel_gemm_xdlops_bwd_weight<
+                            GridwiseGemm,
+                            ADataType, // TODO: distinguish A/B datatype
+                            CDataType,
+                            remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                            remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                            remove_reference_t<
+                                DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                            OutElementwiseOperation,
+                            InElementwiseOperation,
+                            WeiElementwiseOperation,
+                            remove_reference_t<DeviceOp::Block2CTileMap>,
+                            has_main_loop>;
+
+                        return run_conv(kernel);
+                    }
+                    else
+                    {
+                        const auto kernel_type_convert =
+                            kernel_unary_elementwise_1d<GridwiseUEltwise,
+                                                        AccDataType,
+                                                        InDataType,
+                                                        GridDesc_M0,
+                                                        TypeConvertFp32ToBf16Functor>;
+
+                        const auto kernel_conv = kernel_gemm_xdlops_bwd_weight<
+                            GridwiseGemmAtomicAddFloatBf16Splitk,
+                            ADataType, // TODO: distinguish A/B datatype
+                            AccDataType,
+                            remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                            remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                            remove_reference_t<
+                                DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                            OutElementwiseOperation,
+                            InElementwiseOperation,
+                            WeiElementwiseOperation,
+                            remove_reference_t<DeviceOp::Block2CTileMap>,
+                            has_main_loop>;
+
+                        float elapsed_time = 0;
+                        elapsed_time += run_bf16_splitk(kernel_conv);
+                        elapsed_time += run_type_convert(kernel_type_convert);
+                        return elapsed_time;
+                    }
+                };
+                if(has_main_k0_block_loop)
+                {
+                    ave_time = launch_kernel(integral_constant<bool, true>{});
+                }
+                else
+                {
+                    ave_time = launch_kernel(integral_constant<bool, false>{});
+                }
+            }
+            else
+            {
+                auto launch_kernel = [&](auto has_main_k_block_loop) {
+                    constexpr bool has_main_loop = has_main_k_block_loop.value;
+
+                    if(kbatch == 1)
+                    {
+                        const auto kernel = kernel_gemm_xdlops_bwd_weight<
+                            GridwiseGemm,
+                            ADataType, // TODO: distinguish A/B datatype
+                            CDataType,
+                            remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                            remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                            remove_reference_t<
+                                DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                            OutElementwiseOperation,
+                            InElementwiseOperation,
+                            WeiElementwiseOperation,
+                            remove_reference_t<DeviceOp::Block2CTileMap>,
+                            has_main_loop>;
+
+                        return run_conv(kernel);
+                    }
+                    else
+                    {
+                        const auto kernel = kernel_gemm_xdlops_bwd_weight<
+                            GridwiseGemmAtomicAdd,
+                            ADataType, // TODO: distinguish A/B datatype
+                            CDataType,
+                            remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                            remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                            remove_reference_t<
+                                DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                            OutElementwiseOperation,
+                            InElementwiseOperation,
+                            WeiElementwiseOperation,
+                            remove_reference_t<DeviceOp::Block2CTileMap>,
+                            has_main_loop>;
+
+                        return run_conv(kernel);
+                    }
+                };
+                if(has_main_k0_block_loop)
+                {
+                    ave_time = launch_kernel(integral_constant<bool, true>{});
+                }
+                else
+                {
+                    ave_time = launch_kernel(integral_constant<bool, false>{});
+                }
+            }
+
+            return ave_time;
+        }
+
+        float Run(const BaseArgument* p_arg,
+                  const StreamConfig& stream_config = StreamConfig{}) override
+        {
+            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
+        }
+    };
+
+    static constexpr bool IsValidCompilationParameter()
+    {
+        // TODO: properly implement this check
+        return true;
+    }
+
+    static bool IsSupportedArgument(const Argument& arg)
+    {
+        if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + for(int i = 0; i < NumDimSpatial; i++) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + void* p_wei_grid, + const void* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0){ + + str << " Filter1x1Stride1Pad0"; + } + + // clang-format on + + return str.str(); + } + + template ::type = 
false> + static size_t GetWorkSpaceSize(const Argument& arg) + { + size_t WorkSpaceSize = 0; + if(arg.k_batch_ > 1) + { + if constexpr(std::is_same::value) + { + WorkSpaceSize = + arg.Conv_K_ * arg.Conv_C_ * arg.filter_spatial_lengths_[0] * sizeof(float); + } + } + return WorkSpaceSize; + } + + template ::type = false> + static size_t GetWorkSpaceSize(const Argument& arg) + { + size_t WorkSpaceSize = 0; + if(arg.k_batch_ > 1) + { + if constexpr(std::is_same::value) + { + WorkSpaceSize = arg.Conv_K_ * arg.Conv_C_ * arg.filter_spatial_lengths_[0] * + arg.filter_spatial_lengths_[1] * sizeof(float); + } + } + return WorkSpaceSize; + } + + template ::type = false> + static size_t GetWorkSpaceSize(const Argument& arg) + { + size_t WorkSpaceSize = 0; + if(arg.k_batch_ > 1) + { + if constexpr(std::is_same::value) + { + WorkSpaceSize = arg.Conv_K_ * arg.Conv_C_ * arg.filter_spatial_lengths_[0] * + arg.filter_spatial_lengths_[1] * arg.filter_spatial_lengths_[2] * + sizeof(float); + } + } + return WorkSpaceSize; + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override final + { + return GetWorkSpaceSize(*dynamic_cast(p_arg)); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp new file mode 100644 index 00000000..a5970c8f --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -0,0 +1,1551 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template +struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K + : public DeviceConvBwdData +{ + using DeviceOp = DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K; + + using ADataType = OutDataType; + using BDataType = WeiDataType; + using CDataType = InDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + static_assert((K1 % ABlockTransferThreadClusterLengths_K0_M_K1{}[I2]) % + ABlockTransferSrcScalarPerVector == + 0); + static_assert((NPerBlock / BBlockTransferThreadClusterLengths_K0_N_K1{}[I1]) % + BBlockTransferSrcScalarPerVector == + 0); + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + template ::type = false> + static auto + 
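+    // Backward data decomposes the transposed convolution into XTilde independent
+    // GEMMs (one per entry of the `tildes` argument), where
+    // XTilde = ConvStrideW / gcd(ConvStrideW, ConvDilationW). Each GEMM gathers only
+    // the filter taps congruent to its i_xtilde modulo XTilde, so the per-GEMM writes
+    // to the input gradient are disjoint.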
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + std::vector tildes) + { + using namespace ck; + + index_t i_xtilde = tildes[0]; + + const index_t Wi = input_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[0]; + const index_t InLeftPadW = input_left_pads[0]; + const index_t InRightPadW = input_right_pads[0]; + const index_t ConvStrideW = conv_filter_strides[0]; + const index_t ConvDilationW = conv_filter_dilations[0]; + + const auto K0 = K / K1; + + const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); + + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)), + make_tuple(make_pass_through_transform(N * Wo), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + // B: weight tensor + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)), + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: input tensor + const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_merge_transform(make_tuple(N, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + else + { + const auto out_n_wo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Wo, K)); + const auto wei_k_x_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, X, C)); + + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto XDot = math::integer_divide_ceil(X, XTilde); + + const auto WTilde = + Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IWTildeSliceEnd = math::min( + WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // GemmK is different for each GEMM + const auto XDotSlice = 
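+            // Illustrative example: ConvStrideW = 2, ConvDilationW = 1 gives
+            // GcdStrideDilationW = 1 and XTilde = 2; for X = 3, XDot = ceil(3 / 2) = 2.
+            // The slice below then keeps XDotSlice = ceil((3 - 0) / 2) = 2 taps for
+            // i_xtilde = 0 and ceil((3 - 1) / 2) = 1 tap for i_xtilde = 1 (2 + 1 = X).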
math::integer_divide_ceil(X - i_xtilde, XTilde); + + // A: output tensor + const auto out_n_wop_k_grid_desc = transform_tensor_descriptor( + out_n_wo_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto out_n_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto out_n_xdotslice_wtildeslice_k0_k1_grid_desc = transform_tensor_descriptor( + out_n_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{})); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_n_xdotslice_wtildeslice_k0_k1_grid_desc, + make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)), + make_merge_transform(make_tuple(N, WTildeSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B weight tensor + const auto wei_k_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_k_x_c_grid_desc, + make_tuple(make_pass_through_transform(K), + make_embed_transform(make_tuple(XDot, XTilde), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto wei_k0_k1_xdotslice_c_grid_desc = transform_tensor_descriptor( + wei_k_xdot_xtilde_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(i_xtilde), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<>{}, Sequence<3>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_k0_k1_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<2, 0>{}, Sequence<3>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // C: input tensor + const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( + in_n_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + 
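+                // make_embed_transform maps the index pair (xtilde, wtilde) back to the
+                // padded input coordinate xtilde * ConvDilationW + wtilde * ConvStrideW,
+                // i.e. the usual convolution index arithmetic expressed as a descriptor
+                // transform.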
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto in_n_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N, WTildeSlice)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + + } // function end + template ::type = false> + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + std::vector tildes) + { + using namespace ck; + + index_t i_ytilde = tildes[0]; + index_t i_xtilde = tildes[1]; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const auto K0 = K / K1; + + const auto out_n_ho_wo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K)); + const auto wei_k_y_x_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C)); + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + // B: weight tensor + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)), + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: input tensor + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + 
make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), + make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_freeze_transform(I0), + make_merge_transform(make_tuple(N, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}), + make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + else + { + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto YDot = math::integer_divide_ceil(Y, YTilde); + const auto XDot = math::integer_divide_ceil(X, XTilde); + + const auto HTilde = + Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilde = + Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IHTildeSliceEnd = math::min( + HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildeSliceEnd = math::min( + WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // GemmK is different for each GEMM + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + // A: output tensor + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_n_ho_wo_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(YDot, HTilde), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc = + transform_tensor_descriptor( + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + 
make_tuple(make_pass_through_transform(N), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6>{})); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B weight tensor + const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_k_y_x_c_grid_desc, + make_tuple(make_pass_through_transform(K), + make_embed_transform(make_tuple(YDot, YTilde), + make_tuple(ConvStrideH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, XTilde), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(i_ytilde), + make_freeze_transform(i_xtilde), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0, 1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<4>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_k0_k1_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // C: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YTilde, HTilde), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, 
Sequence<5>{})); + + const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_ytilde), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<3>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + + } // function end + + template ::type = false> + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + std::vector tildes) + { + using namespace ck; + + const index_t i_ztilde = tildes[0]; + const index_t i_ytilde = tildes[1]; + const index_t i_xtilde = tildes[2]; + + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + const index_t Z = filter_spatial_lengths[0]; + const index_t Y = filter_spatial_lengths[1]; + const index_t X = filter_spatial_lengths[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const auto K0 = K / K1; + + const auto out_n_do_ho_wo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K)); + const auto wei_k_z_y_x_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C)); + const auto in_n_di_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); + + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Do * Ho * Wo), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), 
+ make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + // B: weight tensor + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)), + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: input tensor + const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(I1, Do), make_tuple(I1, ConvStrideD)), + make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), + make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_freeze_transform(I0), + make_freeze_transform(I0), + make_merge_transform(make_tuple(N, Do, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<1>{}, + Sequence<3>{}, + Sequence<5>{}, + Sequence<0, 2, 4, 6>{}, + Sequence<7>{}), + make_tuple(Sequence<>{}, Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + else + { + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto ZTilde = ConvStrideD / GcdStrideDilationD; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto ZDot = math::integer_divide_ceil(Z, ZTilde); + const auto YDot = math::integer_divide_ceil(Y, YTilde); + const auto XDot = math::integer_divide_ceil(X, XTilde); + + const auto DTilde = + Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD); + const auto HTilde = + Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilde = + Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IDTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD); + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IDTildeSliceEnd = math::min( + DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1); + const auto IHTildeSliceEnd = math::min( + HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildeSliceEnd = math::min( + WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin; + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // 
GemmK is different for each GEMM + const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde); + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + // A: output tensor + const auto out_n_dop_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_n_do_ho_wo_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Do, I0, I0), + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc = + transform_tensor_descriptor( + out_n_dop_hop_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(ZDot, DTilde), + make_tuple(-ConvDilationD / GcdStrideDilationD, I1)), + make_embed_transform(make_tuple(YDot, HTilde), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto + out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc = + transform_tensor_descriptor( + out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(ZDot, I0, ZDotSlice), + make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7, 8>{})); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, + make_tuple( + make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)), + make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}, Sequence<8>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B weight tensor + const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc = + transform_tensor_descriptor( + wei_k_z_y_x_c_grid_desc, + make_tuple( + make_pass_through_transform(K), + make_embed_transform(make_tuple(ZDot, ZTilde), + make_tuple(ConvStrideD / GcdStrideDilationD, I1)), + make_embed_transform(make_tuple(YDot, YTilde), + make_tuple(ConvStrideH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, XTilde), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + 
make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor(wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_slice_transform(ZDot, I0, ZDotSlice), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(i_ztilde), + make_freeze_transform(i_ytilde), + make_freeze_transform(i_xtilde), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<5>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0, 1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<5>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<2, 3, 4, 0>{}, Sequence<5>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // C: input tensor + const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc = + transform_tensor_descriptor( + in_n_dip_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(ZTilde, DTilde), + make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(YTilde, HTilde), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc = + transform_tensor_descriptor( + in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_ztilde), + make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice), + make_freeze_transform(i_ytilde), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<>{}, + Sequence<3>{}, + Sequence<4>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc, + make_tuple( + 
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + + } // function end + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( + 1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {0}); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1, + 1, + 1, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {0, 0, 0}); + } + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder, + 7, // CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; + + // Argument + struct Argument : public BaseArgument + { + Argument(InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t M01, + ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_in_grid}, + M01_{M01}, + N01_{N01}, + a_element_op_{out_element_op}, + b_element_op_{wei_element_op}, + c_element_op_{in_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + input_spatial_lengths_{input_spatial_lengths}, + filter_spatial_lengths_{filter_spatial_lengths}, + output_spatial_lengths_{output_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + 
input_right_pads_{input_right_pads} + { + CreateABCDesc(); + } + + template ::type = false> + void CreateABCDesc() + { + const index_t ConvStrideW = conv_filter_strides_[0]; + const index_t ConvDilationW = conv_filter_dilations_[0]; + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const index_t X = filter_spatial_lengths_[0]; + + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + if(XDotSlice <= 0) + { + continue; + } + + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + {i_xtilde}); + a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); + b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); + c_grid_desc_m_n_container_.push_back(descs[I2]); + + auto block_2_ctile_map = + GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_); + + if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2], block_2_ctile_map)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back( + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2])); + + block_2_ctile_map_container_.push_back(block_2_ctile_map); + } + } + } + template ::type = false> + void CreateABCDesc() + { + const index_t ConvStrideH = conv_filter_strides_[0]; + const index_t ConvStrideW = conv_filter_strides_[1]; + + const index_t ConvDilationH = conv_filter_dilations_[0]; + const index_t ConvDilationW = conv_filter_dilations_[1]; + + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const index_t Y = filter_spatial_lengths_[0]; + const index_t X = filter_spatial_lengths_[1]; + for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + if(YDotSlice * XDotSlice <= 0) + { + continue; + } + + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + {i_ytilde, i_xtilde}); + a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); + b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); + c_grid_desc_m_n_container_.push_back(descs[I2]); + + auto block_2_ctile_map = + GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_); + + if(GridwiseGemm::CheckValidity( + descs[I0], descs[I1], descs[I2], block_2_ctile_map)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back( + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2])); + + block_2_ctile_map_container_.push_back(block_2_ctile_map); + } + } + } + } + template ::type = false> + void CreateABCDesc() + { + const index_t ConvStrideD = conv_filter_strides_[0]; + const index_t ConvStrideH = conv_filter_strides_[1]; + const index_t ConvStrideW = 
conv_filter_strides_[2]; + + const index_t ConvDilationD = conv_filter_dilations_[0]; + const index_t ConvDilationH = conv_filter_dilations_[1]; + const index_t ConvDilationW = conv_filter_dilations_[2]; + + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto ZTilde = ConvStrideD / GcdStrideDilationD; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const index_t Z = filter_spatial_lengths_[0]; + const index_t Y = filter_spatial_lengths_[1]; + const index_t X = filter_spatial_lengths_[2]; + for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde) + { + for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde); + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + if(ZDotSlice * YDotSlice * XDotSlice <= 0) + { + continue; + } + + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N< + NumDimSpatial>(Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + {i_ztilde, i_ytilde, i_xtilde}); + a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); + b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); + c_grid_desc_m_n_container_.push_back(descs[I2]); + + auto block_2_ctile_map = + GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_); + + if(GridwiseGemm::CheckValidity( + descs[I0], descs[I1], descs[I2], block_2_ctile_map)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back( + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2( + descs[I2])); + + block_2_ctile_map_container_.push_back(block_2_ctile_map); + } + } + } + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + std::vector a_grid_desc_k0_m_k1_container_; + std::vector b_grid_desc_k0_n_k1_container_; + std::vector c_grid_desc_m_n_container_; + std::vector + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_; + std::vector block_2_ctile_map_container_; + index_t M01_; + index_t N01_; + OutElementwiseOperation a_element_op_; + WeiElementwiseOperation b_element_op_; + InElementwiseOperation c_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + + std::vector input_spatial_lengths_; + std::vector filter_spatial_lengths_; + std::vector output_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector conv_filter_dilations_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float ave_time = 0; + for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) + { + { + std::cout << "arg.a_grid_desc_k0_m_k1_container_{" + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", " + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}" + << std::endl; + + 
std::cout << "arg.b_grid_desc_k0_n_k1_container_{" + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", " + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}" + << std::endl; + + std::cout << "arg.c_grid_desc_m_n_container_{ " + << arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", " + << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}" + << std::endl; + + std::cout << "arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I0) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I1) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I2) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I3) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I4) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I5) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I6) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I7) + << " ) " << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i], + arg.block_2_ctile_map_container_[i])) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); + } + + const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize( + arg.c_grid_desc_m_n_container_[i]); + + const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) * + arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, + OutElementwiseOperation, + WeiElementwiseOperation, + InElementwiseOperation, + remove_reference_t, + true>; + + ave_time += launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i], + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_container_[i]); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, + OutElementwiseOperation, + WeiElementwiseOperation, + InElementwiseOperation, + remove_reference_t, + false>; + + ave_time += launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i], + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_container_[i]); + } + } + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; 
+ + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + for(int i = 0; i < NumDimSpatial; i++) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 1 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CThreadTransferDstScalarPerVector == 0)) + { + return false; + } + + // Gridwise GEMM size + for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i], + arg.block_2_ctile_map_container_[i])) + { + return false; + } + } + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(void* p_in_grid, + const void* p_wei_grid, + const void* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K" + << "<" + << BlockSize 
<< ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0){ + + str<< " Filter1x1Stride1Pad0"; + } + + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_convnd_fwd_nwc_kxc_nwk_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_convnd_fwd_nwc_kxc_nwk_xdl.hpp new file mode 100644 index 00000000..39ec1b49 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_convnd_fwd_nwc_kxc_nwk_xdl.hpp @@ -0,0 +1,1085 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// +// @brief Device Convolution operation. +// +// Supports: +// @li Forward convolution with up to 3 spatial dimentions +// @li Input tensor in NWC data format +// @li Weight tensor in KXC data format +// @li Output tensor in NWK data format +// +// 1D: +// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C] +// 2D: +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +// 3D: +// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C] +// +template +struct DeviceConvNdFwdNwcKxcNwk_Xdl + : public DeviceConvFwd>, + ck::tuple_element_t>, + ck::tuple_element_t>, + InDataType, + WeiDataType, + OutDataType, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> +{ + using Base = + DeviceConvFwd>, + ck::tuple_element_t>, + ck::tuple_element_t>, + InDataType, + WeiDataType, + OutDataType, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation>; + + using DeviceOp = DeviceConvNdFwdNwcKxcNwk_Xdl; + + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr index_t NDimSpatial = NumDimSpatial; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static auto GetWeightTensorDescriptor(ck::index_t gemm_n, ck::index_t gemm_k) + { + const ck::index_t gemm_k0 = gemm_k / GemmK1Number; + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(gemm_n, gemm_k)); + + // wei_gemmk0_gemmn_gemmk1_grid_desc + return transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_pass_through_transform(gemm_n)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + static auto + GetOutputTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n, 
ck::index_t gemm_m_pad) + { + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_n)); + + // out_gemmm_gemmn_grid_desc + return transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(gemm_m, gemm_m_pad), + make_pass_through_transform(gemm_n)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + template ::type = false> + static auto GetInputTensorDescriptor(ck::index_t N, + ck::index_t C, + ck::index_t gemm_m, + ck::index_t gemm_k, + ck::index_t gemm_m_pad, + const std::vector& input_spatial_lengths, + const std::vector& filter_spatial_lengths, + const std::vector& output_spatial_lengths, + const std::vector& conv_filter_strides, + const std::vector& conv_filter_dilations, + const std::vector& input_left_pads, + const std::vector& input_right_pads) + { + const ck::index_t gemm_k0 = gemm_k / GemmK1Number; + const index_t Wi = input_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[0]; + const index_t ConvStrideW = conv_filter_strides[0]; + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + const auto in_gemmmraw_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k)); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return transform_tensor_descriptor( + in_gemmmraw_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_right_pad_transform(gemm_m, gemm_m_pad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + const auto in_n_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); + + const auto in_n_wo_c_grid_desc = transform_tensor_descriptor( + in_n_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_n_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_merge_transform(make_tuple(N, Wo))), + make_tuple(Sequence<2>{}, Sequence<0, 1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(gemm_k0), + make_right_pad_transform(gemm_m, gemm_m_pad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + } + else + { + const index_t X = filter_spatial_lengths[0]; + const index_t ConvDilationW = conv_filter_dilations[0]; + const index_t InLeftPadW = input_left_pads[0]; + const index_t InRightPadW = input_right_pads[0]; + + const auto in_n_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); + + const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( + in_n_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, 
Sequence<2>{})); + + const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto in_gemmk_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(X, C)), + make_merge_transform(make_tuple(N, Wo))), + make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_pass_through_transform(gemm_m)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(gemm_k0), + make_right_pad_transform(gemm_m, gemm_m_pad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + } + } + + template ::type = false> + static auto GetInputTensorDescriptor(ck::index_t N, + ck::index_t C, + ck::index_t gemm_m, + ck::index_t gemm_k, + ck::index_t gemm_m_pad, + const std::vector& input_spatial_lengths, + const std::vector& filter_spatial_lengths, + const std::vector& output_spatial_lengths, + const std::vector& conv_filter_strides, + const std::vector& conv_filter_dilations, + const std::vector& input_left_pads, + const std::vector& input_right_pads) + { + const ck::index_t gemm_k0 = gemm_k / GemmK1Number; + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + const auto in_gemmmraw_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k)); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return transform_tensor_descriptor( + in_gemmmraw_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_right_pad_transform(gemm_m, gemm_m_pad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_n_ho_wo_c_grid_desc, + 
make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(gemm_k0), + make_right_pad_transform(gemm_m, gemm_m_pad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + } + else + { + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_pass_through_transform(gemm_m)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(gemm_k0), + make_right_pad_transform(gemm_m, gemm_m_pad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + } + } + + template ::type = false> + static auto GetInputTensorDescriptor(ck::index_t N, + ck::index_t C, + ck::index_t gemm_m, + ck::index_t gemm_k, + ck::index_t gemm_m_pad, + const std::vector& input_spatial_lengths, + const std::vector& filter_spatial_lengths, + const std::vector& output_spatial_lengths, + const std::vector& conv_filter_strides, + const std::vector& conv_filter_dilations, + const std::vector& input_left_pads, + const std::vector& input_right_pads) + { + 
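+        // Note on construction (mirrors the 1D/2D overloads above):
+        // Filter1x1Stride1Pad0 reads the input directly as a GemmK0 x GemmM x GemmK1
+        // matrix, Filter1x1Pad0 first gathers the strided output positions, and the
+        // general path embeds (Z, Do), (Y, Ho), (X, Wo) and merges Z*Y*X*C into
+        // GemmK (an implicit im2col). Worked example (sizes made up purely for
+        // illustration): N=1, (Di, Hi, Wi)=(8, 16, 16), C=64, Z=Y=X=3, stride 1,
+        // dilation 1, pad 1 gives (Do, Ho, Wo)=(8, 16, 16), hence
+        // GemmM = N*Do*Ho*Wo = 2048 and GemmK = Z*Y*X*C = 1728; with K1 = 8 the K
+        // dimension is then unmerged into gemm_k0 = 1728 / 8 = 216 slices of length K1.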
const ck::index_t gemm_k0 = gemm_k / GemmK1Number; + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + const auto in_gemmmraw_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k)); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return transform_tensor_descriptor( + in_gemmmraw_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_right_pad_transform(gemm_m, gemm_m_pad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + const auto in_n_di_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); + + const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_n_do_ho_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_merge_transform(make_tuple(N, Do, Ho, Wo))), + make_tuple(Sequence<4>{}, Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(gemm_k0), + make_right_pad_transform(gemm_m, gemm_m_pad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + } + else + { + const index_t Z = filter_spatial_lengths[0]; + const index_t Y = filter_spatial_lengths[1]; + const index_t X = filter_spatial_lengths[2]; + + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + const auto in_n_di_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + 
make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)), + make_merge_transform(make_tuple(N, Do, Ho, Wo))), + make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_pass_through_transform(gemm_m)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(gemm_k0), + make_right_pad_transform(gemm_m, gemm_m_pad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + } + } + + static index_t GetGemmMRaw(ck::index_t N, + const std::vector& output_spatial_lengths) + { + return N * std::accumulate(std::begin(output_spatial_lengths), + std::end(output_spatial_lengths), + 1, + std::multiplies()); + } + + static index_t GetGemmK(ck::index_t C, const std::vector& filter_spatial_lengths) + { + return C * std::accumulate(std::begin(filter_spatial_lengths), + std::end(filter_spatial_lengths), + 1, + std::multiplies()); + } + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + { + using namespace ck; + + const index_t GemmMRaw = GetGemmMRaw(N, output_spatial_lengths); + const index_t GemmN = K; + const index_t GemmK = GetGemmK(C, filter_spatial_lengths); + + const auto GemmMPad = math::integer_least_multiple(GemmMRaw, MPerBlock) - GemmMRaw; + + assert(GemmK % GemmK1Number == 0); + + // C = A^T*B + // A: + const auto in_gemmk0_gemmm_gemmk1_grid_desc = + GetInputTensorDescriptor(N, + C, + GemmMRaw, + GemmK, + GemmMPad, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + // B: + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = GetWeightTensorDescriptor(GemmN, GemmK); + // C: + const auto out_gemmm_gemmn_grid_desc = GetOutputTensorDescriptor(GemmMRaw, 
GemmN, GemmMPad); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}); + } + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + using Block2CTileMap = BlockToCTileMap_M00_N0_M01; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, + 2, // BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder, + 7, // CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_in_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_out_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, + block_2_ctile_map_{}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + filter_spatial_lengths_{filter_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + 
output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + block_2_ctile_map_ = Block2CTileMap{c_grid_desc_m_n_}; + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; + Block2CTileMap block_2_ctile_map_; + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::vector filter_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } +#endif + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
+            }
+
+            const index_t grid_size =
+                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
+
+            const auto K =
+                arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
+
+            float ave_time = 0;
+
+            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
+            {
+                const auto kernel = kernel_gemm_xdlops_v2r3<
+                    GridwiseGemm,
+                    ADataType, // TODO: distinguish A/B datatype
+                    CDataType,
+                    remove_reference_t,
+                    remove_reference_t,
+                    remove_reference_t,
+                    InElementwiseOperation,
+                    WeiElementwiseOperation,
+                    OutElementwiseOperation,
+                    Block2CTileMap,
+                    true>;
+
+                ave_time = launch_and_time_kernel(stream_config,
+                                                  kernel,
+                                                  dim3(grid_size),
+                                                  dim3(BlockSize),
+                                                  0,
+                                                  arg.p_a_grid_,
+                                                  arg.p_b_grid_,
+                                                  arg.p_c_grid_,
+                                                  arg.a_grid_desc_k0_m_k1_,
+                                                  arg.b_grid_desc_k0_n_k1_,
+                                                  arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
+                                                  arg.in_element_op_,
+                                                  arg.wei_element_op_,
+                                                  arg.out_element_op_,
+                                                  arg.block_2_ctile_map_);
+            }
+            else
+            {
+                const auto kernel = kernel_gemm_xdlops_v2r3<
+                    GridwiseGemm,
+                    ADataType, // TODO: distinguish A/B datatype
+                    CDataType,
+                    remove_reference_t,
+                    remove_reference_t,
+                    remove_reference_t,
+                    InElementwiseOperation,
+                    WeiElementwiseOperation,
+                    OutElementwiseOperation,
+                    Block2CTileMap,
+                    false>;
+
+                ave_time = launch_and_time_kernel(stream_config,
+                                                  kernel,
+                                                  dim3(grid_size),
+                                                  dim3(BlockSize),
+                                                  0,
+                                                  arg.p_a_grid_,
+                                                  arg.p_b_grid_,
+                                                  arg.p_c_grid_,
+                                                  arg.a_grid_desc_k0_m_k1_,
+                                                  arg.b_grid_desc_k0_n_k1_,
+                                                  arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
+                                                  arg.in_element_op_,
+                                                  arg.wei_element_op_,
+                                                  arg.out_element_op_,
+                                                  arg.block_2_ctile_map_);
+            }
+
+            return ave_time;
+        }
+
+        float Run(const BaseArgument* p_arg,
+                  const StreamConfig& stream_config = StreamConfig{}) override
+        {
+            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
+        }
+    };
+
+    static constexpr bool IsValidCompilationParameter()
+    {
+        // TODO: properly implement this check
+        return true;
+    }
+
+    static bool IsSupportedArgument(const Argument& arg)
+    {
+        if(ck::get_device_name() == "gfx908")
+        {
+            if constexpr(!(is_same_v || is_same_v ||
+                           is_same_v))
+            {
+                return false;
+            }
+        }
+        else if(ck::get_device_name() == "gfx90a")
+        {
+            if constexpr(!(is_same_v || is_same_v ||
+                           is_same_v || is_same_v))
+            {
+                return false;
+            }
+        }
+        else
+        {
+            return false;
+        }
+
+        // Input tensors can't be bigger than 2GB each.
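+        // (2 GB = 2^31 bytes. Presumably this protects the 32-bit signed offset
+        // arithmetic used for device-side addressing; for example, a float tensor
+        // holding more than 2^29 elements would already exceed the limit and is
+        // rejected by the check below.)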
+ constexpr ck::long_index_t GB2 = (ck::long_index_t{1} << 31); + + if(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) > GB2 || + arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) > GB2 || + arg.c_grid_desc_m_n_.GetElementSpaceSize() * sizeof(CDataType) > GB2) + { + return false; + } + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + for(ck::index_t i = 0; i < NumDimSpatial; ++i) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + for(ck::index_t i = 0; i < NumDimSpatial; ++i) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.input_left_pads_[i] == 0 && + arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_K_ % CThreadTransferDstScalarPerVector == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + const void* p_wei_grid, + void* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + 
{ + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConvNdFwdNwcKxcNwk_Xdl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << getConvForwardSpecializationString(ConvForwardSpecialization) + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_elementwise.hpp b/include/ck/tensor_operation/gpu/device/device_elementwise.hpp new file mode 100644 index 00000000..f0946eb8 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_elementwise.hpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceElementwise : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(std::array p_inputs, + std::array p_outputs, + std::vector lengths, + std::vector> input_strides, + std::vector> output_strides, + ElementwiseFunctor functor) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceElementwisePtr = + std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_gemm.hpp new file mode 100644 index 00000000..bde29ba5 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm.hpp @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+
+#pragma once
+
+#include
+#include
+
+#include "ck/tensor_operation/gpu/device/device_base.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+struct GemmShape
+{
+    ck::index_t M, N, K;
+    ck::index_t StrideA, StrideB, StrideC;
+};
+
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+struct DeviceGemm : public BaseOperator
+{
+    virtual std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(const void* p_a,
+                        const void* p_b,
+                        void* p_c,
+                        ck::index_t M,
+                        ck::index_t N,
+                        ck::index_t K,
+                        ck::index_t StrideA,
+                        ck::index_t StrideB,
+                        ck::index_t StrideC,
+                        AElementwiseOperation a_element_op,
+                        BElementwiseOperation b_element_op,
+                        CElementwiseOperation c_element_op) = 0;
+
+    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+};
+
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+struct DeviceGroupedGemm : public BaseOperator
+{
+    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*>& p_a,
+                                                              std::vector<const void*>& p_b,
+                                                              std::vector<void*>& p_c,
+                                                              std::vector<GemmShape>& gemm_shapes,
+                                                              AElementwiseOperation a_element_op,
+                                                              BElementwiseOperation b_element_op,
+                                                              CElementwiseOperation c_element_op,
+                                                              ck::index_t KBatch = 1) = 0;
+
+    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+};
+
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+using DeviceGroupedGemmPtr = std::unique_ptr<
+    DeviceGroupedGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp
new file mode 100644
index 00000000..1aa38855
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp
@@ -0,0 +1,875 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include
+#include
+
+#include "ck/utility/common_header.hpp"
+#include "ck/tensor_description/tensor_descriptor.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp"
+#include "ck/device_utility/device_prop.hpp"
+#include "ck/device_utility/kernel_launch.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+// Note: the inter-wave loop scheduler is rolled out to the c-shuffle version first, because the
+// non-c-shuffle version currently has compiler issues with register spills, which in turn cause
+// validation failures.
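+
+// A minimal host-side usage sketch (illustrative only: the template argument list,
+// pointer names, and element-op objects below are placeholders, not part of this header):
+//
+//   using DeviceOp = DeviceGemmBiasAddReduce_Xdl_CShuffle</* full configuration */>;
+//   DeviceOp op;
+//   auto argument = op.MakeArgumentPointer(p_a, p_b, p_bias, {p_d0}, p_c, {p_reduce0},
+//                                          M, N, K, StrideA, StrideB, StrideC, {StrideD0},
+//                                          gemm_element_ops, d_element_ops,
+//                                          reduce_in_element_ops, reduce_out_element_ops);
+//   if(op.IsSupportedArgument(argument.get()))
+//       op.MakeInvokerPointer()->Run(argument.get(), StreamConfig{});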
+template +struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceOperations::Size()> +{ + using DeviceOp = DeviceGemmBiasAddReduce_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return 
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == 
GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + // assume D is packed tensor + static auto MakeReduceGridDescriptor_M(index_t MRaw) + { + const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw)); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto MPad = M - MRaw; + + if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M + return transform_tensor_descriptor(d_grid_desc_mraw, + make_tuple(make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + } + else + { + // not pad M + return d_grid_desc_mraw; + } + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + using C0GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 0)); + using C1GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + BiasDataType, + D0DataType, + ReduceAccDataType, + ReducePtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + D0ElementwiseOperation, + ReduceOperations, + ReduceInElementwiseOperations, + ReduceAccElementwiseOperations, + InMemoryDataOperationEnum::Set, + ReduceGlobalMemoryDataOperation, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + CGridDesc_M_N, + C0GridDesc_M_N, + C1GridDesc_M_N, + ReduceGridDesc_M, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + 
BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + CReduceThreadClusterLengths_MPerBlock_NPerBlock, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, + LoopSched>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + const BiasDataType* p_bias_grid, + const D0DataType* p_d0_grid, + ReducePtrsGlobal p_reduces_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideC1, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + D0ElementwiseOperation d0_element_op, + ReduceInElementwiseOperations reduce_in_element_ops, + ReduceAccElementwiseOperations reduce_out_element_ops) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + p_bias_grid_{p_bias_grid}, + p_d0_grid_{p_d0_grid}, + p_reduces_grid_{p_reduces_grid}, + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, + c0_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, 0)}, + c1_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC1)}, + reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + c0_grid_desc_mblock_mperblock_nblock_nperblock_{}, + c1_grid_desc_mblock_mperblock_nblock_nperblock_{}, + reduce_grid_desc_mblock_mperblock_{}, + block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + d0_element_op_{d0_element_op}, + reduce_in_element_ops_{reduce_in_element_ops}, + reduce_out_element_ops_{reduce_out_element_ops} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + + c0_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c0_grid_desc_m_n_); + + c1_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c1_grid_desc_m_n_); + + reduce_grid_desc_mblock_mperblock_ = + GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + const BiasDataType* p_bias_grid_; + const D0DataType* p_d0_grid_; + ReducePtrsGlobal p_reduces_grid_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + CGridDesc_M_N c_grid_desc_m_n_; + C0GridDesc_M_N c0_grid_desc_m_n_; + C1GridDesc_M_N c1_grid_desc_m_n_; + ReduceGridDesc_M reduce_grid_desc_m_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename 
GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            c0_grid_desc_mblock_mperblock_nblock_nperblock_;
+        typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            c1_grid_desc_mblock_mperblock_nblock_nperblock_;
+        typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock
+            reduce_grid_desc_mblock_mperblock_;
+        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
+        AElementwiseOperation a_element_op_;
+        BElementwiseOperation b_element_op_;
+        CElementwiseOperation c_element_op_;
+        D0ElementwiseOperation d0_element_op_;
+        ReduceInElementwiseOperations reduce_in_element_ops_;
+        ReduceAccElementwiseOperations reduce_out_element_ops_;
+    };
+
+    // Invoker
+    struct Invoker : public BaseInvoker
+    {
+        using Argument = DeviceOp::Argument;
+
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+        {
+            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
+                                            arg.b_grid_desc_bk0_n_bk1_,
+                                            arg.c_grid_desc_m_n_,
+                                            arg.block_2_ctile_map_))
+            {
+                throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
+            }
+
+            const index_t grid_size =
+                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
+
+            const auto K =
+                arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
+
+            float elapsed_time = 0.0f;
+            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
+            {
+                const auto kernel = kernel_gemm_bias_add_reduce_xdl_cshuffle_v1<
+                    GridwiseGemm,
+                    ADataType, // TODO: distinguish A/B datatype
+                    CDataType,
+                    BiasDataType,
+                    D0DataType,
+                    ReducePtrsGlobal,
+                    AElementwiseOperation,
+                    BElementwiseOperation,
+                    CElementwiseOperation,
+                    D0ElementwiseOperation,
+                    ReduceInElementwiseOperations,
+                    ReduceAccElementwiseOperations,
+                    DeviceOp::AGridDesc_AK0_M_AK1,
+                    DeviceOp::BGridDesc_BK0_N_BK1,
+                    typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                    typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                    typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                    typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
+                    typename GridwiseGemm::DefaultBlock2CTileMap,
+                    true>;
+
+                elapsed_time =
+                    launch_and_time_kernel(stream_config,
+                                           kernel,
+                                           dim3(grid_size),
+                                           dim3(BlockSize),
+                                           0,
+                                           arg.p_a_grid_,
+                                           arg.p_b_grid_,
+                                           arg.p_c_grid_,
+                                           arg.p_bias_grid_,
+                                           arg.p_d0_grid_,
+                                           arg.p_reduces_grid_,
+                                           arg.a_element_op_,
+                                           arg.b_element_op_,
+                                           arg.c_element_op_,
+                                           arg.d0_element_op_,
+                                           arg.reduce_in_element_ops_,
+                                           arg.reduce_out_element_ops_,
+                                           arg.a_grid_desc_ak0_m_ak1_,
+                                           arg.b_grid_desc_bk0_n_bk1_,
+                                           arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                           arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                           arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                           arg.reduce_grid_desc_mblock_mperblock_,
+                                           arg.block_2_ctile_map_);
+            }
+            else
+            {
+                const auto kernel = kernel_gemm_bias_add_reduce_xdl_cshuffle_v1<
+                    GridwiseGemm,
+                    ADataType, // TODO: distinguish A/B datatype
+                    CDataType,
+                    BiasDataType,
+                    D0DataType,
+                    ReducePtrsGlobal,
+                    AElementwiseOperation,
+                    BElementwiseOperation,
+                    CElementwiseOperation,
+                    D0ElementwiseOperation,
+                    ReduceInElementwiseOperations,
+                    ReduceAccElementwiseOperations,
+                    DeviceOp::AGridDesc_AK0_M_AK1,
+                    DeviceOp::BGridDesc_BK0_N_BK1,
+                    typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                    typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                    typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                    typename
GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + false>; + + elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_bias_grid_, + arg.p_d0_grid_, + arg.p_reduces_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.d0_element_op_, + arg.reduce_in_element_ops_, + arg.reduce_out_element_ops_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.reduce_grid_desc_mblock_mperblock_, + arg.block_2_ctile_map_); + } + + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static constexpr int NumReduce = ReduceOperations::Size(); + static auto MakeArgument(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op) + { + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + D0ElementwiseOperation d_element_op = + *(static_cast(d_element_ops[0])); + + return Argument{static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + static_cast(p_bias), + static_cast(p_ds[0]), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideDs[0], + a_element_op, + b_element_op, + c_element_op, + d_element_op, + reduce_in_element_ops, + reduce_out_element_ops}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + 
ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op, + index_t /* KBatch */ = 1) override + { + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + D0ElementwiseOperation d_element_op = + *(static_cast(d_element_ops[0])); + + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + static_cast(p_bias), + static_cast(p_ds[0]), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideDs[0], + a_element_op, + b_element_op, + c_element_op, + d_element_op, + reduce_in_element_ops, + reduce_out_element_ops); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmBiasAddReduce_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_bias_c_permute.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_bias_c_permute.hpp new file mode 100644 index 00000000..bde0d48c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_bias_c_permute.hpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +struct DEGridDesc_M0_M1_M2_N0_N1 +{ + ck::index_t M0_, M1_, M2_, N0_, N1_; + ck::index_t stride_M0_, stride_M1_, stride_M2_, stride_N0_, stride_N1_; +}; + +// input : A[M, K], B[K, N], +// input : D[M, N], ... 
+// output : E[M, N]
+// C = a_op(A) * b_op(B)
+// E = cde_op(C, D)
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CDEElementwiseOperation>
+struct DeviceGemmBiasCPermute : public BaseOperator
+{
+    virtual std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(const void* p_a,
+                        const void* p_b,
+                        const void* p_d,
+                        void* p_e,
+                        ck::index_t M,
+                        ck::index_t N,
+                        ck::index_t K,
+                        ck::index_t StrideA,
+                        ck::index_t StrideB,
+                        DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc,
+                        DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc,
+                        AElementwiseOperation a_element_op,
+                        BElementwiseOperation b_element_op,
+                        CDEElementwiseOperation cde_element_op) = 0;
+
+    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+};
+
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CDEElementwiseOperation>
+using DeviceGemmBiasCPermutePtr = std::unique_ptr<
+    DeviceGemmBiasCPermute<AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation>>;
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_bias_c_permute_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_bias_c_permute_xdl.hpp
new file mode 100644
index 00000000..f74cb0dc
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/device_gemm_bias_c_permute_xdl.hpp
@@ -0,0 +1,761 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include
+#include
+
+#include "ck/utility/common_header.hpp"
+#include "ck/tensor_description/tensor_descriptor.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_bias_c_permute.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
+#include "ck/device_utility/device_prop.hpp"
+#include "ck/device_utility/kernel_launch.hpp"
+
+namespace ck {
+
+template <typename GridwiseGemm,
+          typename FloatAB,
+          typename FloatDsPointer,
+          typename FloatE,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CDEElementwiseOperation,
+          typename AGridDesc_AK0_M_AK1,
+          typename BGridDesc_BK0_N_BK1,
+          typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+          typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+          typename Block2ETileMap,
+          bool HasMainKBlockLoop>
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+        kernel_gemm_bias_c_permute(const FloatAB* __restrict__ p_a_grid,
+                                   const FloatAB* __restrict__ p_b_grid,
+                                   FloatDsPointer p_ds_grid,
+                                   FloatE* __restrict__ p_e_grid,
+                                   const AElementwiseOperation a_element_op,
+                                   const BElementwiseOperation b_element_op,
+                                   const CDEElementwiseOperation cde_element_op,
+                                   const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+                                   const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+                                   const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+                                       ds_grid_desc_mblock_mperblock_nblock_nperblock,
+                                   const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+                                       e_grid_desc_mblock_mperblock_nblock_nperblock,
+                                   const Block2ETileMap block_2_etile_map)
+{
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
+                                                  p_b_grid,
+                                                  p_ds_grid,
+                                                  p_e_grid,
+                                                  p_shared,
+                                                  a_element_op,
+                                                  b_element_op,
+                                                  cde_element_op,
+                                                  a_grid_desc_ak0_m_ak1,
+                                                  b_grid_desc_bk0_n_bk1,
+                                                  ds_grid_desc_mblock_mperblock_nblock_nperblock,
+                                                  e_grid_desc_mblock_mperblock_nblock_nperblock,
+                                                  block_2_etile_map);
+#else
+    ignore = p_a_grid;
+    ignore = p_b_grid;
+    ignore = p_ds_grid;
+    ignore = p_e_grid;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = cde_element_op;
+    ignore = a_grid_desc_ak0_m_ak1;
+    ignore = b_grid_desc_bk0_n_bk1;
+    ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
+    ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
+    ignore = block_2_etile_map;
+#endif
+}
+
+} // namespace ck
+
+namespace ck {
+namespace tensor_operation {
+namespace
device { + +// input : A[M, K], or A[K, N] +// input : B[K, N], or A[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +template +struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute +{ + using DeviceOp = DeviceGemmBiasCPermute_Xdl; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr index_t NumDTensor = I1; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, 
Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + static auto MakeEGridDescriptor_M_N(DEGridDesc_M0_M1_M2_N0_N1 d_e_grid_desc) + { + index_t M0 = d_e_grid_desc.M0_; + index_t M1 = d_e_grid_desc.M1_; + index_t M2 = d_e_grid_desc.M2_; + index_t N0 = d_e_grid_desc.N0_; + index_t N1 = d_e_grid_desc.N1_; + + index_t stride_M0 = d_e_grid_desc.stride_M0_; + index_t stride_M1 = d_e_grid_desc.stride_M1_; + index_t stride_M2 = d_e_grid_desc.stride_M2_; + index_t 
stride_N0 = d_e_grid_desc.stride_N0_; + index_t stride_N1 = d_e_grid_desc.stride_N1_; + + const auto MRaw = M0 * M1 * M2; + const auto NRaw = N0 * N1; + + const auto c_grid_desc_mraw_nraw = [&]() { + const auto c_grid_desc_m0_m1_m2_n0_n1 = make_naive_tensor_descriptor( + make_tuple(M0, M1, M2, N0, N1), + make_tuple(stride_M0, stride_M1, stride_M2, stride_N0, stride_N1)); + + return transform_tensor_descriptor( + c_grid_desc_m0_m1_m2_n0_n1, + make_tuple(make_merge_transform(make_tuple(M0, M1, M2)), + make_merge_transform(make_tuple(N0, N1))), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(DEGridDesc_M0_M1_M2_N0_N1{})); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + EGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + // 
Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + const void* p_d_grid, + void* p_e_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc, + DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, // FIXME + p_e_grid_{static_cast(p_e_grid)}, + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_grid_desc)}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + + if(MRaw != d_grid_desc.M0_ * d_grid_desc.M1_ * d_grid_desc.M2_) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + if(NRaw != d_grid_desc.N0_ * d_grid_desc.N1_) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + p_ds_grid_(I0) = static_cast(p_d_grid); + + const auto d_grid_desc_m_n = DeviceOp::MakeEGridDescriptor_M_N(d_grid_desc); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_(I0) = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + d_grid_desc_m_n); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + StaticallyIndexedArray< + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + NumDTensor> + ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different + // type from E + EGridDesc_M_N e_grid_desc_m_n_; + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting");
+            }
+
+            const index_t grid_size =
+                arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
+
+            const auto K =
+                arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
+
+            auto launch_kernel = [&](auto has_main_k_block_loop) {
+                constexpr bool has_main_loop = has_main_k_block_loop.value;
+
+                const auto kernel = kernel_gemm_bias_c_permute<
+                    GridwiseGemm,
+                    ADataType, // TODO: distinguish A/B datatype
+                    typename GridwiseGemm::DsGridPointer,
+                    EDataType,
+                    AElementwiseOperation,
+                    BElementwiseOperation,
+                    CDEElementwiseOperation,
+                    DeviceOp::AGridDesc_AK0_M_AK1,
+                    DeviceOp::BGridDesc_BK0_N_BK1,
+                    ck::StaticallyIndexedArray<
+                        typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                        NumDTensor>,
+                    typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                    typename GridwiseGemm::DefaultBlock2ETileMap,
+                    has_main_loop>;
+
+                return launch_and_time_kernel(stream_config,
+                                              kernel,
+                                              dim3(grid_size),
+                                              dim3(BlockSize),
+                                              0,
+                                              arg.p_a_grid_,
+                                              arg.p_b_grid_,
+                                              arg.p_ds_grid_,
+                                              arg.p_e_grid_,
+                                              arg.a_element_op_,
+                                              arg.b_element_op_,
+                                              arg.cde_element_op_,
+                                              arg.a_grid_desc_ak0_m_ak1_,
+                                              arg.b_grid_desc_bk0_n_bk1_,
+                                              arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                              arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                              arg.block_2_etile_map_);
+            };
+
+            float ave_time = 0;
+
+            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
+            {
+                ave_time = launch_kernel(integral_constant<bool, true>{});
+            }
+            else
+            {
+                ave_time = launch_kernel(integral_constant<bool, false>{});
+            }
+
+            return ave_time;
+        }
+
+        // polymorphic
+        float Run(const BaseArgument* p_arg,
+                  const StreamConfig& stream_config = StreamConfig{}) override
+        {
+            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
+        }
+    };
+
+    static bool IsSupportedArgument(const Argument& arg)
+    {
+        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
+        {
+            return false;
+        }
+
+        return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
+                                           arg.b_grid_desc_bk0_n_bk1_,
+                                           arg.e_grid_desc_m_n_,
+                                           arg.block_2_etile_map_);
+    }
+
+    // polymorphic
+    bool IsSupportedArgument(const BaseArgument* p_arg) override
+    {
+        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
+    }
+
+    static auto MakeArgument(const void* p_a,
+                             const void* p_b,
+                             const void* p_d,
+                             void* p_e,
+                             index_t MRaw,
+                             index_t NRaw,
+                             index_t KRaw,
+                             index_t StrideA,
+                             index_t StrideB,
+                             DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc,
+                             DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc,
+                             AElementwiseOperation a_element_op,
+                             BElementwiseOperation b_element_op,
+                             CDEElementwiseOperation cde_element_op)
+    {
+        return Argument{p_a,
+                        p_b,
+                        p_d,
+                        p_e,
+                        MRaw,
+                        NRaw,
+                        KRaw,
+                        StrideA,
+                        StrideB,
+                        d_grid_desc,
+                        e_grid_desc,
+                        a_element_op,
+                        b_element_op,
+                        cde_element_op};
+    }
+
+    static auto MakeInvoker() { return Invoker{}; }
+
+    // polymorphic
+    std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(const void* p_a,
+                        const void* p_b,
+                        const void* p_d,
+                        void* p_e,
+                        index_t MRaw,
+                        index_t NRaw,
+                        index_t KRaw,
+                        index_t StrideA,
+                        index_t StrideB,
+                        DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc,
+                        DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc,
+                        AElementwiseOperation a_element_op,
+                        BElementwiseOperation b_element_op,
+                        CDEElementwiseOperation cde_element_op) override
+    {
+        return std::make_unique<Argument>(p_a,
+                                          p_b,
+                                          p_d,
+                                          p_e,
+                                          MRaw,
+                                          NRaw,
+                                          KRaw,
+                                          StrideA,
+                                          StrideB,
+                                          d_grid_desc,
+                                          e_grid_desc,
+                                          a_element_op,
+                                          b_element_op,
+                                          cde_element_op);
+    }
+
+    // polymorphic
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
+    {
+        return
std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmBiasCPermute_Xdl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp new file mode 100644 index 00000000..aca413f8 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp @@ -0,0 +1,594 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template < + typename ADataType, + typename BDataType, + typename CDataType, + typename AccDataType, + typename ALayout, + typename BLayout, + typename CLayout, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + GemmSpecialization GemmSpec, + index_t BlockSize, + index_t MPerBlock, + index_t NPerBlock, + index_t K0PerBlock, + index_t K1, + index_t M1PerThread, + index_t N1PerThread, + index_t KPerThread, + typename M1N1ThreadClusterM1Xs, + typename M1N1ThreadClusterN1Xs, + typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1, + typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, + typename ABlockTransferSrcVectorTensorContiguousDimOrder, + typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, + typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1, + typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, + typename BBlockTransferSrcVectorTensorContiguousDimOrder, + typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, + typename CThreadTransferSrcDstAccessOrder, + index_t CThreadTransferSrcDstVectorDim, + index_t CThreadTransferDstScalarPerVector, + enable_if_t< + is_same_v && + is_same_v && + is_same_v, + bool> = false> +struct DeviceGemmDl : public DeviceGemm + +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return 
make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = + GridwiseGemmDl_km_kn_mn_v1r3; + + using AGridDesc_K0_M0_M1_K1 = + decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); + using BGridDesc_K0_N0_N1_K1 = + decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); + using CGridDesc_M0_M10_M11_N0_N10_N11 = + 
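+    // Note the idiom used throughout this block: the Make*Descriptor functions return
+    // different concrete types depending on layout and GemmSpec, so these aliases name the
+    // descriptor types via decltype over a call with dummy arguments (the (1, 1, 1) above,
+    // default-constructed descriptors here); only the type is used, the dummy descriptor is
+    // never materialized at run time.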
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})); + using DefaultBlock2CTileMap = + decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{})); + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_k0_m0_m1_k1_{}, + b_grid_desc_k0_n0_n1_k1_{}, + c_grid_desc_m0_m10_m11_n0_n10_n11_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + a_grid_desc_k0_m_k1_ = DeviceGemmDl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + b_grid_desc_k0_n_k1_ = DeviceGemmDl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + c_grid_desc_m_n_ = DeviceGemmDl::MakeCGridDescriptor_M_N(M, N, StrideC); + + if(GridwiseGemm::CheckValidity( + a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_)) + { + a_grid_desc_k0_m0_m1_k1_ = + GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1_); + b_grid_desc_k0_n0_n1_k1_ = + GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1_); + c_grid_desc_m0_m10_m11_n0_n10_n11_ = + GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_); + + block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + + AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_; + BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_; + CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_; + + DefaultBlock2CTileMap block_2_ctile_map_; + + // TODO: unused, but may be useful in future. + index_t M01_; + index_t N01_; + + // TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being. + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceGemmDl::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + { + std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{" + << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n0_n1_k1_{" + << arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity( + arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_k0mk1_k0nk1_mn_xdl_v2r3 has invalid setting"); + } + + const index_t grid_size = GridwiseGemm::CalculateGridSize( + arg.c_grid_desc_m_n_.GetLength(I0), arg.c_grid_desc_m_n_.GetLength(I1)); + + const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0); + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); + const bool has_double_tail_k_block_loop = + GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); + + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dl_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.block_2_ctile_map_); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dl_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.block_2_ctile_map_); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dl_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = + kernel_gemm_dl_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030") + { + return GridwiseGemm::CheckValidity( + arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_); + } + else + { + return false; + } + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + 
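+                        // The two 1s above are the M01/N01 block-swizzle defaults; Argument
+                        // stores them, but gridwise_gemm_dl_v1r3 does not consume them yet
+                        // (see the TODO on M01_/N01_ in Argument), so callers never need to
+                        // vary them.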
a_element_op,
+                        b_element_op,
+                        c_element_op};
+    }
+
+    static auto MakeInvoker() { return Invoker{}; }
+
+    // polymorphic
+    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
+                                                      const void* p_b,
+                                                      void* p_c,
+                                                      index_t M,
+                                                      index_t N,
+                                                      index_t K,
+                                                      index_t StrideA,
+                                                      index_t StrideB,
+                                                      index_t StrideC,
+                                                      AElementwiseOperation a_element_op,
+                                                      BElementwiseOperation b_element_op,
+                                                      CElementwiseOperation c_element_op) override
+    {
+        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
+                                          static_cast<const BDataType*>(p_b),
+                                          static_cast<CDataType*>(p_c),
+                                          M,
+                                          N,
+                                          K,
+                                          StrideA,
+                                          StrideB,
+                                          StrideC,
+                                          1,
+                                          1,
+                                          a_element_op,
+                                          b_element_op,
+                                          c_element_op);
+    }
+
+    // polymorphic
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
+    {
+        return std::make_unique<Invoker>(Invoker{});
+    }
+
+    // polymorphic
+    std::string GetTypeString() const override
+    {
+        auto str = std::stringstream();
+
+        // clang-format off
+        str << "DeviceGemmDl"
+            << "<"
+            << BlockSize << ", "
+            << MPerBlock << ", "
+            << NPerBlock << ", "
+            << K0PerBlock << ", "
+            << K1 << ", "
+            << M1PerThread << ", "
+            << N1PerThread << ", "
+            << KPerThread
+            << ">";
+        // clang-format on
+
+        return str.str();
+    }
+};
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp
new file mode 100644
index 00000000..9c0594e3
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp
@@ -0,0 +1,78 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <array>
+
+#include "device_base.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+// GEMM:
+//   input : A[M, K], B[K, N],
+//   input : D0[M, N], D1[M, N], ...
+//   output : E[M, N]
+//   C = a_op(A) * b_op(B)
+//   E = cde_op(C, D0, D1, ...)
+// Assume:
+//   D0, D1, ... and E have the same layout
+template <typename ALayout,
+          typename BLayout,
+          typename DsLayout,
+          typename ELayout,
+          typename ADataType,
+          typename BDataType,
+          typename DsDataType,
+          typename EDataType,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CDEElementwiseOperation>
+struct DeviceGemmMultipleD : public BaseOperator
+{
+    static constexpr index_t NumDTensor = DsDataType::Size();
+
+    virtual std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(const void* p_a,
+                        const void* p_b,
+                        std::array<const void*, NumDTensor> p_ds,
+                        void* p_e,
+                        ck::index_t M,
+                        ck::index_t N,
+                        ck::index_t K,
+                        ck::index_t StrideA,
+                        ck::index_t StrideB,
+                        std::array<ck::index_t, NumDTensor> StrideDs,
+                        ck::index_t StrideE,
+                        AElementwiseOperation a_element_op,
+                        BElementwiseOperation b_element_op,
+                        CDEElementwiseOperation cde_element_op) = 0;
+
+    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+};
+
+template <typename ALayout,
+          typename BLayout,
+          typename DsLayout,
+          typename ELayout,
+          typename ADataType,
+          typename BDataType,
+          typename DsDataType,
+          typename EDataType,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CDEElementwiseOperation>
+using DeviceGemmMultipleDPtr = std::unique_ptr<DeviceGemmMultipleD<ALayout,
+                                                                   BLayout,
+                                                                   DsLayout,
+                                                                   ELayout,
+                                                                   ADataType,
+                                                                   BDataType,
+                                                                   DsDataType,
+                                                                   EDataType,
+                                                                   AElementwiseOperation,
+                                                                   BElementwiseOperation,
+                                                                   CDEElementwiseOperation>>;
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
new file mode 100644
index 00000000..8c7f2c15
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
@@ -0,0 +1,758 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
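+// How a caller typically drives the type-erased DeviceGemmMultipleD interface defined in
+// device_gemm_multiple_d.hpp: a minimal sketch, assuming hypothetical device pointers and a
+// single D tensor used as a bias, with element_wise::Add as the CDE operation (illustration
+// only, not code from this file).
+#if 0
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using Add         = ck::tensor_operation::element_wise::Add;
+
+// gemm_op is some concrete instance of a DeviceGemmMultipleD specialization.
+auto argument_ptr = gemm_op.MakeArgumentPointer(p_a_dev,      // A[M, K]
+                                                p_b_dev,      // B[K, N]
+                                                {p_bias_dev}, // p_ds: one D tensor
+                                                p_e_dev,      // E[M, N]
+                                                M, N, K,
+                                                StrideA, StrideB,
+                                                {StrideD0},   // StrideDs
+                                                StrideE,
+                                                PassThrough{},
+                                                PassThrough{},
+                                                Add{}); // E = C + D0
+auto invoker_ptr = gemm_op.MakeInvokerPointer();
+if(gemm_op.IsSupportedArgument(argument_ptr.get()))
+{
+    invoker_ptr->Run(argument_ptr.get(), StreamConfig{});
+}
+#endif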
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_multiple_d_xdl_cshuffle(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatDsPointer p_ds_grid, + FloatE* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_etile_map; +#endif +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// GEMM: +// input : A[AK0, M, AK1] +// input : B[AK0, N, AK1] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... 
and E have the same layout +template +struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD +{ + using DeviceOp = DeviceGemmMultipleD_Xdl_CShuffle; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t 
StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + 
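+        // Round the raw extents up to whole block tiles before computing the padding. E.g.
+        // with MRaw = 1000, NRaw = 500 and MPerBlock = NPerBlock = 256: M = 1024 and N = 512,
+        // so the pad amounts computed below are MPad = 24 and NPad = 12.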
const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + EGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + void* p_e_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, // FIXME + p_e_grid_{static_cast(p_e_grid)}, + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, 
StrideE)}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + const auto d_grid_desc_m_n = + DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + d_grid_desc_m_n); + }); + } + } + + // private: + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + StaticallyIndexedArray< + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + NumDTensor> + ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different + // type from E + EGridDesc_M_N e_grid_desc_m_n_; + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_gemm_multiple_d_xdl_cshuffle< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + ck::StaticallyIndexedArray< + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + NumDTensor>, + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DefaultBlock2ETileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_); + }; + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + ave_time = launch_kernel(integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideDs, + StrideE, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideDs, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const 
override
+    {
+        auto str = std::stringstream();
+
+        // clang-format off
+        str << "DeviceGemmMultipleD_Xdl_CShuffle"
+            << "<"
+            << BlockSize << ", "
+            << MPerBlock << ", "
+            << NPerBlock << ", "
+            << KPerBlock << ", "
+            << AK1 << ", "
+            << BK1 << ", "
+            << getGemmSpecializationString(GemmSpec)
+            << ">";
+        // clang-format on
+
+        return str.str();
+    }
+};
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp
new file mode 100644
index 00000000..fcc088ca
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp
@@ -0,0 +1,46 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <array>
+
+#include "device_base.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+// FIXME: the DeviceGemmReduce type still needs to define the problem it solves more precisely
+template <index_t NumDTensor, index_t NumReduce>
+struct DeviceGemmReduce : public BaseOperator
+{
+    virtual std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(const void* p_a,
+                        const void* p_b,
+                        const void* p_bias,
+                        std::array<const void*, NumDTensor> p_ds,
+                        void* p_c,
+                        std::array<void*, NumReduce> p_reduces,
+                        ck::index_t M,
+                        ck::index_t N,
+                        ck::index_t K,
+                        ck::index_t StrideA,
+                        ck::index_t StrideB,
+                        ck::index_t StrideC,
+                        std::array<ck::index_t, NumDTensor> StrideDs,
+                        std::array<const void*, 3> gemm_element_ops,
+                        std::array<const void*, NumDTensor> d_element_ops,
+                        std::array<const void*, NumReduce> reduce_in_element_ops,
+                        std::array<const void*, NumReduce> reduce_out_element_ops,
+                        ck::index_t BatchCount = 1) = 0;
+
+    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+};
+
+template <index_t NumDTensor, index_t NumReduce>
+using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<NumDTensor, NumReduce>>;
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp
new file mode 100644
index 00000000..722ae113
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp
@@ -0,0 +1,835 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <iostream>
+#include <sstream>
+
+#include "ck/utility/common_header.hpp"
+#include "ck/tensor_description/tensor_descriptor.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp"
+#include "ck/device_utility/device_prop.hpp"
+#include "ck/device_utility/kernel_launch.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+// Note: the inter-wave loop scheduler is rolled out to the c-shuffle version first, because
+// the non-c-shuffle version currently has compiler issues with register spills, which in turn
+// cause validation failures.
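+// What this operator adds on top of C = c_op(a_op(A) * b_op(B)): for each slot i in
+// ReduceOperations it also produces an M-element vector
+//     R_i[m] = acc_op_i( reduce_i over n of in_op_i(C[m, n]) ),
+// written through the matching pointer in ReducePtrsGlobal. A typical pairing (an illustrative
+// assumption, not fixed by this header) is two Add slots with PassThrough and Square input
+// ops, yielding the per-row sum and sum of squares needed for normalization statistics.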
+template +struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperations::Size()> +{ + using DeviceOp = DeviceGemmReduce_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return 
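+                // B is always described as an N x K tensor here, regardless of how it is
+                // stored: for a row-major B[K, N], N is the fast-moving axis, so in (N, K)
+                // order the strides are (1, StrideB); the column-major branch below gives
+                // (StrideB, 1).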
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == 
GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + // assume Reduce is packed tensor + static auto MakeReduceGridDescriptor_M(index_t MRaw) + { + const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw)); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto MPad = M - MRaw; + + if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M + return transform_tensor_descriptor(d_grid_desc_mraw, + make_tuple(make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + } + else + { + // not pad M + return d_grid_desc_mraw; + } + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + ReduceAccDataType, + ReducePtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + ReduceOperations, + ReduceInElementwiseOperations, + ReduceAccElementwiseOperations, + InMemoryDataOperationEnum::Set, + ReduceGlobalMemoryDataOperation, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + CGridDesc_M_N, + ReduceGridDesc_M, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + 
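+        // The CShuffle*/CReduce* parameters that follow configure the fused epilogue rather
+        // than the main loop: the vector width used when the shuffled C tile is written out,
+        // and the thread clustering plus LDS/VGPR copy widths for the per-row reduction.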
CShuffleBlockTransferScalarPerVector_NPerBlock, + CReduceThreadClusterLengths_MPerBlock_NPerBlock, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, + LoopSched>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + ReducePtrsGlobal p_reduces_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ReduceInElementwiseOperations reduce_in_element_ops, + ReduceAccElementwiseOperations reduce_out_element_ops) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + p_reduces_grid_{p_reduces_grid}, + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, + reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + reduce_grid_desc_mblock_mperblock_{}, + block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + reduce_in_element_ops_{reduce_in_element_ops}, + reduce_out_element_ops_{reduce_out_element_ops} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + + reduce_grid_desc_mblock_mperblock_ = + GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + ReducePtrsGlobal p_reduces_grid_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + CGridDesc_M_N c_grid_desc_m_n_; + ReduceGridDesc_M reduce_grid_desc_m_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock + reduce_grid_desc_mblock_mperblock_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + ReduceInElementwiseOperations reduce_in_element_ops_; + ReduceAccElementwiseOperations reduce_out_element_ops_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << "arg.a_grid_desc_ak0_m_ak1_{" + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_bk0_n_bk1_{" + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" 
<< std::endl; + + std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0) << "}" + << std::endl; + } +#endif + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + float elapsed_time = 0.0f; + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_reduce_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + ReducePtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + ReduceInElementwiseOperations, + ReduceAccElementwiseOperations, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + true>; + + elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_reduces_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.reduce_in_element_ops_, + arg.reduce_out_element_ops_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.reduce_grid_desc_mblock_mperblock_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_reduce_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + ReducePtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + ReduceInElementwiseOperations, + ReduceAccElementwiseOperations, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + false>; + + elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_reduces_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.reduce_in_element_ops_, + arg.reduce_out_element_ops_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.reduce_grid_desc_mblock_mperblock_, + arg.block_2_ctile_map_); + } + + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static constexpr int NumReduce = ReduceOperations::Size(); + static auto MakeArgument(const 
void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op) + { + (void)p_bias; + (void)p_ds; + (void)StrideDs; + (void)d_element_ops; + + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + + return Argument{static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + reduce_in_element_ops, + reduce_out_element_ops}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op, + ck::index_t = 1) override + { + (void)p_bias; + (void)p_ds; + (void)StrideDs; + (void)d_element_ops; + + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + reduce_in_element_ops, + reduce_out_element_ops); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = 
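+        // GetTypeString() renders the tile configuration into a readable instance name such
+        // as "DeviceGemmReduce_Xdl_CShuffle<256, 256, 128, 32, 8, 8>"; profiling and logging
+        // code can print it to report which instance was selected.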
std::stringstream();
+
+        // clang-format off
+        str << "DeviceGemmReduce_Xdl_CShuffle"
+            << "<"
+            << BlockSize << ", "
+            << MPerBlock << ", "
+            << NPerBlock << ", "
+            << KPerBlock << ", "
+            << AK1 << ", "
+            << BK1
+            << ">";
+        // clang-format on
+
+        return str.str();
+    }
+};
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp
new file mode 100644
index 00000000..c701bff5
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp
@@ -0,0 +1,64 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <iostream>
+#include <memory>
+
+#include "device_base.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+struct DeviceGemmSplitK : public BaseOperator
+{
+    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
+                                                              const void* p_b,
+                                                              void* p_c,
+                                                              ck::index_t M,
+                                                              ck::index_t N,
+                                                              ck::index_t K,
+                                                              ck::index_t StrideA,
+                                                              ck::index_t StrideB,
+                                                              ck::index_t StrideC,
+                                                              AElementwiseOperation a_element_op,
+                                                              BElementwiseOperation b_element_op,
+                                                              CElementwiseOperation c_element_op,
+                                                              ck::index_t KBatch) = 0;
+
+    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+};
+
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+using DeviceGemmSplitKPtr = std::unique_ptr<
+    DeviceGemmSplitK<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp
new file mode 100644
index 00000000..98028e1f
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp
@@ -0,0 +1,547 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
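+// DeviceGemmXdl is the plain-GEMM counterpart of the multiple-D/reduce operators above, built
+// on gridwise_gemm_xdlops_v2r3 (XDL/MFMA matrix-core instructions) rather than the DL path,
+// so it targets MFMA-capable GPUs; the DL variant earlier in this diff is the fallback whose
+// IsSupportedArgument() accepts gfx906/gfx1030 instead.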
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmXdl : public DeviceGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto K1Number = Number{}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % 
NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + NumPrefetch>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + a_grid_desc_k0_m_k1_ = DeviceGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + b_grid_desc_k0_n_k1_ = DeviceGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + c_grid_desc_m_n_ = DeviceGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC); + + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 + 
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceGemmXdl::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } +#endif + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + true>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::get_device_name() == "gfx908") + { + if constexpr(!(is_same_v || is_same_v || + is_same_v)) + { + return false; + } + } + else if(ck::get_device_name() == "gfx90a") + { + if constexpr(!(is_same_v || is_same_v || + is_same_v || is_same_v)) + { + return false; + } + } + else + { + 
return false;
+        }
+
+        return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
+                                           arg.b_grid_desc_k0_n_k1_,
+                                           arg.c_grid_desc_m_n_,
+                                           arg.block_2_ctile_map_);
+    }
+
+    // polymorphic
+    bool IsSupportedArgument(const BaseArgument* p_arg) override
+    {
+        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
+    }
+
+    static auto MakeArgument(const ADataType* p_a,
+                             const BDataType* p_b,
+                             CDataType* p_c,
+                             index_t M,
+                             index_t N,
+                             index_t K,
+                             index_t StrideA,
+                             index_t StrideB,
+                             index_t StrideC,
+                             AElementwiseOperation a_element_op,
+                             BElementwiseOperation b_element_op,
+                             CElementwiseOperation c_element_op)
+    {
+        return Argument{p_a,
+                        p_b,
+                        p_c,
+                        M,
+                        N,
+                        K,
+                        StrideA,
+                        StrideB,
+                        StrideC,
+                        1,
+                        1,
+                        a_element_op,
+                        b_element_op,
+                        c_element_op};
+    }
+
+    static auto MakeInvoker() { return Invoker{}; }
+
+    // polymorphic
+    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
+                                                      const void* p_b,
+                                                      void* p_c,
+                                                      index_t M,
+                                                      index_t N,
+                                                      index_t K,
+                                                      index_t StrideA,
+                                                      index_t StrideB,
+                                                      index_t StrideC,
+                                                      AElementwiseOperation a_element_op,
+                                                      BElementwiseOperation b_element_op,
+                                                      CElementwiseOperation c_element_op) override
+    {
+        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
+                                          static_cast<const BDataType*>(p_b),
+                                          static_cast<CDataType*>(p_c),
+                                          M,
+                                          N,
+                                          K,
+                                          StrideA,
+                                          StrideB,
+                                          StrideC,
+                                          1,
+                                          1,
+                                          a_element_op,
+                                          b_element_op,
+                                          c_element_op);
+    }
+
+    // polymorphic
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
+    {
+        return std::make_unique<Invoker>(Invoker{});
+    }
+
+    // polymorphic
+    std::string GetTypeString() const override
+    {
+        auto str = std::stringstream();
+
+        // clang-format off
+        str << "DeviceGemmXdl"
+            << "<"
+            << BlockSize << ", "
+            << MPerBlock << ", "
+            << NPerBlock << ", "
+            << K0PerBlock << ", "
+            << K1 << ", "
+            << MPerXDL << ", "
+            << NPerXDL << ", "
+            << MXdlPerWave << ", "
+            << NXdlPerWave
+            << ">";
+        // clang-format on
+
+        return str.str();
+    }
+};
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp
new file mode 100644
index 00000000..9c8b189a
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp
@@ -0,0 +1,677 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <iostream>
+#include <sstream>
+
+#include "ck/utility/common_header.hpp"
+#include "ck/tensor_description/tensor_descriptor.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp"
+#include "ck/device_utility/device_prop.hpp"
+#include "ck/device_utility/kernel_launch.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+// Note: the inter-wave loop scheduler is rolled out to the c-shuffle version first, because the
+// non-c-shuffle version currently has compiler issues with register spill, which in turn cause
+// validation failures.
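
`MakeAGridDescriptor_K0_M_K1` above reinterprets the 2-D tensor A[M, K] as a 3-D view [K0, M, K1] with K = K0 * K1. The index arithmetic behind that unmerge transform, assuming row-major A (element A[m, k] at offset m * StrideA + k):

```cpp
// Coordinate sketch for the [K0, M, K1] view built above.
int a_offset_k0_m_k1(int k0, int m, int k1, int K1, int StrideA)
{
    const int k = k0 * K1 + k1; // the unmerge transform: k <-> (k0, k1)
    return m * StrideA + k;
}
// e.g. K = 64, K1 = 8 gives K0 = 8; coordinate (k0=2, m=5, k1=3) addresses
// A[5, 19], since k = 2 * 8 + 3 = 19.
```
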
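
For orientation, a minimal host-side sketch of driving the polymorphic interface above (`MakeArgumentPointer`, `IsSupportedArgument`, `MakeInvokerPointer`, `Run`). The helper `run_gemm` and the template parameter `DeviceGemmInstance` are made-up names, and the `PassThrough` element-ops are assumed to be CK's usual no-op functors:

```cpp
#include <memory>

template <typename DeviceGemmInstance>
float run_gemm(DeviceGemmInstance& gemm,
               const void* p_a, const void* p_b, void* p_c,
               ck::index_t M, ck::index_t N, ck::index_t K,
               ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC)
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    auto argument = gemm.MakeArgumentPointer(p_a, p_b, p_c,
                                             M, N, K,
                                             StrideA, StrideB, StrideC,
                                             PassThrough{}, PassThrough{}, PassThrough{});

    if(!gemm.IsSupportedArgument(argument.get()))
        return -1.f; // this tile configuration does not cover the problem size

    auto invoker = gemm.MakeInvokerPointer();
    return invoker->Run(argument.get(), StreamConfig{}); // kernel time reported by the invoker
}
```
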
+template +struct DeviceGemm_Xdl_CShuffle : public DeviceGemm +{ + using DeviceOp = DeviceGemm_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + 
make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == 
GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + CGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + 
b_grid_desc_bk0_n_bk1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << "arg.a_grid_desc_ak0_m_ak1_{" + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_bk0_n_bk1_{" + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } +#endif + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + true>; + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + false>; + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + 
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemm_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_layernorm_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_layernorm_cshuffle.hpp new file mode 100644 index 00000000..b82fcb67 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_layernorm_cshuffle.hpp @@ -0,0 +1,773 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
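
The c-shuffle descriptor builders above round every raw extent up to its block tile before padding. The rounding, written out as plain integer math:

```cpp
// The rounding used by MakeAGridDescriptor_AK0_M_AK1 and friends above.
int integer_divide_ceil(int x, int y) { return (x + y - 1) / y; }

int padded_extent(int raw, int per_block)
{
    return integer_divide_ceil(raw, per_block) * per_block;
}

// e.g. MRaw = 1000, MPerBlock = 256:
//   M    = integer_divide_ceil(1000, 256) * 256 = 4 * 256 = 1024
//   MPad = M - MRaw = 24 padded rows appended by make_right_pad_transform
```
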
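
The `if constexpr` chains in those builders reduce to a small table: which of M, N, K each `GemmSpecialization` pads. A sketch of that table (the enum is written out here only for illustration; its actual definition lives in `gemm_specialization.hpp`, and the `Default` member is an assumption):

```cpp
enum class GemmSpecialization
{
    Default, MPadding, NPadding, KPadding,
    MNPadding, MKPadding, NKPadding, MNKPadding
};

struct PadFlags { bool m, n, k; };

constexpr PadFlags pads(GemmSpecialization s)
{
    switch(s)
    {
    case GemmSpecialization::MPadding:   return {true,  false, false};
    case GemmSpecialization::NPadding:   return {false, true,  false};
    case GemmSpecialization::KPadding:   return {false, false, true };
    case GemmSpecialization::MNPadding:  return {true,  true,  false};
    case GemmSpecialization::MKPadding:  return {true,  false, true };
    case GemmSpecialization::NKPadding:  return {false, true,  true };
    case GemmSpecialization::MNKPadding: return {true,  true,  true };
    default:                             return {false, false, false};
    }
}
```
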
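
The `Invoker::Run` above instantiates the same kernel twice, with `HasMainKBlockLoop` true and false, and selects one per launch. A condensed sketch of that runtime-to-compile-time dispatch (the predicate here is a simplified stand-in for `GridwiseGemm::CalculateHasMainKBlockLoop(K)`):

```cpp
template <bool HasMainKBlockLoop>
void kernel_sketch()
{
    if constexpr(HasMainKBlockLoop)
    {
        // steady-state pipeline: prefetch plus a main loop over K blocks, then the tail
    }
    else
    {
        // tail-only path for problems that fit in a single K block
    }
}

void dispatch(int K, int KPerBlock)
{
    const bool has_main_loop = K > KPerBlock; // simplified stand-in predicate

    if(has_main_loop)
        kernel_sketch<true>();
    else
        kernel_sketch<false>();
}
```
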
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// The GEMM + Layernorm implementation is a specialized kernel which allows fusing both layers +// together given the condition GEMM extents N of MNK is spanned by a single workgroup. For example, +// a kernel configured with NPerBlock = 128 allows to operate on all GEMM sizes if N <= 128 +// +// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle +// version currently has compiler issues with register spill which further causes validation +// failures. +// +// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta) +template +struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator +{ + using DeviceOp = DeviceGemmLayerNorm_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + 
+ const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), 
+ make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + static auto MakeGridDescriptor_N(index_t NRaw) + { + const auto grid_desc_nraw = make_naive_tensor_descriptor_packed(make_tuple(NRaw)); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad N + return transform_tensor_descriptor(grid_desc_nraw, + make_tuple(make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + } + else + { + // not pad N + return grid_desc_nraw; + } + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + using C0GridDesc_N = decltype(MakeGridDescriptor_N(1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + C0DataType, + 
ReduceAccDataType, + AElementwiseOperation, + BElementwiseOperation, + AccElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + CGridDesc_M_N, + C0GridDesc_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + CReduceThreadClusterLengths_MPerBlock_NPerBlock, + CReduceThreadCopySrcDstScalarPerVector_NPerBlock, + LoopSched>; + + using Block2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + const C0DataType* p_c0_grid_add, + const C0DataType* p_c0_grid_bias, + const C0DataType* p_c0_grid_gamma, + const C0DataType* p_c0_grid_beta, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + p_c0_grid_bias_{p_c0_grid_bias}, + p_c0_grid_add_{p_c0_grid_add}, + p_c0_grid_gamma_{p_c0_grid_gamma}, + p_c0_grid_beta_{p_c0_grid_beta}, + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, + c0_grid_desc_n_{MakeGridDescriptor_N(NRaw)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + c0_grid_desc_nblock_nperblock_{}, + block_2_ctile_map_{Block2CTileMap(c_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + acc_element_op_{acc_element_op}, + c_element_op_{c_element_op} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + + c0_grid_desc_nblock_nperblock_ = + GridwiseGemm::MakeC0GridDescriptor_NBlock_NPerBlock(c0_grid_desc_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + const C0DataType* p_c0_grid_bias_; + const C0DataType* p_c0_grid_add_; + const C0DataType* p_c0_grid_gamma_; + const C0DataType* p_c0_grid_beta_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + CGridDesc_M_N c_grid_desc_m_n_; + C0GridDesc_N c0_grid_desc_n_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename 
GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock_; + Block2CTileMap block_2_ctile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + AccElementwiseOperation acc_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << "arg.a_grid_desc_ak0_m_ak1_{" + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_bk0_n_bk1_{" + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } +#endif + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_layernorm_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + C0DataType, + AElementwiseOperation, + BElementwiseOperation, + AccElementwiseOperation, + CElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock, + Block2CTileMap, + true>; + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_bias_, + arg.p_c0_grid_add_, + arg.p_c0_grid_gamma_, + arg.p_c0_grid_beta_, + arg.a_element_op_, + arg.b_element_op_, + arg.acc_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.c0_grid_desc_nblock_nperblock_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_layernorm_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + C0DataType, + AElementwiseOperation, + BElementwiseOperation, + AccElementwiseOperation, + CElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock, + Block2CTileMap, + false>; + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_bias_, + arg.p_c0_grid_add_, + arg.p_c0_grid_gamma_, + arg.p_c0_grid_beta_, + arg.a_element_op_, + arg.b_element_op_, + arg.acc_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + 
arg.c0_grid_desc_nblock_nperblock_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + const C0DataType* p_c0_bias, + const C0DataType* p_c0_add, + const C0DataType* p_c0_gamma, + const C0DataType* p_c0_beta, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + p_c0_bias, + p_c0_add, + p_c0_gamma, + p_c0_beta, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + acc_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + const void* p_c0_bias, + const void* p_c0_add, + const void* p_c0_gamma, + const void* p_c0_beta, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + CElementwiseOperation c_element_op, + index_t /* KBatch */ = 1) + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + static_cast(p_c0_bias), + static_cast(p_c0_add), + static_cast(p_c0_gamma), + static_cast(p_c0_beta), + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + acc_element_op, + c_element_op); + } + + std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmLayerNorm_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp new file mode 100644 index 00000000..306a73df --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp @@ -0,0 +1,646 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
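
A scalar reference for the fusion the layernorm device op above computes, taking `acc_element_op` as identity: D = Layernorm(A*B + broadcast(bias) + add) * broadcast(gamma) + broadcast(beta), with mean and variance reduced over N per row. The single-workgroup constraint on N (N <= NPerBlock) is what lets the kernel form that reduction without a second pass. Layouts are assumptions: `acc` holds A*B row-major [M, N], `add` is [M, N], bias/gamma/beta have length N, and `eps` is an assumed stabilizer:

```cpp
#include <cmath>
#include <vector>

void gemm_layernorm_reference(const std::vector<float>& acc,
                              const std::vector<float>& bias,
                              const std::vector<float>& add,
                              const std::vector<float>& gamma,
                              const std::vector<float>& beta,
                              std::vector<float>& D,
                              int M, int N, float eps = 1e-5f)
{
    for(int m = 0; m < M; ++m)
    {
        float mean = 0.f, mean_sq = 0.f;
        for(int n = 0; n < N; ++n)
        {
            const float x = acc[m * N + n] + bias[n] + add[m * N + n];
            mean += x;
            mean_sq += x * x;
        }
        mean /= N;
        const float var     = mean_sq / N - mean * mean; // E[x^2] - E[x]^2
        const float inv_std = 1.f / std::sqrt(var + eps);

        for(int n = 0; n < N; ++n)
        {
            const float x  = acc[m * N + n] + bias[n] + add[m * N + n];
            D[m * N + n]   = (x - mean) * inv_std * gamma[n] + beta[n];
        }
    }
}
```
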
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmXdlSplitK : public DeviceGemmSplitK +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto K1Number = Number{}; + + static auto + MakeAGridDescriptor_KBatch_K0_M_K1(index_t M, index_t K, index_t StrideA, int KBatch, int KPad) + { + assert(KPad % (K1 * KBatch) == 0); + + const index_t K0 = KPad / (K1 * KBatch); + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(M)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + static auto + MakeBGridDescriptor_KBatch_K0_N_K1(index_t K, index_t N, index_t StrideB, int KBatch, int KPad) + { + assert(KPad % (K1 * KBatch) == 0); + + const index_t K0 = KPad / (K1 * KBatch); + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_kpad_n, 
+ make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + static auto GetKPad(index_t K, index_t KBatch) + { + const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock; + const index_t KPad = KBatch * K0 * K1; + return KPad; + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_KBatch_K0_M_K1(1, 1, 1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_KBatch_K0_N_K1(1, 1, 1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; + + // GridwiseGemm + using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::AtomicAdd, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + 
ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; + + using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = + decltype(GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(CGridDesc_M_N{})); + + using Block2CTileMap = + decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t k_batch) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + k_batch_{k_batch} + { + int KPad = DeviceGemmXdlSplitK::GetKPad(K, k_batch_); + + a_grid_desc_kbatch_k0_m_k1_ = DeviceGemmXdlSplitK::MakeAGridDescriptor_KBatch_K0_M_K1( + M, K, StrideA, k_batch_, KPad); + b_grid_desc_kbatch_k0_n_k1_ = DeviceGemmXdlSplitK::MakeBGridDescriptor_KBatch_K0_N_K1( + K, N, StrideB, k_batch_, KPad); + c_grid_desc_m_n_ = DeviceGemmXdlSplitK::MakeCGridDescriptor_M_N(M, N, StrideC); + + block_2_ctile_map_ = + GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); + + if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, + b_grid_desc_kbatch_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = + GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; + Block2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceGemmXdlSplitK::Argument; + + void ShowInfo(const Argument& arg) + { + std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " + << 
arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + ShowInfo(arg); + + const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0); + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; + + const auto Run = [&](const auto& kernel) { + // FIXME: this should be moved outside of DeviceOp + hipGetErrorString( + hipMemset(arg.p_c_grid_, + 0, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_.GetElementSpaceSize() * + sizeof(CDataType))); + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + }; + + if(has_main_k0_block_loop) + { + if(kbatch == 1) + { + const auto kernel = kernel_gemm_xdlops_v2r4< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + true>; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r4< + GridwiseGemmAtomicAdd, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + true>; + + Run(kernel); + } + } + else + { + if(kbatch == 1) + { + const auto kernel = kernel_gemm_xdlops_v2r4< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + false>; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r4< + GridwiseGemmAtomicAdd, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + false>; + + Run(kernel); + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + 
arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t KBatch) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op, + KBatch}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ck::index_t KBatch = 1) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op, + KBatch); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmXdlSplitK" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp new file mode 100644 index 00000000..52bdacf7 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp @@ -0,0 +1,650 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
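The split-K device op above divides the K dimension across KBatch groups of workgroups: K is first padded so it factors exactly as KPad = KBatch * K0 * K1 (with K0 a multiple of K0PerBlock), each group reduces its own K-slice, and for KBatch > 1 the partial products are combined in global memory with atomic adds, which is why Invoker::Run zero-fills the C buffer with hipMemset before launching the AtomicAdd kernel. Below is a minimal host-side sketch of the padding arithmetic, assuming nothing beyond the formulas visible in GetKPad; GetKPadSketch is a hypothetical stand-in, not part of the library.

#include <cassert>
#include <cstdio>

// Host-side sketch of the split-K padding arithmetic in GetKPad:
// round K up so that KPad = KBatch * K0 * K1 with K0 a multiple of K0PerBlock.
int GetKPadSketch(int K, int KBatch, int K1, int K0PerBlock)
{
    const int K0 = (K + K1 * K0PerBlock * KBatch - 1) / (K1 * K0PerBlock * KBatch) * K0PerBlock;
    return KBatch * K0 * K1;
}

int main()
{
    // Example: K = 1000, K1 = 8, K0PerBlock = 4, KBatch = 4.
    // ceil(1000 / 128) = 8 -> K0 = 8 * 4 = 32 -> KPad = 4 * 32 * 8 = 1024.
    const int KPad = GetKPadSketch(1000, 4, 8, 4);
    assert(KPad % (8 * 4) == 0); // the assert in MakeAGridDescriptor_KBatch_K0_M_K1
    std::printf("KPad = %d\n", KPad); // 1024; the 24 extra K elements are zero-padded
}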
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto K1Number = Number{}; + + static auto + MakeAGridDescriptor_KBatch_K0_M_K1(index_t M, index_t K, index_t StrideA, int KBatch, int KPad) + { + assert(KPad % (K1 * KBatch) == 0); + + const index_t K0 = KPad / (K1 * KBatch); + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(M)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + static auto + MakeBGridDescriptor_KBatch_K0_N_K1(index_t K, index_t N, index_t StrideB, int KBatch, int KPad) + { + assert(KPad % (K1 * KBatch) == 0); + + const index_t K0 = KPad / (K1 * KBatch); + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + 
b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + static auto GetKPad(index_t K, index_t KBatch) + { + const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock; + const index_t KPad = KBatch * K0 * K1; + return KPad; + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_KBatch_K0_M_K1(1, 1, 1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_KBatch_K0_N_K1(1, 1, 1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXDL, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>; + + // GridwiseGemm + using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::AtomicAdd, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + 
ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXDL, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>; + + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); + + using Block2CTileMap = typename GridwiseGemm::CBlockClusterAdaptor; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t k_batch) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + k_batch_{k_batch} + { + int KPad = DeviceGemmXdlSplitKCShuffle::GetKPad(K, k_batch_); + + a_grid_desc_kbatch_k0_m_k1_ = + DeviceGemmXdlSplitKCShuffle::MakeAGridDescriptor_KBatch_K0_M_K1( + M, K, StrideA, k_batch_, KPad); + b_grid_desc_kbatch_k0_n_k1_ = + DeviceGemmXdlSplitKCShuffle::MakeBGridDescriptor_KBatch_K0_N_K1( + K, N, StrideB, k_batch_, KPad); + c_grid_desc_m_n_ = DeviceGemmXdlSplitKCShuffle::MakeCGridDescriptor_M_N(M, N, StrideC); + + block_2_ctile_map_ = + GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); + + if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, + b_grid_desc_kbatch_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; + Block2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceGemmXdlSplitKCShuffle::Argument; + + void ShowInfo(const Argument& arg) + { + std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " + << 
arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + ShowInfo(arg); + + const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0); + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; + + const auto Run = [&](const auto& kernel) { + hipGetErrorString(hipMemset( + arg.p_c_grid_, + 0, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * + sizeof(CDataType))); + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + }; + + if(has_main_k0_block_loop) + { + if(kbatch == 1) + { + const auto kernel = kernel_gemm_xdlops_v2r4r2< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + true>; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r4r2< + GridwiseGemmAtomicAdd, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + true>; + + Run(kernel); + } + } + else + { + if(kbatch == 1) + { + const auto kernel = kernel_gemm_xdlops_v2r4r2< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + false>; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r4r2< + GridwiseGemmAtomicAdd, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + false>; + + Run(kernel); + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + 
arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t KBatch) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op, + KBatch}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ck::index_t KBatch = 1) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op, + KBatch); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmXdlSplitKCShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp new file mode 100644 index 00000000..99979280 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp @@ -0,0 +1,652 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
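Both split-K variants above pick one of four kernel instantiations (Set vs. AtomicAdd, with or without the main K0 loop); note that unlike DeviceGemmXdlSplitK, the CShuffle variant's IsSupportedArgument does not check for a gfx908/gfx90a device. The grouped-GEMM operation added next takes a different launch strategy: all groups run in a single kernel, and each workgroup locates its group by binary-searching the contiguous [BlockStart_, BlockEnd_) block ranges stored in the device-side descriptor array. A host-side sketch of that lookup, assuming only what the kernel body shows (ranges are sorted, contiguous, and cover the whole grid); the device kernel additionally guards the loop with left <= right.

#include <cassert>
#include <vector>

struct GemmDescSketch { int BlockStart_, BlockEnd_; }; // mirrors the kernel-arg fields

// Host-side equivalent of the group lookup in kernel_grouped_gemm_xdlops_v2r3.
int FindGroupId(const std::vector<GemmDescSketch>& descs, int block_id)
{
    int left = 0, right = static_cast<int>(descs.size());
    int group_id = (left + right) / 2;
    while(!(block_id >= descs[group_id].BlockStart_ && block_id < descs[group_id].BlockEnd_))
    {
        if(block_id < descs[group_id].BlockStart_)
            right = group_id;
        else
            left = group_id;
        group_id = (left + right) / 2;
    }
    return group_id;
}

int main()
{
    // Three GEMMs needing 4, 2 and 6 workgroups: ranges [0,4), [4,6), [6,12).
    const std::vector<GemmDescSketch> descs{{0, 4}, {4, 6}, {6, 12}};
    assert(FindGroupId(descs, 0) == 0);
    assert(FindGroupId(descs, 5) == 1);
    assert(FindGroupId(descs, 11) == 2);
}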
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_gemm_xdlops_v2r3(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t block_id = get_block_1d_id(); + + const auto gemm_desc_ptr = + reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); + + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_id >= gemm_desc_ptr[group_id].BlockStart_ && + block_id < gemm_desc_ptr[group_id].BlockEnd_)) && + left <= right) + { + if(block_id < gemm_desc_ptr[group_id].BlockStart_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + GridwiseGemm::template Run( + gemm_desc_ptr[group_id].a_ptr, + gemm_desc_ptr[group_id].b_ptr, + gemm_desc_ptr[group_id].c_ptr, + p_shared, + gemm_desc_ptr[group_id].a_grid_desc_k0_m_k1_, + gemm_desc_ptr[group_id].b_grid_desc_k0_n_k1_, + gemm_desc_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + a_element_op, + b_element_op, + c_element_op, + gemm_desc_ptr[group_id].grouped_gemm_block_2_ctile_map_); +#else + ignore = gemm_descs_const; + ignore = group_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct DeviceGroupedGemmXdl + : public DeviceGroupedGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto K1Number = Number{}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + 
make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<0, 2, 4, 5, 6, 1, 
3, 7>, // CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + NumPrefetch>; + + struct GroupedGemmBlock2CTileMap + { + using UnderlyingBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; + static_assert( + std::is_same::value, + "Wrong! Should be the same type name"); + GroupedGemmBlock2CTileMap() + { + block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1); + BlockStart_ = -1; + } + + GroupedGemmBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01, + index_t N01, + ck::index_t BlockStart) + { + block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, M01, N01); + BlockStart_ = BlockStart; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + return block_2_ctile_map_.CalculateBottomIndex( + make_multi_index(idx_top[I0] - BlockStart_)); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + return block_2_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); + } + + __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_2_ctile_map_.CheckValidity(c_grid_desc_m_n); + } + + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + ck::index_t BlockStart_; + }; + + struct GemmDescKernelArg + { + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + + typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; + + GroupedGemmBlock2CTileMap grouped_gemm_block_2_ctile_map_; + + const ADataType* a_ptr; + const BDataType* b_ptr; + CDataType* c_ptr; + + ck::index_t BlockStart_, BlockEnd_; + }; + + // Argument + struct Argument : public BaseArgument + { + Argument(std::vector& p_a, + std::vector& p_b, + std::vector& p_c, + std::vector& gemm_shapes, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + grid_size_ = 0; + + p_workspace_ = nullptr; + + group_count_ = ck::type_convert(gemm_shapes.size()); + + if(!(group_count_ == ck::type_convert(p_a.size()) && + group_count_ == ck::type_convert(p_b.size()) && + group_count_ == ck::type_convert(p_c.size()))) + { + throw std::runtime_error("wrong! 
group_count_ != P_a/b/c.size"); + } + + gemm_desc_kernel_arg_.reserve(group_count_); + + for(std::size_t i = 0; i < gemm_shapes.size(); i++) + { + const index_t M = gemm_shapes[i].M; + const index_t N = gemm_shapes[i].N; + const index_t K = gemm_shapes[i].K; + + const index_t StrideA = gemm_shapes[i].StrideA; + const index_t StrideB = gemm_shapes[i].StrideB; + const index_t StrideC = gemm_shapes[i].StrideC; + + const auto a_grid_desc_k0_m_k1_ = + DeviceGroupedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + const auto b_grid_desc_k0_n_k1_ = + DeviceGroupedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + const auto c_grid_desc_m_n_ = + DeviceGroupedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC); + + const index_t grid_size_grp = + GroupedGemmBlock2CTileMap(c_grid_desc_m_n_, M01, N01, 0) + .block_2_ctile_map_.CalculateGridSize(c_grid_desc_m_n_); + + const index_t BlockStart = grid_size_; + const index_t BlockEnd = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + const auto grouped_gemm_block_2_ctile_map_ = + GroupedGemmBlock2CTileMap(c_grid_desc_m_n_, M01, N01, BlockStart); + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + grouped_gemm_block_2_ctile_map_)) + { + const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); + + gemm_desc_kernel_arg_.push_back( + GemmDescKernelArg{a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + grouped_gemm_block_2_ctile_map_, + static_cast(p_a[i]), + static_cast(p_b[i]), + static_cast(p_c[i]), + BlockStart, + BlockEnd}); + } + } + } + + // private: + index_t M01_; + index_t N01_; + index_t group_count_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + + std::vector gemm_desc_kernel_arg_; + + index_t grid_size_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceGroupedGemmXdl::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + bool has_main_k_block_loop = true; + + for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) + { + std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{" + << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", " + << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}"; + + std::cout << ", arg.b_grid_desc_k0_n_k1_{" + << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", " + << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}"; + + std::cout << ", arg.c_grid_desc_m_n_{ " + << arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_.GetLength(I1) << "}" + << std::endl; + + if(!GridwiseGemm::CheckValidity( + arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_, + arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_, + arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_, + arg.gemm_desc_kernel_arg_[i].grouped_gemm_block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + } + + const auto K = arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0) * + arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop) + { + throw std::runtime_error("wrong! not all gemm has_main_k_block_loop"); + } + } + + hipGetErrorString( + hipMemcpy(arg.p_workspace_, + arg.gemm_desc_kernel_arg_.data(), + arg.gemm_desc_kernel_arg_.size() * sizeof(GemmDescKernelArg), + hipMemcpyHostToDevice)); + + float ave_time = 0; + + if(has_main_k_block_loop) + { + const auto kernel = + kernel_grouped_gemm_xdlops_v2r3; + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.p_workspace_), + arg.gemm_desc_kernel_arg_.size(), + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + } + else + { + const auto kernel = + kernel_grouped_gemm_xdlops_v2r3; + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.p_workspace_), + arg.gemm_desc_kernel_arg_.size(), + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::type_convert(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_) + return false; + else + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector& p_a, + std::vector& p_b, + std::vector& p_c, + std::vector gemm_shapes, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, p_b, p_c, gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(std::vector& p_a, + std::vector& p_b, + std::vector& p_c, + std::vector& gemm_shapes, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t /* KBatch */ = 1) override + { + return std::make_unique( + p_a, p_b, p_c, gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedGemmXdl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave + << ">"; + // clang-format on + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + return dynamic_cast(p_arg)->group_count_ * sizeof(GemmDescKernelArg); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git 
a/include/ck/tensor_operation/gpu/device/device_normalization.hpp b/include/ck/tensor_operation/gpu/device/device_normalization.hpp new file mode 100644 index 00000000..0e4313f1 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_normalization.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +struct DeviceNormalization : public BaseOperator +{ + // inLengths: input tensor extent(s) from high to low dimension + // inStrides: input tensor stride(s) from high to low dimension + // reduceDims: the dimension(s) the normalization operation is applied + // alpha: typeless pointer in host memory storing the alpha scaling value of type AccDataType + // beta: typeless pointer in host memory storing the beta scaling value of type AccDataType + // in_dev: typeless const pointer in device memory storing the input tensor + // out_dev: typeless pointer in device memory storing the output tensor + virtual std::unique_ptr MakeArgumentPointer(const std::vector inLengths, + const std::vector inStrides, + const std::vector reduceDims, + const void* alpha, + const void* beta, + const void* in_dev, + void* out_dev) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual index_t GetRank() const = 0; + + virtual index_t GetNumReduceDim() const = 0; +}; + +using DeviceNormalizationPtr = std::unique_ptr; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_pool2d_fwd.hpp b/include/ck/tensor_operation/gpu/device/device_pool2d_fwd.hpp new file mode 100644 index 00000000..3b376c6f --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_pool2d_fwd.hpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/utility/reduction_enums.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DevicePool2dFwd : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* in_dev, + void* out_dev, + void* out_indices_dev, + ck::index_t N, + ck::index_t C, + std::array input_spatial_lengths, + std::array window_spatial_lengths, + std::array output_spatial_lengths, + std::array window_strides, + std::array input_left_pads, + std::array input_right_pads) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DevicePool2dFwdPtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_pool2d_fwd_nhwc_nhwc.hpp b/include/ck/tensor_operation/gpu/device/device_pool2d_fwd_nhwc_nhwc.hpp new file mode 100644 index 00000000..3edf9bd3 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_pool2d_fwd_nhwc_nhwc.hpp @@ -0,0 +1,327 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
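The DevicePool2dFwd interface above takes the window geometry as std::array parameters (the array element type and extent were lost in the angle brackets above; <ck::index_t, 2> is assumed here), and the caller must supply output_spatial_lengths consistent with the other parameters. A hedged usage sketch using the standard pooling output-extent formula Ho = (Hi + padL + padR - Y) / strideH + 1, which matches the pad/embed transforms in the NHWC implementation that follows; PoolOutLen is a hypothetical helper.

#include <array>

using Index = int; // stand-in for ck::index_t

// Standard forward-pooling output extent (assumption consistent with the
// pad + embed transforms in MakeABGridDescriptor_A_M_K_B_M below).
Index PoolOutLen(Index in, Index window, Index stride, Index pad_l, Index pad_r)
{
    return (in + pad_l + pad_r - window) / stride + 1;
}

int main()
{
    const Index Hi = 224, Wi = 224;
    const std::array<Index, 2> window{3, 3}, strides{2, 2}, pads_l{1, 1}, pads_r{1, 1};

    const std::array<Index, 2> out{
        PoolOutLen(Hi, window[0], strides[0], pads_l[0], pads_r[0]),  // Ho = 112
        PoolOutLen(Wi, window[1], strides[1], pads_l[1], pads_r[1])}; // Wo = 112

    // These extents would be passed as output_spatial_lengths to
    // MakeArgumentPointer, alongside N, C and the arrays above.
    (void)out;
}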
+ +#pragma once + +#include +#include + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/device_pool2d_fwd.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + using IndexDataType = int32_t; + + using ReduceOperation = typename reduce_binary_operator::opType; + + using InElementwiseOperation = + typename reduce_unary_operator::InElementwiseOperation; + + using AccElementwiseOperation = + typename reduce_unary_operator::AccElementwiseOperation; + + static constexpr index_t InSrcOutDstVectorDim = + 0; // for NHWC, the dim C is the vector Dim for both input and output in memory, which is + // not reduced. + + static constexpr ck::index_t ReduceM_BlockTileSize = + ReduceMThreadClusterSize * ReduceMThreadSliceSize; + static constexpr ck::index_t ReduceK_BlockTileSize = + ReduceKThreadClusterSize * ReduceKThreadSliceSize; + + static auto MakeABGridDescriptor_A_M_K_B_M(ck::index_t N, + ck::index_t C, + std::array input_spatial_lengths, + std::array window_spatial_lengths, + std::array output_spatial_lengths, + std::array window_strides, + std::array input_left_pads, + std::array input_right_pads) + { + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = window_spatial_lengths[0]; + const index_t X = window_spatial_lengths[1]; + + const index_t ConvStrideH = window_strides[0]; + const index_t ConvStrideW = window_strides[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t ReduceMRaw = N * Ho * Wo * C; + const index_t ReduceMPad = + math::integer_least_multiple(ReduceMRaw, ReduceM_BlockTileSize) - ReduceMRaw; + + const index_t ReduceKRaw = Y * X; + const index_t ReduceKPad = + math::integer_least_multiple(ReduceKRaw, ReduceK_BlockTileSize) - ReduceKRaw; + + // A[ReduceM, ReduceK] + const auto in_grid_desc_n_hi_wi_c = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_grid_desc_n_hip_wip_c = transform_tensor_descriptor( + in_grid_desc_n_hi_wi_c, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_grid_desc_n_y_ho_x_wo_c = transform_tensor_descriptor( + in_grid_desc_n_hip_wip_c, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(I1, ConvStrideH)), + make_embed_transform(make_tuple(X, 
Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_grid_desc_reducemraw_reducekraw = + transform_tensor_descriptor(in_grid_desc_n_y_ho_x_wo_c, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, C)), + make_merge_transform(make_tuple(Y, X))), + make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_grid_desc_reducem_reducek = transform_tensor_descriptor( + in_grid_desc_reducemraw_reducekraw, + make_tuple(make_right_pad_transform(ReduceMRaw, ReduceMPad), + make_right_pad_transform(ReduceKRaw, ReduceKPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // B[ReduceM] + const auto out_grid_desc_reducemraw = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo * C)); + + const auto out_grid_desc_reducem = transform_tensor_descriptor( + out_grid_desc_reducemraw, + make_tuple(make_right_pad_transform(ReduceMRaw, ReduceMPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + + return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem); + } + + using ABGridDescs = decltype( + MakeABGridDescriptor_A_M_K_B_M(1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})); + + using AGridDesc_M_K = remove_cvref_t; + using BGridDesc_M = remove_cvref_t; + + // TODO + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_dev, + OutDataType* p_out_dev, + int* p_out_indices_dev, + ck::index_t N, + ck::index_t C, + std::array& input_spatial_lengths, + std::array& window_spatial_lengths, + std::array& output_spatial_lengths, + std::array& window_strides, + std::array& input_left_pads, + std::array& input_right_pads) + : p_in_dev_{p_in_dev}, + p_out_dev_{p_out_dev}, + p_out_indices_dev_{p_out_indices_dev}, + a_grid_desc_m_k_{}, + b_grid_desc_m_{} + { + const auto descs = MakeABGridDescriptor_A_M_K_B_M(N, + C, + input_spatial_lengths, + window_spatial_lengths, + output_spatial_lengths, + window_strides, + input_left_pads, + input_right_pads); + + a_grid_desc_m_k_ = descs[I0]; + b_grid_desc_m_ = descs[I1]; + + invariant_lowest_length_ = C; + reduce_lowest_length_ = window_spatial_lengths[1]; + + int32_t reduceLength = window_spatial_lengths[0] * window_spatial_lengths[1]; + + std::tie(in_element_op_, acc_element_op_) = + reduce_unary_operator::GetElementwiseOperator(reduceLength); + } + + const InDataType* p_in_dev_; + OutDataType* p_out_dev_; + int* p_out_indices_dev_; + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_M b_grid_desc_m_; + InElementwiseOperation in_element_op_; + AccElementwiseOperation acc_element_op_; + + // for checking vector load/store + ck::index_t invariant_lowest_length_; + ck::index_t reduce_lowest_length_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + using gridwise_reduce = + GridwiseReduction_mk_to_m_threadwise; + + const auto kernel = kernel_reduce_threadwise; + + ck::index_t ReduceM = arg.a_grid_desc_m_k_.GetLength(I0); + + const index_t grid_size = (ReduceM / ReduceM_BlockTileSize); + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.a_grid_desc_m_k_, + arg.b_grid_desc_m_, + arg.in_element_op_, + arg.acc_element_op_, + float(1), + arg.p_in_dev_, + nullptr, + float(0), + arg.p_out_dev_, + 
arg.p_out_indices_dev_); + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + if(pArg->invariant_lowest_length_ % InSrcOutDstVectorSize != 0) + { + return (false); + } + + return (true); + } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_dev, + void* p_out_dev, + void* p_out_indices_dev, + ck::index_t N, + ck::index_t C, + std::array input_spatial_lengths, + std::array window_spatial_lengths, + std::array output_spatial_lengths, + std::array window_strides, + std::array input_left_pads, + std::array input_right_pads) override + { + return std::make_unique(static_cast(p_in_dev), + static_cast(p_out_dev), + static_cast(p_out_indices_dev), + N, + C, + input_spatial_lengths, + window_spatial_lengths, + output_spatial_lengths, + window_strides, + input_left_pads, + input_right_pads); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<" << BlockSize << ","; + str << "M_C" << ReduceMThreadClusterSize << "_S" << ReduceMThreadSliceSize << ","; + str << "K_C" << ReduceKThreadClusterSize << "_S" << ReduceKThreadSliceSize << ","; + str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_reduce.hpp b/include/ck/tensor_operation/gpu/device/device_reduce.hpp new file mode 100644 index 00000000..468d0b5a --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_reduce.hpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceReduce : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const std::vector inLengths, + const std::vector inStrides, + const std::vector outLengths, + const std::vector outStrides, + const std::vector reduceDims, + float alpha, + float beta, + const void* in_dev, + const void* in_index_dev, + void* out_dev, + void* out_index_dev, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceReducePtr = + std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_reduce_common.hpp b/include/ck/tensor_operation/gpu/device/device_reduce_common.hpp new file mode 100644 index 00000000..42e74f29 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_reduce_common.hpp @@ -0,0 +1,90 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
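The pooling implementation above reduces to a generic 2D reduction: every output element (n, ho, wo, c) becomes one row of an [M, K] problem (ReduceMRaw = N * Ho * Wo * C) and the Y * X window positions become the K columns (ReduceKRaw = Y * X); both extents are then padded up to the block tile sizes so the grid divides evenly. A small sketch of that flattening; the concrete tile sizes in the trailing comment are hypothetical, since they are template parameters not fixed by this diff.

// Sketch of how the pooling problem is flattened into a 2D reduction.
int IntegerLeastMultiple(int x, int tile) { return (x + tile - 1) / tile * tile; }

struct Reduce2d { int M, K; };

Reduce2d FlattenPool(int N, int C, int Ho, int Wo, int Y, int X,
                     int ReduceM_BlockTileSize, int ReduceK_BlockTileSize)
{
    const int MRaw = N * Ho * Wo * C; // one row per output element
    const int KRaw = Y * X;           // one column per window position
    return {IntegerLeastMultiple(MRaw, ReduceM_BlockTileSize),
            IntegerLeastMultiple(KRaw, ReduceK_BlockTileSize)};
}
// e.g. N = 1, C = 64, Ho = Wo = 112, Y = X = 3 gives MRaw = 802816, KRaw = 9;
// with a hypothetical K block tile of 8, K pads from 9 to 16 (the ReduceKPad above).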
+
+#pragma once
+
+#include <utility>
+#include <vector>
+
+#include "ck/utility/common_header.hpp"
+#include "ck/utility/reduction_enums.hpp"
+#include "ck/utility/reduction_operator.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+// here, inLengths[] is already shuffled so that lengths of invariant dims are included before those
+// of reduce dims
+template <index_t Rank, index_t NumReduceDim>
+std::pair<long_index_t, long_index_t> get_2d_lengths(const std::vector<index_t>& inLengths)
+{
+    static_assert(Rank <= 6, "bigger Rank size not supported!");
+
+    long_index_t invariant_total_length = 1;
+    long_index_t reduce_total_length    = 1;
+
+    constexpr int NumInvariantDim = Rank - NumReduceDim;
+
+    for(int i = NumInvariantDim; i < Rank; i++)
+        reduce_total_length *= inLengths[i];
+
+    for(int i = 0; i < NumInvariantDim; i++)
+        invariant_total_length *= inLengths[i];
+
+    return std::make_pair(invariant_total_length, reduce_total_length);
+};
+
+// helper functions using variadic template arguments
+template <index_t... Ns>
+auto make_tuple_from_array_and_index_seq(const std::vector<index_t>& lengths, Sequence<Ns...>)
+{
+    return make_tuple(static_cast<index_t>(lengths[Ns])...);
+};
+
+template <index_t arraySize>
+auto make_tuple_from_array(const std::vector<index_t>& lengths, Number<arraySize>)
+{
+    static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions");
+
+    constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{};
+
+    return make_tuple_from_array_and_index_seq(lengths, index_seq);
+};
+
+template <index_t Rank, index_t NumReduceDim>
+std::vector<index_t> shuffle_tensor_dimensions(const std::vector<index_t>& origLengthsStrides,
+                                               const std::vector<int>& reduceDims)
+{
+    std::vector<index_t> newLengthsStrides;
+
+    assert(Rank == origLengthsStrides.size() && NumReduceDim == reduceDims.size());
+
+    int reduceFlag = 0;
+
+    // flag the bits for the reduceDims
+    for(int i = 0; i < NumReduceDim; i++)
+    {
+        reduceFlag |= 1 << reduceDims[i];
+    };
+
+    // collect invariant dimensions
+    for(int i = 0; i < Rank; i++)
+        if((reduceFlag & (1 << i)) == 0)
+        {
+            newLengthsStrides.push_back(origLengthsStrides[i]);
+        };
+
+    // collect reduce dimensions
+    for(int i = 0; i < Rank; i++)
+        if((reduceFlag & (1 << i)) > 0)
+        {
+            newLengthsStrides.push_back(origLengthsStrides[i]);
+        };
+
+    return newLengthsStrides;
+};
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp b/include/ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp
new file mode 100644
index 00000000..a903fc41
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp
@@ -0,0 +1,513 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
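The helpers in device_reduce_common.hpp above feed the multiblock reduction that follows: reduceDims are moved to the back of the length array, then get_2d_lengths collapses the shuffled lengths into an [invariant, reduce] pair. A self-contained worked example of both steps, reusing only the bit-flag logic visible in shuffle_tensor_dimensions:

#include <cassert>
#include <vector>

// Rank-4 tensor reduced over dims {0, 2}: invariant dims move to the front,
// reduce dims to the back, so lengths {2, 3, 4, 5} reorder to {3, 5, 2, 4}
// and the 2D view is [invariant = 3 * 5 = 15, reduce = 2 * 4 = 8].
int main()
{
    const std::vector<int> lengths{2, 3, 4, 5};
    const std::vector<int> reduceDims{0, 2};

    int reduceFlag = 0;
    for(int d : reduceDims)
        reduceFlag |= 1 << d;

    std::vector<int> shuffled;
    for(int i = 0; i < 4; i++) // invariant dims first
        if((reduceFlag & (1 << i)) == 0) shuffled.push_back(lengths[i]);
    for(int i = 0; i < 4; i++) // then reduce dims
        if((reduceFlag & (1 << i)) != 0) shuffled.push_back(lengths[i]);

    assert((shuffled == std::vector<int>{3, 5, 2, 4}));

    long invariant = 1, reduce = 1; // get_2d_lengths over the shuffled lengths
    for(int i = 0; i < 2; i++) invariant *= shuffled[i];
    for(int i = 2; i < 4; i++) reduce *= shuffled[i];
    assert(invariant == 15 && reduce == 8);
}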
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceReduceMultiBlock : public DeviceReduce +{ + static_assert(Rank <= 6, "Bigger Rank size is not supported!"); + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, + "Invalid thread cluster size assignments!"); + + static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) && + (MThreadSliceSize % OutDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + using IndexDataType = int32_t; + + static constexpr bool HaveIndexInput = OutputIndex && HaveIndexInputIfOutputIndex; + + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; + + static constexpr index_t numSrcDim = Rank; + static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim; + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + + // So far, only AtomicAdd is considered, other Atomic Operation like AtomicMax can be added + // later + static constexpr bool use_multiblock = + (OutMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd); + + static_assert(ck::reduce::InMemoryDataOperatonSupportedOnDataType::value, + "The OutDataType must support the specified OutMemoryDataOperation!"); + + static_assert(!use_multiblock || (use_multiblock && !OutputIndex), + "MultiBlock reduction can only be used when outputing index is not required"); + + static_assert( + ReduceOperation::IsCompatibleInMemoryDataOperation(OutMemoryDataOperation), + "The reduction accumulation operation must be compatible with the OutMemoryDataOperation!"); + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static auto MakeSrc2dDescriptor(const std::vector& inLengths, + const std::vector& inStrides, + int blkGroupSize, + int numBlockTileIteration) + { + const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number{}); + const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number{}); + + const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto in_grid_desc_m_k = [&]() { + if constexpr(reduceAllDim) + { + const auto one_dim_inDesc = transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + return transform_tensor_descriptor(one_dim_inDesc, + make_tuple(make_unmerge_transform(make_tuple( + 1, one_dim_inDesc.GetLength(Number<0>{})))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + } + else + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + 
+ const auto reduceDimLengths = + make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); + const auto invariantDimLengths = + make_tuple_from_array_and_index_seq(inLengths, InvariantDims{}); + + return transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }(); + + const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration; + const auto inPad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength; + + auto in_grid_desc_m_k_padded = transform_tensor_descriptor( + in_grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, inPad_M), + make_right_pad_transform(reduceLength, inPad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (in_grid_desc_m_k_padded); + }; + + static auto MakeDst1dDescriptor(const std::vector& outLengths, + const std::vector& outStrides) + { + const auto tupleDstLengths = make_tuple_from_array(outLengths, Number{}); + const auto tupleDstStrides = make_tuple_from_array(outStrides, Number{}); + + auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + + auto out_grid_desc_m = transform_tensor_descriptor( + outDesc, + make_tuple(make_merge_transform(tupleDstLengths)), + make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{}); + + const auto outPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto out_grid_desc_m_padded = transform_tensor_descriptor( + out_grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, outPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (out_grid_desc_m_padded); + }; + + static auto MakeDst1dDescriptorForBufferSet(const std::vector& outLengths, + const std::vector& outStrides) + { + const auto tupleDstLengths = make_tuple_from_array(outLengths, Number{}); + const auto tupleDstStrides = make_tuple_from_array(outStrides, Number{}); + + auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + + auto out_grid_desc_m = transform_tensor_descriptor( + outDesc, + make_tuple(make_merge_transform(tupleDstLengths)), + make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto length = out_grid_desc_m.GetLength(Number<0>{}); + + const auto pad = math::integer_least_multiple(length, BlockSize) - length; + + auto out_grid_desc_m_padded = + transform_tensor_descriptor(out_grid_desc_m, + make_tuple(make_right_pad_transform(length, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (out_grid_desc_m_padded); + }; + + struct Argument : public BaseArgument + { + Argument(const std::vector inLengths, + const std::vector inStrides, + const std::vector outLengths, + const std::vector outStrides, + const std::vector reduceDims, + float alpha, + float beta, + const InDataType* in_dev, + const IndexDataType* in_index_dev, + OutDataType* out_dev, + IndexDataType* out_index_dev, + const InElementwiseOperation in_elementwise_op, + const 
AccElementwiseOperation acc_elementwise_op) + : outLengths_{outLengths}, + outStrides_{outStrides}, + in_dev_{in_dev}, + in_index_dev_{in_index_dev}, + out_dev_{out_dev}, + out_index_dev_{out_index_dev}, + in_elementwise_op_{in_elementwise_op}, + acc_elementwise_op_{acc_elementwise_op} + { + inLengths_ = shuffle_tensor_dimensions(inLengths, reduceDims); + inStrides_ = shuffle_tensor_dimensions(inStrides, reduceDims); + + alpha_ = type_convert(alpha); + beta_ = type_convert(beta); + + std::tie(invariant_total_length, reduce_total_length) = + get_2d_lengths(inLengths_); + + if constexpr(NumInvariantDim == 0) + invariant_lowest_length = 1; + else + invariant_lowest_length = inLengths_[NumInvariantDim - 1]; + + reduce_lowest_length = inLengths_[Rank - 1]; + + if constexpr(use_multiblock) + { + + int iterations = 1; + while(true) + { + int testBlkGroupSize = + (reduce_total_length + (K_BlockTileSize * iterations) - 1) / + (K_BlockTileSize * iterations); + + // we want the blkGroupSize be not more than 128 + if(testBlkGroupSize <= 128) + break; + + iterations++; + }; + + blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) / + (K_BlockTileSize * iterations); + + numBlockTileIteration = iterations; + } + else + { + blkGroupSize = 1; + numBlockTileIteration = + (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize; + }; + + gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / + M_BlockTileSize * blkGroupSize; + + gridSize_pre = + math::integer_least_multiple(invariant_total_length, BlockSize) / BlockSize; + } + + std::vector inLengths_; + std::vector inStrides_; + std::vector outLengths_; + std::vector outStrides_; + + AccDataType alpha_; + AccDataType beta_; + + const InDataType* in_dev_; + const IndexDataType* in_index_dev_; + OutDataType* out_dev_; + IndexDataType* out_index_dev_; + + InElementwiseOperation in_elementwise_op_; + AccElementwiseOperation acc_elementwise_op_; + + index_t invariant_lowest_length; + index_t reduce_lowest_length; + long_index_t invariant_total_length; + long_index_t reduce_total_length; + + int blkGroupSize; + int numBlockTileIteration; + size_t gridSize; + + size_t gridSize_pre; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto in_grid_desc_m_k = DeviceReduceMultiBlock::MakeSrc2dDescriptor( + arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration); + const auto out_grid_desc_m = + DeviceReduceMultiBlock::MakeDst1dDescriptor(arg.outLengths_, arg.outStrides_); + const auto out_grid_desc_m_2 = DeviceReduceMultiBlock::MakeDst1dDescriptorForBufferSet( + arg.outLengths_, arg.outStrides_); + + using InGridDesc_M_K = decltype(in_grid_desc_m_k); + using OutGridDesc_M = decltype(out_grid_desc_m); + using OutGridDesc_M_2 = decltype(out_grid_desc_m_2); + + using GridwiseReduce = GridwiseReduction_mk_to_m_multiblock; + + const auto kernel_main = kernel_reduce_multiblock; + + float avg_time = 0; + + if constexpr(use_multiblock) + { + const auto identityVal = + ck::reduce::GetIdentityValueForInMemoryDataOperation( + OutMemoryDataOperation); + + const auto kernel_pre = + kernel_buffer_set_value; + + avg_time += launch_and_time_kernel(stream_config, + kernel_pre, + dim3(arg.gridSize_pre), + dim3(BlockSize), + 0, + out_grid_desc_m_2, + arg.out_dev_, + identityVal); + }; + + avg_time += launch_and_time_kernel(stream_config, + kernel_main, + dim3(arg.gridSize), + dim3(BlockSize), + 0, + in_grid_desc_m_k, + 
out_grid_desc_m, + arg.in_elementwise_op_, + arg.acc_elementwise_op_, + arg.blkGroupSize, + arg.numBlockTileIteration, + arg.alpha_, + arg.in_dev_, + arg.in_index_dev_, + arg.beta_, + arg.out_dev_, + arg.out_index_dev_); + + return (avg_time); + }; + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + }; + }; + + static bool IsSupportedArgument(const Argument* pArg) + { + if constexpr(use_multiblock) + { + if(static_cast(pArg->beta_) != 0.0f) + return (false); + }; + + if constexpr(InSrcVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + { + return (false); + } + else + { + if(pArg->inStrides_[NumInvariantDim - 1] != 1) + return (false); + + if(pArg->invariant_lowest_length % InSrcVectorSize != 0) + return (false); + }; + } + else + { + if(pArg->inStrides_[Rank - 1] != 1) + return (false); + + if(pArg->reduce_lowest_length % InSrcVectorSize != 0) + return (false); + }; + + // To improve + if(pArg->invariant_lowest_length % OutDstVectorSize != 0) + return (false); + + if constexpr(use_multiblock) + { + // blkGroupSize of 1 should be handled by Blockwise path using + // InMemoryDataOperationEnum::Set + if(pArg->blkGroupSize == 1) + return (false); + + // This is very strong restriction, but needed to avoid some failure + if(pArg->invariant_lowest_length % M_BlockTileSize != 0) + return (false); + } + else + { + // cases with very small reduce_total_length should be handled by ThreadWise kernel + // if(pArg->reduce_total_length / KThreadSliceSize < 2) + // return (false); + }; + + return (true); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(dynamic_cast(p_arg)); + }; + + std::unique_ptr + MakeArgumentPointer(const std::vector inLengths, + const std::vector inStrides, + const std::vector outLengths, + const std::vector outStrides, + const std::vector reduceDims, + float alpha, + float beta, + const void* in_dev, + const void* in_index_dev, + void* out_dev, + void* out_index_dev, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op) override + { + return std::make_unique(inLengths, + inStrides, + outLengths, + outStrides, + reduceDims, + alpha, + beta, + static_cast(in_dev), + static_cast(in_index_dev), + static_cast(out_dev), + static_cast(out_index_dev), + in_elementwise_op, + acc_elementwise_op); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << (OutMemoryDataOperation == InMemoryDataOperationEnum::Set? 
"DeviceReduceBlockWise<" : "DeviceReduceMultiBlock<") << BlockSize << ","; + str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; + str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_reduce_threadwise.hpp b/include/ck/tensor_operation/gpu/device/device_reduce_threadwise.hpp new file mode 100644 index 00000000..d9169549 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_reduce_threadwise.hpp @@ -0,0 +1,376 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceReduceThreadWise : public DeviceReduce +{ + static_assert(Rank <= 6, "Bigger Rank size is not supported!"); + + static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) && + (MThreadSliceSize % OutDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + using IndexDataType = int32_t; + + static constexpr bool HaveIndexInput = OutputIndex && HaveIndexInputIfOutputIndex; + + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; + + static constexpr index_t numSrcDim = Rank; + static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 
1 : NumInvariantDim; + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + + static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize; + + static auto MakeSrc2dDescriptor(const std::vector& inLengths, + const std::vector& inStrides) + { + const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number{}); + const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number{}); + + const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto in_grid_desc_m_k = [&]() { + if constexpr(reduceAllDim) + { + const auto one_dim_inDesc = transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + return transform_tensor_descriptor(one_dim_inDesc, + make_tuple(make_unmerge_transform(make_tuple( + 1, one_dim_inDesc.GetLength(Number<0>{})))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + } + else + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = + make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); + const auto invariantDimLengths = + make_tuple_from_array_and_index_seq(inLengths, InvariantDims{}); + + return transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }(); + + const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + const auto inPad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto inPad_K = + math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength; + + auto in_grid_desc_m_k_padded = transform_tensor_descriptor( + in_grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, inPad_M), + make_right_pad_transform(reduceLength, inPad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (in_grid_desc_m_k_padded); + }; + + static auto MakeDst1dDescriptor(const std::vector& outLengths, + const std::vector& outStrides) + { + const auto tupleDstLengths = make_tuple_from_array(outLengths, Number{}); + const auto tupleDstStrides = make_tuple_from_array(outStrides, Number{}); + + auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + + auto out_grid_desc_m = transform_tensor_descriptor( + outDesc, + make_tuple(make_merge_transform(tupleDstLengths)), + make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{}); + + const auto outPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto out_grid_desc_m_padded = transform_tensor_descriptor( + out_grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, outPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (out_grid_desc_m_padded); + }; + + struct Argument : public BaseArgument + { + Argument(const std::vector inLengths, + const std::vector inStrides, + const std::vector outLengths, + const std::vector outStrides, + 
const std::vector reduceDims, + float alpha, + float beta, + const InDataType* in_dev, + OutDataType* out_dev, + IndexDataType* out_index_dev, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op) + : outLengths_{outLengths}, + outStrides_{outStrides}, + in_dev_{in_dev}, + out_dev_{out_dev}, + out_index_dev_{out_index_dev}, + in_elementwise_op_{in_elementwise_op}, + acc_elementwise_op_{acc_elementwise_op} + { + inLengths_ = shuffle_tensor_dimensions(inLengths, reduceDims); + inStrides_ = shuffle_tensor_dimensions(inStrides, reduceDims); + + alpha_ = type_convert(alpha); + beta_ = type_convert(beta); + + std::tie(invariant_total_length, reduce_total_length) = + get_2d_lengths(inLengths_); + + if constexpr(NumInvariantDim == 0) + invariant_lowest_length = 1; + else + invariant_lowest_length = inLengths_[NumInvariantDim - 1]; + + reduce_lowest_length = inLengths_[Rank - 1]; + + numBlockTileIteration = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize; + + gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / + M_BlockTileSize; + } + + std::vector inLengths_; + std::vector inStrides_; + std::vector outLengths_; + std::vector outStrides_; + + AccDataType alpha_; + AccDataType beta_; + + const InDataType* in_dev_; + OutDataType* out_dev_; + IndexDataType* out_index_dev_; + + InElementwiseOperation in_elementwise_op_; + AccElementwiseOperation acc_elementwise_op_; + + index_t invariant_lowest_length; + index_t reduce_lowest_length; + long_index_t invariant_total_length; + long_index_t reduce_total_length; + + int numBlockTileIteration; + size_t gridSize; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto in_grid_desc_m_k = + DeviceReduceThreadWise::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_); + const auto out_grid_desc_m = + DeviceReduceThreadWise::MakeDst1dDescriptor(arg.outLengths_, arg.outStrides_); + using InGridDesc_M_K = decltype(in_grid_desc_m_k); + using OutGridDesc_M = decltype(out_grid_desc_m); + + float avg_time = 0; + + using GridwiseReduce = + GridwiseReduction_mk_to_m_threadwise; + + const auto kernel = kernel_reduce_threadwise; + + avg_time = launch_and_time_kernel(stream_config, + kernel, + dim3(arg.gridSize), + dim3(BlockSize), + 0, + in_grid_desc_m_k, + out_grid_desc_m, + arg.in_elementwise_op_, + arg.acc_elementwise_op_, + arg.alpha_, + arg.in_dev_, + nullptr, + arg.beta_, + arg.out_dev_, + arg.out_index_dev_); + + return (avg_time); + }; + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + }; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + if constexpr(InSrcVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + { + return (false); + } + else + { + if(pArg->inStrides_[NumInvariantDim - 1] != 1) + return (false); + + if(pArg->invariant_lowest_length % InSrcVectorSize != 0) + return (false); + }; + } + else + { + if(pArg->inStrides_[Rank - 1] != 1) + return (false); + + if(pArg->reduce_lowest_length % InSrcVectorSize != 0) + return (false); + }; + + // To improve + if(pArg->invariant_lowest_length % OutDstVectorSize != 0) + return (false); + + // cases with big reduce_total_length should be handled by Blockwise kernel + if(pArg->reduce_total_length / KThreadSliceSize >= 32) + return (false); + + return 
(true); + }; + + std::unique_ptr + MakeArgumentPointer(const std::vector inLengths, + const std::vector inStrides, + const std::vector outLengths, + const std::vector outStrides, + const std::vector reduceDims, + float alpha, + float beta, + const void* in_dev, + const void* in_index_dev, + void* out_dev, + void* out_index_dev, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op) override + { + (void)in_index_dev; + + return std::make_unique(inLengths, + inStrides, + outLengths, + outStrides, + reduceDims, + alpha, + beta, + static_cast(in_dev), + static_cast(out_dev), + static_cast(out_index_dev), + in_elementwise_op, + acc_elementwise_op); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceReduceThreadWise<" << BlockSize << ","; + str << "M_C" << BlockSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << 1 << "_S" << KThreadSliceSize << ","; + str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_softmax.hpp b/include/ck/tensor_operation/gpu/device/device_softmax.hpp new file mode 100644 index 00000000..6a5dfc4d --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_softmax.hpp @@ -0,0 +1,250 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_softmax.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceSoftmax : public DeviceNormalization +{ + static constexpr index_t kRank = Rank; + static constexpr index_t kNumReduceDim = NumReduceDim; + + virtual index_t GetRank() const override { return kRank; } + + virtual index_t GetNumReduceDim() const override { return kNumReduceDim; } + + using PassThrough = tensor_operation::element_wise::PassThrough; + + // Used for freeloading of some handy functions from DeviceReduceMultiBlock + using Reduction = DeviceReduceMultiBlock; // OutDstVectorSize + + using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1)); + + using GridwiseSoftmaxGeneric = GridwiseSoftmax_mk_to_mk; + + using GridwiseSoftmaxSweepOnce = GridwiseSoftmax_mk_to_mk; + + struct Argument : public Reduction::Argument + { + Argument(const std::vector inLengths, + const std::vector inStrides, + const std::vector reduceDims, + AccDataType alpha, + AccDataType beta, + const InDataType* in_dev, + OutDataType* out_dev) + : Reduction::Argument(inLengths, + inStrides, + {}, + {}, + reduceDims, + 0.0f, // alpha + 0.0f, // beta + in_dev, + nullptr, + out_dev, + 
nullptr, + PassThrough{}, + PassThrough{}), + // FIXME: The base class DeviceReduceMultiBlock::Argument only supports alpha/beta of + // float32 precision. Make it support any data type so the fields can be removed. + alpha_(alpha), + beta_(beta) + { + // std::cout << "blkGroupSize= " << this->blkGroupSize + // << ", numBlockTileIteration= " << this->numBlockTileIteration + // << ", gridSize=" << this->gridSize + // << ", invariant_total_length=" << this->invariant_total_length << + // std::endl; + } + + AccDataType alpha_; + AccDataType beta_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto in_grid_desc_m_k = Reduction::MakeSrc2dDescriptor( + arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration); + const auto out_grid_desc_m_k = Reduction::MakeSrc2dDescriptor( + arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration); + + bool sweep_once = + in_grid_desc_m_k.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize; + + const auto kernel_main = sweep_once ? kernel_softmax + : kernel_softmax; + + float avg_time = 0; + + avg_time += launch_and_time_kernel(stream_config, + kernel_main, + dim3(arg.gridSize), + dim3(BlockSize), + 0, + in_grid_desc_m_k, + out_grid_desc_m_k, + arg.blkGroupSize, + arg.numBlockTileIteration, + arg.alpha_, + arg.in_dev_, + arg.beta_, + arg.out_dev_); + + return (avg_time); + }; + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + }; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* p_arg_ = dynamic_cast(p_arg); + + if(!Reduction::IsSupportedArgument(p_arg_)) + { + return false; + } + + if(p_arg_->inLengths_[Rank - 1] % OutDstVectorSize != 0) + { + return false; + } + + return true; + }; + + // inLengths: input tensor extent(s) from high to low dimension + // inStrides: input tensor stride(s) from high to low dimension + // reduceDims: the dimension(s) the softmax normalization operate on + // alpha: typeless pointer in host memory storing the alpha scaling value as type AccDataType + // beta: typeless pointer in host memory storing the beta scaling value as type AccDataType + // in_dev: typeless const pointer in device memory storing the input tensor + // out_dev: typeless pointer in device memory storing the output tensor + std::unique_ptr MakeArgumentPointer(const std::vector inLengths, + const std::vector inStrides, + const std::vector reduceDims, + const void* alpha, + const void* beta, + const void* in_dev, + void* out_dev) override + { + return std::make_unique(inLengths, + inStrides, + reduceDims, + *static_cast(alpha), + *static_cast(beta), + static_cast(in_dev), + static_cast(out_dev)); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceReduceSoftmax<" << BlockSize << ","; + str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; + str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git 
a/include/ck/tensor_operation/gpu/device/device_unary_elementwise.hpp b/include/ck/tensor_operation/gpu/device/device_unary_elementwise.hpp
new file mode 100644
index 00000000..05424542
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/device_unary_elementwise.hpp
@@ -0,0 +1,183 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <functional>
+#include <numeric>
+#include <vector>
+
+#include "ck/device_utility/device_prop.hpp"
+#include "ck/device_utility/kernel_launch.hpp"
+#include "ck/tensor_operation/gpu/device/device_base.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+template <typename ADataType,
+          typename BDataType,
+          typename ElementwiseFunctor,
+          index_t Dim,
+          index_t ScalarPerVector>
+struct DeviceUnaryElementwise : public BaseOperator
+{
+    static constexpr auto I0 = Number<0>{};
+
+    template <typename Desc_M0>
+    static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize)
+    {
+        const auto m0           = desc_m0.GetLength(I0);
+        const index_t loop_step = gridSize * blockSize * ScalarPerVector;
+        const auto pad          = math::integer_least_multiple(m0, loop_step) - m0;
+        const auto desc_m0_pad =
+            transform_tensor_descriptor(desc_m0,
+                                        make_tuple(make_right_pad_transform(m0, pad)),
+                                        make_tuple(Sequence<0>{}),
+                                        make_tuple(Sequence<0>{}));
+        return desc_m0_pad;
+    }
+
+    static auto MakeDescriptor_M0(const std::vector<index_t>& shape,
+                                  const std::vector<index_t>& stride,
+                                  index_t gridSize,
+                                  index_t blockSize)
+    {
+        auto tupleOfShape  = generate_tuple([&](auto I) { return shape[I]; }, Number<Dim>{});
+        auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<Dim>{});
+
+        // nd desc - [s0, s1, s2, ...]
+        const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
+
+        // merge nd to 1d desc - [s0 * s1 * ...]
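+        // e.g. (illustrative) shape = {2, 3, 4} with packed stride = {12, 4, 1} merges into
+        // one dimension of length 24, which PadDescriptor_M0_1d right-pads to a multiple of
+        // gridSize * blockSize * ScalarPerVector so the grid-stride sweep divides evenly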
+        if constexpr(Dim > 1)
+        {
+            const auto desc_m0 = transform_tensor_descriptor(
+                desc,
+                make_tuple(make_merge_transform(tupleOfShape)),
+                make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<Dim>{})),
+                make_tuple(Sequence<0>{}));
+
+            return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize);
+        }
+        else
+            return PadDescriptor_M0_1d(desc, gridSize, blockSize);
+    }
+
+    using GridDesc_M0      = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
+    using GridwiseUEltwise = GridwiseUnaryElementwise_1D<ADataType,
+                                                         BDataType,
+                                                         GridDesc_M0,
+                                                         ElementwiseFunctor,
+                                                         ScalarPerVector>;
+
+    struct Argument : public BaseArgument
+    {
+        Argument(const ADataType* p_a,
+                 BDataType* p_b,
+                 const std::vector<index_t>& shape,
+                 const std::vector<index_t>& stride_a,
+                 const std::vector<index_t>& stride_b,
+                 ElementwiseFunctor functor)
+            : p_a_(p_a),
+              p_b_(p_b),
+              shape_(shape),
+              functor_(functor),
+              blockSize_(256) // FIXME - Calculate the grid size by number of CU in the future
+        {
+            index_t tensor_size =
+                std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<index_t>{});
+            gridSize_       = GridwiseUEltwise::CalculateGridSize(tensor_size);
+            a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, blockSize_);
+            b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, blockSize_);
+        }
+
+        const ADataType* p_a_;
+        BDataType* p_b_;
+        std::vector<index_t> shape_;
+        GridDesc_M0 a_grid_desc_m0_;
+        GridDesc_M0 b_grid_desc_m0_;
+        ElementwiseFunctor functor_;
+        index_t blockSize_;
+        index_t gridSize_;
+    };
+
+    struct Invoker : public BaseInvoker
+    {
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+        {
+            const auto kernel = kernel_unary_elementwise_1d<GridwiseUEltwise,
+                                                            ADataType,
+                                                            BDataType,
+                                                            GridDesc_M0,
+                                                            ElementwiseFunctor>;
+
+            float elapsed_time = launch_and_time_kernel(stream_config,
+                                                        kernel,
+                                                        dim3(arg.gridSize_),
+                                                        dim3(arg.blockSize_),
+                                                        0,
+                                                        arg.p_a_,
+                                                        arg.p_b_,
+                                                        arg.a_grid_desc_m0_,
+                                                        arg.b_grid_desc_m0_,
+                                                        arg.functor_);
+            return elapsed_time;
+        }
+
+        // polymorphic
+        float Run(const BaseArgument* p_arg,
+                  const StreamConfig& stream_config = StreamConfig{}) override
+        {
+            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
+        }
+    };
+
+    bool IsSupportedArgument(const BaseArgument* p_arg) override
+    {
+        const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
+
+        if(pArg == nullptr)
+            return false;
+
+        if(pArg->shape_.back() % ScalarPerVector != 0)
+            return false;
+
+        return true;
+    };
+
+    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
+                                                      void* p_b,
+                                                      std::vector<index_t> shape,
+                                                      std::vector<index_t> stride_a,
+                                                      std::vector<index_t> stride_b,
+                                                      ElementwiseFunctor functor)
+    {
+        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
+                                          static_cast<BDataType*>(p_b),
+                                          shape,
+                                          stride_a,
+                                          stride_b,
+                                          functor);
+    }
+
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); }
+
+    std::string GetTypeString() const override
+    {
+        auto str = std::stringstream();
+
+        // clang-format off
+        str << "DeviceUnaryElementwise"
+            << "<"
+            << "ScalarPerVector = " << ScalarPerVector
+            << ">";
+        // clang-format on
+
+        return str.str();
+    }
+};
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp b/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp
new file mode 100644
index 00000000..927a92e6
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp
@@ -0,0 +1,40 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
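+
+// Illustrative reading of the specializations below (sizes are hypothetical): a GEMM with
+// M = 1000 run by a device op tiled with MPerBlock = 256 leaves a remainder of 232, so a
+// specialization that pads M (MPadding, MNPadding, MKPadding or MNKPadding) is required;
+// Default assumes every dimension already divides its corresponding block tile size.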
+
+#pragma once
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+enum struct GemmSpecialization
+{
+    Default,
+    MPadding,
+    NPadding,
+    KPadding,
+    MNPadding,
+    MKPadding,
+    NKPadding,
+    MNKPadding,
+};
+
+inline std::string getGemmSpecializationString(const GemmSpecialization& s)
+{
+    switch(s)
+    {
+    case GemmSpecialization::Default: return "Default";
+    case GemmSpecialization::MPadding: return "MPadding";
+    case GemmSpecialization::NPadding: return "NPadding";
+    case GemmSpecialization::KPadding: return "KPadding";
+    case GemmSpecialization::MNPadding: return "MNPadding";
+    case GemmSpecialization::MKPadding: return "MKPadding";
+    case GemmSpecialization::NKPadding: return "NKPadding";
+    case GemmSpecialization::MNKPadding: return "MNKPadding";
+    default: return "Unrecognized specialization!";
+    }
+}
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp b/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp
new file mode 100644
index 00000000..d3531835
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp
@@ -0,0 +1,186 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck/utility/reduction_operator.hpp"
+#include "ck/utility/reduction_enums.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+// FIXME: can it be replaced with ck::Tuple?
+#include <tuple>
+
+namespace ck {
+
+// The templated struct reduce_binary_operator maps the enum Ids of binary operators to their
+// respective functor classes.
+// The boolean member "indexable" is also provided in reduce_binary_operator for
+// easier checking by the upper-layer code in the kernels.
+
+template <ReduceTensorOp Op>
+struct reduce_binary_operator;
+
+template <>
+struct reduce_binary_operator<ReduceTensorOp::ADD>
+{
+    using opType = reduce::Add;
+
+    static constexpr bool indexable = false;
+};
+
+template <>
+struct reduce_binary_operator<ReduceTensorOp::MUL>
+{
+    using opType = reduce::Mul;
+
+    static constexpr bool indexable = false;
+};
+
+template <>
+struct reduce_binary_operator<ReduceTensorOp::MIN>
+{
+    using opType = reduce::Min;
+
+    static constexpr bool indexable = true;
+};
+
+template <>
+struct reduce_binary_operator<ReduceTensorOp::MAX>
+{
+    using opType = reduce::Max;
+
+    static constexpr bool indexable = true;
+};
+
+template <>
+struct reduce_binary_operator<ReduceTensorOp::AMAX>
+{
+    using opType = reduce::AMax;
+
+    static constexpr bool indexable = true;
+};
+
+template <>
+struct reduce_binary_operator<ReduceTensorOp::AVG>
+{
+    using opType = reduce::Add;
+
+    static constexpr bool indexable = false;
+};
+
+template <>
+struct reduce_binary_operator<ReduceTensorOp::NORM1>
+{
+    using opType = reduce::Add;
+
+    static constexpr bool indexable = false;
+};
+
+template <>
+struct reduce_binary_operator<ReduceTensorOp::NORM2>
+{
+    using opType = reduce::Add;
+
+    static constexpr bool indexable = false;
+};
+
+// The templated struct reduce_unary_operator maps the enum Ids of Reduce operators to two unary
+// functor classes.
+// The two unary functors are called before and after the Reduction is executed, respectively.
+template <ReduceTensorOp ReduceOpId, bool IsFirstReduce, bool IsLastReduce>
+struct reduce_unary_operator
+{
+    using InElementwiseOperation  = tensor_operation::element_wise::PassThrough;
+    using AccElementwiseOperation = tensor_operation::element_wise::PassThrough;
+
+    static std::tuple<InElementwiseOperation, AccElementwiseOperation>
+    GetElementwiseOperator(int32_t reduceLength)
+    {
+        (void)reduceLength;
+        return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{});
+    };
+};
+
+template <bool IsFirstReduce>
+struct reduce_unary_operator<ReduceTensorOp::AVG, IsFirstReduce, true>
+{
+    using InElementwiseOperation  = tensor_operation::element_wise::PassThrough;
+    using AccElementwiseOperation = tensor_operation::element_wise::UnaryDivide;
+
+    static std::tuple<InElementwiseOperation, AccElementwiseOperation>
+    GetElementwiseOperator(int32_t reduceLength)
+    {
+        return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{reduceLength});
+    };
+};
+
+template <bool IsLastReduce>
+struct reduce_unary_operator<ReduceTensorOp::NORM1, true, IsLastReduce>
+{
+    using InElementwiseOperation  = tensor_operation::element_wise::UnaryAbs;
+    using AccElementwiseOperation = tensor_operation::element_wise::PassThrough;
+
+    static std::tuple<InElementwiseOperation, AccElementwiseOperation>
+    GetElementwiseOperator(int32_t reduceLength)
+    {
+        (void)reduceLength;
+        return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{});
+    };
+};
+
+template <bool IsLastReduce>
+struct reduce_unary_operator<ReduceTensorOp::AMAX, true, IsLastReduce>
+{
+    using InElementwiseOperation  = tensor_operation::element_wise::UnaryAbs;
+    using AccElementwiseOperation = tensor_operation::element_wise::PassThrough;
+
+    static std::tuple<InElementwiseOperation, AccElementwiseOperation>
+    GetElementwiseOperator(int32_t reduceLength)
+    {
+        (void)reduceLength;
+        return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{});
+    };
+};
+
+template <>
+struct reduce_unary_operator<ReduceTensorOp::NORM2, true, false>
+{
+    using InElementwiseOperation  = tensor_operation::element_wise::UnarySquare;
+    using AccElementwiseOperation = tensor_operation::element_wise::PassThrough;
+
+    static std::tuple<InElementwiseOperation, AccElementwiseOperation>
+    GetElementwiseOperator(int32_t reduceLength)
+    {
+        (void)reduceLength;
+        return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{});
+    };
+};
+
+template <>
+struct reduce_unary_operator<ReduceTensorOp::NORM2, true, true>
+{
+    using InElementwiseOperation  = tensor_operation::element_wise::UnarySquare;
+    using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt;
+
+    static std::tuple<InElementwiseOperation, AccElementwiseOperation>
+    GetElementwiseOperator(int32_t reduceLength)
+    {
+        (void)reduceLength;
+        return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{});
+    };
+};
+
+template <>
+struct reduce_unary_operator<ReduceTensorOp::NORM2, false, true>
+{
+    using InElementwiseOperation  = tensor_operation::element_wise::PassThrough;
+    using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt;
+
+    static std::tuple<InElementwiseOperation, AccElementwiseOperation>
+    GetElementwiseOperator(int32_t reduceLength)
+    {
+        (void)reduceLength;
+        return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{});
+    };
+};
+
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/tensor_layout.hpp b/include/ck/tensor_operation/gpu/device/tensor_layout.hpp
new file mode 100644
index 00000000..40c7eb7d
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/tensor_layout.hpp
@@ -0,0 +1,132 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
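+
+// For reference (illustrative): a packed tensor in the NHWC layout below with extents
+// [N, H, W, C] has strides [H * W * C, W * C, C, 1], i.e. C is the fastest-moving
+// dimension, while NCHW has strides [C * H * W, H * W, W, 1].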
+ +#pragma once + +namespace ck { +namespace tensor_layout { + +struct BaseTensorLayout +{ +}; + +namespace gemm { + +struct RowMajor : public BaseTensorLayout +{ + static constexpr const char* name = "RowMajor"; +}; + +struct ColumnMajor : public BaseTensorLayout +{ + static constexpr const char* name = "ColumnMajor"; +}; +} // namespace gemm + +namespace convolution { + +// 1D Conv +struct NWC : public BaseTensorLayout +{ + static constexpr const char* name = "NWC"; +}; + +struct KXC : public BaseTensorLayout +{ + static constexpr const char* name = "KXC"; +}; + +struct NWK : public BaseTensorLayout +{ + static constexpr const char* name = "NWK"; +}; + +struct NCW : public BaseTensorLayout +{ + static constexpr const char* name = "NCW"; +}; + +struct KCX : public BaseTensorLayout +{ + static constexpr const char* name = "KCX"; +}; + +struct NKW : public BaseTensorLayout +{ + static constexpr const char* name = "NKW"; +}; + +// 2D Conv +struct NHWC : public BaseTensorLayout +{ + static constexpr const char* name = "NHWC"; +}; + +struct KYXC : public BaseTensorLayout +{ + static constexpr const char* name = "KYXC"; +}; + +struct NHWK : public BaseTensorLayout +{ + static constexpr const char* name = "NHWK"; +}; + +struct NCHW : public BaseTensorLayout +{ + static constexpr const char* name = "NCHW"; +}; + +struct KCYX : public BaseTensorLayout +{ + static constexpr const char* name = "KCYX"; +}; + +struct NKHW : public BaseTensorLayout +{ + static constexpr const char* name = "NKHW"; +}; + +// 3D Conv +struct NDHWC : public BaseTensorLayout +{ + static constexpr const char* name = "NDHWC"; +}; + +struct KZYXC : public BaseTensorLayout +{ + static constexpr const char* name = "KZYXC"; +}; + +struct NDHWK : public BaseTensorLayout +{ + static constexpr const char* name = "NDHWK"; +}; +struct NCDHW : public BaseTensorLayout +{ + static constexpr const char* name = "NCDHW"; +}; + +struct KCZYX : public BaseTensorLayout +{ + static constexpr const char* name = "KCZYX"; +}; + +struct NKDHW : public BaseTensorLayout +{ + static constexpr const char* name = "NKDHW"; +}; + +} // namespace convolution + +template < + typename Layout, + typename std::enable_if::value, bool>::type = false> +std::ostream& operator<<(std::ostream& os, const Layout&) +{ + os << Layout::name; + return os; +} + +} // namespace tensor_layout +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp new file mode 100644 index 00000000..ece1ecb8 --- /dev/null +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -0,0 +1,186 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/utility/data_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace element_wise { + +struct Add +{ + template + __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + y = x0 + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + y = x0 + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const half_t& x1) const + { + y = type_convert(x0) + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + y = x0 + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x0); + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x1_tmp + x2_tmp; + y = ck::type_convert(y_tmp); + } +}; + +struct Subtract +{ + template + __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + y = x0 - x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + y = x0 - x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + y = x0 - x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x0); + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x1_tmp - x2_tmp; + y = ck::type_convert(y_tmp); + } +}; + +struct Bilinear +{ + Bilinear(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + y = alpha_ * x0 + beta_ * x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const half_t& x1) const + { + y = type_convert(alpha_ * x0 + beta_ * ck::type_convert(x1)); + }; + + float alpha_; + float beta_; +}; + +struct AddRelu +{ + template + __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + const float a = x0 + x1; + y = a > 0.0f ? a : 0.0f; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + const double a = x0 + x1; + y = a > 0.0 ? a : 0.0; + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + const half_t a = x0 + x1; + y = a > type_convert(0.0f) ? a : type_convert(0.0f); + }; +}; + +struct AddHardswish +{ + template + __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + float a = x0 + x1; + float b = a + float{3}; + float c = (b > 0) * (b > 6.0f ? 
6.0f : b) * a * 0.166667f; + y = c; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + double a = x0 + x1; + double b = a + 3.0; + double c = (b > 0) * (b > 6.0 ? 6.0 : b) * a * 0.166667; + y = c; + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + float a = x0 + x1; + float b = a + 3.0f; + float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f; + y = c; + }; +}; + +} // namespace element_wise +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp new file mode 100644 index 00000000..9c273e75 --- /dev/null +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -0,0 +1,206 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/math_v2.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace element_wise { + +// Need to ensure compiler will fail if there is no matching candidate, instead of compiler +// siliently do implicit type conversion +// +// Method 1: +// +// struct ExampleElementwiseOp +// { +// template +// __host__ __device__ constexpr void +// operator()(Y&, const X) const; +// +// template<> +// __host__ __device__ constexpr void +// operator()(half_t& y, const half_t& x) const +// { +// } +// }; +// +// Method 2: +// +// template +// struct ExampleElementwiseOp; +// +// template <> +// struct ExampleElementwiseOp +// { +// __host__ __device__ void operator()(float& y, ck::bhalf_t& x) const +// { +// } +// }; + +struct AddReluAdd +{ + template + __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&, const X2&) const; + + template <> + __host__ __device__ constexpr void operator()( + half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const + { + half_t a = x0 + x1; + half_t b = a > 0 ? a : 0; + y = b + x2; + } + + template <> + __host__ __device__ constexpr void operator()(float& y, + const float& x0, + const float& x1, + const float& x2) const + { + float a = x0 + x1; + float b = a > 0 ? a : 0; + float c = b + x2; + y = c; + } + + template <> + __host__ __device__ constexpr void operator()( + half_t& y, const float& x0, const half_t& x1, const half_t& x2) const + { + float a = x0 + x1; + float b = a > 0 ? a : 0; + float c = b + x2; + y = c; + } +}; + +struct AddHardswishAdd +{ + template + __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&, const X2&) const; + + template <> + __host__ __device__ constexpr void operator()(float& y, + const float& x0, + const float& x1, + const float& x2) const + { + float a = x0 + x1; + float b = a + float{3}; + float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667}; + float d = c + x2; + y = d; + } + + template <> + __host__ __device__ constexpr void operator()( + half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const + { + float a = x0 + x1; + float b = a + float{3}; + float c = (b > 0) * (b > float{6} ? 
float{6} : b) * a * float{0.166667}; + float d = c + x2; + y = d; + } +}; + +// C = A * B +// E = FastGelu(C + D0 + D1) +struct AddAddFastGelu +{ + template + __host__ __device__ void operator()(E&, const C&, const D0&, const D1&) const; + + template <> + __host__ __device__ void operator()(half_t& e, + const float& c, + const half_t& d0, + const half_t& d1) const + { + // Fast GeLU + // https://paperswithcode.com/method/gelu + // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))) + const auto fast_gelu = [&](float x) { + const float u = float(2) * x * (float(0.035677) * x * x + float(0.797885)); + const float emu = exp(-u); + const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1)); + return x * cdf; + }; + + const float y = fast_gelu(c + float(d0) + float(d1)); + + e = type_convert(y); + } +}; + +struct Normalize +{ + // FIXME: is double absolutely necessary? + Normalize(double epsilon = 1e-4) : epsilon_(epsilon) {} + + template + __host__ __device__ constexpr void operator()( + T& y, const T& x, const T& mean, const T& mean_square, const T& gamma, const T& beta) const; + + template <> + __host__ __device__ constexpr void operator()(float& y, + const float& x, + const float& mean, + const float& mean_square, + const float& gamma, + const float& beta) const + { + using ck::math::sqrt; + + float variance = mean_square - (mean * mean); + y = ((x - mean) / sqrt(variance + type_convert(epsilon_))) * gamma + beta; + }; + + template <> + __host__ __device__ constexpr void operator()(double& y, + const double& x, + const double& mean, + const double& mean_square, + const double& gamma, + const double& beta) const + { + using ck::math::sqrt; + + double variance = mean_square - (mean * mean); + y = ((x - mean) / sqrt(variance + epsilon_)) * gamma + beta; + }; + + // FIXME: is double absolutely necessary? + double epsilon_; +}; + +template +struct UnaryTypeConvert; + +template <> +struct UnaryTypeConvert +{ + __host__ __device__ void operator()(float& y, ck::bhalf_t& x) const + { + y = ck::type_convert(x); + } +}; + +template <> +struct UnaryTypeConvert +{ + __host__ __device__ void operator()(ck::bhalf_t& y, float& x) const + { + y = ck::type_convert(x); + } +}; + +} // namespace element_wise +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp new file mode 100644 index 00000000..440f0de4 --- /dev/null +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -0,0 +1,139 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
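+
+// Note on the FastGelu approximation defined below and used by AddAddFastGelu above
+// (a derivation, assuming only the formula already quoted in these headers): with
+// u = 2 * x * (0.035677 * x^2 + 0.797885) = 2 * sqrt(2/pi) * (x + 0.044715 * x^3),
+// the identity tanh(u/2) = 2 / (1 + e^(-u)) - 1 turns
+// y = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) into the exp-based
+// cdf = 0.5 + 0.5 * (2 / (1 + e^(-u)) - 1) that the operators actually compute.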
+
+#pragma once
+
+#include "ck/utility/data_type.hpp"
+#include "ck/utility/math_v2.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace element_wise {
+
+struct PassThrough
+{
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, bhalf_t>::value ||
+                          is_same<T, int32_t>::value || is_same<T, int8_t>::value,
+                      "Data type is not supported by this operation!");
+
+        y = x;
+    };
+};
+
+struct Scale
+{
+    __host__ __device__ Scale(float scale) : scale_(scale) {}
+
+    template <typename Y, typename X>
+    __host__ __device__ void operator()(Y& y, const X& x) const;
+
+    template <>
+    __host__ __device__ void operator()(float& y, const float& x) const
+    {
+        y = scale_ * x;
+    };
+
+    float scale_;
+};
+
+struct UnaryDivide
+{
+    __host__ __device__ UnaryDivide(const int32_t divider = 1) : divider_(divider) {}
+
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, int32_t>::value,
+                      "Data type is not supported by this operation!");
+
+        y = x / type_convert<T>(divider_);
+    };
+
+    int32_t divider_ = 1;
+};
+
+struct UnarySquare
+{
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value,
+                      "Data type is not supported by this operation!");
+
+        y = x * x;
+    };
+};
+
+struct UnaryAbs
+{
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "Data type is not supported by this operation!");
+
+        y = ck::math::abs(x);
+    };
+};
+
+struct UnarySqrt
+{
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value,
+                      "Data type is not supported by this operation!");
+
+        y = ck::math::sqrt(x);
+    };
+};
+
+struct Relu
+{
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "Data type is not supported by this operation!");
+        y = x > 0 ? x : 0;
+    }
+
+    template <>
+    __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const
+    {
+        float x_f32 = ck::type_convert<float>(x);
+        float y_f32 = x_f32 > 0 ? x_f32 : 0;
+        y           = ck::type_convert<bhalf_t>(y_f32);
+    }
+};
+
+// https://paperswithcode.com/method/gelu
+// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
+struct FastGelu
+{
+    template <typename Y, typename X>
+    __host__ __device__ void operator()(Y& y, const X& x) const;
+
+    template <>
+    __host__ __device__ void operator()(float& y, const float& x) const
+    {
+        const float u   = float(2) * x * (float(0.035677) * x * x + float(0.797885));
+        const float emu = exp(-u);
+        const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1));
+
+        y = x * cdf;
+    }
+};
+
+} // namespace element_wise
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
new file mode 100644
index 00000000..498a88af
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
@@ -0,0 +1,489 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
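+
+// Worked example for BlockToCTileMap_M00_N0_M01 below (illustrative sizes): with a C grid
+// of M0 = 4 by N0 = 3 tiles and M01 = 2, block ids decode with m01 fastest, then n0, then
+// m00, so blocks 0..5 cover tiles (0,0), (1,0), (0,1), (1,1), (0,2), (1,2) -- a row band of
+// height M01 swept across all columns -- before blocks 6..11 continue with tile rows 2..3.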
+
+#pragma once
+
+#include "ck/utility/math.hpp"
+#include "ck/utility/number.hpp"
+#include "ck/tensor_description/tensor_adaptor.hpp"
+#include "ck/tensor_description/multi_index_transform_helper.hpp"
+
+namespace ck {
+
+// Rows of column-vectors
+template <index_t MPerBlock,
+          index_t NPerBlock,
+          typename CGridDesc_M_N,
+          bool DeviceCTileIndexCheck = false>
+struct BlockToCTileMap_M00_N0_M01
+{
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+    static constexpr auto I2 = Number<2>{};
+    static constexpr auto I3 = Number<3>{};
+
+    __host__ __device__ BlockToCTileMap_M00_N0_M01() = default;
+
+    __host__ __device__ BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N& c_grid_desc_m_n,
+                                                   index_t M01 = 1)
+        : M01_(M01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01))
+    {
+    }
+
+    __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
+    {
+        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
+        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
+
+        const auto M00 = math::integer_divide_ceil(M0, M01_);
+
+        const index_t grid_size = M00 * M01_ * N0;
+
+        return grid_size;
+    }
+
+    template <typename TopIdx>
+    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
+    {
+        return underlying_map_.CalculateBottomIndex(idx_top);
+    }
+
+    template <typename CTileIdx, typename CTileDim>
+    __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
+                                             const CTileDim& c_tile_dim) const
+    {
+        if constexpr(DeviceCTileIndexCheck)
+            return DefaultValidCTileIndex(c_tile_idx, c_tile_dim);
+        else
+            return true;
+    }
+
+    __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
+    {
+        if constexpr(DeviceCTileIndexCheck)
+            return true; // validity check moved to kernel
+
+        const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
+        if(M0 % M01_ == 0)
+        {
+            return true;
+        }
+        else
+        {
+            return false;
+        }
+    }
+
+    private:
+    __host__ __device__ static constexpr auto
+    GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01)
+    {
+        const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
+        const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
+
+        const auto M00 = math::integer_divide_ceil(M0, M01);
+
+        const auto m00_n0_m01_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
+            make_tuple(make_insert_transform(1),
+                       make_unmerge_transform(make_tuple(M00, M01)),
+                       make_pass_through_transform(make_tuple(N0))),
+            make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));
+
+        const auto cblockid_to_m00_n0_m01_block_cluster_adaptor = make_single_stage_tensor_adaptor(
+            make_tuple(make_merge_transform(make_tuple(1, M00, N0, M01))),
+            make_tuple(Sequence<0, 1, 2, 3>{}),
+            make_tuple(Sequence<0>{}));
+
+        const auto cblockid_to_m0_n0_block_cluster_adaptor =
+            chain_tensor_adaptors(m00_n0_m01_to_m0_n0_block_cluster_adaptor,
+                                  cblockid_to_m00_n0_m01_block_cluster_adaptor);
+
+        return cblockid_to_m0_n0_block_cluster_adaptor;
+    }
+
+    index_t M01_;
+    using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1));
+    UnderlyingMap underlying_map_;
+};
+
+// Rows of column-vectors
+// This C-tile map dynamically adjusts M01 when C-tile index is out of range
+template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
+struct BlockToCTileMap_M00_N0_M01Adapt
+{
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+    static constexpr auto I2 = Number<2>{};
+    static constexpr auto I3 = Number<3>{};
+
+    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt()
= default; + + __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 8) + : M01_(M01), c_grid_desc_m_n_(c_grid_desc_m_n) + { + } + + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const index_t grid_size = M0 * N0; + + return grid_size; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock); + + block_1d_id = block_1d_id % (M0 * N0); // swallow batch index + + index_t idx_N0 = block_1d_id % N0; + index_t idx_M0 = block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; } + + private: + index_t M01_; + CGridDesc_M_N c_grid_desc_m_n_; +}; + +// 2D slices of column-vectors in 3D space +// This C-tile map dynamically adjusts M01 when C-tile index is out of range +template +struct BlockToCTileMap_KSplit_M00_N0_M01Adapt +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + __host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt() = default; + + __host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 8, + index_t KSplit = 1) + : M01_(M01), KSplit_(KSplit), c_grid_desc_m_n_(c_grid_desc_m_n) + { + } + + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const index_t grid_size = M0 * N0 * KSplit_; + + return grid_size; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock); + + const index_t idx_ksplit = block_1d_id / (M0 * N0); + block_1d_id = block_1d_id % (M0 * N0); + + index_t idx_N0 = block_1d_id % N0; + index_t idx_M0 = block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? 
M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + return make_tuple(idx_ksplit, + idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; } + + private: + index_t M01_; + index_t KSplit_; + CGridDesc_M_N c_grid_desc_m_n_; +}; + +// Blocks of row-vectors +template +struct BlockToCTileMap_M00_N00_M01_N01 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + __host__ __device__ BlockToCTileMap_M00_N00_M01_N01() = default; + + __host__ __device__ BlockToCTileMap_M00_N00_M01_N01(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 1, + index_t N01 = 1) + : M01_(M01), N01_(N01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01, N01)) + { + } + + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const auto M00 = math::integer_divide_ceil(M0, M01_); + const auto N00 = math::integer_divide_ceil(N0, N01_); + + const index_t grid_size = M00 * M01_ * N00 * N01_; + + return grid_size; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + return underlying_map_.CalculateBottomIndex(idx_top); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + if constexpr(DeviceCTileIndexCheck) + return DefaultValidCTileIndex(c_tile_idx, c_tile_dim); + else + return true; + } + + __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + if constexpr(DeviceCTileIndexCheck) + return true; // validity check moved to kernel + + const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const index_t N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + if(M0 % M01_ == 0 && N0 % N01_ == 0) + { + return true; + } + else + { + return false; + } + } + + private: + __host__ __device__ static constexpr auto + GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const auto M00 = math::integer_divide_ceil(M0, M01); + const auto N00 = math::integer_divide_ceil(N0, N01); + + const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_insert_transform(1), // swallow the carry from lower dimensions + make_unmerge_transform(make_tuple(M00, M01)), + make_unmerge_transform(make_tuple(N00, N01))), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{})); + + const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(1, M00, N00, M01, N01))), 
+ make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto cblockid_to_m0_n0_block_cluster_adaptor = + chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, + cblockid_to_m00_m01_n00_n01_block_cluster_adaptor); + + return cblockid_to_m0_n0_block_cluster_adaptor; + } + + index_t M01_, N01_; + using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1, 1)); + UnderlyingMap underlying_map_; +}; + +// 2D slices of row-vectors in 3D space +template +struct BlockToCTileMap_KSplit_M00_N00_M01_N01 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + __host__ BlockToCTileMap_KSplit_M00_N00_M01_N01() = default; + + __host__ BlockToCTileMap_KSplit_M00_N00_M01_N01(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 1, + index_t N01 = 1, + index_t KSplit = 1) + : M01_(M01), + N01_(N01), + KSplit_(KSplit), + underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01, N01, KSplit)) + { + } + + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const auto M00 = math::integer_divide_ceil(M0, M01_); + const auto N00 = math::integer_divide_ceil(N0, N01_); + + const index_t grid_size = M00 * M01_ * N00 * N01_ * KSplit_; + + return grid_size; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + return underlying_map_.CalculateBottomIndex(idx_top); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + if constexpr(DeviceCTileIndexCheck) + return DefaultValidCTileIndex(c_tile_idx, c_tile_dim); + else + return true; + } + + __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + if constexpr(DeviceCTileIndexCheck) + return true; // validity check moved to kernel + + const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const index_t N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + if(M0 % M01_ == 0 && N0 % N01_ == 0) + { + return true; + } + else + { + return false; + } + } + + private: + __host__ static constexpr auto GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01, + index_t N01, + index_t KSplit) + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const auto M00 = math::integer_divide_ceil(M0, M01); + const auto N00 = math::integer_divide_ceil(N0, N01); + + const auto ksplit_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_pass_through_transform(KSplit), + make_unmerge_transform(make_tuple(M00, M01)), + make_unmerge_transform(make_tuple(N00, N01))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{})); + + const auto c_blockid_to_ksplit_m00_m01_n00_n01_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(KSplit, M00, N00, M01, N01))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto c_blockid_to_ksplit_m0_n0_block_cluster_adaptor = + 
chain_tensor_adaptors(ksplit_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, + c_blockid_to_ksplit_m00_m01_n00_n01_block_cluster_adaptor); + + return c_blockid_to_ksplit_m0_n0_block_cluster_adaptor; + } + + index_t M01_, N01_, KSplit_; + using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1, 1, 1)); + UnderlyingMap underlying_map_; +}; + +template +__host__ __device__ bool DefaultValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) +{ + bool is_valid = false; + + const index_t m_block = c_tile_dim[Number<0>{}]; + const index_t n_block = c_tile_dim[Number<1>{}]; + + if constexpr(CTileIdx::Size() == 2) + { + const index_t m_block_idx = c_tile_idx[Number<0>{}]; + const index_t n_block_idx = c_tile_idx[Number<1>{}]; + if(0 <= m_block_idx && m_block_idx < m_block && 0 <= n_block_idx && n_block_idx < n_block) + { + is_valid = true; + } + } + else if constexpr(CTileIdx::Size() == 3) + { + const index_t ksplit_idx = c_tile_idx[Number<0>{}]; + const index_t m_block_idx = c_tile_idx[Number<1>{}]; + const index_t n_block_idx = c_tile_idx[Number<2>{}]; + if(0 <= m_block_idx && m_block_idx < m_block && 0 <= n_block_idx && n_block_idx < n_block) + { + is_valid = true; + } + ignore = ksplit_idx; + } + + return is_valid; +} + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp new file mode 100644 index 00000000..6836a660 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp @@ -0,0 +1,613 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/reduction_common.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k, + const OutGridDesc_M out_grid_desc_m, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op, + index_t block_group_size, + index_t num_k_block_tile_iteration, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + const IndexDataType* const __restrict__ p_in_index_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global, + IndexDataType* const __restrict__ p_out_index_global) +{ + if constexpr(!OutputIndex) + { + (void)p_in_index_global; + (void)p_out_index_global; + + GridwiseReduction::Run(in_grid_desc_m_k, + out_grid_desc_m, + in_elementwise_op, + acc_elementwise_op, + block_group_size, + num_k_block_tile_iteration, + alpha, + p_in_value_global, + beta, + p_out_value_global); + } + else + { + GridwiseReduction::template RunWithIndex(in_grid_desc_m_k, + out_grid_desc_m, + in_elementwise_op, + acc_elementwise_op, + num_k_block_tile_iteration, + alpha, + p_in_value_global, + p_in_index_global, + beta, + p_out_value_global, + p_out_index_global); + }; +}; + +template +struct GridwiseReduction_mk_to_m_multiblock +{ + static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize 
% InSrcVectorSize == 0)) && + (MThreadSliceSize % OutDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using BlockwiseReduce = PartitionedBlockwiseReduction; + + using ThreadwiseReduce = ThreadwiseReduction; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + using Accumulation = detail::AccumulateWithNanCheck; + + __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k, + const OutGridDesc_M& out_grid_desc_m, + const InElementwiseOperation& in_elementwise_op, + const AccElementwiseOperation& acc_elementwise_op, + index_t block_group_size, + index_t num_k_block_tile_iteration, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global) + { + const auto identityVal = ReduceOperation::template GetIdentityValue(); + + // LDS + __shared__ AccDataType p_reduce_work_buffer[BlockSize]; + + const auto in_global_val_buf = make_dynamic_buffer( + p_in_value_global, + in_grid_desc_m_k.GetElementSpaceSize(), + ReduceOperation::template GetIdentityValue()); + auto out_global_val_buf = make_dynamic_buffer( + p_out_value_global, out_grid_desc_m.GetElementSpaceSize()); + + auto reduce_work_buf = + make_dynamic_buffer(p_reduce_work_buffer, BlockSize); + + StaticBuffer + in_thread_buf; + + StaticBuffer accu_value_buf; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; }); + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + const index_t blkgroup_id = block_global_id / block_group_size; + const index_t block_local_id = block_global_id % block_group_size; + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration; + + using ThreadBufferLengths = Sequence; + constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_local_id * reduceSizePerBlock + + thread_k_cluster_id * KThreadSliceSize)); + + constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize); + + index_t reducedTiles = 0; + do + { + threadwise_src_load.Run(in_grid_desc_m_k, + 
in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // do element-wise pre-reduction operation + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + in_elementwise_op(in_thread_buf(Number{}), + in_thread_buf(Number{})); + }); + }); + + ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + + constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; + + static_for<0, MThreadSliceSize, 1>{}( + [&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if(thread_k_cluster_id == 0) + { + acc_elementwise_op(accu_value_buf(I), accu_value_buf(I)); + + accu_value_buf(I) *= alpha; + } + }); + + if(thread_k_cluster_id == 0) + { + if(block_group_size == 0 && !float_equal_zero{}(beta)) + { + StaticBuffer + priorDstValueBuf; + + auto threadwise_dst_load = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0>, + 0, + OutDstVectorSize, + 1, + false>( + out_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + threadwise_dst_load.Run(out_grid_desc_m, + out_global_val_buf, + reduced_data_desc, + make_tuple(I0), + priorDstValueBuf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) += type_convert(priorDstValueBuf[I]) * beta; + }); + }; + + auto threadwise_dst_store = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSize, + OutMemoryDataOperation, + 1, + true>( + out_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_dst_store.Run(reduced_data_desc, + make_tuple(I0), + accu_value_buf, + out_grid_desc_m, + out_global_val_buf); + } + }; + + template + __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k, + const OutGridDesc_M& out_grid_desc_m, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op, + index_t num_k_block_tile_iteration, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + const IndexDataType* const __restrict__ p_in_index_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global, + IndexDataType* const __restrict__ p_out_index_global) + { + using BlockwiseReduceWithIndex = + PartitionedBlockwiseReductionWithIndex, + ThreadClusterArrangeOrder, + ReduceOperation, + PropagateNan>; + + using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck; + + (void)in_elementwise_op; + + // LDS + __shared__ AccDataType p_reduce_work_val_buffer[BlockSize]; + __shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize]; + + const auto identityVal = ReduceOperation::template GetIdentityValue(); + + const auto in_global_val_buf = make_dynamic_buffer( + p_in_value_global, + in_grid_desc_m_k.GetElementSpaceSize(), + ReduceOperation::template GetIdentityValue()); + const auto in_global_idx_buf = make_dynamic_buffer( + p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize()); + auto out_global_val_buf = make_dynamic_buffer( + p_out_value_global, out_grid_desc_m.GetElementSpaceSize()); + auto out_global_idx_buf = make_dynamic_buffer( + p_out_index_global, out_grid_desc_m.GetElementSpaceSize()); + 
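+        // Arg-style reduction: every thread carries a (value, index) pair, so the
+        // LDS workspace provides one value buffer and one index buffer. E.g. for a
+        // Max reduction, the accumulation keeps the larger value together with the
+        // source index it came from.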
+ auto reduce_work_val_buf = + make_dynamic_buffer(p_reduce_work_val_buffer, BlockSize); + auto reduce_work_idx_buf = + make_dynamic_buffer(p_reduce_work_idx_buffer, BlockSize); + + StaticBuffer + in_thread_val_buf; + + StaticBuffer + in_thread_idx_buf; + + StaticBuffer accu_value_buf; + StaticBuffer accu_index_buf; + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_1d_id = get_block_1d_id(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + using ThreadBufferLengths = Sequence; + constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + auto threadwise_src_val_load = + ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, + make_multi_index(block_global_1d_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) = identityVal; + accu_index_buf(I) = 0; + }); + + constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize); + + index_t reducedTiles = 0; + + if constexpr(HaveIndexInput) + { + auto threadwise_src_idx_load = + ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, + make_multi_index(block_global_1d_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + do + { + // load the thread slice + threadwise_src_val_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_val_buf); + threadwise_src_idx_load.Run(in_grid_desc_m_k, + in_global_idx_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_idx_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + AccDataType tmpValue = identityVal; + IndexDataType tmpIndex = 0; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + + AccumulationWithIndex::Calculate(tmpValue, + in_thread_val_buf[Number{}], + tmpIndex, + in_thread_idx_buf[Number{}]); + }); + + BlockwiseReduceWithIndex::Reduce( + reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex); + + AccumulationWithIndex::Calculate( + accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex); + }); + + threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + } + else + { + index_t indexOffset = 0; + + do + { + // load the thread slice + threadwise_src_val_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_val_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + + // initialize the indices for the per-thread to-reduce values + in_thread_idx_buf(Number{}) = + indexOffset + thread_k_cluster_id * KThreadSliceSize + iK(); + + // do element-wise pre-reduction operation + in_elementwise_op(in_thread_val_buf(Number{}), + in_thread_val_buf(Number{})); + }); + + AccDataType tmpValue = identityVal; + IndexDataType tmpIndex = 0; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + 
constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + + AccumulationWithIndex::Calculate(tmpValue, + in_thread_val_buf[Number{}], + tmpIndex, + in_thread_idx_buf[Number{}]); + }); + + BlockwiseReduceWithIndex::Reduce( + reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex); + + AccumulationWithIndex::Calculate( + accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex); + }); + + threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + indexOffset += K_BlockTileSize; + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + }; + + constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if(thread_k_cluster_id == 0) + { + // for indiced operation, acc_elementwise_op shoud do nothing + acc_elementwise_op(accu_value_buf(I), accu_value_buf(I)); + + accu_value_buf(I) *= alpha; + } + }); + + if(thread_k_cluster_id == 0) + { + if(!float_equal_zero{}(beta)) + { + StaticBuffer + priorDstValueBuf; + + auto threadwise_dst_load = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0>, + 0, + OutDstVectorSize, + 1, + true>( + out_grid_desc_m, + make_multi_index(block_global_1d_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + threadwise_dst_load.Run(out_grid_desc_m, + out_global_val_buf, + reduced_data_desc, + make_tuple(I0), + priorDstValueBuf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) += type_convert(priorDstValueBuf[I]) * beta; + }); + }; + + auto threadwise_dst_val_store = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + out_grid_desc_m, + make_multi_index(block_global_1d_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + auto threadwise_dst_idx_store = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + out_grid_desc_m, + make_multi_index(block_global_1d_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_dst_val_store.Run(reduced_data_desc, + make_tuple(I0), + accu_value_buf, + out_grid_desc_m, + out_global_val_buf); + threadwise_dst_idx_store.Run(reduced_data_desc, + make_tuple(I0), + accu_index_buf, + out_grid_desc_m, + out_global_idx_buf); + } + }; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp new file mode 100644 index 00000000..6c5bd29f --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp @@ -0,0 +1,474 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
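+// Unlike the multiblock reduction above, this threadwise variant gives each
+// thread its own slice of M rows and lets it walk the entire K dimension
+// alone, so no cross-thread (LDS) combine step is required. Both variants
+// realize the same BLAS-like update, roughly
+//
+//     out[m] = alpha * acc_op(reduce_k(in_op(in[m, k]))) + beta * out_prior[m];
+//
+// where the beta term is only applied when beta is non-zero.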
+ +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_common.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k, + const OutGridDesc_M out_grid_desc_m, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + const IndexDataType* const __restrict__ p_in_index_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global, + IndexDataType* const __restrict__ p_out_index_global) +{ + if constexpr(!OutputIndex) + { + GridwiseReduction::Run(in_grid_desc_m_k, + out_grid_desc_m, + in_elementwise_op, + acc_elementwise_op, + alpha, + p_in_value_global, + beta, + p_out_value_global); + } + else + { + GridwiseReduction::template RunWithIndex(in_grid_desc_m_k, + out_grid_desc_m, + in_elementwise_op, + acc_elementwise_op, + alpha, + p_in_value_global, + p_in_index_global, + beta, + p_out_value_global, + p_out_index_global); + }; +}; + +template +struct GridwiseReduction_mk_to_m_threadwise +{ + static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) && + (MThreadSliceSize % OutDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + + __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k, + const OutGridDesc_M& out_grid_desc_m, + const InElementwiseOperation& in_elementwise_op, + const AccElementwiseOperation& acc_elementwise_op, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global) + { + using ThreadwiseReduce = ThreadwiseReduction; + + const auto identityVal = ReduceOperation::template GetIdentityValue(); + + const auto in_global_val_buf = make_dynamic_buffer( + p_in_value_global, + in_grid_desc_m_k.GetElementSpaceSize(), + ReduceOperation::template GetIdentityValue()); + auto dst_global_buf = make_dynamic_buffer( + p_out_value_global, out_grid_desc_m.GetElementSpaceSize()); + + StaticBuffer + in_thread_buf; + + StaticBuffer accu_value_buf; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; }); + + const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + using ThreadBufferLengths = Sequence; + constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); + + auto threadwise_src_val_load = + ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, make_multi_index(thread_global_1d_id * 
MThreadSliceSize, 0)); + + constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize); + + index_t reducedLength = 0; + do + { + threadwise_src_val_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // do element-wise pre-reduction operation + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + in_elementwise_op(in_thread_buf(Number{}), + in_thread_buf(Number{})); + }); + }); + + ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf); + + threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + reducedLength += KThreadSliceSize; + } while(reducedLength < toReduceLength); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + acc_elementwise_op(accu_value_buf(I), accu_value_buf(I)); + + accu_value_buf(I) *= alpha; + }); + + constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; + + if(!float_equal_zero{}(beta)) + { + auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2, + Sequence<0>, + 0, + 1, + 1, + true>( + out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize)); + + StaticBuffer + priorDstValue_buf; + + threadwise_dst_load.Run(out_grid_desc_m, + dst_global_buf, + reduced_data_desc, + make_tuple(I0), + priorDstValue_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) += type_convert(priorDstValue_buf[I]) * beta; + }); + }; + + auto threadwise_dst_store = ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSize, + OutMemoryDataOperation, + 1, + false>( + out_grid_desc_m, + make_multi_index(thread_global_1d_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_dst_store.Run( + reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf); + }; + + template + __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k, + const OutGridDesc_M& out_grid_desc_m, + const InElementwiseOperation& in_elementwise_op, + const AccElementwiseOperation& acc_elementwise_op, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + const IndexDataType* const __restrict__ p_in_index_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global, + IndexDataType* const __restrict__ p_out_index_global) + { + using ThreadwiseReduceWithIndex = ThreadwiseReductionWithIndex; + + (void)acc_elementwise_op; + + const auto identityVal = ReduceOperation::template GetIdentityValue(); + + const auto in_global_val_buf = make_dynamic_buffer( + p_in_value_global, + in_grid_desc_m_k.GetElementSpaceSize(), + ReduceOperation::template GetIdentityValue()); + const auto in_global_idx_buf = make_dynamic_buffer( + p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize()); + + auto out_global_val_buf = make_dynamic_buffer( + p_out_value_global, out_grid_desc_m.GetElementSpaceSize()); + auto out_global_idx_buf = make_dynamic_buffer( + p_out_index_global, out_grid_desc_m.GetElementSpaceSize()); + + StaticBuffer + in_thread_val_buf; + + StaticBuffer + in_thread_idx_buf; + + StaticBuffer accu_value_buf; + StaticBuffer accu_index_buf; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) = identityVal; + accu_index_buf(I) = 0; + }); + + const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + using ThreadBufferLengths = Sequence; + constexpr auto thread_buffer_desc = 
make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); + + auto threadwise_src_val_load = + ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0)); + + constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize); + + index_t indexStart = 0; + index_t reducedLength = 0; + if constexpr(HaveIndexInput) + { + auto threadwise_src_idx_load = + ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0)); + + do + { + threadwise_src_val_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_val_buf); + + threadwise_src_idx_load.Run(in_grid_desc_m_k, + in_global_idx_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_idx_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // do element-wise pre-reduction operation + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + + in_elementwise_op(in_thread_val_buf(Number{}), + in_thread_val_buf(Number{})); + }); + }); + + ThreadwiseReduceWithIndex::Reduce( + in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf); + + threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + indexStart += KThreadSliceSize; + reducedLength += KThreadSliceSize; + } while(reducedLength < toReduceLength); + } + else + { + do + { + threadwise_src_val_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_val_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // do element-wise pre-reduction operation + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + + in_thread_idx_buf(Number{}) = indexStart + iK(); + + in_elementwise_op(in_thread_val_buf(Number{}), + in_thread_val_buf(Number{})); + }); + }); + + ThreadwiseReduceWithIndex::Reduce( + in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf); + + threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + indexStart += KThreadSliceSize; + reducedLength += KThreadSliceSize; + } while(reducedLength < toReduceLength); + }; + + // for indiced operation, acc_elementwise_op shoud do nothing + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + acc_elementwise_op(accu_value_buf(I), accu_value_buf(I)); + + accu_value_buf(I) *= alpha; + }); + + constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; + + if(!float_equal_zero{}(beta)) + { + auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2, + Sequence<0>, + 0, + 1, + 1, + false>( + out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize)); + + StaticBuffer + priorDstValue_buf; + + threadwise_dst_load.Run(out_grid_desc_m, + out_global_val_buf, + reduced_data_desc, + make_tuple(I0), + priorDstValue_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) += type_convert(priorDstValue_buf[I]) * beta; + }); + }; + + auto threadwise_dst_val_store = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSize, + OutMemoryDataOperation, + 1, + false>( + out_grid_desc_m, + make_multi_index(thread_global_1d_id * MThreadSliceSize), 
+ PassThroughOp{}); + + auto threadwise_dst_idx_store = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSize, + OutMemoryDataOperation, + 1, + false>( + out_grid_desc_m, + make_multi_index(thread_global_1d_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_dst_val_store.Run( + reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_val_buf); + + threadwise_dst_idx_store.Run( + reduced_data_desc, make_tuple(I0), accu_index_buf, out_grid_desc_m, out_global_idx_buf); + }; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp new file mode 100644 index 00000000..23937348 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp @@ -0,0 +1,254 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" + +namespace ck { + +template +__global__ void kernel_5ary_elementwise_1d(const ADataType* __restrict__ p_a_global, + const BDataType* __restrict__ p_b_global, + const CDataType* __restrict__ p_c_global, + const DDataType* __restrict__ p_d_global, + const EDataType* __restrict__ p_e_global, + FDataType* __restrict__ p_f_global, + const AGridDesc_M a_grid_desc_m, + const BGridDesc_M b_grid_desc_m, + const CGridDesc_M c_grid_desc_m, + const DGridDesc_M d_grid_desc_m, + const EGridDesc_M e_grid_desc_m, + const FGridDesc_M f_grid_desc_m, + const ElementwiseFunctor functor) +{ + Gridwise5AryEltwise::Run(p_a_global, + p_b_global, + p_c_global, + p_d_global, + p_e_global, + p_f_global, + a_grid_desc_m, + b_grid_desc_m, + c_grid_desc_m, + d_grid_desc_m, + e_grid_desc_m, + f_grid_desc_m, + functor); +} + +// TODO - implement n-ary Elemenetwise_1D, tuple of inputs and tuple of outputs +template +struct Gridwise5AryElementwise_1D +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto thread_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + using PassThrough = tensor_operation::element_wise::PassThrough; + + static __device__ auto CalculateElementwiseIndex() + { + const index_t global_thread_id = get_thread_global_1d_id(); + return make_multi_index(global_thread_id * MPerThread); + } + + __device__ static void Run(const ADataType* __restrict__ p_a_global, + const BDataType* __restrict__ p_b_global, + const CDataType* __restrict__ p_c_global, + const DDataType* __restrict__ p_d_global, + const EDataType* __restrict__ p_e_global, + FDataType* __restrict__ p_f_global, + const AGridDesc_M a_grid_desc_m, + const BGridDesc_M b_grid_desc_m, + const CGridDesc_M c_grid_desc_m, + const DGridDesc_M d_grid_desc_m, + const EGridDesc_M e_grid_desc_m, + const FGridDesc_M f_grid_desc_m, + const ElementwiseFunctor functor) + { + const auto a_global_buf = make_dynamic_buffer( + p_a_global, a_grid_desc_m.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_global, b_grid_desc_m.GetElementSpaceSize()); + const auto c_global_buf = make_dynamic_buffer( + p_c_global, c_grid_desc_m.GetElementSpaceSize()); + const auto d_global_buf = make_dynamic_buffer( + p_d_global, d_grid_desc_m.GetElementSpaceSize()); + const auto e_global_buf = 
make_dynamic_buffer( + p_e_global, e_grid_desc_m.GetElementSpaceSize()); + auto f_global_buf = make_dynamic_buffer( + p_f_global, f_grid_desc_m.GetElementSpaceSize()); + + StaticBuffer a_thread_buf; + StaticBuffer b_thread_buf; + StaticBuffer c_thread_buf; + StaticBuffer d_thread_buf; + StaticBuffer e_thread_buf; + StaticBuffer f_thread_buf; + + const auto thread_store_global_offset = CalculateElementwiseIndex(); + + auto a_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + AScalarPerVector, // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{a_grid_desc_m, thread_store_global_offset}; + + auto b_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + BScalarPerVector, // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{b_grid_desc_m, thread_store_global_offset}; + + auto c_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + CScalarPerVector, // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{c_grid_desc_m, thread_store_global_offset}; + + auto d_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + DScalarPerVector, // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{d_grid_desc_m, thread_store_global_offset}; + + auto e_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + EScalarPerVector, // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{e_grid_desc_m, thread_store_global_offset}; + + auto f_global_write = + ThreadwiseTensorSliceTransfer_v1r3, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // DstVectorDim + FScalarPerVector, // ScalarPerVector + InMemoryDataOperationEnum::Set, + 1, // DstScalarStrideInVector + false>{ + f_grid_desc_m, thread_store_global_offset, PassThrough{}}; + + const index_t blockSize = get_block_size(); + const index_t blockPerGrid = get_grid_size(); + const auto M = c_grid_desc_m.GetLength(I0); + const index_t loop_step = blockPerGrid * blockSize * MPerThread; + const auto loop_step_index = make_multi_index(loop_step); + + index_t num_iter = M / (loop_step); + do + { + // read and process MPerThread elements + a_global_load.Run( + a_grid_desc_m, a_global_buf, thread_desc_m, make_tuple(I0), a_thread_buf); + + b_global_load.Run( + b_grid_desc_m, b_global_buf, thread_desc_m, make_tuple(I0), b_thread_buf); + + c_global_load.Run( + c_grid_desc_m, c_global_buf, thread_desc_m, make_tuple(I0), c_thread_buf); + + d_global_load.Run( + d_grid_desc_m, d_global_buf, thread_desc_m, make_tuple(I0), d_thread_buf); + + e_global_load.Run( + e_grid_desc_m, e_global_buf, thread_desc_m, make_tuple(I0), e_thread_buf); + + static_for<0, MPerThread, 1>{}([&](auto m) { + constexpr auto offset = thread_desc_m.CalculateOffset(make_tuple(m)); + functor(f_thread_buf(Number{}), + a_thread_buf(Number{}), + b_thread_buf(Number{}), + c_thread_buf(Number{}), + d_thread_buf(Number{}), + e_thread_buf(Number{})); + }); + + f_global_write.Run(thread_desc_m, + make_tuple(I0), // SrcSliceOriginIdx + f_thread_buf, + f_grid_desc_m, + f_global_buf); + + a_global_load.MoveSrcSliceWindow(a_grid_desc_m, loop_step_index); + b_global_load.MoveSrcSliceWindow(b_grid_desc_m, loop_step_index); + c_global_load.MoveSrcSliceWindow(c_grid_desc_m, loop_step_index); + d_global_load.MoveSrcSliceWindow(d_grid_desc_m, loop_step_index); + 
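+            // advance the remaining input window and the output window by the same
+            // grid-stride step, so each thread handles a disjoint MPerThread-wide
+            // chunk on the next iteration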
e_global_load.MoveSrcSliceWindow(e_grid_desc_m, loop_step_index); + f_global_write.MoveDstSliceWindow(f_grid_desc_m, loop_step_index); + } while(--num_iter); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp new file mode 100644 index 00000000..d4e7d142 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp @@ -0,0 +1,155 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" + +namespace ck { + +template +__global__ void kernel_binary_elementwise_1d(const ADataType* __restrict__ p_a_global, + const BDataType* __restrict__ p_b_global, + CDataType* __restrict__ p_c_global, + const AGridDesc_M a_grid_desc_m, + const BGridDesc_M b_grid_desc_m, + const CGridDesc_M c_grid_desc_m, + const ElementwiseFunctor functor) +{ + GridwiseBinEltwise::Run( + p_a_global, p_b_global, p_c_global, a_grid_desc_m, b_grid_desc_m, c_grid_desc_m, functor); +} + +template +struct GridwiseBinaryElementwise_1D +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto thread_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + using PassThrough = tensor_operation::element_wise::PassThrough; + + static __device__ auto CalculateElementwiseIndex() + { + const index_t global_thread_id = get_thread_global_1d_id(); + return make_multi_index(global_thread_id * MPerThread); + } + + __device__ static void Run(const ADataType* __restrict__ p_a_global, + const BDataType* __restrict__ p_b_global, + CDataType* __restrict__ p_c_global, + const AGridDesc_M a_grid_desc_m, + const BGridDesc_M b_grid_desc_m, + const CGridDesc_M c_grid_desc_m, + const ElementwiseFunctor functor) + { + const auto a_global_buf = make_dynamic_buffer( + p_a_global, a_grid_desc_m.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_global, b_grid_desc_m.GetElementSpaceSize()); + auto c_global_buf = make_dynamic_buffer( + p_c_global, c_grid_desc_m.GetElementSpaceSize()); + + StaticBuffer a_thread_buf; + StaticBuffer b_thread_buf; + StaticBuffer c_thread_buf; + + const auto thread_store_global_offset = CalculateElementwiseIndex(); + + auto a_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + AScalarPerVector, // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{a_grid_desc_m, thread_store_global_offset}; + + auto b_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + BScalarPerVector, // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{b_grid_desc_m, thread_store_global_offset}; + + auto c_global_write = + ThreadwiseTensorSliceTransfer_v1r3, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // DstVectorDim + CScalarPerVector, // ScalarPerVector + InMemoryDataOperationEnum::Set, + 1, // DstScalarStrideInVector + false>{ + c_grid_desc_m, thread_store_global_offset, PassThrough{}}; + + const index_t blockSize = get_block_size(); + const index_t blockPerGrid = get_grid_size(); + const auto M = c_grid_desc_m.GetLength(I0); + const index_t loop_step = blockPerGrid * 
blockSize * MPerThread; + const auto loop_step_index = make_multi_index(loop_step); + + index_t num_iter = M / (loop_step); + do + { + // read and process MPerThread elements + a_global_load.Run( + a_grid_desc_m, a_global_buf, thread_desc_m, make_tuple(I0), a_thread_buf); + + b_global_load.Run( + b_grid_desc_m, b_global_buf, thread_desc_m, make_tuple(I0), b_thread_buf); + + static_for<0, MPerThread, 1>{}([&](auto m) { + constexpr auto offset = thread_desc_m.CalculateOffset(make_tuple(m)); + functor(c_thread_buf(Number{}), + a_thread_buf(Number{}), + b_thread_buf(Number{})); + }); + + c_global_write.Run(thread_desc_m, + make_tuple(I0), // SrcSliceOriginIdx + c_thread_buf, + c_grid_desc_m, + c_global_buf); + + a_global_load.MoveSrcSliceWindow(a_grid_desc_m, loop_step_index); + b_global_load.MoveSrcSliceWindow(b_grid_desc_m, loop_step_index); + c_global_write.MoveDstSliceWindow(c_grid_desc_m, loop_step_index); + } while(--num_iter); + } +}; + +} // namespace ck diff --git a/composable_kernel/include/tensor_operation/gridwise_contraction_dlops_v1r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_contraction_dlops_v1r2.hpp similarity index 98% rename from composable_kernel/include/tensor_operation/gridwise_contraction_dlops_v1r2.hpp rename to include/ck/tensor_operation/gpu/grid/gridwise_contraction_dlops_v1r2.hpp index fe56d0d8..2369f517 100644 --- a/composable_kernel/include/tensor_operation/gridwise_contraction_dlops_v1r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_contraction_dlops_v1r2.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP #define CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP @@ -55,7 +58,7 @@ template , integral_constant) { - const auto a_global_buf = make_dynamic_buffer( + const auto a_global_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); - const auto b_global_buf = make_dynamic_buffer( + const auto b_global_buf = make_dynamic_buffer( p_b_grid, b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( + auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetElementSpaceSize()); const auto GK0 = a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0); @@ -381,9 +384,9 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN "wrong!"); // A matrix blockwise copy - auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1< + auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< BlockSize, - InMemoryDataOperationEnum_t::Set, + InMemoryDataOperationEnum::Set, Sequence, ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, @@ -405,9 +408,9 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN make_multi_index(0, 0, 0, 0, 0)); // B matrix blockwise copy - auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1< + auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< BlockSize, - InMemoryDataOperationEnum_t::Set, + InMemoryDataOperationEnum::Set, Sequence, BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, @@ -467,7 +470,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; // register allocation for output - auto c_thread_buf = 
make_static_buffer( + auto c_thread_buf = make_static_buffer( c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize()); ThreadwiseTensorSliceSet_v1( + auto a_block_even_buf = make_dynamic_buffer( p_a_block_double, a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); - auto b_block_even_buf = make_dynamic_buffer( + auto b_block_even_buf = make_dynamic_buffer( p_b_block_double, b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); - auto a_block_odd_buf = make_dynamic_buffer( + auto a_block_odd_buf = make_dynamic_buffer( p_a_block_double + a_block_aligned_space_size, a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); - auto b_block_odd_buf = make_dynamic_buffer( + auto b_block_odd_buf = make_dynamic_buffer( p_b_block_double + b_block_aligned_space_size, b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp new file mode 100644 index 00000000..22d96a10 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp @@ -0,0 +1,995 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_bias_add_reduce_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC0* __restrict__ p_bias_grid, + const FloatC1* __restrict__ p_d0_grid, + ReducePtrsGlobal p_reduces_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const C1ElementwiseOperation c1_element_op, + const ReduceInElementwiseOperations reduce_in_element_ops, + const ReduceAccElementwiseOperations reduce_out_element_ops, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c0_grid_desc_mblock_mperblock_nblock_nperblock, + const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c1_grid_desc_mblock_mperblock_nblock_nperblock, + const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char 
p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_bias_grid, + p_d0_grid, + p_reduces_grid, + p_shared, + a_element_op, + b_element_op, + c_element_op, + c1_element_op, + reduce_in_element_ops, + reduce_out_element_ops, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c0_grid_desc_mblock_mperblock_nblock_nperblock, + c1_grid_desc_mblock_mperblock_nblock_nperblock, + reduce_grid_desc_mblock_mperblock, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = p_bias_grid; + ignore = p_d0_grid; + ignore = p_reduces_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = c1_element_op; + ignore = reduce_in_element_ops; + ignore = reduce_out_element_ops; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = c0_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = c1_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = reduce_grid_desc_mblock_mperblock; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto 
b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatCShuffle)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + // static_assert(is_known_at_compile_time>::value && + // is_known_at_compile_time>::value, + // "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); + const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N_& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr auto + MakeReduceGridDescriptor_MBlock_MPerBlock(const ReduceGridDesc_M& d_grid_desc_m) + { + const auto M = d_grid_desc_m.GetLength(I0); + const auto MBlock = M / MPerBlock; + + const auto reduce_grid_desc_mblock_mperblock = transform_tensor_descriptor( + d_grid_desc_m, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{}))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + + return reduce_grid_desc_mblock_mperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + 
c_grid_desc_m_n); + } + + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using ReduceGridDescriptor_MBlock_MPerBlock = + remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC0* __restrict__ p_bias_grid, + const FloatC1* __restrict__ p_d0_grid, + ReducePtrsGlobal p_reduces_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const C1ElementwiseOperation& c1_element_op, + const ReduceInElementwiseOperations& reduce_in_element_ops, + const ReduceAccElementwiseOperations& reduce_out_element_ops, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c0_grid_desc_mblock_mperblock_nblock_nperblock, + const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c1_grid_desc_mblock_mperblock_nblock_nperblock, + const ReduceGridDescriptor_MBlock_MPerBlock& reduce_grid_desc_mblock_mperblock, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + auto c0_grid_buf = make_dynamic_buffer( + p_bias_grid, c0_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + auto c1_grid_buf = make_dynamic_buffer( + p_d0_grid, c1_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + 
ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_v1_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C + reduction + write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
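+            // For orientation: these shuffle descriptors decompose the block tile as
+            // MPerBlock = MXdlPerWave * MWave * MPerXdl (and NPerBlock likewise). With
+            // illustrative values (not taken from this diff) MPerBlock = 256,
+            // MXdlPerWave = 4 and MPerXdl = 32, that gives MWave = 256 / (4 * 32) = 2
+            // waves along M. In the 8-d descriptors below, M0/N0 count xdlops tiles
+            // per wave, M1/N1 select the wave, and M2 * M3 * M4 (resp. N2) span one
+            // MPerXdl x NPerXdl xdlops tile.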
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + 
SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + // TODO: this should be implemented as a blockwise reduction + // LDS c_reduce_block_desc_mperblock_nperblock + constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)), + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{})); + + static_assert(CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) * + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) == + BlockSize, + "wrong!"); + + static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) % + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) == + 0 && + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) % + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) == + 0, + "wrong!"); + + constexpr index_t mreduce_per_thread = + (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) / + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0); + + constexpr index_t nreduce_per_thread = + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) / + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1); + + constexpr auto c_reduce_thread_lengths_mperblock_nperblock = + Sequence{}; + + // VGPR c_reduce_thread_desc_mperblock_nperblock + constexpr auto c_reduce_thread_desc_mperblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + // VGPR reduce_thread_desc_mperblock + constexpr auto reduce_thread_desc_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + // VGPR reduce_thread_desc_mblock_mperblock + constexpr auto reduce_thread_desc_mblock_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(I1, Number{})); + + auto c_reduce_thread_buf = make_static_buffer( + c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); + + // reduce: threadwise copy from LDS to VGPR + constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor( + CReduceThreadClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{}); + + const auto c_reduce_thread_cluster_idx = + c_reduce_thread_cluster_desc.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto c_reduce_thread_data_idx_begin = + c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock; + + auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatCShuffle, + FloatReduceAcc, + decltype(c_reduce_block_desc_mperblock_nperblock), + decltype(c_reduce_thread_desc_mperblock_nperblock), + decltype(c_reduce_thread_lengths_mperblock_nperblock), + Sequence<0, 1>, + 1, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + 1, + true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin}; + + auto reduce_tuple_thread_copy_vgpr_to_global = generate_tuple( + [&](auto I) { + auto p_reduce_grid = p_reduces_grid[I]; + auto reduce_acc_element_op = 
reduce_out_element_ops[I]; + + return ThreadwiseTensorSliceTransfer_v1r3< + FloatReduceAcc, + remove_pointer_t, + decltype(reduce_thread_desc_mblock_mperblock), + decltype(reduce_grid_desc_mblock_mperblock), + decltype(reduce_acc_element_op), + Sequence<1, mreduce_per_thread>, + Sequence<0, 1>, + 1, + CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, + ReduceGlobalMemoryDataOperation::At(I), + 1, + false>{reduce_grid_desc_mblock_mperblock, + make_multi_index(block_work_idx[I0], // mblock + c_reduce_thread_data_idx_begin[I0]), // mperblock + reduce_acc_element_op}; + }, + Number{}); + + // c0 and c1 + constexpr auto c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + constexpr auto c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock = + c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock; + + auto c01_thread_buf = make_static_buffer( + c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + auto c0_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatC0, + FloatReduceAcc, + decltype(c0_grid_desc_mblock_mperblock_nblock_nperblock), + decltype(c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock), + Sequence, + Sequence<0, 1, 2, 3>, + 3, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + 1, + true>( + c0_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(I0, + m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0], + I0, + n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1])); + + auto c1_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatC1, + FloatReduceAcc, + decltype(c1_grid_desc_mblock_mperblock_nblock_nperblock), + decltype(c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock), + Sequence, + Sequence<0, 1, 2, 3>, + 3, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + 1, + true>( + c1_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(I0, + m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0], + I0, + n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1])); + + constexpr auto c_reduce_thread_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + auto c_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< + FloatReduceAcc, + FloatC, + decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + tensor_operation::element_wise::PassThrough, + Sequence, // SliceLengths + Sequence<0, 1, 2, 3>, // DimAccessOrder + 3, // DstVectorDim + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(I0, + m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0], + I0, + n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]), + tensor_operation::element_wise::PassThrough{}}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to write to LDS + 
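// (note: p_shared is a single allocation reused across phases; it held the
+                // A/B tiles during the GEMM main loop and now backs the C shuffle
+                // staging buffer, which is why GetSharedMemoryNumberOfByte() sizes it
+                // as the max of the two uses, and why each shuffle iteration must
+                // synchronize before overwriting it) +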
block_sync_lds(); + { + c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock, + c_shuffle_block_buf, + c_reduce_thread_desc_mperblock_nperblock, + make_tuple(I0, I0), + c_reduce_thread_buf); + + c0_thread_copy_global_to_vgpr.Run( + c0_grid_desc_mblock_mperblock_nblock_nperblock, + c0_grid_buf, + c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c01_thread_buf); + + // c = activation(c + bias) + static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + FloatReduceAcc out; + c_element_op(out, c_reduce_thread_buf(i) + c01_thread_buf(i)); + c_reduce_thread_buf(i) = out; + }); + + c1_thread_copy_global_to_vgpr.Run( + c1_grid_desc_mblock_mperblock_nblock_nperblock, + c1_grid_buf, + c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c01_thread_buf); + + // c = c + c1_element_op(c1) + static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + c1_element_op(c01_thread_buf(i), c01_thread_buf(i)); + c_reduce_thread_buf(i) += c01_thread_buf(i); + }); + + c_reduce_thread_copy_vgpr_to_global.Run( + c_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c_reduce_thread_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) { + auto& p_reduce_grid = p_reduces_grid[In]; + + auto reduce_grid_buf = make_dynamic_buffer( + p_reduce_grid, reduce_grid_desc_mblock_mperblock.GetElementSpaceSize()); + + auto reduce_thread_buf = + make_static_buffer( + reduce_thread_desc_mperblock.GetElementSpaceSize()); + + auto& reduce_in_element_op = reduce_in_element_ops[In]; + + auto& reduce_thread_copy_vgpr_to_global = + reduce_tuple_thread_copy_vgpr_to_global(In); + + using ReduceOperation = remove_cvref_t; + using ThreadwiseReduce = + ThreadwiseReduction; + + // initialize the reduction accumulator with the reduce op's identity value + const auto reduce_identityVal = + ReduceOperation::template GetIdentityValue(); + + static_for<0, mreduce_per_thread, 1>{}( + [&](auto I) { reduce_thread_buf(I) = reduce_identityVal; }); + + // reduce in VGPR + static_for<0, mreduce_per_thread, 1>{}([&](auto im) { + static_for<0, nreduce_per_thread, 1>{}([&](auto in) { + constexpr auto offset = + Number{}; + + reduce_in_element_op(c_reduce_thread_buf(offset), + c_reduce_thread_buf(offset)); + }); + }); + + ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf); + + // copy from VGPR to Global + reduce_thread_copy_vgpr_to_global.Run(reduce_thread_desc_mblock_mperblock, + make_tuple(I0, I0), + reduce_thread_buf, + reduce_grid_desc_mblock_mperblock, + reduce_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( + reduce_grid_desc_mblock_mperblock, + make_tuple(c_global_step[I0], c_global_step[I1])); + } + }); + } + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + + // move on C0 + c0_thread_copy_global_to_vgpr.MoveSrcSliceWindow( + c0_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + + // move on C1 + c1_thread_copy_global_to_vgpr.MoveSrcSliceWindow( + c1_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } //
Reduction + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp new file mode 100644 index 00000000..ed98b626 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp @@ -0,0 +1,576 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_dl_v1r3(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, + const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap block_2_ctile_map) +{ + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_grid_desc_k0_m0_m1_k1, + b_grid_desc_k0_n0_n1_k1, + c_grid_desc_m0_m10_m11_n0_n10_n11, + block_2_ctile_map, + integral_constant{}, + integral_constant{}); +} + +template +struct GridwiseGemmDl_km_kn_mn_v1r3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + // K1 should be Number<...> + static constexpr auto K1 = AGridDesc_K0_M_K1{}.GetLength(I2); + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // TODO: change this. 
I think it needs multi-dimensional alignment + constexpr auto max_lds_align = K1; + + // TODO: check alignment + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // TODO: check alignment + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // TODO: check alignment + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = + math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = + math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align); + + return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); + } + + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + + return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K0 == b_grid_desc_k0_n_k1.GetLength(I0) && + K1 == a_grid_desc_k0_m_k1.GetLength(I2) && + K1 == b_grid_desc_k0_n_k1.GetLength(I2)) && + (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0); + } + + __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + { + const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); + + return grid_size; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0) + { + const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1; + + return has_main_k_block_loop; + } + + __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0) + { + const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0; + + return has_double_tail_k_block_loop; + } + + __host__ __device__ static constexpr auto + MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1) + { + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + + const auto M1 = Number{}; + const auto M0 = M / M1; + + const auto a_grid_desc_k0_m0_m1_k1 = + transform_tensor_descriptor(a_grid_desc_k0_m_k1, + make_tuple(make_pass_through_transform(K0), + make_unmerge_transform(make_tuple(M0, M1)), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return a_grid_desc_k0_m0_m1_k1; + } + + __host__ __device__ static constexpr auto + MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1) + { + const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + + const auto N1 = Number{}; + const auto N0 = N / N1; + + const auto b_grid_desc_k0_n0_n1_k1 = + transform_tensor_descriptor(b_grid_desc_k0_n_k1, + make_tuple(make_pass_through_transform(K0), + make_unmerge_transform(make_tuple(N0, N1)), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + 
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return b_grid_desc_k0_n0_n1_k1; + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + constexpr auto M11 = + Number{}; + constexpr auto N11 = + Number{}; + + constexpr auto M10 = M1 / M11; + constexpr auto N10 = N1 / N11; + + const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)), + make_unmerge_transform(make_tuple(N0, N10, N11))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return c_grid_desc_m0_m10_m11_n0_n10_n11; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) + { + return BlockToCTileMap_M00_N00_M01_N01( + c_grid_desc_m_n); + } + + using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); + using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); + using CGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})); + using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{})); + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + FloatAB* __restrict__ p_shared_block, + const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1, + const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap& block_2_ctile_map, + integral_constant, + integral_constant) + { + const auto a_global_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_k0_n0_n1_k1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto c_m0_n0_block_cluster_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + // HACK: this force index data into SGPR + const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]); + const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]); + + if(!block_2_ctile_map.ValidCTileIndex( + make_tuple(im0, in0), + make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0), + c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3)))) + { + return; + } + + // TODO: change this. 
I think it needs multi-dimensional alignment + constexpr auto max_lds_align = K1; + + // TODO: check alignment + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, I1, Number{}, K1), max_lds_align); + + // TODO: check alignment + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, I1, Number{}, K1), max_lds_align); + + // TODO: check alignment + // A matrix in LDS memory, for blockwise GEMM + constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // TODO: check alignment + // B matrix in LDS memory, for blockwise GEMM + constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() == + a_k0_m_k1_block_desc.GetElementSpaceSize() && + b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() == + b_k0_n_k1_block_desc.GetElementSpaceSize() && + "wrong!"); + + // A matrix blockwise copy + auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< + BlockSize, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadSliceLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + remove_reference_t, + decltype(a_block_desc_k0_m0_m1_k1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3>, + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths + ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths + ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder + Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder + false, + true>(a_grid_desc_k0_m0_m1_k1, + make_multi_index(0, im0, 0, 0), + a_block_desc_k0_m0_m1_k1, + make_multi_index(0, 0, 0, 0)); + + // B matrix blockwise copy + auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< + BlockSize, + InMemoryDataOperationEnum::Set, + Sequence, + BBlockTransferThreadSliceLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + remove_reference_t, + decltype(b_block_desc_k0_n0_n1_k1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3>, + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths + BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths + BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder + Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder + false, + true>(b_grid_desc_k0_n0_n1_k1, + make_multi_index(0, in0, 0, 0), + b_block_desc_k0_n0_n1_k1, + make_multi_index(0, 0, 0, 0)); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + const auto blockwise_gemm = + BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< + BlockSize, + FloatAB, + FloatAB, + FloatAcc, + decltype(a_k0_m_k1_block_desc), + decltype(b_k0_n_k1_block_desc), + M1PerThreadM111, + N1PerThreadN111, + KPerThread, + M11N11ThreadClusterM110Xs, +
M11N11ThreadClusterN110Xs, + M1PerThreadM111, + N1PerThreadN111>{}; + + constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = + decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); + + constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed( + sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths)); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = math::integer_least_multiple( + a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = math::integer_least_multiple( + b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block_double = p_shared_block; + FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; + + // register allocation for output + auto c_thread_buf = make_static_buffer( + c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize()); + + // Initialize C + c_thread_buf.Clear(); + + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0); + + auto a_block_even_buf = make_dynamic_buffer( + p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize()); + auto b_block_even_buf = make_dynamic_buffer( + p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize()); + + auto a_block_odd_buf = make_dynamic_buffer( + p_a_block_double + a_block_aligned_space_size, + a_block_desc_k0_m0_m1_k1.GetElementSpaceSize()); + auto b_block_odd_buf = make_dynamic_buffer( + p_b_block_double + b_block_aligned_space_size, + b_block_desc_k0_n0_n1_k1.GetElementSpaceSize()); + + // LDS double buffer: preload data into LDS + { + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); + + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf); + } + + if constexpr(HasMainKBlockLoop) + { + const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0); + + index_t k_block_data_begin = 0; + + // LDS double buffer: main body + // use Do-While loop instead of For loop to simplify control flow + do + { + // even iteration + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, + a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, + b_block_slice_copy_step); + + // LDS double buffer: load next data from device mem + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); + + block_sync_lds(); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11, + a_block_even_buf, + b_block_even_buf, + c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf); + + // odd iteration + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, + a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, + b_block_slice_copy_step); + + // LDS double buffer: load next data from device mem + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); + + block_sync_lds(); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run( +
c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf); + + k_block_data_begin += 2 * K0PerBlock; + } while(k_block_data_begin < K0 - 2 * K0PerBlock); + } + + // LDS double buffer: tail + if constexpr(HasDoubleTailKBlockLoop) // if there are 2 iterations left + { + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step); + + block_sync_lds(); + + // LDS double buffer: load last data from device mem + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); + + // LDS double buffer: GEMM on 2nd-last data + blockwise_gemm.Run( + c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf); + + // LDS double buffer: store last data to LDS + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf); + + block_sync_lds(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + } + else // if there is 1 iteration left + { + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf); + } + + // output: register to global memory + { + constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + const auto c_m10_m11_n10_n11_thread_origin_idx_on_block = + blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( + get_thread_local_1d_id()); + + ThreadwiseTensorSliceTransfer_v1r3< + FloatAcc, + FloatC, + decltype(c_thread_desc_m0_m10_m11_n0_n10_n11), + decltype(c_grid_desc_m0_m10_m11_n0_n10_n11), + ck::tensor_operation::element_wise::PassThrough, + Sequence<1, + c_m10_m11_n10_n11_thread_tensor_lengths[I0], + c_m10_m11_n10_n11_thread_tensor_lengths[I1], + 1, + c_m10_m11_n10_n11_thread_tensor_lengths[I2], + c_m10_m11_n10_n11_thread_tensor_lengths[I3]>, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{c_grid_desc_m0_m10_m11_n0_n10_n11, + make_multi_index(im0, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I0], + c_m10_m11_n10_n11_thread_origin_idx_on_block[I1], + in0, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I2], + c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]), + ck::tensor_operation::element_wise::PassThrough{}} + .Run(c_thread_desc_m0_m10_m11_n0_n10_n11, + make_tuple(I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_grid_desc_m0_m10_m11_n0_n10_n11, + c_grid_buf); + } + } +}; + +} // namespace ck diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r2.hpp similarity index 88% rename from composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r2.hpp rename to include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r2.hpp index d91159b8..84e033e1 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r2.hpp +++
b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r2.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP #define CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP @@ -12,7 +15,6 @@ namespace ck { -#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE template {}, - integral_constant{}); -} -#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER -// pass tensor descriptor by CONSTANT void pointer -// CONSTANT is needed to inform compiler void pointers in the kernel signature are pointing to -// non-modifiable parameter address space, so compiler can enable corresponding optimization -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_gemm_dlops_v1r2(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const void CONSTANT* p_a_k_m0_m1_grid_desc, - const void CONSTANT* p_b_k_n0_n1_grid_desc, - const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc, - const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) + const CBlockIdToM0N0BlockClusterAdaptor cblockid_to_m0_n0_block_cluster_adaptor) { - // first cast void CONSTANT void* to void* - // second cast void* to Desc* - // the copy constructor of tensor descriptor doesn't take address_space(4) - const auto a_k_m0_m1_grid_desc = *reinterpret_cast( - cast_pointer_to_generic_address_space(p_a_k_m0_m1_grid_desc)); - const auto b_k_n0_n1_grid_desc = *reinterpret_cast( - cast_pointer_to_generic_address_space(p_b_k_n0_n1_grid_desc)); - const auto c_m0_m10_m11_n0_n10_n11_grid_desc = - *reinterpret_cast( - cast_pointer_to_generic_address_space(p_c_m0_m10_m11_n0_n10_n11_grid_desc)); - const auto c_blockid_to_m0_n0_block_cluster_adaptor = - *reinterpret_cast( - cast_pointer_to_generic_address_space(p_c_blockid_to_m0_n0_block_cluster_adaptor)); - constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); @@ -102,17 +49,16 @@ __global__ void a_k_m0_m1_grid_desc, b_k_n0_n1_grid_desc, c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor, + cblockid_to_m0_n0_block_cluster_adaptor, integral_constant{}, integral_constant{}); } -#endif template {}), make_tuple(Sequence<0>{})); - return c_blockid_to_m0_n0_block_cluster_adaptor; + return cblockid_to_m0_n0_block_cluster_adaptor; } using AKM0M1GridDesc = decltype(MakeAKM0M1GridDescriptor(AKMGridDesc{})); @@ -321,22 +267,22 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2 const AKM0M1GridDesc& a_k_m0_m1_grid_desc, const BKN0N1GridDesc& b_k_n0_n1_grid_desc, const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc, - const CBlockIdToM0N0BlockClusterAdaptor& c_blockid_to_m0_n0_block_cluster_adaptor, + const CBlockIdToM0N0BlockClusterAdaptor& cblockid_to_m0_n0_block_cluster_adaptor, integral_constant, integral_constant) { - const auto a_global_buf = make_dynamic_buffer( + const auto a_global_buf = make_dynamic_buffer( p_a_grid, a_k_m0_m1_grid_desc.GetElementSpaceSize()); - const auto b_global_buf = make_dynamic_buffer( + const auto b_global_buf = make_dynamic_buffer( p_b_grid, b_k_n0_n1_grid_desc.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( + auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize()); const auto K = a_k_m0_m1_grid_desc.GetLength(I0); // divide block work by [M, N] const auto c_m0_n0_block_cluster_idx = - 
c_blockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex( + cblockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex( make_multi_index(get_block_1d_id())); // HACK: this force index data into SGPR @@ -372,7 +318,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2 // A matrix blockwise copy auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4, ABlockTransferThreadSliceLengths_K_M0_M1, ABlockTransferThreadClusterLengths_K_M0_M1, @@ -398,7 +344,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2 // B matrix blockwise copy auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4, BBlockTransferThreadSliceLengths_K_N0_N1, BBlockTransferThreadClusterLengths_K_N0_N1, @@ -460,7 +406,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2 FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; // register allocation for output - auto c_thread_buf = make_static_buffer( + auto c_thread_buf = make_static_buffer( c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize()); ThreadwiseTensorSliceSet_v1( + auto a_block_even_buf = make_dynamic_buffer( p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize()); - auto b_block_even_buf = make_dynamic_buffer( + auto b_block_even_buf = make_dynamic_buffer( p_b_block_double, b_k_n0_n1_block_desc.GetElementSpaceSize()); - auto a_block_odd_buf = make_dynamic_buffer( + auto a_block_odd_buf = make_dynamic_buffer( p_a_block_double + a_block_aligned_space_size, a_k_m0_m1_block_desc.GetElementSpaceSize()); - auto b_block_odd_buf = make_dynamic_buffer( + auto b_block_odd_buf = make_dynamic_buffer( p_b_block_double + b_block_aligned_space_size, b_k_n0_n1_block_desc.GetElementSpaceSize()); diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v2.hpp similarity index 98% rename from composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp rename to include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v2.hpp index 84ee6f40..b1dfb0c7 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v2.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
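// The dlops_v1r2 hunk above retires the experimental path that shipped tensor
// descriptors through `const void CONSTANT*` kernel arguments and re-cast them
// on kernel entry; descriptors are now ordinary by-value kernel arguments. A
// minimal sketch of the two calling conventions follows (the Desc2D type and
// kernel names are illustrative stand-ins, not CK API, and the CONSTANT
// address-space attribute plus CK's cast helper are omitted for brevity):

#include <hip/hip_runtime.h>

struct Desc2D // stand-in for a (much richer) CK tensor descriptor
{
    int length;
    int stride;
};

// retained convention: the descriptor travels by value in the read-only
// kernarg segment, so its fields arrive as uniform scalar kernel arguments
__global__ void scale_by_value(float* p, Desc2D d)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < d.length)
        p[i * d.stride] *= 2.0f;
}

// removed convention: an opaque pointer into non-modifiable parameter memory
// that the kernel must cast back to the concrete descriptor type before use
__global__ void scale_by_pointer(float* p, const void* p_desc)
{
    const Desc2D d = *static_cast<const Desc2D*>(p_desc);
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < d.length)
        p[i * d.stride] *= 2.0f;
}

// By-value passing lets the compiler see the descriptor contents as read-only
// kernel arguments without any address-space cast, which is presumably why the
// CONSTANT machinery could be deleted here.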
+ #ifndef CK_GRIDWISE_GEMM_V2_HPP #define CK_GRIDWISE_GEMM_V2_HPP @@ -15,7 +18,7 @@ template {}; constexpr auto I3 = Number<3>{}; - const auto a_global_buf = make_dynamic_buffer( + const auto a_global_buf = make_dynamic_buffer( p_a_global, a_e_k_global_desc.GetElementSpaceSize()); - const auto b_global_buf = make_dynamic_buffer( + const auto b_global_buf = make_dynamic_buffer( p_b_global, b_e_n_ho_wo_global_desc.GetElementSpaceSize()); - auto c_global_buf = make_dynamic_buffer( + auto c_global_buf = make_dynamic_buffer( p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize()); constexpr auto E = EPerBlock * 3 * 3; @@ -181,7 +184,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3 // A matrix blockwise copy auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4, ABlockTransferThreadSliceLengths_E_K, ABlockTransferThreadClusterLengths_E_K, @@ -221,11 +224,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3 b_e_n_ho_wo_global_desc, make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global)); - auto a_block_buf = make_dynamic_buffer( + auto a_block_buf = make_dynamic_buffer( p_shared_block, a_e_k_desc.GetElementSpaceSize()); // register allocation for output - StaticBuffer @@ -250,7 +253,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3 BGlobalMoveSliceWindowStepHacks{}; // double register buffer for b - StaticBuffer diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v3.hpp new file mode 100644 index 00000000..ace84433 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v3.hpp @@ -0,0 +1,1597 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_GRIDWISE_GEMM_V3_HPP +#define CK_GRIDWISE_GEMM_V3_HPP + +#include "common_header.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "blockwise_tensor_slice_transfer.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "threadwise_tensor_slice_set.hpp" +#include "blockwise_gemm_dlops_v3.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_dlops_v3( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + const FloatC* __restrict__ p_bias_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc, + const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + const CBlockIdToBlockClusterAdaptor_K_N_H_W cblockid_to_k_n_h_w_block_cluster_adaptor) +{ + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::ConvBiasActiv(p_a_grid, + p_b_grid, + p_bias_grid, + p_c_grid, + p_shared_block, + a_e0_e1_k0_k1_e2_grid_desc, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + cblockid_to_k_n_h_w_block_cluster_adaptor, + integral_constant{}, + integral_constant{}); +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_dlops_v3_resize_add( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + const FloatC* __restrict__ p_bias_grid, + FloatC*
__restrict__ p_d_grid, + const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc, + const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + const CBlockIdToBlockClusterAdaptor_K_N_H_W cblockid_to_k_n_h_w_block_cluster_adaptor) +{ + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::ConvBiasActivResizeAdd(p_a_grid, + p_b_grid, + p_bias_grid, + p_d_grid, + p_shared_block, + a_e0_e1_k0_k1_e2_grid_desc, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + cblockid_to_k_n_h_w_block_cluster_adaptor, + integral_constant{}, + integral_constant{}); +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_dlops_v3_maxpool( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + const FloatC* __restrict__ p_bias_grid, + FloatC* __restrict__ p_c_grid, + FloatC* __restrict__ p_d_grid, + const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc, + const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + const CBlockIdToBlockClusterAdaptor_K_N_H_W cblockid_to_k_n_h_w_block_cluster_adaptor) +{ + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::ConvBiasActivMaxpool(p_a_grid, + p_b_grid, + p_bias_grid, + p_c_grid, + p_d_grid, + p_shared_block, + a_e0_e1_k0_k1_e2_grid_desc, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + cblockid_to_k_n_h_w_block_cluster_adaptor, + integral_constant{}, + integral_constant{}); +} + +template +struct GridwiseGemmDlops_km_kn_mn_v3 +{ + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto E1 = Number{}; + static constexpr auto E2 = Number{}; + static constexpr auto K2 = Number{}; + + static constexpr auto NPerBlock = I1; + + static constexpr FloatAcc alpha = 0.3; + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = Number{}; + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_e0_e1_k1_e2_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(I1, Number{}, Number{}, Number{}), max_lds_align); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = math::integer_least_multiple( + a_e0_e1_k1_e2_block_desc.GetElementSpaceSize(), max_lds_align); + + return a_block_space_size * sizeof(FloatAB); + } + + __host__ __device__ static constexpr index_t + CalculateGridSize(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc) + { + const auto K = c_k_n_ho_wo_grid_desc.GetLength(I0); + const auto N = 
c_k_n_ho_wo_grid_desc.GetLength(I1); + const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2); + const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3); + + const auto K0 = K / KPerBlock; + const auto N0 = N / NPerBlock; + const auto H0 = Ho / HoPerBlock; + const auto W0 = Wo / WoPerBlock; + + const index_t grid_size = K0 * N0 * H0 * W0; + + return grid_size; + } + + __host__ __device__ static constexpr bool CalculateHasMainE0BlockLoop(const index_t E0) + { + const bool has_main_e0_block_loop = E0 > 1; + + return has_main_e0_block_loop; + } + + __host__ __device__ static constexpr bool CalculateHasMainE1BlockLoop() + { + const bool has_main_e1_block_loop = ((E1 + E1PerBlock) / (2 * E1PerBlock)) > 1; + + return has_main_e1_block_loop; + } + + __host__ __device__ static constexpr bool CalculateHasDoubleTailE1BlockLoop() + { + const bool has_double_tail_e1_block_loop = (E1 / E1PerBlock) % 2 == 0; + + return has_double_tail_e1_block_loop; + } + + __host__ __device__ static constexpr auto + MakeAE0E1K0K1E2GridDescriptor(const AGridDesc_E0_E1_K_E2& a_e0_e1_k_e2_grid_desc) + { + const auto E0 = a_e0_e1_k_e2_grid_desc.GetLength(I0); + const auto K = a_e0_e1_k_e2_grid_desc.GetLength(I2); + + const auto K1 = Number{}; + const auto K0 = K / K1; + + const auto a_e0_e1_k0_k1_e2_grid_desc = transform_tensor_descriptor( + a_e0_e1_k_e2_grid_desc, + make_tuple(make_pass_through_transform(E0), + make_pass_through_transform(E1), + make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(E2)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{})); + + return a_e0_e1_k0_k1_e2_grid_desc; + } + + __host__ __device__ static constexpr auto MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor( + const BGridDesc_E0_E1_N_Ho_Wo_E2& b_e0_e1_n_ho_wo_e2_grid_desc) + { + const auto E0 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I0); + // const auto E1 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I1); + const auto N = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I2); + const auto Ho = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I3); + const auto Wo = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I4); + // const auto E2 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I5); + + const auto H2 = Number{}; + const auto H1 = Number{}; + const auto H0 = Ho / (H1 * H2); + + const auto W2 = Number{}; + const auto W1 = Number{}; + const auto W0 = Wo / (W1 * W2); + + const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc = + transform_tensor_descriptor(b_e0_e1_n_ho_wo_e2_grid_desc, + make_tuple(make_pass_through_transform(E0), + make_pass_through_transform(E1), + make_pass_through_transform(N), + make_unmerge_transform(make_tuple(H0, H1, H2)), + make_unmerge_transform(make_tuple(W0, W1, W2)), + make_pass_through_transform(E2)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3, 4, 5>{}, + Sequence<6, 7, 8>{}, + Sequence<9>{})); + + return b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeCK0K1NH0H1H2W0W1W2GridDescriptor(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc) + { + const auto K = c_k_n_ho_wo_grid_desc.GetLength(I0); + const auto N = c_k_n_ho_wo_grid_desc.GetLength(I1); + const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2); + const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3); + + const auto K1 = Number{}; + const auto K0 = K / K1; + + const auto H2 = Number{}; + const auto H1 = 
Number{}; + const auto H0 = Ho / (H1 * H2); + + const auto W2 = Number{}; + const auto W1 = Number{}; + const auto W0 = Wo / (W1 * W2); + + const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc = transform_tensor_descriptor( + c_k_n_ho_wo_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(N), + make_unmerge_transform(make_tuple(H0, H1, H2)), + make_unmerge_transform(make_tuple(W0, W1, W2))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{})); + + return c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeDK0K1NH0H1HxW0W1WxGridDescriptorMaxPool(const DGridDesc_K_N_Hx_Wx& d_k_n_hx_wx_grid_desc) + { + const auto K = d_k_n_hx_wx_grid_desc.GetLength(I0); + const auto N = d_k_n_hx_wx_grid_desc.GetLength(I1); + const auto Hx = d_k_n_hx_wx_grid_desc.GetLength(I2); + const auto Wx = d_k_n_hx_wx_grid_desc.GetLength(I3); + + const auto K1 = Number{}; + const auto K0 = K / K1; + +#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR + const auto H2 = Number{}; + const auto H1 = Number{}; + const auto H0 = Number{}; + + const auto W2 = Number{}; + const auto W1 = Number{}; + const auto W0 = Number{}; +#else + const auto H2 = HoPerThread / 2; + const auto H1 = HoPerBlock / HoPerThread; + const auto H0 = Hx / (H1 * H2); + + const auto W2 = WoPerThread / 2; + const auto W1 = WoPerBlock / WoPerThread; + const auto W0 = Wx / (W1 * W2); +#endif + + const auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc = transform_tensor_descriptor( + d_k_n_hx_wx_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(N), + make_unmerge_transform(make_tuple(H0, H1, H2)), + make_unmerge_transform(make_tuple(W0, W1, W2))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{})); + + return d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeDK0K1NH0H1HxW0W1WxGridDescriptorResizeAdd(const DGridDesc_K_N_Hx_Wx& d_k_n_hx_wx_grid_desc) + { + const auto K = d_k_n_hx_wx_grid_desc.GetLength(I0); + const auto N = d_k_n_hx_wx_grid_desc.GetLength(I1); + const auto Hx = d_k_n_hx_wx_grid_desc.GetLength(I2); + const auto Wx = d_k_n_hx_wx_grid_desc.GetLength(I3); + + const auto K1 = Number{}; + const auto K0 = K / K1; + + const auto H2 = Number{}; + const auto H1 = Number{}; + + const auto W2 = Number{}; + const auto W1 = Number{}; + +#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR + const auto H0 = Number{}; + const auto W0 = Number{}; +#else + const auto H0 = Hx / (H1 * H2); + const auto W0 = Wx / (W1 * W2); +#endif + + const auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc = transform_tensor_descriptor( + d_k_n_hx_wx_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(N), + make_unmerge_transform(make_tuple(H0, H1, H2)), + make_unmerge_transform(make_tuple(W0, W1, W2))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{})); + + return d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeCBlockIdToKNHoWoBlockClusterAdaptor(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc) + { + const auto K = c_k_n_ho_wo_grid_desc.GetLength(I0); + const auto N = c_k_n_ho_wo_grid_desc.GetLength(I1); + 
const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2); + const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3); + +#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR + const auto K0 = Number{}; + const auto N0 = Number{}; + const auto H0 = Number{}; + const auto W0 = Number{}; +#else + const auto K0 = K / KPerBlock; + const auto N0 = N / NPerBlock; + const auto H0 = Ho / HoPerBlock; + const auto W0 = Wo / WoPerBlock; +#endif + + const auto cblockid_to_k_n_ho_wo_block_cluster_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(K0, N0, H0, W0))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + return cblockid_to_k_n_ho_wo_block_cluster_adaptor; + } + + // using AGridDesc_E0_E1_K0_K1_E2 = + // decltype(MakeAE0E1K0K1E2GridDescriptor(AGridDesc_E0_E1_K_E2{})); + // using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 = + // decltype(MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(BGridDesc_E0_E1_N_Ho_Wo_E2{})); + // using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 = + // decltype(MakeCK0K1NH0H1H2W0W1W2GridDescriptor(CGridDesc_K_N_Ho_Wo{})); + // using DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx = + // decltype(MakeDK0K1NH0H1HxW0W1WxGridDescriptor(DGridDesc_K_N_Hx_Wx{})); + + using CBlockIdToBlockClusterAdaptor_K_N_H_W = + decltype(MakeCBlockIdToKNHoWoBlockClusterAdaptor(CGridDesc_K_N_Ho_Wo{})); + + template + __host__ __device__ static constexpr auto MakeBiasK0K1GridDescriptor( + const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc) + { + const auto K0 = c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetLength(I0); + const auto K1 = c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetLength(I1); + + return make_naive_tensor_descriptor_packed(make_tuple(K0, K1)); + } + + __host__ __device__ static constexpr auto MakeCK1NH2W2ThreadDescriptor() + { + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + return c_k1_n_h2_w2_thread_gemm_desc; + } + + // using CThreadDesc_K1_N_H2_W2 = decltype(MakeCK1NH2W2ThreadDescriptor()); + + __host__ __device__ static constexpr auto GetBlockWiseGemm() + { + constexpr auto max_lds_align = Number{}; + + constexpr auto a_e1_k1_e2_block_gemm_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, Number{}), max_lds_align); + + constexpr auto b_e1_n_h_w_e2_block_gemm_desc = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + I1, + Number{}, + Number{}, + Number{})); + + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor(); + + auto blockwise_gemm = + BlockwiseGemmDlops_km_kn_m0m1n0n1_v3{}; + + return blockwise_gemm; + } + + __device__ static constexpr auto GetCThreadIndex() + { + auto blockwise_gemm = GetBlockWiseGemm(); + auto c_thread_mtx_index = + blockwise_gemm.GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id()); + + return c_thread_mtx_index; + }; + + __device__ static constexpr auto GetCBlockIndex( + const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor) + { + const auto c_k_n_h_w_block_cluster_idx = + cblockid_to_k_n_h_w_block_cluster_adaptor.CalculateBottomIndex( + make_multi_index(get_block_1d_id())); + return c_k_n_h_w_block_cluster_idx; + } + + template + __device__ static void BiasOp(BiasGlobalBuff& bias_global_buf, + CThreadBuff& c_thread_buf, + const CBlockIndex& c_block_idx, + const CThreadIndex& c_thread_idx, + const BiasGridDesc_K0_K1& bias_k0_k1_grid_desc, + const CThreadDesc_K1_N_H2_W2&) + + { + const index_t k_block_work_id = 
__builtin_amdgcn_readfirstlane(c_block_idx[I0]); + + const auto k_thread_id = c_thread_idx[I0]; + + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{}; + + constexpr auto bias_k0_k1_thread_desc = + make_naive_tensor_descriptor_packed(make_tuple(I1, Number{})); + + StaticBuffer + bias_thread_buf; + + const index_t k_thread_data_on_global = k_thread_id * KPerThread; + + auto bias_threadwise_transfer = + ThreadwiseTensorSliceTransfer_v2{}>, + Sequence<0, 1>, + 1, + CThreadTransferDstScalarPerVector, + false, + true>( + bias_k0_k1_grid_desc, make_multi_index(k_block_work_id, k_thread_data_on_global)); + + constexpr auto bias_k0_k1_global_tensor_step_hacks = make_tuple( + make_tuple(Sequence<0>{}, Sequence<0>{}), make_tuple(Sequence<0>{}, Sequence<0>{})); + + bias_threadwise_transfer.Run(bias_k0_k1_grid_desc, + bias_global_buf, + bias_k0_k1_thread_desc, + make_tuple(I0, I0), + bias_thread_buf, + bias_k0_k1_global_tensor_step_hacks); + + static_for<0, KPerThread, 1>{}([&](auto ki) { + static_for<0, HoPerThread, 1>{}([&](auto hi) { + static_for<0, WoPerThread, 1>{}([&](auto wi) { + constexpr index_t c_offset = + c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset(make_tuple(ki, 0, hi, wi)); + c_thread_buf(Number{}) = + c_thread_buf[Number{}] + bias_thread_buf[ki]; + }); + }); + }); + } + + template + __device__ static void Activation(CThreadBuff& c_thread_buf, + const CThreadDesc_K1_N_H2_W2&, + integral_constant) + { + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{}; + + static_for<0, c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(), 1>{}([&](auto i) { + if constexpr(activ_type_ == 1) + { + c_thread_buf(i) = c_thread_buf[i] >= 0 ? c_thread_buf[i] : alpha * c_thread_buf[i]; + } + else if constexpr(activ_type_ == 2) + { + FloatAcc x = 1.0 + exp(-c_thread_buf[i]); + + asm volatile("\n \ + v_rcp_f32 %0, %1 \n" + : "=v"(x) + : "0"(x)); + + c_thread_buf(i) = x; + } + }); + } + + template + __device__ static void + WriteOut(const CThreadBuff& c_thread_buf, + CGlobalBuff& c_global_buf, + const CBlockIndex& c_block_idx, + const CThreadIndex& c_thread_idx, + const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc) + { + const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]); + const index_t n_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I1]); + const index_t ho_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I2]); + const index_t wo_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I3]); + + const auto k_thread_id = c_thread_idx[I0]; + const auto ho_thread_id = c_thread_idx[I2]; + const auto wo_thread_id = c_thread_idx[I3]; + + // hack to control index calculation when iterating over c_k_n_h0_h1_h2_w0_w1_w2_global + // tensor + constexpr auto c_k_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks = CGlobalStepHacks{}; + + constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + I1, + I1, + I1, + Number{}, + I1, + I1, + Number{})); + + const index_t k_thread_data_on_global = k_thread_id * KPerThread; + + ThreadwiseTensorSliceTransfer_v1r3< + FloatAcc, + FloatC, + decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc), + decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc), + Sequence, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + make_multi_index(k_block_work_id, + 
k_thread_data_on_global, + n_block_work_id, + ho_block_work_id, + ho_thread_id, + 0, + wo_block_work_id, + wo_thread_id, + 0)) + .Run(c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + c_global_buf, + c_k_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks); + } + + template + __device__ static void + MaxPool(const CThreadBuff& c_thread_buf, + DGlobalBuff& d_global_buf, + const CBlockIndex& c_block_idx, + const CThreadIndex& c_thread_idx, + const CThreadDesc_K1_N_H2_W2&, + const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc) + { + + const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]); + const index_t n_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I1]); + const index_t ho_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I2]); + const index_t wo_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I3]); + + const auto k_thread_id = c_thread_idx[I0]; + const auto ho_thread_id = c_thread_idx[I2]; + const auto wo_thread_id = c_thread_idx[I3]; + + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{}; + + static_assert(HoPerThread % 2 == 0 && WoPerThread % 2 == 0, ""); + + constexpr auto HoPerThread_2 = HoPerThread / 2; + constexpr auto WoPerThread_2 = WoPerThread / 2; + + constexpr auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + I1, + I1, + I1, + Number{}, + I1, + I1, + Number{})); + + StaticBuffer + d_thread_buf; + + static_for<0, KPerThread, 1>{}([&](auto ki) { + static_for<0, HoPerThread_2, 1>{}([&](auto hi) { + static_for<0, WoPerThread_2, 1>{}([&](auto wi) { + constexpr index_t d_offset = + d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.CalculateOffset( + make_tuple(0, ki, 0, 0, 0, hi, 0, 0, wi)); + + constexpr index_t c_offset_0 = c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset( + make_tuple(ki, 0, hi * 2, wi * 2)); + constexpr index_t c_offset_1 = c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset( + make_tuple(ki, 0, hi * 2, wi * 2 + 1)); + constexpr index_t c_offset_2 = c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset( + make_tuple(ki, 0, hi * 2 + 1, wi * 2)); + constexpr index_t c_offset_3 = c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset( + make_tuple(ki, 0, hi * 2 + 1, wi * 2 + 1)); + + d_thread_buf(Number{}) = c_thread_buf[Number{}]; + d_thread_buf(Number{}) = + fmaxf(c_thread_buf[Number{}], d_thread_buf(Number{})); + d_thread_buf(Number{}) = + fmaxf(c_thread_buf[Number{}], d_thread_buf(Number{})); + d_thread_buf(Number{}) = + fmax(c_thread_buf[Number{}], d_thread_buf(Number{})); + }); + }); + }); + + const index_t k_thread_data_on_global = k_thread_id * KPerThread; + + constexpr auto d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks = DGlobalStepHacks{}; + + ThreadwiseTensorSliceTransfer_v1r3< + FloatC, + FloatC, + decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc), + decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc), + Sequence, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + InMemoryDataOperationEnum::Set, + 1, + true>(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + make_multi_index(k_block_work_id, + k_thread_data_on_global, + n_block_work_id, + ho_block_work_id, + ho_thread_id, + 0, + wo_block_work_id, + wo_thread_id, + 0)) + .Run(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0), + d_thread_buf, + d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + 
d_global_buf, + d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks); + } + + template + __device__ static void + ResizeAdd(const CThreadBuff& c_thread_buf, + DGlobalBuff& d_global_buf, + const CBlockIndex& c_block_idx, + const CThreadIndex& c_thread_idx, + const CThreadDesc_K1_N_H2_W2&, + const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc) + { + + const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]); + const index_t n_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I1]); + const index_t ho_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I2]); + const index_t wo_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I3]); + + const auto k_thread_id = c_thread_idx[I0]; + const auto ho_thread_id = c_thread_idx[I2]; + const auto wo_thread_id = c_thread_idx[I3]; + + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{}; + + constexpr auto HoPerThreadx2 = HoPerThread * 2; + constexpr auto WoPerThreadx2 = WoPerThread * 2; + + constexpr auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + I1, + I1, + I1, + Number{}, + I1, + I1, + Number{})); + + StaticBuffer + d_thread_buf; + + static_for<0, KPerThread, 1>{}([&](auto k_i) { + static_for<0, HoPerThreadx2, 1>{}([&](auto h_i) { + static_for<0, WoPerThreadx2, 1>{}([&](auto w_i) { + d_thread_buf(Number{}) = + c_thread_buf[Number{}]; + }); + }); + }); + + // hack to control index calculation when iterating over d_k_n_ho_wo_global tensor + constexpr auto d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks = DGlobalStepHacks{}; + + const index_t k_thread_data_on_global = k_thread_id * KPerThread; + + ThreadwiseTensorSliceTransfer_v1r3< + FloatC, + FloatC, + decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc), + decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc), + Sequence, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + InMemoryDataOperationEnum::Add, + 1, + true>(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + make_multi_index(k_block_work_id, + k_thread_data_on_global, + n_block_work_id, + ho_block_work_id, + ho_thread_id, + 0, + wo_block_work_id, + wo_thread_id, + 0)) + .Run(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0), + d_thread_buf, + d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + d_global_buf, + d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks); + } + + template + __device__ static void + GemmOp(const AGlobalBuff& a_global_buf, + const BGlobalBuff& b_global_buf, + CThreadBuff& c_thread_buf, + FloatAB* __restrict__ p_shared_block, + const CBlockIndex& c_block_idx, + const CThreadIndex& c_thread_idx, + const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc, + const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + const CThreadDesc_K1_N_H2_W2&, + integral_constant) + { + constexpr auto HasMainE1BlockLoop = CalculateHasMainE1BlockLoop(); + constexpr auto HasDoubleTailE1BlockLoop = CalculateHasDoubleTailE1BlockLoop(); + + // const auto c_k_n_h_w_block_cluster_idx = + // GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor); + // cblockid_to_k_n_h_w_block_cluster_adaptor.CalculateBottomIndex( + // make_multi_index(get_block_1d_id())); + + const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]); + const index_t n_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I1]); + const index_t ho_block_work_id = 
__builtin_amdgcn_readfirstlane(c_block_idx[I2]); + const index_t wo_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I3]); + + constexpr auto max_lds_align = Number{}; + + constexpr auto a_e1_k1_e2_block_gemm_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, Number{}), max_lds_align); + + constexpr auto b_e1_n_h_w_e2_block_gemm_desc = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + I1, + Number{}, + Number{}, + Number{})); + + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{}; + + auto blockwise_gemm = + BlockwiseGemmDlops_km_kn_m0m1n0n1_v3{}; + // blockwise_gemm.GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id()); + + const auto ho_thread_id = c_thread_idx[I2]; + const auto wo_thread_id = c_thread_idx[I3]; + + constexpr auto a_e0_e1_k0_k1_e2_block_copy_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, I1, Number{}, Number{}), + max_lds_align); + + // A matrix blockwise copy + auto a_blockwise_copy = + BlockwiseTensorSliceTransfer_v4, + ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2, + ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_e0_e1_k0_k1_e2_grid_desc), + decltype(a_e0_e1_k0_k1_e2_block_copy_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3, 4>, + ABlockTransferSrcVectorDim, + 4, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_E2, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + false>(a_e0_e1_k0_k1_e2_grid_desc, + make_multi_index(0, 0, k_block_work_id, 0, 0), + a_e0_e1_k0_k1_e2_block_copy_desc, + make_multi_index(0, 0, 0, 0, 0)); + + constexpr auto a_block_slice_copy_step = make_multi_index(I1, 0, 0, 0, 0); + + constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + I1, + I1, + I1, + Number{}, + I1, + I1, + Number{}, + Number{})); + + auto b_threadwise_transfer = ThreadwiseTensorSliceTransfer_v2< + FloatAB, + FloatAB, + decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc), + decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc), + Sequence, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + make_multi_index(0, + 0, + n_block_work_id, + ho_block_work_id, + ho_thread_id, + 0, + wo_block_work_id, + wo_thread_id, + 0, + 0)); + + auto a_block_buf = make_dynamic_buffer( + p_shared_block, a_e0_e1_k0_k1_e2_block_copy_desc.GetElementSpaceSize()); + + //// register allocation for output + // StaticBuffer + // c_thread_buf; + + // initialize output thread tensor + ThreadwiseTensorSliceSet_v1>{} + .Run(c_k1_n_h2_w2_thread_gemm_desc, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + FloatAcc{0}); + + constexpr auto b_thread_slice_copy_step = + make_multi_index(0, E1PerBlock, 0, 0, 0, 0, 0, 0, 0, 0); + + // hack to control index calculation when iterating over A and B matrices for threadwise copy + constexpr auto a_e0_e1_k_e2_global_step_hacks = AGlobalStepHacks{}; + constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks = BGlobalStepHacks{}; + + // double register buffer for b + StaticBuffer + b_thread_even_buf, b_thread_odd_buf; + + if constexpr(HasMainE0BlockLoop) + { + const auto E0 = b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetLength(I0); + + index_t e0_block_data_begin = 0; + + do + { + // LDS double buffer: preload data + { + 
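+                // Editor's note (a sketch of the schedule this code implements, not new behavior):
+                //   1. preload: the A tile is staged to LDS via a_blockwise_copy while the first
+                //      B tile is fetched into registers (b_thread_even_buf) by b_threadwise_transfer;
+                //   2. steady state: while blockwise_gemm consumes one B buffer, the next B tile is
+                //      fetched into the other buffer, and the two swap roles each E1PerBlock step;
+                //   3. tail: the last one or two E1PerBlock iterations are peeled off below.
+                // The even/odd register pair is what lets global loads overlap with the math.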
a_blockwise_copy.RunRead( + a_e0_e1_k0_k1_e2_grid_desc, a_global_buf, a_e0_e1_k_e2_global_step_hacks); + + b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_global_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + b_thread_even_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); + + a_blockwise_copy.RunWrite(a_e0_e1_k0_k1_e2_block_copy_desc, a_block_buf); + } + + __syncthreads(); + + if constexpr(HasMainE1BlockLoop) + { + index_t e1_block_data_begin = 0; + + // LDS double buffer: main body + // use Do-While loop instead of For loop to simplify control flow + do + { + // even iteration + b_threadwise_transfer.MoveSrcSliceWindow( + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_thread_slice_copy_step, + BGlobalMoveSliceWindowStepHacks{}); + + b_threadwise_transfer.Run( + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_global_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + b_thread_odd_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); + + blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); + + b_threadwise_transfer.MoveSrcSliceWindow( + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_thread_slice_copy_step, + BGlobalMoveSliceWindowStepHacks{}); + + b_threadwise_transfer.Run( + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_global_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + b_thread_even_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); + + blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); + + e1_block_data_begin += 2 * E1PerBlock; + + } while(e1_block_data_begin < E1 - 2 * E1PerBlock); + } + + // LDS double buffer: tail + if constexpr(HasDoubleTailE1BlockLoop) // if there are 2 iterations left + { + b_threadwise_transfer.MoveSrcSliceWindow( + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_thread_slice_copy_step, + BGlobalMoveSliceWindowStepHacks{}); + + b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_global_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + b_thread_odd_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); + + // LDS double buffer: GEMM on 2nd-last data + blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); + + blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); + } + else // if there is 1 iteration left + { + // LDS double buffer: GEMM on last data + blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); + } + + a_blockwise_copy.MoveSrcSliceWindow(a_e0_e1_k0_k1_e2_grid_desc, + a_block_slice_copy_step, + AGlobalMoveSliceWindowStepHacks{}); + + blockwise_gemm.MoveABlockSliceWindow(make_tuple(-(E1 - E1PerBlock), 0, 0)); + + b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_thread_slice_copy_step, + BGlobalMoveSliceWindowStepHacks{}); + + e0_block_data_begin += 1; + + } while(e0_block_data_begin < E0); + } + else + { + // LDS double buffer: preload data + { + a_blockwise_copy.RunRead( + a_e0_e1_k0_k1_e2_grid_desc, a_global_buf, 
a_e0_e1_k_e2_global_step_hacks); + + b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_global_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + b_thread_even_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); + + a_blockwise_copy.RunWrite(a_e0_e1_k0_k1_e2_block_copy_desc, a_block_buf); + } + + __syncthreads(); + + if constexpr(HasMainE1BlockLoop) + { + index_t e1_block_data_begin = 0; + + // LDS double buffer: main body + // use Do-While loop instead of For loop to simplify control flow + do + { + // even iteration + b_threadwise_transfer.MoveSrcSliceWindow( + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_thread_slice_copy_step, + BGlobalMoveSliceWindowStepHacks{}); + + b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_global_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + b_thread_odd_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); + + blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); + + b_threadwise_transfer.MoveSrcSliceWindow( + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_thread_slice_copy_step, + BGlobalMoveSliceWindowStepHacks{}); + + b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_global_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + b_thread_even_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); + + blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); + + e1_block_data_begin += 2 * E1PerBlock; + + } while(e1_block_data_begin < E1 - 2 * E1PerBlock); + } + + // LDS double buffer: tail + if constexpr(HasDoubleTailE1BlockLoop) // if there are 2 iterations left + { + b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_thread_slice_copy_step, + BGlobalMoveSliceWindowStepHacks{}); + + b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_global_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + b_thread_odd_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); + + // LDS double buffer: GEMM on 2nd-last data + blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); + + blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); + } + else // if there is 1 iteration left + { + // LDS double buffer: GEMM on last data + blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); + } + } + } + + template + __device__ static void + Conv(const FloatAB* __restrict__ p_a_global, + const FloatAB* __restrict__ p_b_global, + const FloatC* __restrict__ p_bias_global, + FloatC* __restrict__ p_c_global, + FloatC* __restrict__ p_d_global, + FloatAB* __restrict__ p_shared_block, + const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc, + const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + const 
CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor, + integral_constant) + { + const auto bias_k0_k1_grid_desc = + MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); + + const auto a_global_buf = make_dynamic_buffer( + p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize()); + auto c_global_buf = make_dynamic_buffer( + p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize()); + auto d_global_buf = make_dynamic_buffer( + p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize()); + auto bias_global_buf = make_dynamic_buffer( + p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize()); + + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor(); + + // register allocation for output + StaticBuffer + c_thread_buf; + + const auto c_k_n_h_w_block_cluster_idx = + GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor); + + const auto c_thread_mtx_index = GetCThreadIndex(); + + // GemmOp + GemmOp(a_global_buf, + b_global_buf, + c_thread_buf, + p_shared_block, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + a_e0_e1_k0_k1_e2_grid_desc, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + c_k1_n_h2_w2_thread_gemm_desc, + integral_constant{}); + + // Output + WriteOut(c_thread_buf, + c_global_buf, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); + } + + template + __device__ static void ConvBiasActiv( + const FloatAB* __restrict__ p_a_global, + const FloatAB* __restrict__ p_b_global, + const FloatC* __restrict__ p_bias_global, + FloatC* __restrict__ p_c_global, + FloatAB* __restrict__ p_shared_block, + const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc, + const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor, + integral_constant, + integral_constant) + { + static constexpr auto activ_type = integral_constant{}; + + const auto bias_k0_k1_grid_desc = + MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); + + const auto a_global_buf = make_dynamic_buffer( + p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize()); + auto c_global_buf = make_dynamic_buffer( + p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize()); + auto bias_global_buf = make_dynamic_buffer( + p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize()); + + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor(); + + // register allocation for output + StaticBuffer + c_thread_buf; + + const auto c_k_n_h_w_block_cluster_idx = + GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor); + + const auto c_thread_mtx_index = GetCThreadIndex(); + + // GemmOp + GemmOp(a_global_buf, + b_global_buf, + c_thread_buf, + p_shared_block, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + a_e0_e1_k0_k1_e2_grid_desc, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + c_k1_n_h2_w2_thread_gemm_desc, + integral_constant{}); + + // Bias + BiasOp(bias_global_buf, + c_thread_buf, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + bias_k0_k1_grid_desc, + c_k1_n_h2_w2_thread_gemm_desc); + 
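+        // Editor's note: at this point c_thread_buf holds the convolution result plus bias.
+        // Activation() below applies, per element, either leaky ReLU (activ_type == 1) or a
+        // sigmoid (activ_type == 2). A scalar reference of the fusion, for illustration only
+        // (conv_bias_activ_ref is a hypothetical helper; alpha = 0.3 as defined above):
+        //   float conv_bias_activ_ref(float acc, float bias, int activ_type)
+        //   {
+        //       float x = acc + bias;
+        //       if(activ_type == 1) return x >= 0 ? x : 0.3f * x;    // leaky ReLU
+        //       if(activ_type == 2) return 1.0f / (1.0f + expf(-x)); // sigmoid (v_rcp_f32 on device)
+        //       return x;
+        //   }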
+ // Activ + Activation(c_thread_buf, c_k1_n_h2_w2_thread_gemm_desc, activ_type); + + // Output + WriteOut(c_thread_buf, + c_global_buf, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); + } + + template + __device__ static void ConvBiasActivMaxpool( + const FloatAB* __restrict__ p_a_global, + const FloatAB* __restrict__ p_b_global, + const FloatC* __restrict__ p_bias_global, + FloatC* __restrict__ p_c_global, + FloatC* __restrict__ p_d_global, + FloatAB* __restrict__ p_shared_block, + const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc, + const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor, + integral_constant, + integral_constant) + { + static constexpr auto activ_type = integral_constant{}; + + const auto bias_k0_k1_grid_desc = + MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); + + const auto a_global_buf = make_dynamic_buffer( + p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize()); + auto c_global_buf = make_dynamic_buffer( + p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize()); + auto d_global_buf = make_dynamic_buffer( + p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize()); + auto bias_global_buf = make_dynamic_buffer( + p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize()); + + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor(); + + // register allocation for output + StaticBuffer + c_thread_buf; + + const auto c_k_n_h_w_block_cluster_idx = + GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor); + + const auto c_thread_mtx_index = GetCThreadIndex(); + + // GemmOp + GemmOp(a_global_buf, + b_global_buf, + c_thread_buf, + p_shared_block, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + a_e0_e1_k0_k1_e2_grid_desc, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + c_k1_n_h2_w2_thread_gemm_desc, + integral_constant{}); + + // Bias + BiasOp(bias_global_buf, + c_thread_buf, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + bias_k0_k1_grid_desc, + c_k1_n_h2_w2_thread_gemm_desc); + + // Activ + Activation(c_thread_buf, c_k1_n_h2_w2_thread_gemm_desc, activ_type); + + // Output + WriteOut(c_thread_buf, + c_global_buf, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); + + // MaxPool + MaxPool(c_thread_buf, + d_global_buf, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + c_k1_n_h2_w2_thread_gemm_desc, + d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc); + } + + template + __device__ static void ConvBiasActivResizeAdd( + const FloatAB* __restrict__ p_a_global, + const FloatAB* __restrict__ p_b_global, + const FloatC* __restrict__ p_bias_global, + FloatC* __restrict__ p_d_global, + FloatAB* __restrict__ p_shared_block, + const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc, + const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + const CBlockIdToBlockClusterAdaptor_K_N_H_W& 
cblockid_to_k_n_h_w_block_cluster_adaptor, + integral_constant, + integral_constant) + { + static constexpr auto activ_type = integral_constant{}; + + const auto bias_k0_k1_grid_desc = + MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); + + const auto a_global_buf = make_dynamic_buffer( + p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize()); + auto d_global_buf = make_dynamic_buffer( + p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize()); + auto bias_global_buf = make_dynamic_buffer( + p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize()); + + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor(); + + // register allocation for output + StaticBuffer + c_thread_buf; + + const auto c_k_n_h_w_block_cluster_idx = + GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor); + + const auto c_thread_mtx_index = GetCThreadIndex(); + + // GemmOp + GemmOp(a_global_buf, + b_global_buf, + c_thread_buf, + p_shared_block, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + a_e0_e1_k0_k1_e2_grid_desc, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + c_k1_n_h2_w2_thread_gemm_desc, + integral_constant{}); + + // Bias + BiasOp(bias_global_buf, + c_thread_buf, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + bias_k0_k1_grid_desc, + c_k1_n_h2_w2_thread_gemm_desc); + + // Activ + Activation(c_thread_buf, c_k1_n_h2_w2_thread_gemm_desc, activ_type); + + // Resize_Add + ResizeAdd(c_thread_buf, + d_global_buf, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + c_k1_n_h2_w2_thread_gemm_desc, + d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc); + } +}; + +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp new file mode 100644 index 00000000..5ce7db0a --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,675 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +// GEMM: +// input : A[AK0, M, AK1] +// input : B[BK0, N, BK1] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... 
and E have the same layout +template +struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatCShuffle)); + } + + // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const EGridDesc_M_N& e_grid_desc_m_n, + const Block2ETileMap& block_2_etile_map) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave 
* NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); + const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + + if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to E matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + e_grid_desc_m_n); + } + + using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using DefaultBlock2ETileMap = + remove_cvref_t; + + using DsGridPointer = decltype(MakeDsGridPointer()); + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + DsGridPointer p_ds_grid, + FloatE* __restrict__ p_e_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const StaticallyIndexedArray& + ds_grid_desc_mblock_mperblock_nblock_nperblock, // FIXME: Ds desc may be of different + // type from E + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap& block_2_etile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + 
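+        // Editor's note: block_2_etile_map decomposes the flat workgroup id into an (m0, n0)
+        // tile coordinate of E. As an illustration only (assuming the default
+        // BlockToCTileMap_M00_N0_M01Adapt with M01 = 1, M/MPerBlock = 4, N/NPerBlock = 8),
+        // the 32 workgroups walk the 4 x 8 tile grid row by row, so get_block_1d_id() == 11
+        // yields block_work_idx = (1, 3).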
if(!block_2_etile_map.ValidCTileIndex( + block_work_idx, + make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this forces m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = 
make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_v1_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + 
make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + // blockwise copy C/D/E between LDS and global + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7< + ThisThreadBlock, + decltype(container_concat(make_tuple(FloatCShuffle{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CDEElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), + cde_element_op}; + + // space filling curve for threadwise C in VGPR before shuffle + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + 
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp new file mode 100644 index 00000000..42a56e2a --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp @@ -0,0 +1,368 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
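+
+// This header defines the K-loop software pipelines used by the gridwise GEMM
+// kernels that include it:
+//   * GridwiseGemmPipeline_v1<1> - single-stage prefetch: the global read for
+//     iteration i+1 overlaps the blockwise GEMM of iteration i.
+//   * GridwiseGemmPipeline_v1<2> - two-stage prefetch with double-buffered copy
+//     slots selected by the trailing I0/I1 indices; IsSupported() restricts it
+//     to even loop counts.
+//   * GridwiseGemmPipelineInterwave_v1 - variant for the inter-wave loop
+//     scheduler, which issues the post-GEMM LDS barrier inside blockwise_gemm.
+// GridwiseGemmPipeline_v1_Selector() picks between them based on the scheduler.
+//
+// Minimal usage sketch; the template parameter lists are assumed to be
+// <index_t NumPrefetch> on the pipeline structs and
+// <index_t NumPrefetch, LoopScheduler LoopSched> on the selector:
+//
+//   using GridwiseGemmPipe = remove_cvref_t<decltype(
+//       GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>())>;
+//   GridwiseGemmPipe{}.template Run<HasMainKBlockLoop>(
+//       a_grid_desc, a_block_desc, a_blockwise_copy, a_grid_buf, a_block_buf,
+//       a_block_copy_step, b_grid_desc, b_block_desc, b_blockwise_copy,
+//       b_grid_buf, b_block_buf, b_block_copy_step, blockwise_gemm,
+//       c_thread_buf, num_k_block_main_loop);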
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" + +namespace ck { + +template +struct GridwiseGemmPipeline_v1; + +// 1-stage prefetch +template <> +struct GridwiseGemmPipeline_v1<1> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + // preload data into LDS + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + } +}; + +// 2-stage prefetch +template <> +struct GridwiseGemmPipeline_v1<2> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t num_loop) + { + // TODO: improve applicability + return num_loop % 2 == 0; + } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return (num_loop / 2) > 1; + } + + template + static __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + // preload data into LDS + { + // Read 0 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + // Move + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + 
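+            // NOTE: the blockwise copies keep two slices in flight; the
+            // trailing I0/I1 argument of RunRead/RunWrite selects the
+            // ping-pong slot, so "Read 1" below fetches the second slice
+            // while slice 0 still awaits its LDS write.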
+ // Read 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); + } + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + // Move + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Write i + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); + + // Read i+2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + // Sync + block_sync_lds(); + + // Gemm i + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // Sync + block_sync_lds(); + + // Move + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Write i+1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1); + + // Read i+3 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); + + // Sync + block_sync_lds(); + + // Gemm i+1 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // Sync + block_sync_lds(); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + { + // Write num_loop - 2 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); + + // Sync + block_sync_lds(); + + // Gemm num_loop - 2 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // Sync + block_sync_lds(); + + // Write num_loop - 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1); + + // Sync + block_sync_lds(); + + // Gemm num_loop - 1 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + } +}; + +template +struct GridwiseGemmPipelineInterwave_v1; + +template <> +struct GridwiseGemmPipelineInterwave_v1<1> +{ + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + static __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + // preload data into LDS + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + 
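+                // Inter-wave variant: same schedule as GridwiseGemmPipeline_v1<1>
+                // above, except the LDS barrier that would follow the math is
+                // issued inside blockwise_gemm (see the commented-out
+                // block_sync_lds() below), intended to let the inter-wave
+                // scheduler interleave memory and MFMA work across waves.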
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // block_sync_lds(); // moved into blockwise_gemm + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + } +}; + +// Note: 2 stage prefetch not optimized for inter-wave loop scheduler +template <> +struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2> +{ +}; + +template +constexpr auto GridwiseGemmPipeline_v1_Selector() +{ + if constexpr(LoopSched == LoopScheduler::Default) + { + return GridwiseGemmPipeline_v1{}; + } + else if constexpr(LoopSched == LoopScheduler::Interwave) + { + return GridwiseGemmPipelineInterwave_v1{}; + } +} + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp new file mode 100644 index 00000000..3281b910 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" + +namespace ck { + +struct GridwiseGemmPipeline_v2 +{ + __host__ __device__ static constexpr bool IsSupported(index_t num_loop) + { + // TODO: improve applicability + return num_loop % 2 == 0; + } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return (num_loop / 2) > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + // global read 0 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + // move to 1 + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // LDS write 0 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + // global Read 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + // LDS write 0 + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + // global Read 1 + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + block_sync_lds(); + + // GEMM i + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + // move to i + 2 + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // LDS write i + 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + // global read i + 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + // LDS write i + 1 + b_blockwise_copy.RunWrite(b_block_desc, 
b_block_buf); + // global read i + 2 + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + ++i; + } while(i < (num_loop - 2)); + } + + // tail + { + block_sync_lds(); + + // GEMM num_loop - 2 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + // LDS write num_loop - 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + block_sync_lds(); + + // GEMM num_loop - 1 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp new file mode 100644 index 00000000..8e29b518 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -0,0 +1,877 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_reduce_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + ReducePtrsGlobal p_reduces_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const ReduceInElementwiseOperations reduce_in_element_ops, + const ReduceAccElementwiseOperations reduce_out_element_ops, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_reduces_grid, + p_shared, + a_element_op, + b_element_op, + c_element_op, + reduce_in_element_ops, + reduce_out_element_ops, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + reduce_grid_desc_mblock_mperblock, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = p_reduces_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = reduce_in_element_ops; + ignore = reduce_out_element_ops; + ignore = 
a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = reduce_grid_desc_mblock_mperblock; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatCShuffle)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + // static_assert(is_known_at_compile_time>::value && + // is_known_at_compile_time>::value, + // "wrong! 
K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); + const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr auto + MakeReduceGridDescriptor_MBlock_MPerBlock(const ReduceGridDesc_M& d_grid_desc_m) + { + const auto M = d_grid_desc_m.GetLength(I0); + const auto MBlock = M / MPerBlock; + + const auto reduce_grid_desc_mblock_mperblock = transform_tensor_descriptor( + d_grid_desc_m, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{}))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + + return reduce_grid_desc_mblock_mperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using ReduceGridDescriptor_MBlock_MPerBlock = + remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + ReducePtrsGlobal p_reduces_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const ReduceInElementwiseOperations& reduce_in_element_ops, + const ReduceAccElementwiseOperations& reduce_out_element_ops, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ReduceGridDescriptor_MBlock_MPerBlock& reduce_grid_desc_mblock_mperblock, + const Block2CTileMap& 
block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for 
A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_v1_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C + reduction + write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on 
global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + // TODO: this should be implemented as a blockwise reduction + // LDS c_reduce_block_desc_mperblock_nperblock + constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( 
+ make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)), + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{})); + + static_assert(CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) * + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) == + BlockSize, + "wrong!"); + + static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) % + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) == + 0 && + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) % + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) == + 0, + "wrong!"); + + constexpr index_t mreduce_per_thread = + (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) / + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0); + + constexpr index_t nreduce_per_thread = + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) / + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1); + + constexpr auto c_reduce_thread_lengths_mperblock_nperblock = + Sequence{}; + + // VGPR c_reduce_thread_desc_mperblock_nperblock + constexpr auto c_reduce_thread_desc_mperblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + // VGPR reduce_thread_desc_mperblock + constexpr auto reduce_thread_desc_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + // VGPR reduce_thread_desc_mblock_mperblock + constexpr auto reduce_thread_desc_mblock_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(I1, Number{})); + + auto c_reduce_thread_buf = make_static_buffer( + c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); + + // reduce: threadwise copy from LDS to VGPR + constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor( + CReduceThreadClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{}); + + const auto c_reduce_thread_cluster_idx = + c_reduce_thread_cluster_desc.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto c_reduce_thread_data_idx_begin = + c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock; + + auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatCShuffle, + FloatReduceAcc, + decltype(c_reduce_block_desc_mperblock_nperblock), + decltype(c_reduce_thread_desc_mperblock_nperblock), + decltype(c_reduce_thread_lengths_mperblock_nperblock), + Sequence<0, 1>, + 1, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + 1, + true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin}; + + auto reduce_tuple_thread_copy_vgpr_to_global = generate_tuple( + [&](auto I) { + auto p_reduce_grid = p_reduces_grid[I]; + auto reduce_acc_element_op = reduce_out_element_ops[I]; + + return ThreadwiseTensorSliceTransfer_v1r3< + FloatReduceAcc, + remove_pointer_t, + decltype(reduce_thread_desc_mblock_mperblock), + decltype(reduce_grid_desc_mblock_mperblock), + decltype(reduce_acc_element_op), + Sequence<1, mreduce_per_thread>, + Sequence<0, 1>, + 1, + CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, + ReduceGlobalMemoryDataOperation::At(I), + 1, + false>{reduce_grid_desc_mblock_mperblock, + make_multi_index(block_work_idx[I0], // mblock + c_reduce_thread_data_idx_begin[I0]), // mperblock + reduce_acc_element_op}; + }, + Number{}); + + constexpr index_t 
num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + // TODO - extract following into reduction_blockwise + { + c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock, + c_shuffle_block_buf, + c_reduce_thread_desc_mperblock_nperblock, + make_tuple(I0, I0), + c_reduce_thread_buf); + + static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) { + auto& p_reduce_grid = p_reduces_grid[In]; + + auto reduce_grid_buf = make_dynamic_buffer( + p_reduce_grid, reduce_grid_desc_mblock_mperblock.GetElementSpaceSize()); + + auto reduce_thread_buf = + make_static_buffer( + reduce_thread_desc_mperblock.GetElementSpaceSize()); + + auto& reduce_in_element_op = reduce_in_element_ops[In]; + + auto& reduce_thread_copy_vgpr_to_global = + reduce_tuple_thread_copy_vgpr_to_global(In); + + using ReduceOperation = remove_cvref_t; + using ThreadwiseReduce = + ThreadwiseReduction; + + // Global write Gemm shuffle + reduction + const auto reduce_identityVal = + ReduceOperation::template GetIdentityValue(); + + static_for<0, mreduce_per_thread, 1>{}( + [&](auto I) { reduce_thread_buf(I) = reduce_identityVal; }); + + // reduce in VGPR + static_for<0, mreduce_per_thread, 1>{}([&](auto im) { + static_for<0, nreduce_per_thread, 1>{}([&](auto in) { + constexpr auto offset = + Number{}; + + reduce_in_element_op(c_reduce_thread_buf(offset), + c_reduce_thread_buf(offset)); + }); + }); + + ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf); + + // copy from VGPR to Global + reduce_thread_copy_vgpr_to_global.Run(reduce_thread_desc_mblock_mperblock, + make_tuple(I0, I0), + reduce_thread_buf, + reduce_grid_desc_mblock_mperblock, + reduce_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( + reduce_grid_desc_mblock_mperblock, + make_tuple(c_global_step[I0], c_global_step[I1])); + } + }); + } + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + + // Reduction + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp new file mode 100644 index 00000000..3fa6c10e --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -0,0 +1,658 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
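+
+// Plain GEMM kernel (no fused reduction): runs the gridwise GEMM pipeline for
+// the K loop, then a "cshuffle" epilogue that stages the accumulator tile
+// through LDS so the final global store can be issued with
+// CShuffleBlockTransferScalarPerVector_NPerBlock-wide vector accesses.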
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_element_op, + b_element_op, + c_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + // FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm + using GridwiseGemmPipe = +#if 1 + remove_cvref_t())>; +#else + GridwiseGemmPipeline_v2; +#endif + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return 
make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatCShuffle)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); + const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + 
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + template + __device__ static void Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + 
decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + static_assert(std::is_default_constructible_v); + const auto gridwise_gemm_pipeline = GridwiseGemmPipe{}; + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = 
ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp new file mode 100644 index 00000000..1b8286cf --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp @@ -0,0 +1,1066 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
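
The file below fuses a LayerNorm epilogue into the GEMM. As a reading aid, here is a
minimal host-side sketch (not part of this diff; all names are hypothetical and
FloatReduceAcc is assumed to map to float) of what one output row goes through. It
mirrors the kernel's single-pass statistics, the variance computed as E[x^2] - E[x]^2,
and the PyTorch-default epsilon added inside the square root:

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // One row of D = Layernorm(acc_element_op(acc + bias) + add) * gamma + beta.
    // 'x' is assumed to already hold acc_element_op(acc + bias) + add.
    std::vector<float> layernorm_row_reference(std::vector<float> x,
                                               const std::vector<float>& gamma,
                                               const std::vector<float>& beta,
                                               float eps = 1e-5f)
    {
        const float n = static_cast<float>(x.size());

        // single pass: running sum and squared sum (d0/d1 in the kernel)
        float sum = 0.f, sq_sum = 0.f;
        for(float v : x)
        {
            sum += v;
            sq_sum += v * v;
        }

        const float mean = sum / n;
        const float var  = sq_sum / n - mean * mean; // E[x^2] - E[x]^2

        for(std::size_t i = 0; i < x.size(); ++i)
            x[i] = (x[i] - mean) / std::sqrt(var + eps) * gamma[i] + beta[i];

        return x;
    }
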
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" + +namespace ck { + +// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta) +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_layernorm_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, // MxN + const FloatC0* __restrict__ p_c0_bias_grid, // 1xN + const FloatC0* __restrict__ p_c0_add_grid, // MxN + const FloatC0* __restrict__ p_c0_gamma_grid, // 1xN + const FloatC0* __restrict__ p_c0_beta_grid, // 1xN + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const AccElementwiseOperation acc_element_op, + const CElementwiseOperation c_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + // TODO ANT: separate into MMA + Epilogue + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_c0_bias_grid, + p_c0_add_grid, + p_c0_gamma_grid, + p_c0_beta_grid, + p_shared, + a_element_op, + b_element_op, + acc_element_op, + c_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c0_grid_desc_nblock_nperblock, + block_2_ctile_map); + + // TODO ANT: Run layernorm epilogue here +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = p_c0_bias_grid; + ignore = p_c0_add_grid; + ignore = p_c0_gamma_grid; + ignore = p_c0_beta_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = acc_element_op; + ignore = c_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = c0_grid_desc_nblock_nperblock; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +// The GEMM + Layernorm implementation is a specialized kernel which allows fusing both layers +// together given the condition GEMM extents N of MNK is spanned by a single workgroup. 
For example, +// a kernel configured with NPerBlock = 128 allows to operate on all GEMM sizes if N <= 128 +template +struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + // Align 16 bytes (maximum LDS read/write width) + constexpr auto c_block_size_aligned = + math::integer_least_multiple( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize() * + sizeof(FloatCShuffle), + 16) / + sizeof(FloatCShuffle); + + // LDS allocation for reduction workspace + constexpr index_t c_lds_workspace_size = BlockSize; + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size_aligned * sizeof(FloatCShuffle) + + c_lds_workspace_size * sizeof(FloatReduceAcc)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n, + 
const Block2CTileMap& block_2_ctile_map) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); + const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + return false; + + // in order to reduce N dim without elaborate sync across CUs in single kernel, one + // workgroup must span the entire N extent + if(math::integer_divide_ceil(N, NPerBlock) > 1) + { + return false; + } + + // static check: all waves in the workgroups combined must cover whole N extent in order + // to have efficient N-dim reduction + static_assert(CShuffleNXdlPerWavePerShuffle == NXdlPerWave, + "condition not met for efficient layernorm"); + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // for bias, beta, gamma + __host__ __device__ static constexpr auto + MakeC0GridDescriptor_NBlock_NPerBlock(const C0GridDesc_N& c0_grid_desc_n) + { + const auto N = c0_grid_desc_n.GetLength(I0); + const auto NBlock = N / NPerBlock; + + const auto c0_grid_desc_nblock_nperblock = transform_tensor_descriptor( + c0_grid_desc_n, + make_tuple(make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + + return c0_grid_desc_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using C0GridDescriptor_NBlock_NPerBlock = + remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC0* __restrict__ p_c0_bias_grid, // 1xN + const FloatC0* __restrict__ p_c0_add_grid, // MxN + const FloatC0* __restrict__ p_c0_gamma_grid, // 1xN + const FloatC0* __restrict__ 
p_c0_beta_grid, // 1xN + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const AccElementwiseOperation& acc_element_op, + const CElementwiseOperation& c_element_op, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const C0GridDescriptor_NBlock_NPerBlock& c0_grid_desc_nblock_nperblock, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + auto c0_bias_grid_buf = make_dynamic_buffer( + p_c0_bias_grid, c0_grid_desc_nblock_nperblock.GetElementSpaceSize()); + // Note: c0_add is of same layout as c so we don't declare new c0_add_desc here + auto c0_add_grid_buf = make_dynamic_buffer( + p_c0_add_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + auto c0_gamma_grid_buf = make_dynamic_buffer( + p_c0_gamma_grid, c0_grid_desc_nblock_nperblock.GetElementSpaceSize()); + auto c0_beta_grid_buf = make_dynamic_buffer( + p_c0_beta_grid, c0_grid_desc_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 
2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_v1_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
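
// Editorial note (not part of the original source): unlike the plain cshuffle
// kernel above, this write-out inserts a reduction between the LDS staging and
// the global store. CReduceThreadClusterLengths_MPerBlock_NPerBlock partitions
// the BlockSize threads over the C tile (the static_asserts below require the
// product of its two entries to equal BlockSize); each thread first accumulates
// the sum and the squared sum of its own elements along N in VGPRs, and
// PartitionedBlockwiseReduction then combines the partials of all threads that
// share a row. This is only correct because CheckValidity above guarantees
// that a single workgroup spans the entire N extent.
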
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = 
ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + const auto NBlock = c0_grid_desc_nblock_nperblock.GetLength(I0); + + // for broadcasting bias, beta, gamma + const auto c0_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c0_grid_desc_nblock_nperblock, + make_tuple(make_insert_transform(I1), + make_insert_transform(I1), + make_pass_through_transform(NBlock), + make_pass_through_transform(NPerBlock)), + make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + // LDS c_reduce_block_desc_mperblock_nperblock + constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)), + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{})); + + static_assert(CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) * + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) == + BlockSize, + "wrong!"); + + static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) % + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) == + 0 && + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) % + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) == + 0, + "wrong!"); + + constexpr index_t mreduce_per_thread = + (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) / + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0); + + constexpr index_t nreduce_per_thread = + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) / + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1); + + constexpr auto c_reduce_thread_lengths_mperblock_nperblock = + Sequence{}; + + // pytorch default + // https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html + static constexpr FloatReduceAcc epsilon = 1e-5; + + // VGPR c_reduce_thread_desc_mperblock_nperblock + constexpr auto c_reduce_thread_desc_mperblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr auto c_reduce_thread_desc_mblock_mperblock_nblock_nperblock = + 
make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + // VGPR d_reduce_thread_desc_mperblock + constexpr auto d_reduce_thread_desc_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + // TODO: this should be implemented as a blockwise reduction + auto c_reduce_thread_buf = make_static_buffer( + c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); + + auto c0_thread_buf = make_static_buffer( + c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); + + // Align 16 bytes (maximum LDS read/write width) + constexpr auto c_block_size_aligned = + math::integer_least_multiple( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize() * + sizeof(FloatCShuffle), + 16) / + sizeof(FloatCShuffle); + + auto d_reduce_work_buf = make_dynamic_buffer( + reinterpret_cast(static_cast(p_shared) + + c_block_size_aligned), + BlockSize); + + // Sum thread workspace + auto d0_thread_buf = make_static_buffer( + d_reduce_thread_desc_mperblock.GetElementSpaceSize()); + + // Squared sum thread workspace + auto d1_thread_buf = make_static_buffer( + d_reduce_thread_desc_mperblock.GetElementSpaceSize()); + + // reduce: threadwise copy from LDS to VGPR + constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor( + CReduceThreadClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{}); + + const auto c_reduce_thread_cluster_idx = + c_reduce_thread_cluster_desc.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto c_reduce_thread_data_idx_begin = + c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock; + + auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatCShuffle, + FloatReduceAcc, + decltype(c_reduce_block_desc_mperblock_nperblock), + decltype(c_reduce_thread_desc_mperblock_nperblock), + decltype(c_reduce_thread_lengths_mperblock_nperblock), + Sequence<0, 1>, + 1, + CReduceThreadCopySrcDstScalarPerVector_NPerBlock, + 1, + true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin}; + + auto c_reduce_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + FloatReduceAcc, + FloatCShuffle, + decltype(c_reduce_thread_desc_mperblock_nperblock), + decltype(c_reduce_block_desc_mperblock_nperblock), + tensor_operation::element_wise::PassThrough, + decltype(c_reduce_thread_lengths_mperblock_nperblock), + Sequence<0, 1>, + 1, + CReduceThreadCopySrcDstScalarPerVector_NPerBlock, + InMemoryDataOperationEnum::Set, + 1, + true>{c_reduce_block_desc_mperblock_nperblock, + c_reduce_thread_data_idx_begin, + tensor_operation::element_wise::PassThrough{}}; + + auto c0_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatC0, + FloatC0, + decltype(c0_grid_desc_mblock_mperblock_nblock_nperblock), + decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock), + Sequence, + Sequence<0, 1, 2, 3>, + 3, + CReduceThreadCopySrcDstScalarPerVector_NPerBlock, + 1, + true>(c0_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], + c_reduce_thread_data_idx_begin[I0], + block_work_idx[I1], + c_reduce_thread_data_idx_begin[I1])); + + // Note: c0_add is of same layout as c so we don't declare new c0_add_desc here + auto c0_add_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatC0, + FloatC0, + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock), + Sequence, + Sequence<0, 1, 2, 3>, + 3, + 
CReduceThreadCopySrcDstScalarPerVector_NPerBlock, + 1, + true>(c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], + c_reduce_thread_data_idx_begin[I0], + block_work_idx[I1], + c_reduce_thread_data_idx_begin[I1])); + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + block_sync_lds(); + + // load from LDS and global, add bias + c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock, + c_shuffle_block_buf, + c_reduce_thread_desc_mperblock_nperblock, + make_tuple(I0, I0), + c_reduce_thread_buf); + + c0_thread_copy_global_to_vgpr.Run( + c0_grid_desc_mblock_mperblock_nblock_nperblock, + c0_bias_grid_buf, + c_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c0_thread_buf); + + static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + FloatReduceAcc out; + acc_element_op(out, + c_reduce_thread_buf(i) + + static_cast(c0_thread_buf(i))); + c_reduce_thread_buf(i) = out; // acc_element_op(acc + bias) + }); + + c0_add_thread_copy_global_to_vgpr.Run( + c_grid_desc_mblock_mperblock_nblock_nperblock, + c0_add_grid_buf, + c_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c0_thread_buf); + + static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + c_reduce_thread_buf(i) += + static_cast(c0_thread_buf(i)); // add + }); + + // layernorm + { + using ThreadwiseReduceD0 = + ThreadwiseReduction; + using ThreadwiseReduceD1 = + ThreadwiseReduction; + + const auto d0_zeroVal = + ThreadwiseReduceD0::Op::template GetIdentityValue(); + const auto d1_zeroVal = + ThreadwiseReduceD1::Op::template GetIdentityValue(); + static_for<0, mreduce_per_thread, 1>{}( + [&](auto i) { d0_thread_buf(i) = d0_zeroVal; }); + static_for<0, mreduce_per_thread, 1>{}( + [&](auto i) { d1_thread_buf(i) = d1_zeroVal; }); + + // reduce sum in VGPR + ThreadwiseReduceD0::Reduce(c_reduce_thread_buf, d0_thread_buf); + + // reduce squared sum in VGPR + ThreadwiseReduceD1::Reduce(c_reduce_thread_buf, d1_thread_buf); + + // reduce within workgroup + using BlockwiseReduce = PartitionedBlockwiseReduction< + FloatReduceAcc, + BlockSize, + CReduceThreadClusterLengths_MPerBlock_NPerBlock, // ThreadClusterLengths_M_K + Sequence<1, 0>, // ThreadClusterArrangeOrder + reduce::Add, + false>; + + static_for<0, mreduce_per_thread, 1>{}([&](auto i) { + block_sync_lds(); + BlockwiseReduce::Reduce(d_reduce_work_buf, + d0_thread_buf(i)); // blockwise reduced sum + block_sync_lds(); + BlockwiseReduce::Reduce(d_reduce_work_buf, + d1_thread_buf(i)); // blockwise reduced squared sum + }); + + // 
normalize + const index_t NRaw = + c_grid_desc_mblock_mperblock_nblock_nperblock.GetTransforms()[I0] + .GetUpperLengths()[I1]; // TODO: proper handle + + static_for<0, mreduce_per_thread, 1>{}([&](auto im) { + static_for<0, nreduce_per_thread, 1>{}([&](auto in) { + constexpr auto dst_offset = + Number{}; + + constexpr auto src_offset = + Number{}; + + FloatReduceAcc avg_sum = d0_thread_buf(src_offset) / NRaw; + FloatReduceAcc avg_squared_sum = d1_thread_buf(src_offset) / NRaw; + + FloatReduceAcc numerator = c_reduce_thread_buf(dst_offset) - avg_sum; + FloatReduceAcc divisor = epsilon + avg_squared_sum - avg_sum * avg_sum; + FloatReduceAcc divisor_sqrt; + tensor_operation::element_wise::UnarySqrt{}(divisor_sqrt, divisor); + + c_reduce_thread_buf(dst_offset) = numerator / divisor_sqrt; + }); + }); + + // scaling + c0_thread_copy_global_to_vgpr.Run( + c0_grid_desc_mblock_mperblock_nblock_nperblock, + c0_gamma_grid_buf, + c_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c0_thread_buf); + + static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + c_reduce_thread_buf(i) *= + static_cast(c0_thread_buf(i)); // * gamma + }); + + c0_thread_copy_global_to_vgpr.Run( + c0_grid_desc_mblock_mperblock_nblock_nperblock, + c0_beta_grid_buf, + c_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c0_thread_buf); + + static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + c_reduce_thread_buf(i) += + static_cast(c0_thread_buf(i)); // + beta + }); + + block_sync_lds(); + + c_reduce_thread_copy_vgpr_to_lds.Run(c_reduce_thread_desc_mperblock_nperblock, + make_tuple(I0, I0), + c_reduce_thread_buf, + c_reduce_block_desc_mperblock_nperblock, + c_shuffle_block_buf); + + } // end layernorm + + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + + // move on C0 + c0_thread_copy_global_to_vgpr.MoveSrcSliceWindow( + c0_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + + // move on C0_add + c0_add_thread_copy_global_to_vgpr.MoveSrcSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp new file mode 100644 index 00000000..3bb3774a --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -0,0 +1,980 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
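
The backward-weight file below opens with a custom "merge" coordinate transform. As a
reading aid, here is a hypothetical standalone illustration (not part of this diff) of
the division/mod decomposition it performs: for low lengths (L0, L1, L2), the reverse
exclusive scan under multiplication is (L1*L2, L2, 1), and a flat upper index is peeled
off by successive division and remainder. UpdateLowerIndex then only ever moves the
last dimension (hence "no carry"), which is why the lengths should be compile-time
powers of 2 for this to stay cheap:

    #include <array>
    #include <cstdio>

    int main()
    {
        constexpr std::array<int, 3> low_lengths = {4, 2, 8};

        // reverse exclusive scan with multiplies: scan[i] = product of lengths after i
        constexpr std::array<int, 3> scan = {low_lengths[1] * low_lengths[2],
                                             low_lengths[2], 1};

        int u = 37; // flat (merged) upper index
        std::array<int, 3> idx_low{};
        for(int i = 0; i < 2; ++i) // NDimLow - 1 steps, as in CalculateLowerIndex
        {
            idx_low[i] = u / scan[i];
            u %= scan[i];
        }
        idx_low[2] = u;

        // 37 = 2 * 16 + 0 * 8 + 5, so this prints (2, 0, 5)
        std::printf("(%d, %d, %d)\n", idx_low[0], idx_low[1], idx_low[2]);
        return 0;
    }
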
+
+#pragma once
+
+#include "ck/utility/common_header.hpp"
+#include "ck/tensor_description/multi_index_transform_helper.hpp"
+#include "ck/tensor_description/tensor_descriptor.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
+#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
+#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
+#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
+#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+namespace ck {
+
+// Implementation of the "Merge" transformation primitive that uses division and modulo. It is
+// intended for low_lengths that are known at compile time and are powers of 2; otherwise
+// performance will be poor.
+template <typename LowLengths>
+struct Merge_v4_no_carry
+{
+    static constexpr index_t NDimLow = LowLengths::Size();
+
+    using LowerIndex = MultiIndex<NDimLow>;
+    using UpperIndex = MultiIndex<1>;
+
+    using LowLengthsScan =
+        decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));
+
+    using UpLengths =
+        decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
+
+    LowLengths low_lengths_;
+    LowLengthsScan low_lengths_scan_;
+    UpLengths up_lengths_;
+
+    __host__ __device__ constexpr Merge_v4_no_carry() = default;
+
+    __host__ __device__ constexpr Merge_v4_no_carry(const LowLengths& low_lengths)
+        : low_lengths_{low_lengths},
+          low_lengths_scan_{
+              container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
+          up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
+    {
+        static_assert(LowerIndex::Size() == NDimLow, "wrong!");
+    }
+
+    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }
+
+    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
+
+    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
+
+    template <typename LowIdx, typename UpIdx>
+    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
+                                                           const UpIdx& idx_up) const
+    {
+        static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
+                      "wrong! inconsistent # of dimension");
+
+        index_t tmp = idx_up[Number<0>{}];
+
+        // division and mod
+        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
+            idx_low(i) = tmp / this->low_lengths_scan_[i];
+            tmp %= this->low_lengths_scan_[i];
+        });
+
+        idx_low(Number<NDimLow - 1>{}) = tmp;
+    }
+
+    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx, index_t Hack>
+    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
+                                              const UpIdxDiff& idx_up_diff,
+                                              LowIdx& idx_low,
+                                              const UpIdx& idx_up_new,
+                                              Number<Hack>) const
+    {
+        static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
+                          LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
+                      "wrong! inconsistent # of dimension");
+
+        constexpr auto I0   = Number<0>{};
+        constexpr auto INm1 = Number<NDimLow - 1>{};
+
+        index_t tmp = idx_up_new[I0];
+
+        idx_low(INm1)      = tmp;
+        idx_diff_low(INm1) = idx_up_diff[I0];
+    }
+
+    __host__ __device__ static constexpr bool IsLinearTransform() { return false; }
+
+    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
+    {
+        return true;
+    }
+
+    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
+    {
+        return is_known_at_compile_time<LowLengths>::value &&
+               is_known_at_compile_time<LowLengthsScan>::value &&
+               is_known_at_compile_time<UpLengths>::value;
+    }
+
+    template <typename UpIdx>
+    __host__ __device__ static constexpr bool
+    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
+    {
+        return true;
+    }
+
+    __host__ __device__ void Print() const
+    {
+        printf("{");
+        printf("Merge_v4_no_carry, ");
+        printf("low_lengths_ ");
+        print_multi_index(low_lengths_);
+        printf("low_lengths_scan_ ");
+        print_multi_index(low_lengths_scan_);
+        printf("up_lengths_ ");
+        print_multi_index(up_lengths_);
+        printf("}");
+    }
+};
+
+template <typename LowLengths>
+__host__ __device__ constexpr auto make_merge_transform_v4_no_carry(const LowLengths& low_lengths)
+{
+    return Merge_v4_no_carry<LowLengths>{low_lengths};
+}
+
+template <typename GridwiseGemm,
+          typename FloatAB,
+          typename FloatC,
+          typename AGridDesc_B_K0_M_K1,
+          typename BGridDesc_B_K0_N_K1,
+          typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation,
+          typename CBlockClusterAdaptor,
+          bool HasMainKBlockLoop>
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+    kernel_gemm_xdlops_bwd_weight(const FloatAB* __restrict__ p_a_grid,
+                                  const FloatAB* __restrict__ p_b_grid,
+                                  FloatC* __restrict__ p_c_grid,
+                                  const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
+                                  const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
+                                  const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+                                      c_grid_desc_mblock_mperblock_nblock_nperblock,
+                                  const AElementwiseOperation a_element_op,
+                                  const BElementwiseOperation b_element_op,
+                                  const CElementwiseOperation c_element_op,
+                                  const CBlockClusterAdaptor c_block_cluster_adaptor)
+{
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+    constexpr index_t shared_block_size =
+        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
+
+    __shared__ FloatAB p_shared_block[shared_block_size];
+
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
+                                                  p_b_grid,
+                                                  p_c_grid,
+                                                  p_shared_block,
+                                                  a_b_k0_m_k1_grid_desc,
+                                                  b_b_k0_n_k1_grid_desc,
+                                                  c_grid_desc_mblock_mperblock_nblock_nperblock,
+                                                  a_element_op,
+                                                  b_element_op,
+                                                  c_element_op,
+                                                  c_block_cluster_adaptor);
+#else
+    ignore = p_a_grid;
+    ignore = p_b_grid;
+    ignore = p_c_grid;
+    ignore = a_b_k0_m_k1_grid_desc;
+    ignore = b_b_k0_n_k1_grid_desc;
+    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = c_element_op;
+    ignore = c_block_cluster_adaptor;
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+}
+
+template
+struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
+{
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+    static constexpr auto I2 = Number<2>{};
+    static constexpr auto I3 = Number<3>{};
+    static constexpr auto I4 = Number<4>{};
+    static constexpr auto I5 = Number<5>{};
+    static constexpr auto I6 = Number<6>{};
+    static constexpr auto I7 = Number<7>{};
+
+    // K1 should be Number<...>
+    static constexpr auto K1 = Number<K1Value>{};
+
+    using ThisThreadBlock  = ThisThreadBlock<BlockSize>;
+    using GridwiseGemmPipe = GridwiseGemmPipeline_v1;
+
+    // M0/M1/M1Padding
+    static constexpr auto M1PerBlock = Number{};
+    static constexpr auto M0PerBlock = Number{};
+    static constexpr auto M1Padding = Number{};
+
+    // N0/N1/N1Padding
+    static
constexpr auto N1PerBlock = Number{}; + static constexpr auto N0PerBlock = Number{}; + static constexpr auto N1Padding = Number{}; + + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + if constexpr(ABlockLdsExtraM1Wrw) + { + constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor( + make_tuple( + Number{}, Number{}, Number{}, K1), + make_tuple(Number{} * (Number{} * K1 + M1Padding), + Number{} * K1 + M1Padding, + K1, + I1)); + + constexpr auto a_block_desc_k0_m_k1_tmp = transform_tensor_descriptor( + a_block_desc_k0_m0_m1_k1, + make_tuple(make_pass_through_transform(Number{}), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_block_desc_k0_m_k1_tmp; + } + else + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return a_block_desc_k0_m_k1; + } + + __host__ __device__ static constexpr auto GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_b_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + if constexpr(ABlockLdsExtraM1Wrw) + { + constexpr auto a_block_desc_b_k0_m0_m1_k1 = make_naive_tensor_descriptor( + make_tuple(Number<1>{}, + Number{}, + Number{}, + Number{}, + K1), + make_tuple(Number{} * Number{} * + (Number{} * K1 + M1Padding), + Number{} * (Number{} * K1 + M1Padding), + Number{} * K1 + M1Padding, + K1, + I1)); + + constexpr auto a_block_desc_b_k0_m_k1_tmp = transform_tensor_descriptor( + a_block_desc_b_k0_m0_m1_k1, + make_tuple(make_pass_through_transform(Number<1>{}), + make_pass_through_transform(Number{}), + make_merge_transform_v4_no_carry( + make_tuple(Number{}, Number{})), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + return a_block_desc_b_k0_m_k1_tmp; + } + else + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + + return a_block_desc_b_k0_m_k1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + if constexpr(BBlockLdsExtraN1Wrw) + { + constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor( + make_tuple( + Number{}, Number{}, Number{}, K1), + make_tuple(Number{} * (Number{} * K1 + N1Padding), + Number{} * K1 + N1Padding, + K1, + I1)); + + constexpr auto b_block_desc_k0_n_k1_tmp = transform_tensor_descriptor( + b_block_desc_k0_n0_n1_k1, + make_tuple(make_pass_through_transform(Number{}), + 
make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_block_desc_k0_n_k1_tmp; + } + else + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return b_block_desc_k0_n_k1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_b_k0_n_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + if constexpr(BBlockLdsExtraN1Wrw) + { + constexpr auto b_block_desc_b_k0_n0_n1_k1 = make_naive_tensor_descriptor( + make_tuple(Number<1>{}, + Number{}, + Number{}, + Number{}, + K1), + make_tuple(Number{} * Number{} * + (Number{} * K1 + N1Padding), + Number{} * (Number{} * K1 + N1Padding), + Number{} * K1 + N1Padding, + K1, + I1)); + + constexpr auto b_block_desc_b_k0_n_k1_tmp = transform_tensor_descriptor( + b_block_desc_b_k0_n0_n1_k1, + make_tuple(make_pass_through_transform(Number<1>{}), + make_pass_through_transform(Number{}), + make_merge_transform_v4_no_carry( + make_tuple(Number{}, Number{})), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + return b_block_desc_b_k0_n_k1_tmp; + } + else + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + + return b_block_desc_b_k0_n_k1; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_b_k0_m_k1_block_desc = GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = math::integer_least_multiple( + a_b_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size = math::integer_least_multiple( + b_b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto c_block_size = + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); + + return math::max((a_block_space_size + b_block_space_size) * sizeof(FloatAB), + c_block_size * sizeof(FloatC)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, + const CMNGridDesc& c_m_n_grid_desc, + const Block2CTileMap& block_2_ctile_map) + { + static_assert(is_known_at_compile_time>::value, + "wrong! 
K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerXDL)) == 0, + "Invalid tuning param!"); + + const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2); + const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2); + const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); + const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0); + + // check gridwise gemm pipeline + const auto num_k_loop = K0 / K0PerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && + K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) && + K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) && + K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) && + KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + { + // const bool has_main_k0_block_loop = K0 > K0PerBlock; + const index_t num_loop = K0 / K0PerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + + // return has_main_k0_block_loop; + } + + __host__ __device__ static constexpr auto + MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + return transform_tensor_descriptor( + c_m_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor( + const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch) + { + return BlockToCTileMap_KSplit_M00_N00_M01_N01( + c_m_n_grid_desc, M01, N01, KBatch); + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + } + + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{})); + using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)); + + template + __device__ static void Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + FloatAB* __restrict__ p_shared_block, + const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const CBlockClusterAdaptor& c_block_cluster_adaptor) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, 
a_b_k0_m_k1_grid_desc.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); + + // divide block work by [M, N] + const auto block_work_idx = + c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t k_batch_id = block_work_idx[I0]; + + if(!c_block_cluster_adaptor.ValidCTileIndex( + make_tuple(block_work_idx[I1], block_work_idx[I2]), + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + constexpr auto a_b_k0_m_k1_block_desc = GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1(); + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1(); + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_b_k0_m_k1_grid_desc), + decltype(a_b_k0_m_k1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_b_k0_m_k1_grid_desc, + make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0), + a_element_op, + a_b_k0_m_k1_block_desc, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_b_k0_n_k1_grid_desc), + decltype(b_b_k0_n_k1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_b_k0_n_k1_grid_desc, + make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0), + b_element_op, + b_b_k0_n_k1_block_desc, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + + constexpr index_t KPack = + math::max(K1, MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS 
allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block = p_shared_block; + FloatAB* p_b_block = p_shared_block + a_block_space_size; + + constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + + auto a_block_buf = make_dynamic_buffer( + p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); + auto b_block_buf = make_dynamic_buffer( + p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); + + // gridwise GEMM pipeline + const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); + + GridwiseGemmPipe::template Run(a_b_k0_m_k1_grid_desc, + a_b_k0_m_k1_block_desc, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_b_k0_n_k1_grid_desc, + b_b_k0_n_k1_block_desc, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + K0BlockMainLoop); + + // output: register to global memory + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); + constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); + constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); + constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); + constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); + constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); + constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); + constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); + + constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock = + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + void* p_shared = static_cast(p_shared_block); + + auto c_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + static_assert(M1 == MWave, ""); + static_assert(N1 == NWave, ""); + static_assert(M2 * M3 * M4 == MPerXDL, ""); + static_assert(N2 == NPerXDL, ""); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), // freeze mblock + make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle, + M1, + M2, + M3, + M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_freeze_transform(I0), // freeze nblock + make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle, + N1, + N2))), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = 
c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // LDS to global + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerXDL, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatC, // typename SrcData, + FloatC, // typename DstData, + decltype(c_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun + {c_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL); + + static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? 
nxdlperwave_iter + : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock, + c_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step); + } + }); + } + } +}; // namespace ck + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp new file mode 100644 index 00000000..847bfd47 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -0,0 +1,554 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
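// [Illustration, not part of the patch] A minimal host-side sketch of the
// LDS budgeting rule used by GetSharedMemoryNumberOfByte() in this file:
// the A and B block tiles are K0PerBlock x {M,N}PerBlock x K1 elements,
// each rounded up to a multiple of the max LDS alignment (K1), and the sum
// is charged in units of sizeof(FloatAB). All tile sizes below are
// hypothetical example values, not tuned defaults from the library.
#include <cstdio>

constexpr int integer_least_multiple(int x, int align)
{
    // smallest multiple of `align` that is >= x (mirrors math::integer_least_multiple)
    return ((x + align - 1) / align) * align;
}

int main()
{
    constexpr int K0PerBlock = 4, MPerBlock = 128, NPerBlock = 128, K1 = 8;
    constexpr int max_lds_align = K1;

    // element-space sizes of the A/B LDS tiles (packed layout assumed here;
    // the ABlockLdsExtraM / BBlockLdsExtraN padded variants would be larger)
    constexpr int a_elems = K0PerBlock * MPerBlock * K1;
    constexpr int b_elems = K0PerBlock * NPerBlock * K1;

    constexpr int a_space = integer_least_multiple(a_elems, max_lds_align);
    constexpr int b_space = integer_least_multiple(b_elems, max_lds_align);

    constexpr int sizeof_fp16 = 2; // e.g. FloatAB = half
    std::printf("LDS bytes for A+B: %d\n", (a_space + b_space) * sizeof_fp16);
}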
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v2r3( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return a_block_desc_k0_m_k1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} 
* K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return b_block_desc_k0_n_k1; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + constexpr auto max_lds_align = K1; + + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = + math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); + + return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXDL)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && + K1 == b_grid_desc_k0_n_k1.GetLength(I2))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K0 / K0PerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / (K0PerBlock * K1); + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + using BlockwiseGemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; + + return 
BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = + decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); + using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1)); + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize()); + + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0), + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_k0_m_k1), + decltype(a_block_desc_k0_m_k1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_k0_m_k1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_k0_m_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_k0_n_k1), + decltype(b_block_desc_k0_n_k1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + 
BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_k0_n_k1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_k0_n_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_k0_n_k1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + + // gridwise GEMM pipeline + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); + + GridwiseGemmPipe::template Run(a_grid_desc_k0_m_k1, + a_block_desc_k0_m_k1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_k0_n_k1, + b_block_desc_k0_n_k1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // output: register to global memory + { + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_grid = + m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; + + const index_t n_thread_data_on_grid = + n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_grid_idx = + m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_grid)); + + const auto 
n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_grid_idx = + n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_grid)); + + auto c_thread_copy = + ThreadwiseTensorSliceTransfer_v1r3, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{ + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(m_thread_data_on_grid_idx[I0], + n_thread_data_on_grid_idx[I0], + m_thread_data_on_grid_idx[I1], + n_thread_data_on_grid_idx[I1], + m_thread_data_on_grid_idx[I2], + m_thread_data_on_grid_idx[I3], + m_thread_data_on_grid_idx[I4], + n_thread_data_on_grid_idx[I2]), + c_element_op}; + + c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_grid_buf); + } + } +}; + +} // namespace ck diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp similarity index 67% rename from composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp rename to include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp index f27fc73b..949d5648 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp @@ -1,24 +1,30 @@ -#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R4_HPP -#define CK_GRIDWISE_GEMM_XDLOPS_V2R4_HPP - -#include "common_header.hpp" -#include "multi_index_transform_helper.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "blockwise_gemm_xdlops.hpp" -#include "blockwise_tensor_slice_transfer.hpp" -#include "threadwise_tensor_slice_transfer.hpp" -#include "threadwise_tensor_slice_set.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
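// [Illustration, not part of the patch] A simplified model of the split-K
// block-to-C-tile mapping that this file's diff adopts in place of the
// hand-rolled kbatch/m00/m01/n00/n01 adaptor: the 1-D block id is decoded
// into (k_batch, m0, n0), so each split-K slice of the same C tile lands on
// a distinct workgroup and accumulates into the same output. The real
// mapping, BlockToCTileMap_KSplit_M00_N0_M01Adapt, additionally groups
// M-tiles in stripes (M01 = 8 here) for locality; this sketch, its
// decode_block_id helper, and the grid sizes are all hypothetical.
#include <cstdio>

struct TileIdx
{
    int k_batch, m0, n0;
};

TileIdx decode_block_id(int block_id, int MBlock, int NBlock)
{
    const int tiles_per_batch = MBlock * NBlock;
    TileIdx idx{};
    idx.k_batch = block_id / tiles_per_batch; // outermost: split-K slice
    const int rem = block_id % tiles_per_batch;
    idx.m0 = rem / NBlock; // row of the C tile grid
    idx.n0 = rem % NBlock; // column of the C tile grid
    return idx;
}

int main()
{
    const int MBlock = 4, NBlock = 3, KBatch = 2; // example: M/MPerBlock etc.
    for(int bid = 0; bid < MBlock * NBlock * KBatch; ++bid)
    {
        const TileIdx t = decode_block_id(bid, MBlock, NBlock);
        std::printf("block %2d -> (kb=%d, m0=%d, n0=%d)\n", bid, t.k_batch, t.m0, t.n0);
    }
}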
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" namespace ck { -#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE template __global__ void @@ -31,8 +37,12 @@ __global__ void const ABK0MK1GridDesc a_b_k0_m_k1_grid_desc, const BBK0NK1GridDesc b_b_k0_n_k1_grid_desc, const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, const CBlockClusterAdaptor c_block_cluster_adaptor) { +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); @@ -45,63 +55,35 @@ __global__ void a_b_k0_m_k1_grid_desc, b_b_k0_n_k1_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + a_element_op, + b_element_op, + c_element_op, c_block_cluster_adaptor); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_b_k0_m_k1_grid_desc; + ignore = b_b_k0_n_k1_grid_desc; + ignore = c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = c_block_cluster_adaptor; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } -#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_gemm_xdlops_v2r4(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const void CONSTANT* p_a_b_k0_m_k1_grid_desc, - const void CONSTANT* p_b_b_k0_n_k1_grid_desc, - const void CONSTANT* p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - const void CONSTANT* p_c_block_cluster_adaptor) -{ - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - const auto a_b_k0_m_k1_grid_desc = *reinterpret_cast( - cast_pointer_to_generic_address_space(p_a_b_k0_m_k1_grid_desc)); - const auto b_b_k0_n_k1_grid_desc = *reinterpret_cast( - cast_pointer_to_generic_address_space(p_b_b_k0_n_k1_grid_desc)); - const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc = - *reinterpret_cast( - cast_pointer_to_generic_address_space(p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc)); - const auto c_block_cluster_adaptor = *reinterpret_cast( - cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor)); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_b_k0_m_k1_grid_desc, - b_b_k0_n_k1_grid_desc, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_block_cluster_adaptor); -} -#endif template + index_t CThreadTransferDstScalarPerVector> struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 { static constexpr auto I0 = Number<0>{}; @@ -151,6 +125,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 
// K1 should be Number<...> static constexpr auto K1 = Number{}; + using ThisThreadBlock = ThisThreadBlock; + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { constexpr auto max_lds_align = K1; @@ -196,12 +172,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template __host__ __device__ static constexpr bool CheckValidity(const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc, const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc, const CMNGridDesc& c_m_n_grid_desc, - index_t M01, - index_t N01) + const Block2CTileMap& block_2_ctile_map) { static_assert(is_known_at_compile_time>::value, "wrong! K1 need to be known at compile-time"); @@ -225,31 +201,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) return false; - // check M01, N01 - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - if(!(M0 % M01 == 0 && N0 % N01 == 0)) + if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc)) + { return false; + } // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) return true; } - __host__ __device__ static constexpr index_t - CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc, index_t KBatch) - { - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); - - const index_t grid_size = (M / MPerBlock) * (N / NPerBlock) * KBatch; - - return grid_size; - } - __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) { const bool has_main_k0_block_loop = K0 > K0PerBlock; @@ -304,44 +264,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 NRepeat, K1>; - return BlockwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc); + return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_m_n_grid_desc); } // return block_id to C matrix tile idx (m0, n0) mapping __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor( - const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch) + const CMNGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch) { - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); - - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - const auto M00 = M0 / M01; - const auto N00 = N0 / N01; - - const auto kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_pass_through_transform(KBatch), - make_unmerge_transform(make_tuple(M00, M01)), - make_unmerge_transform(make_tuple(N00, N01))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{})); - - const auto c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(KBatch, M00, N00, M01, N01))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto c_blockid_to_kbatch_m0_n0_block_cluster_adaptor = - chain_tensor_adaptors(kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, - c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor); - - return c_blockid_to_kbatch_m0_n0_block_cluster_adaptor; + return BlockToCTileMap_KSplit_M00_N0_M01Adapt( + c_m_n_grid_desc, 8, KBatch); } using CM0N0M1N1M2M3M4N2GridDesc 
= decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{})); @@ -355,13 +286,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc, const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc, const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, const CBlockClusterAdaptor& c_block_cluster_adaptor) { - const auto a_grid_buf = make_dynamic_buffer( + const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize()); - const auto b_grid_buf = make_dynamic_buffer( + const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( + auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize()); const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); @@ -370,7 +304,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 const auto block_work_idx = c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + if(!c_block_cluster_adaptor.ValidCTileIndex( + make_tuple(block_work_idx[I1], block_work_idx[I2]), + make_tuple(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetLength(I0), + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetLength(I1)))) + { + return; + } + const index_t k_batch_id = block_work_idx[I0]; + // HACK: this force m/n_block_data_idx_on_grid into SGPR const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); @@ -447,57 +390,63 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 }(); // A matrix blockwise copy auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4, - ABlockTransferThreadSliceLengths_K0_M_K1, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_b_k0_m_k1_grid_desc), - decltype(a_b_k0_m_k1_block_desc), - ABlockTransferSrcAccessOrder, - Sequence<0, 2, 1, 3>, - ABlockTransferSrcVectorDim, - 3, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true>( + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_b_k0_m_k1_grid_desc), + decltype(a_b_k0_m_k1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( a_b_k0_m_k1_grid_desc, make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0), + a_element_op, a_b_k0_m_k1_block_desc, - make_multi_index(0, 0, 0, 0)); + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); // B matrix blockwise copy auto b_blockwise_copy = - BlockwiseTensorSliceTransfer_v4, - BBlockTransferThreadSliceLengths_K0_N_K1, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_b_k0_n_k1_grid_desc), - decltype(b_b_k0_n_k1_block_desc), - BBlockTransferSrcAccessOrder, - Sequence<0, 2, 1, 3>, - BBlockTransferSrcVectorDim, - 3, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>( + ThreadGroupTensorSliceTransfer_v4r1, + 
BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_b_k0_n_k1_grid_desc), + decltype(b_b_k0_n_k1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( b_b_k0_n_k1_grid_desc, make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0), + b_element_op, b_b_k0_n_k1_block_desc, - make_multi_index(0, 0, 0, 0)); + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); // GEMM definition // c_mtx += transpose(a_mtx) * b_mtx @@ -531,49 +480,38 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); - // hack to control index calculation when iterating over A and B matrix for threadwise copy - constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{}; - constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{}; - - // hack to control index calculation when move slice window for A and B matrix for - // threadwise copy - constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{}; - constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{}; - - auto a_block_buf = make_dynamic_buffer( + auto a_block_buf = make_dynamic_buffer( p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); - auto b_block_buf = make_dynamic_buffer( + auto b_block_buf = make_dynamic_buffer( p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); // preload data into LDS { - a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks); - b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks); + a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf); a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf); b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); } + // Initialize C + c_thread_buf.Clear(); + // main body - index_t k_block_data_begin = 0; if constexpr(HasMainKBlockLoop) { + index_t k0_block_data_begin = 0; + do { - a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, - a_block_slice_copy_step, - a_k0_m_k1_grid_move_slice_window_step_hack); - b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, - b_block_slice_copy_step, - b_k0_n_k1_grid_move_slice_window_step_hack); + a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, b_block_slice_copy_step); - a_blockwise_copy.RunRead( - a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks); + a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf); block_sync_lds(); - b_blockwise_copy.RunRead( - b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks); + b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); @@ -582,8 +520,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf); b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); - k_block_data_begin += K0PerBlock; - } while(k_block_data_begin < (K0 - K0PerBlock)); + k0_block_data_begin += K0PerBlock; + } 
while(k0_block_data_begin < (K0 - K0PerBlock)); } // tail @@ -596,7 +534,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 // output: register to global memory { constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = - blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor(); + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); @@ -622,8 +560,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 const index_t n_thread_data_on_grid = n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{}; - const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = make_single_stage_tensor_adaptor( make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), @@ -648,6 +584,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 FloatC, decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc), decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc), + CElementwiseOperation, Sequence, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, @@ -664,17 +601,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 m_thread_data_on_grid_idx[I2], m_thread_data_on_grid_idx[I3], m_thread_data_on_grid_idx[I4], - n_thread_data_on_grid_idx[I2])}; + n_thread_data_on_grid_idx[I2]), + c_element_op}; c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), c_thread_buf, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_grid_buf, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); + c_grid_buf); } } -}; // namespace ck +}; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp new file mode 100644 index 00000000..84e1af0a --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp @@ -0,0 +1,726 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
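// [Illustration, not part of the patch] The C-shuffle epilogue in this file
// stores the C tile in CShuffleMRepeatPerShuffle x CShuffleNRepeatPerShuffle
// chunks, sweeping the N direction forward on even M steps and backward on
// odd ones (a serpentine order), so each step moves the LDS-to-global slice
// window by exactly one chunk instead of rewinding across the whole row.
// A minimal host-side model of that traversal; the repeat counts below are
// hypothetical example values.
#include <cstdio>

int main()
{
    const int MRepeat = 4, NRepeat = 4;
    const int MShuffle = 2, NShuffle = 2; // CShuffle{M,N}RepeatPerShuffle

    for(int m = 0; m < MRepeat; m += MShuffle)
    {
        // same predicate as nxdlperwave_forward_sweep in the kernel
        const bool forward = (m % (2 * MShuffle)) == 0;

        for(int n_it = 0; n_it < NRepeat; n_it += NShuffle)
        {
            // on a backward sweep, visit the N chunks in reverse order
            const int n = forward ? n_it : (NRepeat - n_it - NShuffle);
            std::printf("store chunk (m=%d, n=%d)%s\n", m, n, forward ? "" : " [backward]");
        }
    }
}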
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v2r4r2(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const CBlockClusterAdaptor c_block_cluster_adaptor) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + c_block_cluster_adaptor); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_b_k0_m_k1_grid_desc; + ignore = b_b_k0_n_k1_grid_desc; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = c_block_cluster_adaptor; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, 
Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size = + math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto c_block_size = + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); + + return math::max((a_block_space_size + b_block_space_size) * sizeof(FloatAB), + c_block_size * sizeof(FloatC)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, + const CMNGridDesc& c_m_n_grid_desc, + const Block2CTileMap& block_2_ctile_map) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerXDL)) == 0, + "Invalid tuning param!"); + + const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2); + const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2); + const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); + const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0); + + if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && + K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) && + K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) && + K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) && + KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + { + const bool has_main_k0_block_loop = K0 > K0PerBlock; + + return has_main_k0_block_loop; + } + + __host__ __device__ static constexpr auto + MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + return transform_tensor_descriptor( + c_m_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor( + const CMNGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch) + { + return BlockToCTileMap_KSplit_M00_N0_M01Adapt( + c_m_n_grid_desc, 8, KBatch); + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + } + + using 
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{})); + using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)); + + template + __device__ static void Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + FloatAB* __restrict__ p_shared_block, + const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const CBlockClusterAdaptor& c_block_cluster_adaptor) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); + + // divide block work by [M, N] + const auto block_work_idx = + c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!c_block_cluster_adaptor.ValidCTileIndex( + make_tuple(block_work_idx[I1], block_work_idx[I2]), + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t k_batch_id = block_work_idx[I0]; + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + constexpr auto a_b_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + constexpr auto b_b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + // A matrix 
blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_b_k0_m_k1_grid_desc), + decltype(a_b_k0_m_k1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_b_k0_m_k1_grid_desc, + make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0), + a_element_op, + a_b_k0_m_k1_block_desc, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_b_k0_n_k1_grid_desc), + decltype(b_b_k0_n_k1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_b_k0_n_k1_grid_desc, + make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0), + b_element_op, + b_b_k0_n_k1_block_desc, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block = p_shared_block; + FloatAB* p_b_block = p_shared_block + a_block_space_size; + + constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + + auto a_block_buf = make_dynamic_buffer( + p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); + auto b_block_buf = make_dynamic_buffer( + p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); + + // preload data into LDS + { + a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf); + + a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); + } + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainKBlockLoop) + { + index_t k0_block_data_begin = 0; + + do + { + a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, b_block_slice_copy_step); + + a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); + + k0_block_data_begin += K0PerBlock; + } while(k0_block_data_begin < (K0 
- K0PerBlock)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + + // output: register to global memory + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); + constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); + constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); + constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); + constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); + constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); + constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); + constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); + + constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock = + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_block_buf = make_dynamic_buffer( + static_cast(p_shared_block), + c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + static_assert(M1 == MWave, ""); + static_assert(N1 == NWave, ""); + static_assert(M2 * M3 * M4 == MPerXDL, ""); + static_assert(N2 == NPerXDL, ""); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), // freeze mblock + make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle, + M1, + M2, + M3, + M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_freeze_transform(I0), // freeze nblock + make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle, + N1, + N2))), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + 
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // LDS to global + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerXDL, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatC, // typename SrcData, + FloatC, // typename DstData, + decltype(c_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun + {c_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL); + + static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? 
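+                        // forward sweeps take the N index as written; backward sweeps
+                        // mirror it, so the write-out snakes across the NRepeat tiles: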
nxdlperwave_iter + : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock, + c_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp new file mode 100644 index 00000000..71bf05ce --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp @@ -0,0 +1,721 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
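Before the body of the new gridwise_gemm_xdlops_v3r1.hpp: all three v3r headers added here consume A as a [K0, M, K1] tensor and B as [K0, N, K1], with K = K0 * K1 folded so the innermost K1 elements sit contiguously for vectorized loads. A minimal CPU sketch of the product the device code computes, assuming AK1 == BK1 == K1 for simplicity (v3r1 itself lets them differ); reference_gemm_k0mk1_k0nk1_mn is an illustrative name, not library API:

#include <cstddef>
#include <vector>

// A packed as [K0][M][K1], B as [K0][N][K1], C row-major [M][N], K = K0 * K1.
// Mirrors the a_grid_desc_ak0_m_ak1 / b_grid_desc_bk0_n_bk1 layouts below.
void reference_gemm_k0mk1_k0nk1_mn(const std::vector<float>& a,
                                   const std::vector<float>& b,
                                   std::vector<float>& c,
                                   std::size_t M,
                                   std::size_t N,
                                   std::size_t K0,
                                   std::size_t K1)
{
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(std::size_t k0 = 0; k0 < K0; ++k0)
                for(std::size_t k1 = 0; k1 < K1; ++k1)
                    acc += a[(k0 * M + m) * K1 + k1] * b[(k0 * N + n) * K1 + k1];
            // the kernels accumulate into C thread buffers; the reference sets it
            c[m * N + n] = acc;
        }
}

The kernels compute the same sums with MFMA instructions and write C out through the shuffle epilogue instead of storing directly from registers.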
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v3r1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template < + index_t BlockSize, + typename FloatAB, + typename FloatAcc, + typename FloatCShuffle, + typename FloatC, + InMemoryDataOperationEnum CGlobalMemoryDataOperation, + typename AGridDesc_AK0_M_AK1, + typename BGridDesc_BK0_N_BK1, + typename CGridDesc_M_N, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + index_t MPerBlock, + index_t NPerBlock, + index_t KPerBlock, + index_t AK1Value, + index_t BK1Value, + index_t MPerXdl, + index_t NPerXdl, + index_t MXdlPerWave, + index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_AK0_M_AK1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + index_t ABlockTransferSrcVectorDim, + index_t ABlockTransferSrcScalarPerVector, + index_t ABlockTransferDstScalarPerVector_K1, + bool AThreadTransferSrcResetCoordinateAfterRun, + bool ABlockLdsExtraM, + typename BBlockTransferThreadClusterLengths_BK0_N_BK1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + index_t BBlockTransferSrcVectorDim, + index_t BBlockTransferSrcScalarPerVector, + 
index_t BBlockTransferDstScalarPerVector_K1, + bool BThreadTransferSrcResetCoordinateAfterRun, + bool BBlockLdsExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl, + index_t NumGemmKPrefetchStage = 1> +struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + constexpr auto max_lds_align = AK1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(AK0, Number{}, AK1), max_lds_align); + } + }(); + + return a_block_desc_ak0_m_ak1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + constexpr auto max_lds_align = BK1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(BK0, Number{}, BK1), max_lds_align); + } + }(); + + return b_block_desc_bk0_n_bk1; + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_ak0_m_ak1.GetElementSpaceSize(), AK1); + + constexpr auto b_block_space_size_aligned = + math::integer_least_multiple(b_block_desc_bk0_n_bk1.GetElementSpaceSize(), BK1); + + // LDS allocation for C shuffle in LDS + constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + 
GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + + constexpr auto c_block_size = + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatCShuffle)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + // static_assert(is_known_at_compile_time>::value && + // is_known_at_compile_time>::value, + // "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); + const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple( + MBlock, Number{}, Number{})), + make_unmerge_transform(make_tuple( + NBlock, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = + remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc_AK0_M_AK1& 
a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl& + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetLength(I0), + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetLength(I3)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, 
and saved in + // register + // sanity check + constexpr index_t k_pack = math::max( + math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + GridwiseGemmPipe::template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
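+            // Why the shuffle: the MFMA accumulators live in VGPRs in a swizzled
+            // (M0, N0, M1, N1, M2, M3, M4, N2) per-thread layout. Staging
+            // CShuffleMXdlPerWavePerShuffle x CShuffleNXdlPerWavePerShuffle tiles
+            // through LDS re-arranges them so the global store can be vectorized
+            // along the NWave * NPerXdl dimension.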
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_tuple( + make_freeze_transform(I0), // freeze mblock + make_pass_through_transform( + Number{}), // M0 (MXdlPerWave) per shuffle + make_unmerge_transform( + make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_freeze_transform(I0), // freeze nblock + make_pass_through_transform( + Number{}), // N0 (NXdlPerWave) per shuffle + make_unmerge_transform( + make_tuple(N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<>{}, + Sequence<0>{}, + Sequence<2, 4, 5, 6>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + 
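+                                 // components of the (M0, N0, M1, N1, M2, M3, M4, N2)
+                                 // origin; M0/N0 stay zero because the LDS staging
+                                 // tile holds a single CShuffle tile, re-filled on
+                                 // every sweep step: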
m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // LDS to global + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle, + MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle, + NWave * NPerXdl>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder, + 5, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(0, 0, 0, 0, 0, 0), + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMXdlPerWavePerShuffle, 0, 0, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, 0, CShuffleNXdlPerWavePerShuffle, 0); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, 0, -CShuffleNXdlPerWavePerShuffle, 0); + + static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, + NXdlPerWave, + CShuffleNXdlPerWavePerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? 
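+                    // snaking the N direction keeps every MoveDstSliceWindow below a
+                    // single step, instead of rewinding the whole N window at each M
+                    // advance: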
nxdlperwave_iter + : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c_shuffle_block_buf, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + mxdlperwave_forward_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp new file mode 100644 index 00000000..35f3bdef --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp @@ -0,0 +1,760 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
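The v3r2 header that follows is the same gridwise GEMM with one addition: a second global tensor c0 enters at the C-shuffle write-out, where ThreadGroupTensorSliceTransfer_v6r2 reads the shuffled accumulators (src0, from LDS) and c0 (src1, from global) and applies CElementwiseOperation before storing C. A minimal sketch of that epilogue contract, assuming a two-input functor in the op(y, x0, x1) style and an M x N c0; Add and apply_c0_epilogue are illustrative names, the real operation is whatever CElementwiseOperation the caller instantiates:

#include <cstddef>
#include <vector>

// two-input elementwise op in the style the v6r2 transfer applies: op(y, x0, x1)
struct Add
{
    void operator()(float& y, const float& x0, const float& x1) const { y = x0 + x1; }
};

// what the fused LDS->global stage computes, summed over all tiles of C
template <typename CElementwiseOp>
void apply_c0_epilogue(const std::vector<float>& acc, // shuffled accumulators, M * N
                       const std::vector<float>& c0,  // second source, M * N
                       std::vector<float>& c,         // destination, M * N
                       CElementwiseOp c_element_op)
{
    for(std::size_t i = 0; i < c.size(); ++i)
        c_element_op(c[i], acc[i], c0[i]);
}

In practice c0 often carries a bias broadcast to M x N by the host-side descriptor; that broadcast is an assumption here, not something this header enforces.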
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v3r2( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC* __restrict__ p_c0_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_c0_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = p_c0_grid; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + ignore = c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template < + index_t BlockSize, + typename FloatAB, + typename FloatAcc, + typename FloatC, + InMemoryDataOperationEnum CGlobalMemoryDataOperation, + typename AGridDesc_K0_M_K1, + typename BGridDesc_K0_N_K1, + typename CGridDesc_M_N, + typename C0GridDesc_M_N, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + index_t MPerBlock, + index_t NPerBlock, + index_t K0PerBlock, + index_t MPerXdl, + index_t NPerXdl, + index_t K1Value, + index_t MXdlPerWave, + index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + index_t ABlockTransferSrcVectorDim, + index_t ABlockTransferSrcScalarPerVector, + index_t ABlockTransferDstScalarPerVector_K1, + bool 
AThreadTransferSrcResetCoordinateAfterRun, + bool ABlockLdsExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + index_t BBlockTransferSrcVectorDim, + index_t BBlockTransferSrcScalarPerVector, + index_t BBlockTransferDstScalarPerVector_K1, + bool BThreadTransferSrcResetCoordinateAfterRun, + bool BBlockLdsExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl, + index_t NumGemmKPrefetchStage = 1> +struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return a_block_desc_k0_m_k1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return b_block_desc_k0_n_k1; + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + constexpr auto max_lds_align = K1; + + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = + math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), 
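+                // sizes are rounded up to max_lds_align (= K1) so each LDS slab
+                // starts on an aligned boundary when packed back-to-back: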
max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + + constexpr auto c_block_size = + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatC)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && + K1 == b_grid_desc_k0_n_k1.GetLength(I2))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K0 / K0PerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / (K0PerBlock * K1); + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + const CGridDesc_M_N_& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple( + MBlock, Number{}, Number{})), + make_unmerge_transform(make_tuple( + NBlock, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = + remove_cvref_t; + + using 
C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = + remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC* __restrict__ p_c0_grid, + void* __restrict__ p_shared, + const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl& + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl& + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + auto c0_grid_buf = make_dynamic_buffer( + p_c0_grid, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetLength(I0), + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetLength(I3)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_k0_m_k1), + decltype(a_block_desc_k0_m_k1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_k0_m_k1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_k0_m_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + 
FloatAB, + decltype(b_grid_desc_k0_n_k1), + decltype(b_block_desc_k0_n_k1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_k0_n_k1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_k0_n_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_k0_n_k1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + + // gridwise GEMM pipeline + const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); + + GridwiseGemmPipe::template Run(a_grid_desc_k0_m_k1, + a_block_desc_k0_m_k1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_k0_n_k1, + b_block_desc_k0_n_k1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + K0BlockMainLoop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
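+            // Same shuffle staging as in v3r1; the difference in this variant is
+            // that the LDS->global step below also streams c0 in through the second
+            // source window of ThreadGroupTensorSliceTransfer_v6r2.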
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + + auto c_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_tuple( + make_freeze_transform(I0), // freeze mblock + make_pass_through_transform( + Number{}), // M0 (MXdlPerWave) per shuffle + make_unmerge_transform( + make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_freeze_transform(I0), // freeze nblock + make_pass_through_transform( + Number{}), // N0 (NXdlPerWave) per shuffle + make_unmerge_transform( + make_tuple(N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<>{}, + Sequence<0>{}, + Sequence<2, 4, 5, 6>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<3, 7>{}) + + ); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + 
m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r2< + ThisThreadBlock, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle, + MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle, + NWave * NPerXdl>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder, + FloatC, // typename Src0Data, + FloatC, // typename Src1Data, + FloatC, // typename DstData, + decltype( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder, + 5, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector, + true, // bool ThreadTransferSrc0ResetCoordinateAfterRun, + false, // bool ThreadTransferSrc1ResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(0, 0, 0, 0, 0, 0), + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMXdlPerWavePerShuffle, 0, 0, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, 0, CShuffleNXdlPerWavePerShuffle, 0); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, 0, -CShuffleNXdlPerWavePerShuffle, 0); + + static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, + NXdlPerWave, + CShuffleNXdlPerWavePerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? 
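+                    // as in v3r1 the sweep snakes; note below that every window move
+                    // is issued twice, keeping the c0 (src1) and c (dst) windows in
+                    // lockstep over the same tile: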
nxdlperwave_iter + : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c_block_buf, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c0_grid_buf, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle)) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_forward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_backward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + mxdlperwave_forward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + mxdlperwave_forward_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp new file mode 100644 index 00000000..4e4ab9c9 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp @@ -0,0 +1,799 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
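gridwise_gemm_xdlops_v3r3.hpp below takes the v3r2 pattern one step further: two extra global tensors, c0 and c1, are fused into the write-out via ThreadGroupTensorSliceTransfer_v6r3. A sketch of the three-input epilogue under the same assumptions as above; AddAdd is an illustrative functor, not the library's:

#include <cstddef>
#include <vector>

// three-input elementwise op in the style expected here: op(y, x0, x1, x2)
struct AddAdd
{
    void operator()(float& y, const float& x0, const float& x1, const float& x2) const
    {
        y = x0 + x1 + x2; // e.g. accumulator + bias + residual
    }
};

template <typename CElementwiseOp>
void apply_c0_c1_epilogue(const std::vector<float>& acc,
                          const std::vector<float>& c0,
                          const std::vector<float>& c1,
                          std::vector<float>& c,
                          CElementwiseOp c_element_op)
{
    for(std::size_t i = 0; i < c.size(); ++i)
        c_element_op(c[i], acc[i], c0[i], c1[i]);
}

Apart from the extra source window, the control flow is identical to v3r2, down to the snaked shuffle sweep.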
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v3r3( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC* __restrict__ p_c0_grid, + const FloatC* __restrict__ p_c1_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_c0_grid, + p_c1_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = p_c0_grid; + ignore = p_c1_grid; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + ignore = c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + ignore = c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template < + index_t BlockSize, + typename FloatAB, + typename FloatAcc, + typename FloatC, + InMemoryDataOperationEnum CGlobalMemoryDataOperation, + typename AGridDesc_K0_M_K1, + typename BGridDesc_K0_N_K1, + typename CGridDesc_M_N, + typename C0GridDesc_M_N, + typename C1GridDesc_M_N, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + index_t 
+    index_t MPerBlock,
+    index_t NPerBlock,
+    index_t K0PerBlock,
+    index_t MPerXdl,
+    index_t NPerXdl,
+    index_t K1Value,
+    index_t MXdlPerWave,
+    index_t NXdlPerWave,
+    typename ABlockTransferThreadClusterLengths_K0_M_K1,
+    typename ABlockTransferThreadClusterArrangeOrder,
+    typename ABlockTransferSrcAccessOrder,
+    index_t ABlockTransferSrcVectorDim,
+    index_t ABlockTransferSrcScalarPerVector,
+    index_t ABlockTransferDstScalarPerVector_K1,
+    bool AThreadTransferSrcResetCoordinateAfterRun,
+    bool ABlockLdsExtraM,
+    typename BBlockTransferThreadClusterLengths_K0_N_K1,
+    typename BBlockTransferThreadClusterArrangeOrder,
+    typename BBlockTransferSrcAccessOrder,
+    index_t BBlockTransferSrcVectorDim,
+    index_t BBlockTransferSrcScalarPerVector,
+    index_t BBlockTransferDstScalarPerVector_K1,
+    bool BThreadTransferSrcResetCoordinateAfterRun,
+    bool BBlockLdsExtraN,
+    index_t CShuffleMXdlPerWavePerShuffle,
+    index_t CShuffleNXdlPerWavePerShuffle,
+    typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
+    index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
+    index_t NumGemmKPrefetchStage = 1>
+struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
+{
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+    static constexpr auto I2 = Number<2>{};
+    static constexpr auto I3 = Number<3>{};
+    static constexpr auto I4 = Number<4>{};
+    static constexpr auto I5 = Number<5>{};
+    static constexpr auto I6 = Number<6>{};
+    static constexpr auto I7 = Number<7>{};
+
+    // K1 should be Number<...>
+    static constexpr auto K1 = Number<K1Value>{};
+
+    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
+
+    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+
+    __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
+    {
+        constexpr auto max_lds_align = K1;
+
+        // A matrix in LDS memory, dst of blockwise copy
+        constexpr auto a_block_desc_k0_m_k1 = [&]() {
+            if constexpr(ABlockLdsExtraM)
+            {
+                return make_naive_tensor_descriptor(
+                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
+                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
+            }
+            else
+            {
+                return make_naive_tensor_descriptor_aligned(
+                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+            }
+        }();
+
+        return a_block_desc_k0_m_k1;
+    }
+
+    __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
+    {
+        constexpr auto max_lds_align = K1;
+
+        // B matrix in LDS memory, dst of blockwise copy
+        constexpr auto b_block_desc_k0_n_k1 = [&]() {
+            if constexpr(BBlockLdsExtraN)
+            {
+                return make_naive_tensor_descriptor(
+                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
+                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
+            }
+            else
+            {
+                return make_naive_tensor_descriptor_aligned(
+                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+            }
+        }();
+
+        return b_block_desc_k0_n_k1;
+    }
+
+    __host__ __device__ static constexpr auto
+    GetCBlockDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl()
+    {
+        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
+        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
+
+        constexpr auto
+            c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
+                make_naive_tensor_descriptor_packed(
+                    make_tuple(I1,
+                               Number<CShuffleMXdlPerWavePerShuffle>{},
+                               Number<MWave * MPerXdl>{},
+                               I1,
+                               Number<CShuffleNXdlPerWavePerShuffle>{},
+                               Number<NWave * NPerXdl>{}));
+
+        return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
+    }
+
+    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
+    {
+        // LDS allocation for A and B: be careful of alignment
+        constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
+
+        constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
+
+        constexpr auto max_lds_align = K1;
+
+        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
+            a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
+
+        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
+            b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
+
+        // LDS allocation for C shuffle in LDS
+        constexpr auto
+            c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
+                GetCBlockDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl();
+
+        constexpr auto c_block_size =
+            c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
+                .GetElementSpaceSize();
+
+        return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
+                             sizeof(FloatAB),
+                         c_block_size * sizeof(FloatC));
+    }
+
+    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
+    template <typename Block2CTileMap>
+    __host__ __device__ static constexpr bool
+    CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
+                  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
+                  const CGridDesc_M_N& c_grid_desc_m_n,
+                  const Block2CTileMap& block_2_ctile_map)
+    {
+        static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
+                      "wrong! K1 need to be known at compile-time");
+
+        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
+                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
+                      "Invalid tuning param!");
+
+        const auto M  = a_grid_desc_k0_m_k1.GetLength(I1);
+        const auto N  = b_grid_desc_k0_n_k1.GetLength(I1);
+        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
+
+        if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
+             K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
+             K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
+            return false;
+
+        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
+            return false;
+
+        // check gridwise gemm pipeline
+        const auto num_k_loop = K0 / K0PerBlock;
+
+        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
+        {
+            return false;
+        }
+
+        if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
+        {
+            return false;
+        }
+
+        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
+        return true;
+    }
+
+    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
+    {
+        const index_t num_loop = K / (K0PerBlock * K1);
+
+        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
+    }
+
+    template <typename CGridDesc_M_N_>
+    __host__ __device__ static constexpr auto
+    MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
+        const CGridDesc_M_N_& c_grid_desc_m_n)
+    {
+        const auto M = c_grid_desc_m_n.GetLength(I0);
+        const auto N = c_grid_desc_m_n.GetLength(I1);
+
+        const auto MBlock = M / MPerBlock;
+        const auto NBlock = N / NPerBlock;
+
+        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
+        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
+
+        const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
+            transform_tensor_descriptor(
+                c_grid_desc_m_n,
+                make_tuple(make_unmerge_transform(make_tuple(
+                               MBlock, Number<MXdlPerWave>{}, Number<MWave * MPerXdl>{})),
+                           make_unmerge_transform(make_tuple(
+                               NBlock, Number<NXdlPerWave>{}, Number<NWave * NPerXdl>{}))),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
+
+        return c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
+    }
+
+    // return block_id to C matrix tile idx (m0, n0) mapping
+    __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
+        const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
+    {
+        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
+            c_grid_desc_m_n);
+    }
+
+    using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
+        remove_cvref_t<decltype(
+            MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
+                CGridDesc_M_N{}))>;
+
+    using C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
+        remove_cvref_t<decltype(
+            MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
+                C0GridDesc_M_N{}))>;
+
+    using C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
+        remove_cvref_t<decltype(
+            MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
+                C1GridDesc_M_N{}))>;
+
+    using DefaultBlock2CTileMap =
+        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
+
+    template <bool HasMainKBlockLoop, typename Block2CTileMap>
+    __device__ static void
+    Run(const FloatAB* __restrict__ p_a_grid,
+        const FloatAB* __restrict__ p_b_grid,
+        FloatC* __restrict__ p_c_grid,
+        const FloatC* __restrict__ p_c0_grid,
+        const FloatC* __restrict__ p_c1_grid,
+        void* __restrict__ p_shared,
+        const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
+        const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
+        const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl&
+            c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
+        const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl&
+            c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
+        const C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl&
+            c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
+        const AElementwiseOperation& a_element_op,
+        const BElementwiseOperation& b_element_op,
+        const CElementwiseOperation& c_element_op,
+        const Block2CTileMap& block_2_ctile_map)
+    {
+        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
+        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
+        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_c_grid,
+            c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
+                .GetElementSpaceSize());
+        auto c0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_c0_grid,
+            c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
+                .GetElementSpaceSize());
+        auto c1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_c1_grid,
+            c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
+                .GetElementSpaceSize());
+
+        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
+
+        // divide block work by [M, N]
+        const auto block_work_idx =
+            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
+
+        if(!block_2_ctile_map.ValidCTileIndex(
+               block_work_idx,
+               make_tuple(
+                   c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
+                       .GetLength(I0),
+                   c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
+                       .GetLength(I3))))
+        {
+            return;
+        }
+
+        // HACK: this forces m/n_block_data_idx_on_grid into SGPR
+        const index_t m_block_data_idx_on_grid =
+            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
+
+        const index_t n_block_data_idx_on_grid =
+            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
+
+        // lds max alignment
+        constexpr auto max_lds_align = K1;
+
+        // A matrix in LDS memory, dst of blockwise copy
+        constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
+
+        // B matrix in LDS memory, dst of blockwise copy
+        constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
+
+        // A matrix blockwise copy
+        auto a_blockwise_copy =
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
+                                                AElementwiseOperation,
+                                                ck::tensor_operation::element_wise::PassThrough,
+                                                InMemoryDataOperationEnum::Set,
+                                                Sequence<K0PerBlock, MPerBlock, K1>,
+                                                ABlockTransferThreadClusterLengths_K0_M_K1,
+                                                ABlockTransferThreadClusterArrangeOrder,
+                                                FloatAB,
+                                                FloatAB,
+                                                decltype(a_grid_desc_k0_m_k1),
+                                                decltype(a_block_desc_k0_m_k1),
+                                                ABlockTransferSrcAccessOrder,
+                                                Sequence<1, 0, 2>,
+                                                ABlockTransferSrcVectorDim,
+                                                2,
+                                                ABlockTransferSrcScalarPerVector,
+                                                ABlockTransferDstScalarPerVector_K1,
+                                                1,
+                                                1,
+                                                AThreadTransferSrcResetCoordinateAfterRun,
+                                                true>(
+                a_grid_desc_k0_m_k1,
+                make_multi_index(0, m_block_data_idx_on_grid, 0),
+                a_element_op,
+                a_block_desc_k0_m_k1,
+                make_multi_index(0, 0, 0),
+                ck::tensor_operation::element_wise::PassThrough{});
+
+        // B matrix blockwise copy
+        auto b_blockwise_copy =
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
+                                                BElementwiseOperation,
+                                                ck::tensor_operation::element_wise::PassThrough,
+                                                InMemoryDataOperationEnum::Set,
+                                                Sequence<K0PerBlock, NPerBlock, K1>,
+                                                BBlockTransferThreadClusterLengths_K0_N_K1,
+                                                BBlockTransferThreadClusterArrangeOrder,
+                                                FloatAB,
+                                                FloatAB,
+                                                decltype(b_grid_desc_k0_n_k1),
+                                                decltype(b_block_desc_k0_n_k1),
+                                                BBlockTransferSrcAccessOrder,
+                                                Sequence<1, 0, 2>,
+                                                BBlockTransferSrcVectorDim,
+                                                2,
+                                                BBlockTransferSrcScalarPerVector,
+                                                BBlockTransferDstScalarPerVector_K1,
+                                                1,
+                                                1,
+                                                BThreadTransferSrcResetCoordinateAfterRun,
+                                                true>(
+                b_grid_desc_k0_n_k1,
+                make_multi_index(0, n_block_data_idx_on_grid, 0),
+                b_element_op,
+                b_block_desc_k0_n_k1,
+                make_multi_index(0, 0, 0),
+                ck::tensor_operation::element_wise::PassThrough{});
+
+        // GEMM definition
+        //   c_mtx += transpose(a_mtx) * b_mtx
+        //     a_mtx[K0PerBlock, MPerBlock] is in LDS
+        //     b_mtx[K0PerBlock, NPerBlock] is in LDS
+        //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
+        //       register
+        // sanity check
+
+        auto blockwise_gemm =
+            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
+                                                                FloatAB,
+                                                                FloatAcc,
+                                                                decltype(a_block_desc_k0_m_k1),
+                                                                decltype(b_block_desc_k0_n_k1),
+                                                                MPerXdl,
+                                                                NPerXdl,
+                                                                MXdlPerWave,
+                                                                NXdlPerWave,
+                                                                K1>{};
+
+        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
+
+        // LDS allocation for A and B: be careful of alignment
+        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
+            a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
+
+        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+            static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
+
+        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+            static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
+            b_block_desc_k0_n_k1.GetElementSpaceSize());
+
+        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
+        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
+
+        // gridwise GEMM pipeline
+        const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
+
+        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
+                                                          a_block_desc_k0_m_k1,
+                                                          a_blockwise_copy,
+                                                          a_grid_buf,
+                                                          a_block_buf,
+                                                          a_block_slice_copy_step,
+                                                          b_grid_desc_k0_n_k1,
+                                                          b_block_desc_k0_n_k1,
+                                                          b_blockwise_copy,
+                                                          b_grid_buf,
+                                                          b_block_buf,
+                                                          b_block_slice_copy_step,
+                                                          blockwise_gemm,
+                                                          c_thread_buf,
+                                                          K0BlockMainLoop);
+
+        // shuffle C and write out
+        {
+            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
+                              NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
+                          "wrong!");
+
+            constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
+            constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
+
+            // TODO: hacky, fix it!
+            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
+                blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
+
+            // TODO: hacky, fix it!
+            // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
+            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
+                blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
+
+            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
+            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
+            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
+            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
+            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
+            constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
+            constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
+            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
+
+            constexpr auto
+                c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
+                    GetCBlockDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl();
+
+            auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+                static_cast<FloatC*>(p_shared),
+                c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
+                    .GetElementSpaceSize());
+
+            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
+                c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
+                make_tuple(make_freeze_transform(I0), // freeze mblock
+                           make_pass_through_transform(
+                               Number<CShuffleMXdlPerWavePerShuffle>{}), // M0 (MXdlPerWave) per
+                                                                         // shuffle
+                           make_unmerge_transform(
+                               make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl
+                           make_freeze_transform(I0),       // freeze nblock
+                           make_pass_through_transform(
+                               Number<CShuffleNXdlPerWavePerShuffle>{}), // N0 (NXdlPerWave) per
+                                                                         // shuffle
+                           make_unmerge_transform(
+                               make_tuple(N1, N2))), // N1 = NWave, N2 = NPerXdl
+                make_tuple(Sequence<0>{},
+                           Sequence<1>{},
+                           Sequence<2>{},
+                           Sequence<3>{},
+                           Sequence<4>{},
+                           Sequence<5>{}),
+                make_tuple(Sequence<>{},
+                           Sequence<0>{},
+                           Sequence<2, 4, 5, 6>{},
+                           Sequence<>{},
+                           Sequence<1>{},
+                           Sequence<3, 7>{}));
+
+            // calculate origin of thread output tensor on global memory
+            //     blockwise GEMM c matrix starting index
+            const auto c_thread_mtx_on_block =
+                blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
+
+            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
+            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
+
+            const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
+                make_single_stage_tensor_adaptor(
+                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
+                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
+                    make_tuple(Sequence<0>{}));
+
+            const auto m_thread_data_on_block_idx =
+                m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
+                    make_multi_index(m_thread_data_on_block));
+
+            const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
+                make_single_stage_tensor_adaptor(
+                    make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
+                    make_tuple(Sequence<0, 1, 2>{}),
+                    make_tuple(Sequence<0>{}));
+
+            const auto n_thread_data_on_block_idx =
+                n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
+                    make_multi_index(n_thread_data_on_block));
+
+            // VGPR to LDS
+            auto c_thread_copy_vgpr_to_lds =
+                ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
+                                                   FloatC,
+                                                   decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
+                                                   decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
+                                                   ck::tensor_operation::element_wise::PassThrough,
+                                                   Sequence<CShuffleMXdlPerWavePerShuffle,
+                                                            CShuffleNXdlPerWavePerShuffle,
+                                                            I1,
+                                                            I1,
+                                                            M2,
+                                                            I1,
+                                                            M4,
+                                                            I1>,
+                                                   Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
+                                                   7,
+                                                   1,
+                                                   InMemoryDataOperationEnum::Set,
+                                                   1,
+                                                   true>{
+                    c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                    make_multi_index(0,
+                                     0,
+                                     m_thread_data_on_block_idx[I1],
+                                     n_thread_data_on_block_idx[I1],
+                                     m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r3< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle, + MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle, + NWave * NPerXdl>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder, + FloatC, // typename Src0Data, + FloatC, // typename Src1Data, + FloatC, // typename Src2Data, + FloatC, // typename DstData, + decltype( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype( + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder, + 5, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector, + true, // bool ThreadTransferSrc0ResetCoordinateAfterRun, + false, // bool ThreadTransferSrc1ResetCoordinateAfterRun, + false, // bool ThreadTransferSrc2ResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(0, 0, 0, 0, 0, 0), + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMXdlPerWavePerShuffle, 0, 0, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, 0, CShuffleNXdlPerWavePerShuffle, 0); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, 0, -CShuffleNXdlPerWavePerShuffle, 0); + + static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, + NXdlPerWave, + CShuffleNXdlPerWavePerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? 
nxdlperwave_iter + : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c_block_buf, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c0_grid_buf, + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c1_grid_buf, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle)) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_forward_step); + + c_block_copy_lds_to_global.MoveSrc2SliceWindow( + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_forward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_backward_step); + + c_block_copy_lds_to_global.MoveSrc2SliceWindow( + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_backward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + mxdlperwave_forward_step); + + c_block_copy_lds_to_global.MoveSrc2SliceWindow( + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + mxdlperwave_forward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + mxdlperwave_forward_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp new file mode 100644 index 00000000..1e52b405 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
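+//
+// kernel_buffer_set_value fills a 1-D buffer with a single constant: one
+// thread per element, each staging `value` in a one-element VGPR buffer and
+// writing it out through a ThreadwiseTensorSliceTransfer_v1r3 store, with a
+// bounds check so the tail block stays in range.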
+
+#pragma once
+
+#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
+
+namespace ck {
+
+template <index_t BlockSize, typename Grid1dBufferDescType, typename DataType>
+__global__ void kernel_buffer_set_value(const Grid1dBufferDescType grid_1d_buffer_desc,
+                                        DataType* const __restrict__ p_global,
+                                        DataType value)
+{
+    using PassThroughOp = tensor_operation::element_wise::PassThrough;
+
+    constexpr auto I0 = Number<0>{};
+
+    const index_t thread_local_id = get_thread_local_1d_id();
+    const index_t block_global_id = get_block_1d_id();
+
+    const index_t thread_global_id = block_global_id * BlockSize + thread_local_id;
+
+    StaticBuffer<AddressSpaceEnum::Vgpr, DataType, 1, true> value_buf;
+
+    value_buf(I0) = value;
+
+    constexpr auto val_buff_desc = make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));
+
+    auto global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+        p_global, grid_1d_buffer_desc.GetElementSpaceSize());
+
+    if(thread_global_id < grid_1d_buffer_desc.GetElementSize())
+    {
+        auto threadwise_store = ThreadwiseTensorSliceTransfer_v1r3<DataType,
+                                                                   DataType,
+                                                                   decltype(val_buff_desc),
+                                                                   Grid1dBufferDescType,
+                                                                   PassThroughOp,
+                                                                   Sequence<1>,
+                                                                   Sequence<0>,
+                                                                   0,
+                                                                   1,
+                                                                   InMemoryDataOperationEnum::Set,
+                                                                   1,
+                                                                   true>(
+            grid_1d_buffer_desc, make_multi_index(thread_global_id), PassThroughOp{});
+
+        threadwise_store.Run(
+            val_buff_desc, make_tuple(I0), value_buf, grid_1d_buffer_desc, global_buf);
+    }
+};
+
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp
new file mode 100644
index 00000000..98b29ff8
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp
@@ -0,0 +1,404 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck/utility/data_type.hpp"
+#include "ck/utility/reduction_common.hpp"
+#include "ck/utility/reduction_operator.hpp"
+#include "ck/utility/reduction_functions_accumulate.hpp"
+#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
+#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
+#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+namespace ck {
+
+template <typename GridwiseReduction,
+          typename InDataType,
+          typename OutDataType,
+          typename AccDataType,
+          typename GridDesc_M_K>
+__global__ void kernel_softmax(const GridDesc_M_K in_grid_desc_m_k,
+                               const GridDesc_M_K out_grid_desc_m_k,
+                               index_t block_group_size,
+                               index_t num_k_block_tile_iteration,
+                               AccDataType alpha,
+                               const InDataType* const __restrict__ p_in_value_global,
+                               AccDataType beta,
+                               OutDataType* const __restrict__ p_out_value_global)
+{
+    GridwiseReduction::Run(in_grid_desc_m_k,
+                           out_grid_desc_m_k,
+                           block_group_size,
+                           num_k_block_tile_iteration,
+                           alpha,
+                           p_in_value_global,
+                           beta,
+                           p_out_value_global);
+};
+
+template <typename InDataType,
+          typename OutDataType,
+          typename AccDataType,
+          typename GridDesc_M_K,
+          index_t BlockSize,
+          index_t MThreadClusterSize,
+          index_t KThreadClusterSize,
+          index_t MThreadSliceSize,
+          index_t KThreadSliceSize,
+          index_t InSrcVectorDim,
+          index_t InSrcVectorSize,
+          index_t OutDstVectorSize,
+          bool SweepOnce>
+struct GridwiseSoftmax_mk_to_mk
+{
+    static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
+                   (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
+                      (KThreadSliceSize % OutDstVectorSize == 0),
+                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
+
+    static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
+
+    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
+
+    using ThreadBufferDimAccessOrder =
+        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
+
+    using ThreadClusterArrangeOrder =
+        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
+
+    static constexpr auto thread_cluster_desc =
+        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
+
+    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
+        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
+    using ThreadReduceDstDesc_M =
+        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
+
+    using PassThroughOp = tensor_operation::element_wise::PassThrough;
+
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+
+    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
+    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
+
+    __device__ static void Run(const GridDesc_M_K& in_grid_desc_m_k,
+                               const GridDesc_M_K& out_grid_desc_m_k,
+                               index_t block_group_size,
+                               index_t num_k_block_tile_iteration,
+                               AccDataType alpha,
+                               const InDataType* const __restrict__ p_in_value_global,
+                               AccDataType beta,
+                               OutDataType* const __restrict__ p_out_value_global)
+    {
+        if constexpr(SweepOnce)
+        {
+            num_k_block_tile_iteration = 1;
+        }
+
+        // LDS
+        __shared__ AccDataType p_reduce_work_buffer[BlockSize];
+
+        auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_out_value_global, out_grid_desc_m_k.GetElementSpaceSize());
+
+        auto reduce_work_buf =
+            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
+
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
+            in_thread_buf;
+
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
+            out_thread_buf;
+
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> max_value_buf;
+
+        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+            max_value_buf(I) = reduce::Max::template GetIdentityValue<AccDataType>();
+        });
+
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
+
+        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+            accu_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
+        });
+
+        const index_t thread_local_id = get_thread_local_1d_id();
+        const index_t block_global_id = get_block_1d_id();
+        const index_t blkgroup_id     = block_global_id / block_group_size;
+        const index_t block_local_id  = block_global_id % block_group_size;
+
+        const auto thread_cluster_idx =
+            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
+
+        const auto thread_m_cluster_id = thread_cluster_idx[I0];
+        const auto thread_k_cluster_id = thread_cluster_idx[I1];
+
+        const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
+
+        using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
+        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
+            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
+
+        // Normally, 0 as the invalid element value is adequate, since 0 makes no contribution
+        // to the accumulated result. However, in stable softmax every value, zero or not, has
+        // value_max subtracted from it. As the numbers become non-zero, invalid values can
+        // effectively slip through and contribute to the accumulated result.
+        //
+        // The trick here is leveraging the fact that many math functions (add, sub, exp, ...)
+        // propagate NaNs when their operands contain NaNs. By initializing the invalid element
+        // value with NaN, an invalid value remains NaN through any math manipulation, so it can
+        // still be identified as invalid. We can then discard, during accumulation, the invalid
+        // values which originally failed the bound check. This lets us ignore values that failed
+        // the bound check even after multiple math manipulations.
+        //
+        // NOTE: reset coordinate after every step because the same threadwise copy will sweep
+        // through global memory 3 times back and forth
+        auto threadwise_src_load =
+            ThreadwiseTensorSliceTransfer_v2<InDataType,
+                                             AccDataType,
+                                             GridDesc_M_K,
+                                             decltype(thread_buffer_desc),
+                                             ThreadBufferLengths,
+                                             ThreadBufferDimAccessOrder,
+                                             InSrcVectorDim,
+                                             InSrcVectorSize,
+                                             1,
+                                             true,
+                                             true>( // InvalidElementAsNaN
+                in_grid_desc_m_k,
+                make_multi_index(blkgroup_id * M_BlockTileSize +
+                                     thread_m_cluster_id * MThreadSliceSize,
+                                 block_local_id * reduceSizePerBlock +
+                                     thread_k_cluster_id * KThreadSliceSize));
+
+        auto threadwise_dst_load =
+            ThreadwiseTensorSliceTransfer_v2<OutDataType,
+                                             AccDataType,
+                                             GridDesc_M_K,
+                                             decltype(thread_buffer_desc),
+                                             ThreadBufferLengths,
+                                             ThreadBufferDimAccessOrder,
+                                             InSrcVectorDim,
+                                             OutDstVectorSize,
+                                             1,
+                                             true>(
+                out_grid_desc_m_k,
+                make_multi_index(blkgroup_id * M_BlockTileSize +
+                                     thread_m_cluster_id * MThreadSliceSize,
+                                 block_local_id * reduceSizePerBlock +
+                                     thread_k_cluster_id * KThreadSliceSize));
+
+        auto threadwise_dst_store =
+            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
+                                               OutDataType,
+                                               decltype(thread_buffer_desc),
+                                               GridDesc_M_K,
+                                               PassThroughOp,
+                                               ThreadBufferLengths,
+                                               ThreadBufferDimAccessOrder,
+                                               InSrcVectorDim,
+                                               OutDstVectorSize,
+                                               InMemoryDataOperationEnum::Set,
+                                               1,
+                                               true>(
+                out_grid_desc_m_k,
+                make_multi_index(
+                    blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
+                    block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize),
+                PassThroughOp{});
+
+        constexpr auto in_thread_copy_fwd_step =
+            make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
+        constexpr auto in_thread_copy_bwd_step =
+            make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
+
+        ///
+        /// max(x)
+        ///
+        using BlockwiseMaxReduce = PartitionedBlockwiseReduction<
+            AccDataType,
+            BlockSize,
+            ThreadClusterLengths_M_K,
+            ThreadClusterArrangeOrder,
+            reduce::Max,
+            false, // param ignored
+            detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
+
+        using ThreadwiseMaxReduce =
+            ThreadwiseReduction<AccDataType,
+                                ThreadReduceSrcDesc_M_K,
+                                ThreadReduceDstDesc_M,
+                                reduce::Max,
+                                false, // param ignored
+                                detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
+
+        const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_in_value_global, in_grid_desc_m_k.GetElementSpaceSize());
+
+        index_t reducedTiles = 0;
+        do
+        {
+            threadwise_src_load.Run(in_grid_desc_m_k,
+                                    in_global_val_buf,
+                                    thread_buffer_desc,
+                                    make_tuple(I0, I0),
+                                    in_thread_buf);
+
+            ThreadwiseMaxReduce::Reduce(in_thread_buf, max_value_buf);
+
+            threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
+
+            reducedTiles++;
+        } while(reducedTiles < num_k_block_tile_iteration);
+
+        static_for<0, MThreadSliceSize, 1>{}(
+            [&](auto I) { BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I)); });
+
+        threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);
+
+        ///
+        /// sum(exp(x - max(x)))
+        ///
+        using BlockwiseSumReduce = PartitionedBlockwiseReduction<
+            AccDataType,
+            BlockSize,
+            ThreadClusterLengths_M_K,
+            ThreadClusterArrangeOrder,
+            reduce::Add,
+            false, // param ignored
+            detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
+
+        using ThreadwiseSumReduce =
+            ThreadwiseReduction<AccDataType,
+                                ThreadReduceSrcDesc_M_K,
+                                ThreadReduceDstDesc_M,
+                                reduce::Add,
+                                false, // param ignored
+                                detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
+
+        reducedTiles = 0;
+        do
+        {
+            if constexpr(!SweepOnce)
+            {
+                threadwise_src_load.Run(in_grid_desc_m_k,
+                                        in_global_val_buf,
+                                        thread_buffer_desc,
+                                        make_tuple(I0, I0),
+                                        in_thread_buf);
+            }
+
+            // do element-wise pre-reduction operation
+            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
+                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
+                    constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
+                    out_thread_buf(Number<offset>{}) =
+                        math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM));
+                });
+            });
+
+            ThreadwiseSumReduce::Reduce(out_thread_buf, accu_value_buf);
+
+            threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);
+
+            reducedTiles++;
+        } while(reducedTiles < num_k_block_tile_iteration);
+
+        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+            BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I));
+            // block_sync_lds();
+        });
+
+        threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
+
+        ///
+        /// softmax
+        ///
+        reducedTiles = 0;
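+        // Final sweep: out = alpha * exp(x - max(x)) / sum(exp(x - max(x))).
+        // The beta == 0 case skips reading the prior output entirely; otherwise
+        // the prior output is also loaded and accumulated as beta * prior_out on
+        // top of the scaled softmax result.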
+        if(float_equal_zero{}(beta))
+        {
+            do
+            {
+                if constexpr(!SweepOnce)
+                {
+                    threadwise_src_load.Run(in_grid_desc_m_k,
+                                            in_global_val_buf,
+                                            thread_buffer_desc,
+                                            make_tuple(I0, I0),
+                                            in_thread_buf);
+                }
+
+                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
+                    // out = alpha * exp(x - max(x)) / sum(exp(x - max(x)))
+                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
+                        constexpr auto offset =
+                            thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
+                        out_thread_buf(Number<offset>{}) =
+                            alpha *
+                            math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM)) /
+                            accu_value_buf(iM);
+                    });
+                });
+
+                threadwise_dst_store.Run(thread_buffer_desc,
+                                         make_tuple(I0, I0),
+                                         out_thread_buf,
+                                         out_grid_desc_m_k,
+                                         out_global_val_buf);
+
+                threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
+                threadwise_dst_store.MoveDstSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step);
+
+                reducedTiles++;
+            } while(reducedTiles < num_k_block_tile_iteration);
+        }
+        else
+        {
+            StaticBuffer<AddressSpaceEnum::Vgpr,
+                         AccDataType,
+                         MThreadSliceSize * KThreadSliceSize,
+                         true>
+                in_prior_dst_buf;
+            do
+            {
+                if constexpr(!SweepOnce)
+                {
+                    threadwise_src_load.Run(in_grid_desc_m_k,
+                                            in_global_val_buf,
+                                            thread_buffer_desc,
+                                            make_tuple(I0, I0),
+                                            in_thread_buf);
+                }
+                threadwise_dst_load.Run(out_grid_desc_m_k,
+                                        out_global_val_buf,
+                                        thread_buffer_desc,
+                                        make_tuple(I0, I0),
+                                        in_prior_dst_buf);
+
+                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
+                    // out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + beta * prior_out
+                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
+                        constexpr auto offset =
+                            thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
+                        out_thread_buf(Number<offset>{}) =
+                            alpha *
+                                math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM)) /
+                                accu_value_buf(iM) +
+                            beta * in_prior_dst_buf(Number<offset>{});
+                    });
+                });
+
+                threadwise_dst_store.Run(thread_buffer_desc,
+                                         make_tuple(I0, I0),
+                                         out_thread_buf,
+                                         out_grid_desc_m_k,
+                                         out_global_val_buf);
+
+                threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
+                threadwise_dst_store.MoveDstSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step);
+                threadwise_dst_load.MoveSrcSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step);
+
+                reducedTiles++;
+            } while(reducedTiles < num_k_block_tile_iteration);
+        }
+    }
+};
+
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp
new file mode 100644
index 00000000..6e7fbbc6
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp
@@ -0,0 +1,132 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
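+//
+// GridwiseUnaryElementwise_1D applies a unary functor to a 1-D tensor in a
+// grid-stride loop: each thread handles ScalarPerVector contiguous elements
+// per iteration (vectorized load, per-element functor, vectorized store).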
+
+#pragma once
+
+#include "ck/utility/data_type.hpp"
+#include "ck/tensor_description/cluster_descriptor.hpp"
+#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+namespace ck {
+
+template <typename GridwiseUEltwise,
+          typename ADataType,
+          typename BDataType,
+          typename GridDesc_M0,
+          typename ElementwiseFunctor>
+__global__ void kernel_unary_elementwise_1d(const ADataType* __restrict__ p_a_global,
+                                            BDataType* __restrict__ p_b_global,
+                                            const GridDesc_M0 a_grid_desc_m0,
+                                            const GridDesc_M0 b_grid_desc_m0,
+                                            const ElementwiseFunctor functor)
+{
+    GridwiseUEltwise::Run(p_a_global, p_b_global, a_grid_desc_m0, b_grid_desc_m0, functor);
+}
+
+template <typename ADataType,
+          typename BDataType,
+          typename GridDesc_M0,
+          typename ElementwiseFunctor,
+          index_t ScalarPerVector>
+struct GridwiseUnaryElementwise_1D
+{
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto thread_desc_m0 =
+        make_naive_tensor_descriptor_packed(make_tuple(Number<ScalarPerVector>{}));
+
+    using PassThrough = tensor_operation::element_wise::PassThrough;
+
+    static __device__ auto CalculateElementwiseIndex()
+    {
+        const index_t global_thread_id = get_thread_global_1d_id();
+        return make_multi_index(global_thread_id * ScalarPerVector);
+    }
+
+    __host__ __device__ static constexpr bool CheckValidity(const GridDesc_M0 a_grid_desc_m0,
+                                                            const GridDesc_M0 b_grid_desc_m0)
+    {
+        return a_grid_desc_m0.GetLength(I0) == b_grid_desc_m0.GetLength(I0);
+    }
+
+    __host__ __device__ static constexpr index_t CalculateGridSize(const index_t tensor_size)
+    {
+        const index_t grid_size = math::integer_divide_ceil(tensor_size, 256 * ScalarPerVector);
+
+        return grid_size;
+    }
+
+    __device__ static void Run(const ADataType* __restrict__ p_a_global,
+                               BDataType* __restrict__ p_b_global,
+                               const GridDesc_M0 a_grid_desc_m0,
+                               const GridDesc_M0 b_grid_desc_m0,
+                               const ElementwiseFunctor functor)
+    {
+        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_a_global, a_grid_desc_m0.GetElementSpaceSize());
+        auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_b_global, b_grid_desc_m0.GetElementSpaceSize());
+
+        StaticBuffer<AddressSpaceEnum::Vgpr, ADataType, ScalarPerVector, true> a_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, BDataType, ScalarPerVector, true> b_thread_buf;
+
+        const auto thread_store_global_offset = CalculateElementwiseIndex();
+
+        auto a_global_load =
+            ThreadwiseTensorSliceTransfer_v2<ADataType,
+                                             ADataType,
+                                             GridDesc_M0,
+                                             decltype(thread_desc_m0),
+                                             Sequence<ScalarPerVector>, // SliceLengths
+                                             Sequence<0>,               // DimAccessOrder
+                                             0,                         // SrcVectorDim
+                                             ScalarPerVector,
+                                             1, // SrcScalarStrideInVector
+                                             false>{a_grid_desc_m0, thread_store_global_offset};
+
+        auto b_global_write =
+            ThreadwiseTensorSliceTransfer_v1r3<BDataType,
+                                               BDataType,
+                                               decltype(thread_desc_m0),
+                                               GridDesc_M0,
+                                               PassThrough,
+                                               Sequence<ScalarPerVector>, // SliceLengths
+                                               Sequence<0>,               // DimAccessOrder
+                                               0,                         // DstVectorDim
+                                               ScalarPerVector,
+                                               InMemoryDataOperationEnum::Set,
+                                               1, // DstScalarStrideInVector
+                                               false>{
+                b_grid_desc_m0, thread_store_global_offset, PassThrough{}};
+
+        const index_t blockSize    = get_block_size();
+        const index_t blockPerGrid = get_grid_size();
+        const auto m0              = b_grid_desc_m0.GetLength(I0);
+        const index_t loop_step    = blockPerGrid * blockSize * ScalarPerVector;
+        const auto loop_step_index = make_multi_index(loop_step);
+
+        index_t num_iter = m0 / loop_step;
+        do
+        {
+            // read and process ScalarPerVector elements
+            a_global_load.Run(
+                a_grid_desc_m0, a_global_buf, thread_desc_m0, make_tuple(I0), a_thread_buf);
+
+            static_for<0, ScalarPerVector, 1>{}([&](auto m) {
+                constexpr auto offset = thread_desc_m0.CalculateOffset(make_tuple(m));
+                functor(b_thread_buf(Number<offset>{}), a_thread_buf(Number<offset>{}));
+            });
+
+            b_global_write.Run(thread_desc_m0,
+                               make_tuple(I0), // SrcSliceOriginIdx
+                               b_thread_buf,
+                               b_grid_desc_m0,
+                               b_global_buf);
+
+            a_global_load.MoveSrcSliceWindow(a_grid_desc_m0, loop_step_index);
+            b_global_write.MoveDstSliceWindow(b_grid_desc_m0, loop_step_index);
+        } while(--num_iter);
+    }
+};
+
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp b/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp
new file mode 100644
index 00000000..188c62d9
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp
@@ -0,0 +1,99 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck/utility/reduction_functions_accumulate.hpp"
+
+namespace ck {
+
+// Assume
+//  1) SrcDesc is known at compile-time
+//  2) DstDesc is known at compile-time
+//  3) SrcBuffer is static buffer
+//  4) DstBuffer is static buffer
+template <typename AccDataType,
+          typename SrcThreadDesc_M_K,
+          typename DstThreadDesc_M,
+          typename OpReduce,
+          bool PropagateNan,
+          typename Accumulation =
+              detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>>
+struct ThreadwiseReduction
+{
+    static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{};
+    static constexpr auto dst_thread_desc_m   = DstThreadDesc_M{};
+
+    static constexpr auto src_length_m = src_thread_desc_m_k.GetLength(Number<0>{});
+    static constexpr auto src_length_k = src_thread_desc_m_k.GetLength(Number<1>{});
+    static constexpr auto dst_length_m = dst_thread_desc_m.GetLength(Number<0>{});
+
+    static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");
+
+    using Op = OpReduce;
+
+    template <typename SrcBufferType, typename DstBufferType>
+    __device__ static void Reduce(const SrcBufferType& src_buf, DstBufferType& dst_buf)
+    {
+        static_for<0, src_length_m, 1>{}([&](auto iM) {
+            constexpr index_t out_offset = dst_thread_desc_m.CalculateOffset(make_tuple(iM));
+
+            static_for<0, src_length_k, 1>{}([&](auto iK) {
+                constexpr auto offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK));
+
+                Accumulation::Calculate(dst_buf(Number<out_offset>{}), src_buf[Number<offset>{}]);
+            });
+        });
+    };
+};
+
+// Assume
+//  1) SrcDesc is known at compile-time
+//  2) DstDesc is known at compile-time
+//  3) SrcBuffer is static buffer
+//  4) DstBuffer is static buffer
+template <
+    typename AccDataType,
+    typename IndexDataType,
+    typename SrcThreadDesc_M_K,
+    typename DstThreadDesc_M,
+    typename OpReduce,
+    bool PropagateNan,
+    typename Accumulation =
+        detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>>
+struct ThreadwiseReductionWithIndex
+{
+    static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{};
+    static constexpr auto dst_thread_desc_m   = DstThreadDesc_M{};
+
+    static constexpr auto src_length_m = src_thread_desc_m_k.GetLength(Number<0>{});
+    static constexpr auto src_length_k = src_thread_desc_m_k.GetLength(Number<1>{});
+    static constexpr auto dst_length_m = dst_thread_desc_m.GetLength(Number<0>{});
+
+    static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");
+
+    template <typename SrcValueBufferType,
+              typename SrcIndexBufferType,
+              typename DstValueBufferType,
+              typename DstIndexBufferType>
+    __device__ static void Reduce(const SrcValueBufferType& src_val_buf,
+                                  const SrcIndexBufferType& src_idx_buf,
+                                  DstValueBufferType& dst_val_buf,
+                                  DstIndexBufferType& dst_idx_buf)
+    {
+        static_for<0, src_length_m, 1>{}([&](auto iM) {
+            constexpr index_t out_offset = dst_thread_desc_m.CalculateOffset(make_tuple(iM));
+
+            static_for<0, src_length_k, 1>{}([&](auto iK) {
+                constexpr auto offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK));
+
+                Accumulation::Calculate(dst_val_buf(Number<out_offset>{}),
+                                        src_val_buf[Number<offset>{}],
+                                        dst_idx_buf(Number<out_offset>{}),
+                                        src_idx_buf[Number<offset>{}]);
+            });
+        });
+    };
+};
+
+} // namespace ck
diff --git a/composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp
similarity index 95%
rename from composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp
rename to
include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp
index 8b753810..94cdfe01 100644
--- a/composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp
+++ b/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp
@@ -1,8 +1,10 @@
-#ifndef CK_THREADWISE_CONTRACTION_DLOPS_HPP
-#define CK_THREADWISE_CONTRACTION_DLOPS_HPP
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
 
-#include "common_header.hpp"
-#include "math.hpp"
+#pragma once
+
+#include "ck/utility/common_header.hpp"
+#include "ck/utility/math.hpp"
 
 namespace ck {
 
@@ -25,9 +27,9 @@ template ::type = false>
-struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1
+struct ThreadwiseGemmDl_km0m1_kn0n1_m0m1n0n1
 {
-    __device__ constexpr ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1()
+    __device__ constexpr ThreadwiseGemmDl_km0m1_kn0n1_m0m1n0n1()
     {
         static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
                       BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
@@ -124,9 +126,9 @@ template ::type = false>
-struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
+struct ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
 {
-    __device__ constexpr ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
+    __device__ constexpr ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
     {
         static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
                       BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
@@ -220,4 +222,3 @@ struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_
 };
 
 } // namespace ck
-#endif
diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp
new file mode 100644
index 00000000..e045e3b5
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp
@@ -0,0 +1,168 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#ifndef CK_THREADWISE_GEMM_DLOPS_V3_HPP
+#define CK_THREADWISE_GEMM_DLOPS_V3_HPP
+
+#include "common_header.hpp"
+#include "math.hpp"
+
+namespace ck {
+
+// C[M, N] += transpose(A[K, M]) * B[K, N]
+//   Element of matrix can be vectorized data
+// Assume:
+//   1. AThreadDesc_E1_K_E2, BThreadDesc_E1_N_Ho_Wo_E2, CThreadDesc_K_N_Ho_Wo are known at
+//      compile-time
+//   2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
+template <typename FloatA,
+          typename FloatB,
+          typename FloatC,
+          typename AThreadDesc_E1_K_E2,
+          typename BThreadDesc_E1_N_Ho_Wo_E2,
+          typename CThreadDesc_K_N_Ho_Wo,
+          typename enable_if<AThreadDesc_E1_K_E2::IsKnownAtCompileTime() &&
+                                 BThreadDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() &&
+                                 CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
+                             bool>::type = false>
+struct ThreadwiseGemmDlops_km_kn_mn_v3
+{
+
+    template <typename ABuffer,
+              typename AOriginIdx,
+              typename BBuffer,
+              typename BOriginIdx,
+              typename CBuffer,
+              typename COriginIdx>
+    __device__ static void Run(const ABuffer& a_buf,
+                               AOriginIdx,
+                               const BBuffer& b_buf,
+                               BOriginIdx,
+                               CBuffer& c_buf,
+                               COriginIdx)
+    {
+
+        static_assert(AThreadDesc_E1_K_E2::IsKnownAtCompileTime() &&
+                          BThreadDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() &&
+                          CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
+                      "wrong! Desc should be known at compile-time");
+
+        static_assert(is_known_at_compile_time<remove_cvref_t<AOriginIdx>>::value &&
+                          is_known_at_compile_time<remove_cvref_t<BOriginIdx>>::value &&
+                          is_known_at_compile_time<remove_cvref_t<COriginIdx>>::value,
+                      "wrong! AOriginIdx, BOriginIdx, COriginIdx should be known at compile-time");
+
+        static_assert(
+            is_same<remove_cvref_t<typename ABuffer::type>, remove_cvref_t<FloatA>>::value &&
+            is_same<remove_cvref_t<typename BBuffer::type>, remove_cvref_t<FloatB>>::value &&
+            is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value &&
+            "wrong! inconsistent type");
+
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+        constexpr auto I2 = Number<2>{};
+        constexpr auto I3 = Number<3>{};
+
+        constexpr auto E1 = AThreadDesc_E1_K_E2{}.GetLength(I0);
+        constexpr auto K  = AThreadDesc_E1_K_E2{}.GetLength(I1);
+        constexpr auto E2 = AThreadDesc_E1_K_E2{}.GetLength(I2);
+
+        constexpr auto Ho = BThreadDesc_E1_N_Ho_Wo_E2{}.GetLength(I2);
+        constexpr auto Wo = BThreadDesc_E1_N_Ho_Wo_E2{}.GetLength(I3);
+
+        constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
+        constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
+        constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
+
+        if constexpr((Ho % 2 == 0) && (Wo % 2 == 0))
+        {
+            constexpr auto SubHW = 2;
+
+            static_for<0, K, 1>{}([&](auto k) {
+                static_for<0, Ho, SubHW>{}([&](auto h) {
+                    static_for<0, Wo, SubHW>{}([&](auto w) {
+                        static_for<0, E1, 1>{}([&](auto e1) {
+                            static_for<0, E2, 1>{}([&](auto e2) {
+                                constexpr index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset(
+                                    a_origin_idx + make_tuple(e1, k, e2));
+
+                                constexpr index_t b0_offset =
+                                    BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
+                                        b_origin_idx + make_tuple(e1, 0, h, w, e2));
+
+                                constexpr index_t b1_offset =
+                                    BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
+                                        b_origin_idx + make_tuple(e1, 0, h, w + 1, e2));
+
+                                constexpr index_t b2_offset =
+                                    BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
+                                        b_origin_idx + make_tuple(e1, 0, h + 1, w, e2));
+
+                                constexpr index_t b3_offset =
+                                    BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
+                                        b_origin_idx + make_tuple(e1, 0, h + 1, w + 1, e2));
+
+                                constexpr index_t c0_offset =
+                                    CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx +
+                                                                            make_tuple(k, 0, h, w));
+
+                                constexpr index_t c1_offset =
+                                    CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
+                                        c_origin_idx + make_tuple(k, 0, h, w + 1));
+
+                                constexpr index_t c2_offset =
+                                    CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
+                                        c_origin_idx + make_tuple(k, 0, h + 1, w));
+
+                                constexpr index_t c3_offset =
+                                    CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
+                                        c_origin_idx + make_tuple(k, 0, h + 1, w + 1));
+
+                                amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
+                                                               b_buf[Number<b0_offset>{}],
+                                                               b_buf[Number<b1_offset>{}],
+                                                               b_buf[Number<b2_offset>{}],
+                                                               b_buf[Number<b3_offset>{}],
+                                                               c_buf(Number<c0_offset>{}),
+                                                               c_buf(Number<c1_offset>{}),
+                                                               c_buf(Number<c2_offset>{}),
+                                                               c_buf(Number<c3_offset>{}));
+                            });
+                        });
+                    });
+                });
+            });
+        }
+        else
+        {
+
+            static_for<0, K, 1>{}([&](auto k) {
+                static_for<0, Ho, 1>{}([&](auto h) {
+                    static_for<0, Wo, 1>{}([&](auto w) {
+                        static_for<0, E1, 1>{}([&](auto e1) {
+                            static_for<0, E2, 1>{}([&](auto e2) {
+                                constexpr index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset(
+                                    a_origin_idx + make_tuple(e1, k, e2));
+
+                                constexpr index_t b_offset =
+                                    BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
+                                        b_origin_idx + make_tuple(e1, 0, h, w, e2));
+
+                                constexpr index_t c_offset =
+                                    CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx +
+                                                                            make_tuple(k, 0, h, w));
+
+                                inner_product<FloatA, FloatB, FloatC>(a_buf[Number<a_offset>{}],
+                                                                      b_buf[Number<b_offset>{}],
+                                                                      c_buf(Number<c_offset>{}));
+                            });
+                        });
+                    });
+                });
+            });
+        }
+    }
+};
+
+} // namespace ck
+#endif
diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp
similarity index 85%
rename from composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp
rename to include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp
index 20e9a5b3..0a1197a1 100644
--- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp
+++
b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp @@ -1,9 +1,11 @@ -#ifndef CK_THREADWISE_TENSOR_SET_HPP -#define CK_THREADWISE_TENSOR_SET_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" namespace ck { @@ -56,4 +58,3 @@ struct ThreadwiseTensorSliceSet_v1 }; } // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp similarity index 73% rename from composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp rename to include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 7e3f6b34..a50bb851 100644 --- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -1,9 +1,12 @@ -#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_HPP -#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" namespace ck { @@ -50,11 +53,12 @@ template ::type = false> @@ -69,11 +73,15 @@ struct ThreadwiseTensorSliceTransfer_v1r3 using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc& dst_desc, - const Index& dst_slice_origin_idx) - : dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)) + const Index& dst_slice_origin_idx, + const ElementwiseOperation& element_op) + : dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)), + element_op_{element_op} { static_assert(SrcDesc::IsKnownAtCompileTime(), "wrong! SrcDesc need to known at compile-time"); + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! Not divisible"); } __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) @@ -81,16 +89,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3 dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); } - template + template __device__ void Run(const SrcDesc&, const SrcSliceOriginIdx&, const SrcBuffer& src_buf, const DstDesc& dst_desc, - DstBuffer& dst_buf, - const DstStepHacks& dst_step_hacks) + DstBuffer& dst_buf) { static_assert(SrcDesc::IsKnownAtCompileTime(), "wrong! 
SrcDesc need to known at compile-time"); @@ -104,9 +108,6 @@ struct ThreadwiseTensorSliceTransfer_v1r3 constexpr auto src_desc = remove_cvref_t{}; constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - // scalar per access on each dim // TODO: don't use lambda_scalar_per_access constexpr auto dst_scalar_per_access = generate_sequence( @@ -115,141 +116,52 @@ struct ThreadwiseTensorSliceTransfer_v1r3 constexpr auto dst_scalar_step_in_vector = generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); - constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access; - - constexpr auto dim_access_order = DimAccessOrder{}; - - constexpr auto ordered_access_lengths = - container_reorder_given_new2old(access_lengths, dim_access_order); - - // make forward steps - const auto dst_forward_steps = generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step( - dst_desc, forward_step_idx, dst_step_hacks[I0][i]); - }, - Number{}); + using SpaceFillingCurve = SpaceFillingCurve>; - // make backward steps - const auto dst_backward_steps = generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; - }); + // TODO: Use SpaceFillingCurve::ScalarsPerAccess instread of DstScalarPerVector? + static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, + "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); + typename vector_type_maker::type dst_vector; + using dst_vector_t = typename vector_type_maker::type::type; - return make_tensor_coordinate_step( - dst_desc, backward_step_idx, dst_step_hacks[I1][i]); - }, - Number{}); + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); - // loop over tensor and copy - static_ford{}([&](auto ordered_access_idx) { - // judge move forward or move backward - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_access_idx[I0]; - - static_for<0, i, 1>{}([&](auto j) { - tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate dst data index - constexpr auto dst_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] - ? ordered_access_idx[i] - : ordered_access_lengths[i] - 1 - ordered_access_idx[i]; - }); - - return container_reorder_given_old2new(ordered_idx, dim_access_order) * - dst_scalar_per_access; - }(); - - typename vector_type_maker::type dst_vector; - - using dst_vector_t = - typename vector_type_maker::type::type; + static_for<0, num_access, 1>{}([&](auto idx_1d) { + constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); // copy data from src_buf into dst_vector + // TODO: It's a hack here to use \p dst_scalar_step_in_vector. Use SpaceFillingCurve? 
static_for<0, DstScalarPerVector, 1>{}([&](auto i) { constexpr index_t src_offset = src_desc.CalculateOffset( - src_slice_origin_idx + dst_data_idx + i * dst_scalar_step_in_vector); + src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + SrcData v; + + // apply element-wise operation + element_op_(v, src_buf[Number{}]); - dst_vector.template AsType()(i) = - type_convert{}(src_buf[Number{}]); + // apply type convert + dst_vector.template AsType()(i) = type_convert(v); }); const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); // copy data from dst_vector into dst_buf - if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::Set) - { - dst_buf.template Set( - dst_coord_.GetOffset(), - is_dst_valid, - dst_vector.template AsType()[Number<0>{}]); - } - else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::AtomicAdd) - { - dst_buf.template AtomicAdd( - dst_coord_.GetOffset(), - is_dst_valid, - dst_vector.template AsType()[Number<0>{}]); - } + dst_buf.template Update( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector.template AsType()[Number<0>{}]); - constexpr auto move_on_dim = [&]() constexpr + if constexpr(idx_1d.value != num_access - 1) { - StaticallyIndexedArray move_on_dim_; + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); - static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1; - - static_for{}([&](auto j) { - move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1; - }); - }); - - return move_on_dim_; + move_tensor_coordinate( + dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step)); } - (); - - // move - static_for<0, nDim, 1>{}([&](auto i) { - if constexpr(move_on_dim[i]) - { - if constexpr(forward_sweep[i]) - { - move_tensor_coordinate( - dst_desc, dst_coord_, dst_forward_steps[dim_access_order[i]]); - } - else - { - move_tensor_coordinate( - dst_desc, dst_coord_, dst_backward_steps[dim_access_order[i]]); - } - } - }); }); // move dst coordinate back to slice origin (or not) @@ -262,82 +174,27 @@ struct ThreadwiseTensorSliceTransfer_v1r3 } } - template - __device__ void Run(const SrcDesc&, - const SrcSliceOriginIdx&, - const SrcBuffer& src_buf, - const DstDesc& dst_desc, - DstBuffer& dst_buf) - { - constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform(); - - constexpr auto zeros = typename uniform_sequence_gen::type{}; - - constexpr auto dst_step_hacks = - make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), - generate_tuple([&](auto) { return zeros; }, Number{})); - - Run(SrcDesc{}, SrcSliceOriginIdx{}, src_buf, dst_desc, dst_buf, dst_step_hacks); - } - __device__ static constexpr auto GetDstCoordinateResetStep() { - constexpr auto I0 = Number<0>{}; - - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access constexpr auto dst_scalar_per_access = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); - constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access; - - constexpr auto dim_access_order = DimAccessOrder{}; - - constexpr auto ordered_access_lengths = - container_reorder_given_new2old(access_lengths, dim_access_order); + using SpaceFillingCurve = SpaceFillingCurve>; - // judge move forward or move backward during the last iteration - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = 
ordered_access_lengths[I0] - 1;
-
-                static_for<0, i, 1>{}([&](auto j) {
-                    tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
-                });
-
-                forward_sweep_(i) = tmp % 2 == 0;
-            });
-
-            return forward_sweep_;
-        }();
-
-        // calculate dst data index after last iteration in Run(), if it has not being reset by
-        // RunWrite()
-        constexpr auto dst_data_idx = [&]() {
-            Index ordered_idx;
-
-            static_for<0, nDim, 1>{}([&](auto i) {
-                ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0;
-            });
-
-            return container_reorder_given_old2new(ordered_idx, dim_access_order) *
-                   dst_scalar_per_access;
-        }();
-
-        //
-        constexpr auto reset_dst_data_step = [&]() {
-            Index reset_dst_data_step_;
-
-            static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });
-
-            return reset_dst_data_step_;
-        }();
+        constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
+        if constexpr(num_access == 0)
+        {
+            return typename SpaceFillingCurve::Index{};
+        }
+        else
+        {
+            constexpr auto reset_step =
+                SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
 
-        return reset_dst_data_step;
+            return reset_step;
+        }
     }
 
     // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
@@ -357,7 +214,8 @@ struct ThreadwiseTensorSliceTransfer_v1r3
 
     private:
     DstCoord dst_coord_;
-}; // namespace ck
+    const ElementwiseOperation element_op_;
+}; // struct ThreadwiseTensorSliceTransfer_v1r3
 
 // Assume:
 // 1. src:
@@ -378,9 +236,14 @@ template
           ::type = false>
 struct ThreadwiseTensorSliceTransfer_v2
 {
+    static_assert((InvalidElementAsNaN && !std::is_integral<DstData>::value) ||
+                      (!InvalidElementAsNaN),
+                  "Filling invalid element as NaN is only for floating point types");
+
     static constexpr index_t nDim = SliceLengths::Size();
 
     using Index = MultiIndex<nDim>;
@@ -395,6 +258,8 @@ struct ThreadwiseTensorSliceTransfer_v2
     {
         static_assert(DstDesc::IsKnownAtCompileTime(),
                       "wrong! SrcDesc need to known at compile-time");
+        static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
+                      "wrong! Not divisible");
     }
 
     __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
@@ -402,16 +267,12 @@ struct ThreadwiseTensorSliceTransfer_v2
     {
         src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
     }
 
-    template <typename SrcBuffer, typename DstBuffer, typename SrcStepHacks>
+    template <typename SrcBuffer, typename DstBuffer>
     __device__ void Run(const SrcDesc& src_desc,
                         const SrcBuffer& src_buf,
                         const DstDesc&,
                         const DstSliceOriginIdx&,
-                        DstBuffer& dst_buf,
-                        const SrcStepHacks& src_step_hacks)
+                        DstBuffer& dst_buf)
    {
        static_assert(DstDesc::IsKnownAtCompileTime(),
                      "wrong!
DstDesc need to known at compile-time"); @@ -427,9 +288,6 @@ struct ThreadwiseTensorSliceTransfer_v2 constexpr auto dst_desc = remove_cvref_t{}; constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{}; - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - // scalar per access on each dim // TODO: don't use lambda_scalar_per_access constexpr auto src_scalar_per_access = generate_sequence( @@ -438,80 +296,19 @@ struct ThreadwiseTensorSliceTransfer_v2 constexpr auto src_scalar_step_in_vector = generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); - constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto dim_access_order = DimAccessOrder{}; - - constexpr auto ordered_access_lengths = - container_reorder_given_new2old(access_lengths, dim_access_order); - - // make forward steps - const auto src_forward_steps = generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step( - src_desc, forward_step_idx, src_step_hacks[I0][i]); - }, - Number{}); - - // make backward steps - const auto src_backward_steps = generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step( - src_desc, backward_step_idx, src_step_hacks[I1][i]); - }, - Number{}); + using SpaceFillingCurve = SpaceFillingCurve>; // loop over tensor and copy - static_ford{}([&](auto ordered_access_idx) { - // judge move forward or move backward - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_access_idx[I0]; - - static_for<0, i, 1>{}([&](auto j) { - tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate src data index - constexpr auto src_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] - ? ordered_access_idx[i] - : ordered_access_lengths[i] - 1 - ordered_access_idx[i]; - }); - - return container_reorder_given_old2new(ordered_idx, dim_access_order) * - src_scalar_per_access; - }(); + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + static_for<0, num_access, 1>{}([&](auto idx_1d) { typename vector_type_maker::type src_vector; using src_vector_t = typename vector_type_maker::type::type; + constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d); const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); @@ -526,41 +323,27 @@ struct ThreadwiseTensorSliceTransfer_v2 dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx + i * src_scalar_step_in_vector); - dst_buf(Number{}) = src_vector.template AsType()[i]; + if constexpr(InvalidElementAsNaN) + { + dst_buf(Number{}) = + is_src_valid + ? 
type_convert(src_vector.template AsType()[i]) + : NumericLimits::QuietNaN(); + } + else + { + dst_buf(Number{}) = + type_convert(src_vector.template AsType()[i]); + } }); - constexpr auto move_on_dim = [&]() constexpr + if constexpr(idx_1d.value != num_access - 1) { - StaticallyIndexedArray move_on_dim_; + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); - static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1; - - static_for{}([&](auto j) { - move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1; - }); - }); - - return move_on_dim_; + move_tensor_coordinate( + src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step)); } - (); - - // move - static_for<0, nDim, 1>{}([&](auto i) { - if constexpr(move_on_dim[i]) - { - if constexpr(forward_sweep[i]) - { - move_tensor_coordinate( - src_desc, src_coord_, src_forward_steps[dim_access_order[i]]); - } - else - { - move_tensor_coordinate( - src_desc, src_coord_, src_backward_steps[dim_access_order[i]]); - } - } - }); }); // move src coordinate back to slice origin (or not) @@ -573,82 +356,27 @@ struct ThreadwiseTensorSliceTransfer_v2 } } - template - __device__ void Run(const SrcDesc& src_desc, - const SrcBuffer& src_buf, - const DstDesc&, - const DstSliceOriginIdx&, - DstBuffer& dst_buf) - { - constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform(); - - constexpr auto zeros = typename uniform_sequence_gen::type{}; - - constexpr auto src_step_hacks = - make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), - generate_tuple([&](auto) { return zeros; }, Number{})); - - Run(src_desc, src_buf, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_step_hacks); - } - __device__ static constexpr auto GetSrcCoordinateResetStep() { - constexpr auto I0 = Number<0>{}; - - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access constexpr auto src_scalar_per_access = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); - constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto dim_access_order = DimAccessOrder{}; - - constexpr auto ordered_access_lengths = - container_reorder_given_new2old(access_lengths, dim_access_order); - - // judge move forward or move backward during the last iteration - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_access_lengths[I0] - 1; - - static_for<0, i, 1>{}([&](auto j) { - tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); + using SpaceFillingCurve = SpaceFillingCurve>; - // calculate src data index after last iteration in Run(), if it has not being reset by - // RunWrite() - constexpr auto src_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? 
ordered_access_lengths[i] - 1 : 0; - }); - - return container_reorder_given_old2new(ordered_idx, dim_access_order) * - src_scalar_per_access; - }(); - - // - constexpr auto reset_src_data_step = [&]() { - Index reset_src_data_step_; - - static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); - - return reset_src_data_step_; - }(); + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + if constexpr(num_access == 0) + { + return typename SpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); - return reset_src_data_step; + return reset_step; + } } // dst_slice_origin_step_idx need to be known at compile-time, for performance reason @@ -666,6 +394,25 @@ struct ThreadwiseTensorSliceTransfer_v2 move_tensor_coordinate(src_desc, src_coord_, adjusted_step); } + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void + MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx, + const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step( + src_desc, adjusted_step_idx, src_move_slice_window_step_hack); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + private: SrcCoord src_coord_; }; // namespace ck @@ -676,7 +423,7 @@ struct ThreadwiseTensorSliceTransfer_v2 // 3. src_slice_origin and dst_slice_origin are not known at compile-time, // 4. Use thread buffer template {}) % SrcScalarPerVector == 0, + "wrong! Not divisible"); + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! 
Not divisible"); } __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) @@ -729,8 +480,8 @@ struct ThreadwiseTensorSliceTransfer_v3 __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks) { - static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or - SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, + static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, "wrong!"); static_assert( @@ -794,7 +545,7 @@ struct ThreadwiseTensorSliceTransfer_v3 static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_src_access_idx[I0]; - static_for<0, i, 1>{}([&](auto j) { + static_for<1, i, 1>{}([&](auto j) { tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; }); @@ -886,8 +637,8 @@ struct ThreadwiseTensorSliceTransfer_v3 __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks) { - static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or - DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, "wrong!"); static_assert( @@ -951,7 +702,7 @@ struct ThreadwiseTensorSliceTransfer_v3 static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_dst_access_idx[I0]; - static_for<0, i, 1>{}([&](auto j) { + static_for<1, i, 1>{}([&](auto j) { tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; }); @@ -983,7 +734,7 @@ struct ThreadwiseTensorSliceTransfer_v3 buffer_desc_.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector); dst_tmp_vector.template AsType()(i) = - type_convert{}(buffer_[Number{}]); + type_convert(buffer_[Number{}]); }); using dst_vector_t = typename decltype(dst_tmp_vector)::type; @@ -1095,7 +846,7 @@ struct ThreadwiseTensorSliceTransfer_v3 static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_src_access_lengths[I0] - 1; - static_for<0, i, 1>{}([&](auto j) { + static_for<1, i, 1>{}([&](auto j) { tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; }); @@ -1155,7 +906,7 @@ struct ThreadwiseTensorSliceTransfer_v3 static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_dst_access_lengths[I0] - 1; - static_for<0, i, 1>{}([&](auto j) { + static_for<1, i, 1>{}([&](auto j) { tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; }); @@ -1244,7 +995,7 @@ struct ThreadwiseTensorSliceTransfer_v3 static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize(); - StaticBuffer buffer_; + StaticBuffer buffer_; SrcCoord src_coord_; DstCoord dst_coord_; @@ -1290,7 +1041,8 @@ struct ThreadwiseTensorSliceTransfer_v4 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), "wrong! SrcDesc and DstDesc need to known at compile-time"); - static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, "wrong!"); + static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, + "wrong! 
Not divisible"); } template {}([&](auto i) { dst_tmp_vector.template AsType()(i) = - type_convert{}(src_tmp_vector.template AsType()[i]); + type_convert(src_tmp_vector.template AsType()[i]); }); // copy data from dst_tmp_vector into dst_buf @@ -1433,4 +1185,3 @@ struct ThreadwiseTensorSliceTransfer_v4 }; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp new file mode 100644 index 00000000..005f35e9 --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -0,0 +1,793 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor/static_tensor.hpp" + +namespace ck { + +namespace detail { +// TODO: How to fix this? It uses an struct instead of lambda because lambda +// doesn't have constructor +template +struct lambda_scalar_per_access_for_src_and_dst +{ + __host__ __device__ constexpr auto operator()(index_t i) const + { + if(i == SrcVectorDim && i == DstVectorDim) + { + return math::lcm(SrcScalarPerVector, DstScalarPerVector); + } + else if(i == SrcVectorDim) + { + return SrcScalarPerVector; + } + else if(i == DstVectorDim) + { + return DstScalarPerVector; + } + else + { + return 1; + } + } +}; + +} // namespace detail + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +// 4. Use thread buffer +template +struct ThreadwiseTensorSliceTransfer_v3r1 +{ + static constexpr index_t nDim = SliceLengths::Size(); + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); + + static constexpr auto I0 = Number<0>{}; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1( + const SrcDesc& src_desc, + const Index& src_slice_origin, + const SrcElementwiseOperation& src_element_op, + const DstDesc& dst_desc, + const Index& dst_slice_origin, + const DstElementwiseOperation& dst_element_op) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + src_element_op_(src_element_op), + dst_element_op_(dst_element_op) + { + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void RunRead(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + Number thread_scratch_id = Number{}) + { + static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! 
SrcBuffer and SrcData data type are inconsistent"); + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // make forward steps + const auto src_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(src_desc, forward_step_idx); + }, + Number{}); + + // make backward steps + const auto src_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(src_desc, backward_step_idx); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
ordered_src_access_idx[i] + : ordered_src_access_lengths[i] - 1 - + ordered_src_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + constexpr auto src_data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + using src_vector_type = vector_type_maker_t; + using src_vector_t = typename src_vector_type::type; + + // copy data from src_buf into src_vector_container + auto src_vector_container = src_vector_type{ + src_buf.template Get(src_coord_.GetOffset(), is_src_valid)}; + + // apply SrcElementwiseOperation on src_vector_container + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + SrcData src_v; + + src_element_op_(src_v, src_vector_container.template AsType()[i]); + + src_vector_container.template AsType()(i) = src_v; + }); + + // copy data from src_vector_container into src_thread_scratch_ + src_thread_scratch_tuple_(thread_scratch_id) + .template SetAsType( + src_data_idx_seq, src_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move src coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); + } + } + }); + }); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + } + + template + __device__ void + TransferDataFromSrcThreadScratchToDstThreadScratch(Number thread_scratch_id) + { +#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE + static_ford{}([&](auto idx) { + // convert from SrcData to DstData here + dst_thread_scratch_(idx) = + type_convert(src_thread_scratch_tuple_[thread_scratch_id][idx]); + }); +#else + // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_ + // TODO make this logic more generic for more sub-dword datatype + if constexpr(SrcVectorDim != DstVectorDim && + ((is_same>::value && + is_same>::value && + SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) || + (is_same>::value && + is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) + { + // each transpose does + // DstScalarPerVector # of src vectors in src_thread_scratch_ + // SrcScalarPerVector # of dst vectors in dst_thread_scratch_ + constexpr index_t num_src_vector = Number{}; + constexpr index_t num_dst_vector = Number{}; + + // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose + // TODO: make this logic generic for all scenario + static_assert(SrcVectorDim != DstVectorDim, "wrong"); + + constexpr auto src_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto 
dst_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access_for_src_and_dst{}, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + static_ford{}([&](auto access_idx) { + constexpr auto data_idx = access_idx * scalar_per_access; + + constexpr auto data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + // TODO type_convert is not used yet!!!!! + using src_vector_t = vector_type_maker_t; + using dst_vector_t = vector_type_maker_t; + + // get DstScalarPerVector # of read-only references to src vectors from + // src_thread_scratch_ + const auto src_vector_refs = generate_tie( + [&](auto i) -> const src_vector_t& { + // i increment corresponds to movement in DstVectorDim + return src_thread_scratch_tuple_[thread_scratch_id].GetVectorTypeReference( + data_idx_seq + i * dst_scalar_step_in_vector); + }, + Number{}); + + // get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_ + auto dst_vector_refs = generate_tie( + [&](auto i) -> dst_vector_t& { + // i increment corresponds to movement in SrcVectorDim + return dst_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * src_scalar_step_in_vector); + }, + Number{}); + + // do data transpose + // TODO type_convert is not used yet!!!!! + transpose_vectors{}( + src_vector_refs, dst_vector_refs); + }); + } + else + { + static_ford{}([&](auto idx) { + // convert from SrcData to DstData here + dst_thread_scratch_(idx) = + type_convert(src_thread_scratch_tuple_[thread_scratch_id][idx]); + }); + } +#endif + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, + DstBuffer& dst_buf, + Number thread_scratch_id = Number{}) + { + // if there is transpose, it's done here + // TODO move this elsewhere + TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id); + + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); + + // src scalar per access on each dim + // TODO: don't use this + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // make forward steps + const auto dst_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst_desc, forward_step_idx); + }, + Number{}); + + // make backward steps + const auto dst_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? 
-dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst_desc, backward_step_idx); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_dst_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i] + : ordered_dst_access_lengths[i] - 1 - + ordered_dst_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + constexpr auto dst_data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + // copy data from dst_thread_scratch_ into dst_vector_container + auto dst_vector_container = dst_vector_type{ + dst_thread_scratch_.template GetAsType(dst_data_idx_seq)}; + + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + DstData dst_v; + + // apply DstElementwiseOperation + dst_element_op_(dst_v, dst_vector_container.template AsType()[i]); + + dst_vector_container.template AsType()(i) = dst_v; + }); + + // copy data from dst_vector_container to dst_buf + dst_buf.template Set( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move dst coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]); + } + } + }); + }); + + // move dst coordinate back to slice origin (or not) + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // judge move forward or move backward during the last 
iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index after last iteration in RunRead(), if it has not being reset by + // RunRead() + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + // + constexpr auto reset_src_data_step = [&]() { + Index reset_src_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); + + return reset_src_data_step_; + }(); + + return reset_src_data_step; + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index after last iteration in RunWrite(), if it has not being reset by + // RunWrite() + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + // + constexpr auto reset_dst_data_step = [&]() { + Index reset_dst_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); + + return reset_dst_data_step_; + }(); + + return reset_dst_data_step; + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? 
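+        // Illustrative, assuming a 2x4 slice swept in snake order: RunRead()
+        // leaves src_coord_ on the last access, e.g. at data index (1, 0), and
+        // GetSrcCoordinateResetStep() returns its negation (-1, 0), so the move
+        // below is measured from the slice origin rather than from wherever the
+        // last sweep stopped.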
+ const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by RunWrite(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRun ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + __device__ static constexpr auto GetSrcThreadScratchDescriptor() + { + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(src_access_lengths), Number{}); + + // 1st stage of transforms + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(src_access_lengths_and_vector_length[i], + src_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(src_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + __device__ static constexpr auto GetDstThreadScratchDescriptor() + { + // 1st stage of transforms + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(dst_access_lengths), Number{}); + + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(dst_access_lengths_and_vector_length[i], + dst_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + private: + static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; + static constexpr auto 
dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){}; + + using SrcThreadScratch = StaticTensorTupleOfVectorBuffer; + + using DstThreadScratch = StaticTensorTupleOfVectorBuffer; + + StaticallyIndexedArray src_thread_scratch_tuple_; + + DstThreadScratch dst_thread_scratch_; + + SrcCoord src_coord_; + DstCoord dst_coord_; + const SrcElementwiseOperation src_element_op_; + const DstElementwiseOperation dst_element_op_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r3.hpp new file mode 100644 index 00000000..6a73466e --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r3.hpp @@ -0,0 +1,886 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R3_HPP +#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R3_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "static_tensor.hpp" + +namespace ck { + +namespace detail { +// TODO: How to fix this? It uses an struct instead of lambda because lambda +// doesn't have constructor +template +struct lambda_scalar_per_access_for_src_and_dst +{ + __host__ __device__ constexpr auto operator()(index_t i) const + { + if(i == SrcVectorDim && i == DstVectorDim) + { + return math::lcm(SrcScalarPerVector, DstScalarPerVector); + } + else if(i == SrcVectorDim) + { + return SrcScalarPerVector; + } + else if(i == DstVectorDim) + { + return DstScalarPerVector; + } + else + { + return 1; + } + } +}; + +} // namespace detail + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +// 4. 
Use thread buffer +template // control whether to move back dst coordinate after each + // RunWrite(), will be fused with MoveDstSliceWindow to + // save addr computation +struct ThreadwiseTensorSliceTransfer_v3r3 +{ + static constexpr index_t nDim = SliceLengths::Size(); + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + using Dst0Coord = decltype(make_tensor_coordinate(Dst0Desc{}, Index{})); + using Dst1Coord = decltype(make_tensor_coordinate(Dst1Desc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); + using Dst0CoordStep = decltype(make_tensor_coordinate_step(Dst0Desc{}, Index{})); + using Dst1CoordStep = decltype(make_tensor_coordinate_step(Dst1Desc{}, Index{})); + + __device__ constexpr ThreadwiseTensorSliceTransfer_v3r3( + const SrcDesc& src_desc, + const Index& src_slice_origin, + const SrcElementwiseOperation& src_element_op, + const DstDesc& dst_desc, + const Dst0Desc& dst0_desc, + const Dst1Desc& dst1_desc, + const Index& dst_slice_origin, + const DstElementwiseOperation& dst_element_op) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + dst0_coord_(make_tensor_coordinate(dst0_desc, dst_slice_origin)), + dst1_coord_(make_tensor_coordinate(dst1_desc, dst_slice_origin)), + src_element_op_(src_element_op), + dst_element_op_(dst_element_op) + { + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, + const Dst0Desc& dst0_desc, + const Dst1Desc& dst1_desc, + const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + dst0_coord_ = make_tensor_coordinate(dst0_desc, dst_slice_origin_idx); + dst1_coord_ = make_tensor_coordinate(dst1_desc, dst_slice_origin_idx); + } + + template + __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) + { + static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer and SrcData data type are inconsistent"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // make forward steps + const auto src_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? 
src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(src_desc, forward_step_idx); + }, + Number{}); + + // make backward steps + const auto src_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(src_desc, backward_step_idx); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i] + : ordered_src_access_lengths[i] - 1 - + ordered_src_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + constexpr auto src_data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + using src_vector_type = vector_type_maker_t; + using src_vector_t = typename src_vector_type::type; + + // copy data from src_buf into src_vector_container + auto src_vector_container = src_vector_type{ + src_buf.template Get(src_coord_.GetOffset(), is_src_valid)}; + + // apply SrcElementwiseOperation on src_vector_container + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + src_vector_container.template AsType()(i) = + src_element_op_(src_vector_container.template AsType()[i]); + }); + + // copy data from src_vector_container into src_thread_scratch_ + src_thread_scratch_.template SetAsType( + src_data_idx_seq, src_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move src coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); + } + } + }); + }); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + } + + __device__ void TransferDataFromSrcThreadScratchToDstThreadScratch() + { +#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE + static_ford{}([&](auto idx) { + // convert from SrcData to DstData here + dst_thread_scratch_(idx) = 
type_convert(src_thread_scratch_[idx]); + }); +#else + // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_ + // TODO make this logic more generic for more sub-dword datatype + if constexpr(SrcVectorDim != DstVectorDim && + is_same>::value && + is_same>::value && + SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) + { + // each transpose does + // DstScalarPerVector # of src vectors in src_thread_scratch_ + // SrcScalarPerVector # of dst vectors in dst_thread_scratch_ + constexpr index_t num_src_vector = Number{}; + constexpr index_t num_dst_vector = Number{}; + + // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose + // TODO: make this logic generic for all scenario + static_assert(SrcVectorDim != DstVectorDim, "wrong"); + + constexpr auto src_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access_for_src_and_dst{}, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + static_ford{}([&](auto access_idx) { + constexpr auto data_idx = access_idx * scalar_per_access; + + constexpr auto data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + // TODO type_convert is not used yet!!!!! + using src_vector_t = vector_type_maker_t; + using dst_vector_t = vector_type_maker_t; + + // get DstScalarPerVector # of read-only references to src vectors from + // src_thread_scratch_ + const auto src_vector_refs = generate_tie( + [&](auto i) -> const src_vector_t& { + // i increment corresponds to movement in DstVectorDim + return src_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * dst_scalar_step_in_vector); + }, + Number{}); + + // get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_ + auto dst_vector_refs = generate_tie( + [&](auto i) -> dst_vector_t& { + // i increment corresponds to movement in SrcVectorDim + return dst_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * src_scalar_step_in_vector); + }, + Number{}); + + // do data transpose + // TODO type_convert is not used yet!!!!! + transpose_vectors{}( + src_vector_refs, dst_vector_refs); + }); + } + else + { + static_ford{}([&](auto idx) { + // convert from SrcData to DstData here + dst_thread_scratch_(idx) = type_convert(src_thread_scratch_[idx]); + }); + } +#endif + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, + DstBuffer& dst_buf, + const Dst0Desc& dst0_desc, + const Dst0Buffer& dst0_buf, + const Dst1Desc& dst1_desc, + const Dst1Buffer& dst1_buf) + { + // if there is transpose, it's done here + // TODO move this elsewhere + TransferDataFromSrcThreadScratchToDstThreadScratch(); + + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! 
SrcBuffer or DstBuffer data type is wrong"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + // src scalar per access on each dim + // TODO: don't use this + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // make forward steps + const auto dst_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst_desc, forward_step_idx); + }, + Number{}); + + // make forward steps: dst0 + // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same + // DstScalarPerVector + // TODO: fix this + const auto dst0_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst0_desc, forward_step_idx); + }, + Number{}); + + // make forward steps: dst1 + // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same + // DstScalarPerVector + // TODO: fix this + const auto dst1_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst1_desc, forward_step_idx); + }, + Number{}); + + // make backward steps + const auto dst_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst_desc, backward_step_idx); + }, + Number{}); + + // make backward steps: dst0 + // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same + // DstScalarPerVector + // TODO: fix this + const auto dst0_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst0_desc, backward_step_idx); + }, + Number{}); + + // make backward steps: dst1 + // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same + // DstScalarPerVector + // TODO: fix this + const auto dst1_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? 
-dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst1_desc, backward_step_idx); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_dst_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i] + : ordered_dst_access_lengths[i] - 1 - + ordered_dst_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + constexpr auto dst_data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + // copy data from dst_thread_scratch_ into dst_vector_container + auto dst_vector_container = dst_vector_type{ + dst_thread_scratch_.template GetAsType(dst_data_idx_seq)}; + + // apply DstElementwiseOperation on dst_vector_container + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + dst_vector_container.template AsType()(i) = + dst_element_op_(dst_vector_container.template AsType()[i]); + }); + + // copy data from dst_vector_container to dst_buf + dst_buf.template Set( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move dst coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]); + } + } + }); + }); + + // move dst coordinate back to slice origin (or not) + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + constexpr auto I0 = Number<0>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // judge move forward or move 
backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + // TODO: BUG: should start at 1 + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index after last iteration in RunRead(), if it has not being reset by + // RunRead() + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + // + constexpr auto reset_src_data_step = [&]() { + Index reset_src_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); + + return reset_src_data_step_; + }(); + + return reset_src_data_step; + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + constexpr auto I0 = Number<0>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index after last iteration in RunWrite(), if it has not being reset by + // RunWrite() + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + // + constexpr auto reset_dst_data_step = [&]() { + Index reset_dst_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); + + return reset_dst_data_step_; + }(); + + return reset_dst_data_step; + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? 
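+        // e.g. (illustrative) moving the window by (0, 8) when the previous read
+        // sweep stopped at data index (1, 0): adjusted_step_idx becomes
+        // (0, 8) + (-1, 0) = (-1, 8), a single step that both rewinds the sweep
+        // and advances the window.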
+        const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
+
+        move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
+    }
+
+    // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
+    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
+                                       const Dst0Desc dst0_desc,
+                                       const Dst1Desc dst1_desc,
+                                       const Index& dst_slice_origin_step_idx)
+    {
+        // if dst coord was not reset by RunWrite(), then need to adjust the step here
+        const auto adjusted_step_idx =
+            DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
+                                       : dst_slice_origin_step_idx + GetDstCoordinateResetStep();
+
+        // is it OK to construct a new step every time?
+        const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
+
+        move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
+        move_tensor_coordinate(dst0_desc, dst0_coord_, adjusted_step);
+        move_tensor_coordinate(dst1_desc, dst1_coord_, adjusted_step);
+    }
+
+    __device__ static constexpr auto GetSrcThreadScratchDescriptor()
+    {
+        constexpr auto src_scalar_per_access = generate_sequence(
+            detail::lambda_scalar_per_access{}, Number{});
+
+        constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
+
+        constexpr auto src_access_lengths_and_vector_length = container_push_back(
+            sequence_to_tuple_of_number(src_access_lengths), Number{});
+
+        // 1st stage of transforms
+        constexpr auto desc0 =
+            make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);
+
+        // 2nd stage of transforms
+        constexpr auto transforms = generate_tuple(
+            [&](auto i) {
+                if constexpr(i == SrcVectorDim)
+                {
+                    return make_merge_transform_v3_division_mod(
+                        make_tuple(src_access_lengths_and_vector_length[i],
+                                   src_access_lengths_and_vector_length[Number{}]));
+                }
+                else
+                {
+                    return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
+                }
+            },
+            Number{});
+
+        constexpr auto low_dim_idss = generate_tuple(
+            [&](auto i) {
+                if constexpr(i == SrcVectorDim)
+                {
+                    return Sequence{};
+                }
+                else
+                {
+                    return Sequence{};
+                }
+            },
+            Number{});
+
+        constexpr auto up_dim_idss =
+            generate_tuple([&](auto i) { return Sequence{}; }, Number{});
+
+        return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
+    }
+
+    __device__ static constexpr auto GetDstThreadScratchDescriptor()
+    {
+        // 1st stage of transforms
+        constexpr auto dst_scalar_per_access = generate_sequence(
+            detail::lambda_scalar_per_access{}, Number{});
+
+        constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
+
+        constexpr auto dst_access_lengths_and_vector_length = container_push_back(
+            sequence_to_tuple_of_number(dst_access_lengths), Number{});
+
+        constexpr auto desc0 =
+            make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);
+
+        // 2nd stage of transforms
+        constexpr auto transforms = generate_tuple(
+            [&](auto i) {
+                if constexpr(i == DstVectorDim)
+                {
+                    return make_merge_transform_v3_division_mod(
+                        make_tuple(dst_access_lengths_and_vector_length[i],
+                                   dst_access_lengths_and_vector_length[Number{}]));
+                }
+                else
+                {
+                    return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
+                }
+            },
+            Number{});
+
+        constexpr auto low_dim_idss = generate_tuple(
+            [&](auto i) {
+                if constexpr(i == DstVectorDim)
+                {
+                    return Sequence{};
+                }
+                else
+                {
+                    return Sequence{};
+                }
+            },
+            Number{});
+
+        constexpr auto up_dim_idss =
+            generate_tuple([&](auto i) { return Sequence{}; }, Number{});
+
+        return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
+    }
+
+    private:
+    static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
+    static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
+
+    StaticTensorTupleOfVectorBuffer
+        src_thread_scratch_;
+
+    StaticTensorTupleOfVectorBuffer
+        dst_thread_scratch_;
+
+    SrcCoord src_coord_;
+    DstCoord dst_coord_;
+    const SrcElementwiseOperation src_element_op_;
+    const DstElementwiseOperation dst_element_op_;
+};
+
+} // namespace ck
+#endif
diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp
new file mode 100644
index 00000000..6e8a2393
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp
@@ -0,0 +1,175 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck/utility/common_header.hpp"
+#include "ck/tensor_description/tensor_descriptor.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"
+
+namespace ck {
+// Assume:
+// 1. src:
+//     1. SrcDesc is known at compile-time
+//     2. SrcBuffer is DynamicBuffer
+//     3. src_ref_idx is known at run-time
+//     4. SrcRefToOriginDisplacement is known at compile-time
+//     5. use #-step
+// 2. dst:
+//     1. DstDesc is known at compile-time
+//     2. DstBuffer is StaticBuffer
+//     3. DstOriginIdx is known at compile-time
+//     4. use direct address calculation
+// 3. vector access on src
+template ::type = false>
+struct ThreadwiseTensorSliceTransfer_v4r1
+{
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+
+    static constexpr index_t nDim = SliceLengths::Size();
+
+    using Index = MultiIndex;
+
+    using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
+
+    using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
+
+    __device__ constexpr ThreadwiseTensorSliceTransfer_v4r1(const Index& src_ref_idx)
+        : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
+    {
+        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
+                      "wrong! SrcDesc and DstDesc need to be known at compile-time");
+
+        static_for<0, nDim, 1>{}([](auto i) {
+            static_assert(SliceLengths::At(i) % SrcVectorTensorLengths::At(i) == 0, "wrong!");
+        });
+    }
+
+    template 
+    __device__ void Run(const SrcDesc&,
+                        const SrcRefToOriginDisplacement&,
+                        const SrcBuffer& src_buf,
+                        const DstDesc&,
+                        const DstOriginIdx&,
+                        DstBuffer& dst_buf) const
+    {
+        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
+                      "wrong! SrcDesc and DstDesc need to be known at compile-time");
+
+        static_assert(
+            is_same, remove_cvref_t>::value &&
+                is_same, remove_cvref_t>::value,
+            "wrong! 
SrcBuffer or DstBuffer data type is wrong"); + + static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); + + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known " + "at compile-time"); + + // SrcDesc and DstDesc are known at compile-time + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto dst_desc = remove_cvref_t{}; + + // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time + constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); + constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{}); + + // tensor descriptor for src_vector + constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{}; + + constexpr auto src_vector_tensor_strides = container_reorder_given_old2new( + container_reverse_exclusive_scan( + container_reorder_given_new2old(src_vector_tensor_lengths, + SrcVectorTensorContiguousDimOrder{}), + math::multiplies{}, + I1), + SrcVectorTensorContiguousDimOrder{}); + + constexpr auto src_vector_desc = + make_naive_tensor_descriptor(sequence_to_tuple_of_number(src_vector_tensor_lengths), + sequence_to_tuple_of_number(src_vector_tensor_strides)); + + // access order and lengths + constexpr auto access_lengths = SliceLengths{} / src_vector_tensor_lengths; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + static_ford{}([&](auto ordered_access_idx) { + // position in slice window + constexpr auto data_to_origin_disp_idx = + ordered_access_idx.ReorderGivenOld2New(dim_access_order) * + src_vector_tensor_lengths; + + // src coordinate at starting point of src_vector + constexpr auto src_ref_to_data_disp_idx = + src_ref_to_origin_disp_idx + data_to_origin_disp_idx; + + constexpr auto src_ref_to_data_disp_coord_step = + make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx); + + auto src_data_coord = src_ref_coord_; + + move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step); + + vector_type_maker_t src_vector; + + using src_vector_t = typename decltype(src_vector)::type; + + const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( + src_desc, src_data_coord); + + // copy data from src_buf into src_vector + src_vector.template AsType()(I0) = + src_buf.template Get(src_data_coord.GetOffset(), is_src_valid); + + // copy data from src_vector into dst_buf (also cast from SrcData to DstData) + static_ford{}([&](auto src_vector_idx_) { + constexpr auto src_vector_idx = to_multi_index(src_vector_idx_); + + constexpr index_t src_vector_offset = + src_vector_desc.CalculateOffset(src_vector_idx); + + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + src_vector_idx); + + dst_buf(Number{}) = type_convert( + src_vector.template AsType()[Number{}]); + }); + }); + } + + template + __device__ void MoveSrcSliceWindow(const SrcDesc&, + const SrcSliceMoveStepIdx& src_slice_move_step_idx) + { + constexpr auto src_desc = SrcDesc{}; + + const auto src_slice_move_step_iter = + make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx)); + + move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter); + } + + private: + SrcCoord src_ref_coord_; +}; + +} // namespace ck diff --git 
a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp similarity index 75% rename from composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp rename to include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp index bbdaa5fa..f13da341 100644 --- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp @@ -1,9 +1,12 @@ -#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V2_HPP -#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V2_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" namespace ck { @@ -13,7 +16,7 @@ namespace ck { // 3. src_slice_origin and dst_slice_origin are not known at compile-time, // 4. Use thread buffer template // control whether to move back dst coordinate after each // RunWrite(), will be fused with MoveDstSliceWindow to // save addr computation -struct ThreadwiseTensorSliceTransfer_v3r1 +struct ThreadwiseTensorSliceTransfer_v5r1 { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -44,7 +47,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); - __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(const SrcDesc& src_desc, + __device__ constexpr ThreadwiseTensorSliceTransfer_v5r1(const SrcDesc& src_desc, const Index& src_slice_origin, const DstDesc& dst_desc, const Index& dst_slice_origin) @@ -76,8 +79,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1 __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks) { - static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or - SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, + static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, "wrong!"); static_assert( @@ -244,8 +247,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1 __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks) { - static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or - DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, "wrong!"); static_assert( @@ -351,7 +354,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 dst_vector_desc.CalculateOffset(dst_vector_idx); dst_vector.template AsType()(Number{}) = - type_convert{}(buffer_[Number{}]); + type_convert(buffer_[Number{}]); }); using dst_vector_t = typename decltype(dst_vector)::type; @@ -602,175 +605,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1 static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize(); - StaticBuffer buffer_; + StaticBuffer buffer_; SrcCoord src_coord_; DstCoord dst_coord_; }; -// 
Assume: -// 1. src: -// 1. SrcDesc is known at compile-time -// 2. SrcBuffer is DynamicBuffer -// 3. src_ref_idx is known at run-time -// 4. SrcRefToOriginDisplacement is known at compile-time -// 5. use #-step -// 2. dst: -// 1. DstDesc is known at compile-time -// 2. DstBuffer is StaticBuffer -// 3. DstOriginIdx is known at compile-time -// 4. use direct address calculation -// 3. vector access on src -template ::type = false> -struct ThreadwiseTensorSliceTransfer_v4r1 -{ - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - - static constexpr index_t nDim = SliceLengths::Size(); - - using Index = MultiIndex; - - using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); - - using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); - - __device__ constexpr ThreadwiseTensorSliceTransfer_v4r1(const Index& src_ref_idx) - : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx)) - { - static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), - "wrong! SrcDesc and DstDesc need to known at compile-time"); - - static_for<0, nDim, 1>{}([](auto i) { - static_assert(SliceLengths::At(i) % SrcVectorTensorLengths::At(i) == 0, "wrong!"); - }); - } - - template - __device__ void Run(const SrcDesc&, - const SrcRefToOriginDisplacement&, - const SrcBuffer& src_buf, - const DstDesc&, - const DstOriginIdx&, - DstBuffer& dst_buf) const - { - static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), - "wrong! SrcDesc and DstDesc need to known at compile-time"); - - static_assert( - is_same, remove_cvref_t>::value && - is_same, remove_cvref_t>::value, - "wrong! SrcBuffer or DstBuffer data type is wrong"); - - static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); - - static_assert(is_known_at_compile_time>::value && - is_known_at_compile_time>::value, - "wrong! 
SrcOriginToRefDistance and DstOriginToRefDistance need to be known " - "at compile-time"); - - // SrcDesc and DstDesc are known at compile-time - constexpr auto src_desc = remove_cvref_t{}; - constexpr auto dst_desc = remove_cvref_t{}; - - // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time - constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); - constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{}); - - // tensor descriptor for src_vector - constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{}; - - constexpr auto src_vector_tensor_strides = container_reorder_given_old2new( - container_reverse_exclusive_scan( - container_reorder_given_new2old(src_vector_tensor_lengths, - SrcVectorTensorContiguousDimOrder{}), - math::multiplies{}, - I1), - SrcVectorTensorContiguousDimOrder{}); - - constexpr auto src_vector_desc = - make_naive_tensor_descriptor(sequence_to_tuple_of_number(src_vector_tensor_lengths), - sequence_to_tuple_of_number(src_vector_tensor_strides)); - - // access order and lengths - constexpr auto access_lengths = SliceLengths{} / src_vector_tensor_lengths; - - constexpr auto dim_access_order = DimAccessOrder{}; - - constexpr auto ordered_access_lengths = - container_reorder_given_new2old(access_lengths, dim_access_order); - - static_ford{}([&](auto ordered_access_idx) { - // position in slice window - constexpr auto data_to_origin_disp_idx = - ordered_access_idx.ReorderGivenOld2New(dim_access_order) * - src_vector_tensor_lengths; - - // src coordinate at starting point of src_vector - constexpr auto src_ref_to_data_disp_idx = - src_ref_to_origin_disp_idx + data_to_origin_disp_idx; - - constexpr auto src_ref_to_data_disp_coord_step = - make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx); - - auto src_data_coord = src_ref_coord_; - - move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step); - - vector_type_maker_t src_vector; - - using src_vector_t = typename decltype(src_vector)::type; - - const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( - src_desc, src_data_coord); - - // copy data from src_buf into src_vector - src_vector.template AsType()(I0) = - src_buf.template Get(src_data_coord.GetOffset(), is_src_valid); - - // copy data from src_vector into dst_buf (also cast from SrcData to DstData) - static_ford{}([&](auto src_vector_idx_) { - constexpr auto src_vector_idx = to_multi_index(src_vector_idx_); - - constexpr index_t src_vector_offset = - src_vector_desc.CalculateOffset(src_vector_idx); - - constexpr index_t dst_offset = dst_desc.CalculateOffset( - dst_origin_idx + data_to_origin_disp_idx + src_vector_idx); - - dst_buf(Number{}) = type_convert{}( - src_vector.template AsType()[Number{}]); - }); - }); - } - - template - __device__ void MoveSrcSliceWindow(const SrcDesc&, - const SrcSliceMoveStepIdx& src_slice_move_step_idx) - { - constexpr auto src_desc = SrcDesc{}; - - const auto src_slice_move_step_iter = - make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx)); - - move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter); - } - - private: - SrcCoord src_ref_coord_; -}; - } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp new file mode 100644 index 00000000..9c91cd9c --- /dev/null +++ 
b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp @@ -0,0 +1,213 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" + +namespace ck { + +// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory +// and sometimes useless instructions: +// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument +// instead +// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same +// tensor coordinate instead +// 3. Don't use a pointer to VGPR buffer, use vector instead + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +template +struct ThreadwiseTensorSliceTransfer_v6r1 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + static constexpr auto I0 = Number<0>{}; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v6r1(const SrcDesc& src_desc, + const Index& src_slice_origin, + const DstDesc& dst_desc, + const Index& dst_slice_origin, + const ElementwiseOperation& element_op) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + element_op_(element_op) + { + static_assert(SliceLengths::At(Number{}) % ScalarPerVector == 0, + "wrong! 
cannot evenly divide"); + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + // loop over space-filling curve + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + static_for<0, num_access, 1>{}([&](auto idx_1d) { + using src_vector_type = vector_type_maker_t; + using src_vector_t = typename src_vector_type::type; + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + // copy data from src_buf into src_vector_container + auto src_vector_container = src_vector_type{ + src_buf.template Get(src_coord_.GetOffset(), is_src_valid)}; + + auto dst_vector_container = dst_vector_type{}; + + // apply pointwise operation + static_for<0, ScalarPerVector, 1>{}([&](auto i) { + SrcData v; + + // apply element-wise operation + element_op_(v, src_vector_container.template AsType()[i]); + + // apply type convert + dst_vector_container.template AsType()(i) = type_convert(v); + }); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + // copy data from dst_vector into dst_buf + dst_buf.template Update( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + + // move coordinate + if constexpr(idx_1d.value != num_access - 1) + { + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); + move_tensor_coordinate( + src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step)); + move_tensor_coordinate( + dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step)); + } + }); + + // move coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetCoordinateResetStep() + { + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + if constexpr(num_access == 0) + { + return typename SpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + + return reset_step; + } + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& 
src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = SrcResetCoordinateAfterRun + ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = DstResetCoordinateAfterRun + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + private: + SrcCoord src_coord_; + DstCoord dst_coord_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp new file mode 100644 index 00000000..68bc2726 --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp @@ -0,0 +1,260 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" + +namespace ck { + +// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory +// and sometimes useless instructions: +// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument +// instead +// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same +// tensor coordinate instead +// 3. Don't use a pointer to VGPR buffer, use vector instead + +// Assume: +// 1. src0_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. 
src_slice_origin and dst_slice_origin are not known at compile-time, +template +struct ThreadwiseTensorSliceTransfer_v6r2 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using Src0Coord = decltype(make_tensor_coordinate(Src0Desc{}, Index{})); + using Src1Coord = decltype(make_tensor_coordinate(Src1Desc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + static constexpr auto I0 = Number<0>{}; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v6r2(const Src0Desc& src0_desc, + const Index& src0_slice_origin, + const Src1Desc& src1_desc, + const Index& src1_slice_origin, + const DstDesc& dst_desc, + const Index& dst_slice_origin, + const ElementwiseOperation& element_op) + : src0_coord_(make_tensor_coordinate(src0_desc, src0_slice_origin)), + src1_coord_(make_tensor_coordinate(src1_desc, src1_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + element_op_(element_op) + { + static_assert(SliceLengths::At(Number{}) % ScalarPerVector == 0, + "wrong! cannot evenly divide"); + } + + __device__ void SetSrc0SliceOrigin(const Src0Desc& src0_desc, + const Index& src0_slice_origin_idx) + { + src0_coord_ = make_tensor_coordinate(src0_desc, src0_slice_origin_idx); + } + + __device__ void SetSrc1SliceOrigin(const Src1Desc& src1_desc, + const Index& src1_slice_origin_idx) + { + src1_coord_ = make_tensor_coordinate(src1_desc, src1_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void Run(const Src0Desc& src0_desc, + const Src0Buffer& src0_buf, + const Src1Desc& src1_desc, + const Src1Buffer& src1_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + // loop over space-filling curve + static_for<0, num_access, 1>{}([&](auto idx_1d) { + using src0_vector_type = vector_type_maker_t; + using src0_vector_t = typename src0_vector_type::type; + + using src1_vector_type = vector_type_maker_t; + using src1_vector_t = typename src1_vector_type::type; + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + const bool is_src0_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src0_desc, src0_coord_); + + const bool is_src1_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src1_desc, src1_coord_); + + // copy data from src0_buf into src0_vector_container + auto src0_vector_container = src0_vector_type{ + src0_buf.template Get(src0_coord_.GetOffset(), is_src0_valid)}; + + auto src1_vector_container = src1_vector_type{ + src1_buf.template Get(src1_coord_.GetOffset(), is_src1_valid)}; + + auto dst_vector_container = dst_vector_type{}; + + // apply pointwise operation + static_for<0, ScalarPerVector, 1>{}([&](auto i) { + element_op_(dst_vector_container.template AsType()(i), + src0_vector_container.template AsType()[i], + src1_vector_container.template AsType()[i]); + }); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + // copy data from dst_vector into dst_buf + 
dst_buf.template Update( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + + // move coordinate + if constexpr(idx_1d.value != num_access - 1) + { + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); + move_tensor_coordinate( + src0_desc, src0_coord_, make_tensor_coordinate_step(src0_desc, forward_step)); + move_tensor_coordinate( + src1_desc, src1_coord_, make_tensor_coordinate_step(src1_desc, forward_step)); + move_tensor_coordinate( + dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step)); + } + }); + + // move coordinate back to slice origin (or not) + if constexpr(Src0ResetCoordinateAfterRun) + { + const auto src0_reset_step = + make_tensor_coordinate_step(src0_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src0_desc, src0_coord_, src0_reset_step); + } + + if constexpr(Src1ResetCoordinateAfterRun) + { + const auto src1_reset_step = + make_tensor_coordinate_step(src1_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src1_desc, src1_coord_, src1_reset_step); + } + + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetCoordinateResetStep() + { + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + if constexpr(num_access == 0) + { + return typename SpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + + return reset_step; + } + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, + const Index& src0_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = Src0ResetCoordinateAfterRun + ? src0_slice_origin_step_idx + : src0_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src0_desc, adjusted_step_idx); + + move_tensor_coordinate(src0_desc, src0_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, + const Index& src1_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = Src1ResetCoordinateAfterRun + ? src1_slice_origin_step_idx + : src1_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src1_desc, adjusted_step_idx); + + move_tensor_coordinate(src1_desc, src1_coord_, adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = DstResetCoordinateAfterRun + ? 
dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + private: + Src0Coord src0_coord_; + Src1Coord src1_coord_; + DstCoord dst_coord_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp new file mode 100644 index 00000000..0f5fb88b --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp @@ -0,0 +1,310 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" + +namespace ck { + +// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory +// and sometimes useless instructions: +// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument +// instead +// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same +// tensor coordinate instead +// 3. Don't use a pointer to VGPR buffer, use vector instead + +// Assume: +// 1. src0_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +template +struct ThreadwiseTensorSliceTransfer_v6r3 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using Src0Coord = decltype(make_tensor_coordinate(Src0Desc{}, Index{})); + using Src1Coord = decltype(make_tensor_coordinate(Src1Desc{}, Index{})); + using Src2Coord = decltype(make_tensor_coordinate(Src2Desc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + static constexpr auto I0 = Number<0>{}; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v6r3(const Src0Desc& src0_desc, + const Index& src0_slice_origin, + const Src1Desc& src1_desc, + const Index& src1_slice_origin, + const Src2Desc& src2_desc, + const Index& src2_slice_origin, + const DstDesc& dst_desc, + const Index& dst_slice_origin, + const ElementwiseOperation& element_op) + : src0_coord_(make_tensor_coordinate(src0_desc, src0_slice_origin)), + src1_coord_(make_tensor_coordinate(src1_desc, src1_slice_origin)), + src2_coord_(make_tensor_coordinate(src2_desc, src2_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + element_op_(element_op) + { + static_assert(SliceLengths::At(Number{}) % ScalarPerVector == 0, + "wrong! 
cannot evenly divide"); + } + + __device__ void SetSrc0SliceOrigin(const Src0Desc& src0_desc, + const Index& src0_slice_origin_idx) + { + src0_coord_ = make_tensor_coordinate(src0_desc, src0_slice_origin_idx); + } + + __device__ void SetSrc1SliceOrigin(const Src1Desc& src1_desc, + const Index& src1_slice_origin_idx) + { + src1_coord_ = make_tensor_coordinate(src1_desc, src1_slice_origin_idx); + } + + __device__ void SetSrc2SliceOrigin(const Src2Desc& src2_desc, + const Index& src2_slice_origin_idx) + { + src2_coord_ = make_tensor_coordinate(src2_desc, src2_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void Run(const Src0Desc& src0_desc, + const Src0Buffer& src0_buf, + const Src1Desc& src1_desc, + const Src1Buffer& src1_buf, + const Src2Desc& src2_desc, + const Src2Buffer& src2_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + // loop over space-filling curve + static_for<0, num_access, 1>{}([&](auto idx_1d) { + using src0_vector_type = vector_type_maker_t; + using src0_vector_t = typename src0_vector_type::type; + + using src1_vector_type = vector_type_maker_t; + using src1_vector_t = typename src1_vector_type::type; + + using src2_vector_type = vector_type_maker_t; + using src2_vector_t = typename src2_vector_type::type; + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + const bool is_src0_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src0_desc, src0_coord_); + + const bool is_src1_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src1_desc, src1_coord_); + + const bool is_src2_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src2_desc, src2_coord_); + + // copy data from src0_buf into src0_vector_container + auto src0_vector_container = src0_vector_type{ + src0_buf.template Get(src0_coord_.GetOffset(), is_src0_valid)}; + + auto src1_vector_container = src1_vector_type{ + src1_buf.template Get(src1_coord_.GetOffset(), is_src1_valid)}; + + auto src2_vector_container = src2_vector_type{ + src2_buf.template Get(src2_coord_.GetOffset(), is_src2_valid)}; + + auto dst_vector_container = dst_vector_type{}; + + // apply pointwise operation + static_for<0, ScalarPerVector, 1>{}([&](auto i) { + element_op_(dst_vector_container.template AsType()(i), + src0_vector_container.template AsType()[i], + src1_vector_container.template AsType()[i], + src2_vector_container.template AsType()[i]); + }); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + dst_buf.template Update( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + + // move coordinate + if constexpr(idx_1d.value != num_access - 1) + { + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); + move_tensor_coordinate( + src0_desc, src0_coord_, make_tensor_coordinate_step(src0_desc, forward_step)); + move_tensor_coordinate( + src1_desc, src1_coord_, make_tensor_coordinate_step(src1_desc, forward_step)); + 
move_tensor_coordinate( + src2_desc, src2_coord_, make_tensor_coordinate_step(src2_desc, forward_step)); + move_tensor_coordinate( + dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step)); + } + }); + + // move coordinate back to slice origin (or not) + if constexpr(Src0ResetCoordinateAfterRun) + { + const auto src0_reset_step = + make_tensor_coordinate_step(src0_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src0_desc, src0_coord_, src0_reset_step); + } + + if constexpr(Src1ResetCoordinateAfterRun) + { + const auto src1_reset_step = + make_tensor_coordinate_step(src1_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src1_desc, src1_coord_, src1_reset_step); + } + + if constexpr(Src2ResetCoordinateAfterRun) + { + const auto src2_reset_step = + make_tensor_coordinate_step(src2_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src2_desc, src2_coord_, src2_reset_step); + } + + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetCoordinateResetStep() + { + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + if constexpr(num_access == 0) + { + return typename SpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + + return reset_step; + } + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, + const Index& src0_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = Src0ResetCoordinateAfterRun + ? src0_slice_origin_step_idx + : src0_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src0_desc, adjusted_step_idx); + + move_tensor_coordinate(src0_desc, src0_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, + const Index& src1_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = Src1ResetCoordinateAfterRun + ? src1_slice_origin_step_idx + : src1_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src1_desc, adjusted_step_idx); + + move_tensor_coordinate(src1_desc, src1_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, + const Index& src2_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = Src2ResetCoordinateAfterRun + ? src2_slice_origin_step_idx + : src2_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? 
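+        // Note: src0, src1, src2 and dst all traverse the same space-filling curve
+        // over SliceLengths, so one GetCoordinateResetStep() is valid for every
+        // operand; only the descriptor it is applied against differs.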
+        const auto adjusted_step = make_tensor_coordinate_step(src2_desc, adjusted_step_idx);
+
+        move_tensor_coordinate(src2_desc, src2_coord_, adjusted_step);
+    }
+
+    // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
+    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
+                                       const Index& dst_slice_origin_step_idx)
+    {
+        // if dst coord was not reset by Run(), then need to adjust the step here
+        const auto adjusted_step_idx = DstResetCoordinateAfterRun
+                                           ? dst_slice_origin_step_idx
+                                           : dst_slice_origin_step_idx + GetCoordinateResetStep();
+
+        // is it OK to construct a new step every time?
+        const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
+
+        move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
+    }
+
+    private:
+    Src0Coord src0_coord_;
+    Src1Coord src1_coord_;
+    Src2Coord src2_coord_;
+    DstCoord dst_coord_;
+    const ElementwiseOperation element_op_;
+};
+
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp
new file mode 100644
index 00000000..2eb1b0ee
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp
@@ -0,0 +1,298 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck/utility/common_header.hpp"
+#include "ck/tensor_description/tensor_descriptor.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_description/tensor_space_filling_curve.hpp"
+
+namespace ck {
+
+// Thread-level multi-source, multi-destination tensor slice data movement
+// Assume:
+// 1. All sources and destinations are DynamicBuffer
+// 2. Same VectorDim and ScalarPerVector for all sources and destinations
+// 3. DstInMemOps are per destination tensor
+// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor
+// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor
+// 6. Does not need to know src_descs and dst_descs at compile-time
+// 7. Does not need to know src_slice_origins and dst_slice_origins at compile-time,
+//
+// Does the following things to avoid the scratch memory issue
+// 1. Use StaticallyIndexedArray or vector_type instead of C array for thread buffer
+// 2. Pass tensor descriptors by reference (or tuple of references)
+// 3. Does not keep reference to tensor descriptor
+// 4. Does not construct new tensor coordinates when calling Run()
+template 
+          typename SliceLengths,
+          typename DimAccessOrder,
+          index_t VectorDim,
+          index_t ScalarPerVector,
+          typename SrcResetCoordinateAfterRunFlags, // Sequence
+          typename DstResetCoordinateAfterRunFlags> // Sequence
+struct ThreadwiseTensorSliceTransfer_v7
+{
+    static constexpr auto I0 = Number<0>{};
+
+    static constexpr index_t nDim = SliceLengths::Size();
+
+    static constexpr index_t nSrc = SrcDescs::Size();
+    static constexpr index_t nDst = DstDescs::Size();
+
+    using Index = MultiIndex;
+
+    // return a tuple of coordinates for a tuple of tensors
+    template = false>
+    static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices)
+    {
+        return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); },
+                              Number{});
+    }
+
+    using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray{}));
+    using DstCoords = decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray{}));
+
+    // scalar per access on each dim
+    // FIXME: don't use lambda_scalar_per_access
+    static constexpr auto scalar_per_access = generate_sequence(
+        detail::lambda_scalar_per_access{}, Number{});
+
+    using SpaceFillingCurve =
+        SpaceFillingCurve>;
+
+    __device__ constexpr ThreadwiseTensorSliceTransfer_v7(
+        const SrcDescs& src_descs,
+        const StaticallyIndexedArray& src_slice_origins,
+        const DstDescs& dst_descs,
+        const StaticallyIndexedArray& dst_slice_origins,
+        const ElementwiseOperation& element_op)
+        : src_coords_(MakeCoordinates(src_descs, src_slice_origins)),
+          dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)),
+          element_op_(element_op)
+    {
+        static_assert(SliceLengths::At(Number{}) % ScalarPerVector == 0,
+                      "wrong! cannot evenly divide");
+    }
+
+    template = false>
+    __device__ void SetSrcSliceOrigins(const SrcDescs& src_descs,
+                                       const Indices& src_slice_origin_idxs)
+    {
+        static_for<0, nSrc, 1>{}([&](auto i) {
+            src_coords_(i) = make_tensor_coordinate(src_descs[i], src_slice_origin_idxs[i]);
+        });
+    }
+
+    template = false>
+    __device__ void SetDstSliceOrigins(const DstDescs& dst_descs,
+                                       const Indices& dst_slice_origin_idxs)
+    {
+        static_for<0, nDst, 1>{}([&](auto i) {
+            dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]);
+        });
+    }
+
+    // SrcDescs: Tuple
+    // SrcBuffers: Tuple
+    // DstDescs: Tuple
+    // DstBuffers: Tuple
+    template = false>
+    __device__ void Run(const SrcDescs& src_descs,
+                        const SrcBuffers& src_bufs,
+                        const DstDescs& dst_descs,
+                        DstBuffers dst_bufs)
+    {
+        auto generate_vectors = [&](auto data_types) {
+            constexpr index_t num = data_types.Size();
+
+            return generate_tuple(
+                [&](auto i) {
+                    using DataType = remove_cvref_t;
+
+                    return vector_type_maker_t{};
+                },
+                Number{});
+        };
+
+        constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
+
+        // loop over space-filling curve
+        static_for<0, num_access, 1>{}([&](auto iAccess) {
+            auto src_vectors = generate_vectors(SrcDatas{});
+            auto dst_vectors = generate_vectors(DstDatas{});
+
+            // copy data from src_bufs into src_vectors
+            static_for<0, nSrc, 1>{}([&](auto i) {
+                using src_vector_t = typename remove_cvref_t::type;
+
+                const bool is_src_valid =
+                    coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i],
+                                                                                src_coords_[i]);
+
+                src_vectors(i).template AsType()(I0) =
+                    src_bufs[i].template Get(src_coords_[i].GetOffset(),
+                                             is_src_valid);
+            });
+
+            // apply pointwise function
+            static_for<0, ScalarPerVector, 1>{}([&](auto i) {
+                // get reference to src data
+                const 
auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { + using SrcData = remove_cvref_t>; + + return src_vectors[iSrc].template AsType()[i]; + }, + Number{}); + + // get reference to dst data + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto iDst) -> auto& { + using DstData = remove_cvref_t>; + + return dst_vectors(iDst).template AsType()(i); + }, + Number{}); + + // apply pointwise function + // pointwise function signature: + // element_op_(dst_data_refs[I0], + // dst_data_refs[I1], + // ..., + // src_data_refs[I0], + // src_data_refs[I1], + // ...) + unpack2(element_op_, dst_data_refs, src_data_refs); + }); + + // copy data from buf_vectors into dst_bufs + static_for<0, nDst, 1>{}([&](auto i) { + using dst_vector_t = typename remove_cvref_t::type; + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i], + dst_coords_[i]); + + constexpr InMemoryDataOperationEnum DstInMemOp = + static_cast(DstInMemOps::At(i.value)); + + dst_bufs(i).template Update( + dst_coords_[i].GetOffset(), + is_dst_valid, + dst_vectors[i].template AsType()[I0]); + }); + + // move coordinate + if constexpr(iAccess.value != num_access - 1) + { + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(iAccess); + + static_for<0, nSrc, 1>{}([&](auto i) { + move_tensor_coordinate(src_descs[i], + src_coords_(i), + make_tensor_coordinate_step(src_descs[i], forward_step)); + }); + + static_for<0, nDst, 1>{}([&](auto i) { + move_tensor_coordinate(dst_descs[i], + dst_coords_(i), + make_tensor_coordinate_step(dst_descs[i], forward_step)); + }); + } + }); + + // move coordinate back to slice origin (or not) + static_for<0, nSrc, 1>{}([&](auto i) { + if constexpr(SrcResetCoordinateAfterRunFlags::At(i)) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_descs[i], GetCoordinateResetStep()); + + move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step); + } + }); + + static_for<0, nDst, 1>{}([&](auto i) { + if constexpr(DstResetCoordinateAfterRunFlags::At(i)) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_descs[i], GetCoordinateResetStep()); + + move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step); + } + }); + } + + __device__ static constexpr auto GetCoordinateResetStep() + { + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + if constexpr(num_access == 0) + { + return typename SpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + + return reset_step; + } + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, + Number iSrc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = SrcResetCoordinateAfterRunFlags::At(iSrc) + ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? 
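+        // Note: unlike the single-source transfers above, the reset decision here is
+        // made per source tensor: SrcResetCoordinateAfterRunFlags::At(iSrc) selects
+        // whether Run() already moved this source's coordinate back to the slice
+        // origin or whether it must be compensated now. For example, with the
+        // (hypothetical) flags Sequence<true, false>, source 0 gets the plain step
+        // and source 1 gets the adjusted one.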
+ const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], adjusted_step_idx); + + move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, + Number iDst, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = DstResetCoordinateAfterRunFlags::At(iDst) + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx); + + move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step); + } + + private: + SrcCoords src_coords_; + DstCoords dst_coords_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck diff --git a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp similarity index 72% rename from composable_kernel/include/tensor_operation/xdlops_gemm.hpp rename to include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 10633f8f..eaf0f132 100644 --- a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -1,9 +1,11 @@ -#ifndef CK_XDLOPS_GEMM_HPP -#define CK_XDLOPS_GEMM_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "common_header.hpp" -#include "math.hpp" -#include "amd_xdlops.hpp" +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/math.hpp" +#include "ck/utility/amd_xdlops.hpp" namespace ck { @@ -12,18 +14,20 @@ enum struct MfmaInstr mfma_f32_32x32x1xf32 = 0, mfma_f32_16x16x1xf32, mfma_f32_4x4x1xf32, - mfma_f32_32x32x2xf32, // k reduction - mfma_f32_16x16x4xf32, // k reduction + mfma_f32_32x32x2xf32, + mfma_f32_16x16x4xf32, mfma_f32_32x32x4f16, mfma_f32_16x16x4f16, mfma_f32_4x4x4f16, - mfma_f32_32x32x8f16, // k reduction - mfma_f32_16x16x16f16, // k reduction - mfma_f32_32x32x2bf16, - mfma_f32_16x16x2bf16, - mfma_f32_4x4x2bf16, - mfma_f32_32x32x4bf16, // k reduction - mfma_f32_16x16x8bf16, // k reduction + mfma_f32_32x32x8f16, + mfma_f32_16x16x16f16, + mfma_f32_32x32x8bf16_1k, + mfma_f32_16x16x16bf16_1k, + mfma_f32_32x32x4bf16, + mfma_f32_16x16x8bf16, + mfma_i32_32x32x8i8, + mfma_i32_16x16x16i8, + mfma_f64_16x16x4f64 }; template @@ -250,9 +254,8 @@ struct mfma_type } }; -#if 0 template <> -struct mfma_type +struct mfma_type { static constexpr index_t group_size = 4; static constexpr index_t num_groups_per_blk = 4; @@ -260,26 +263,38 @@ struct mfma_type static constexpr index_t num_threads_per_blk = 32; static constexpr index_t wave_size = 64; static constexpr index_t num_input_blks = 2; - static constexpr index_t num_output_blks = 2; + static constexpr index_t num_output_blks = 1; static constexpr index_t m_per_blk = 32; static constexpr index_t n_per_blk = 32; - static constexpr index_t k_per_blk = 2; - static constexpr bool is_k_reduction = false; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; - template - __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - const auto p_a = c_style_pointer_cast(a); - const auto 
p_b = c_style_pointer_cast(b); + intrin_mfma_f32_32x32x8bf16_1k::Run(a, b, reg_c); + } +}; - return intrin_mfma_f32_32x32x2bf16::run( - p_a, p_b, reg_c); +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x16bf16_1k::Run(a, b, reg_c); } }; @@ -298,19 +313,10 @@ struct mfma_type static constexpr index_t k_per_blk = 2; static constexpr bool is_k_reduction = true; - template - __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - const auto p_a = c_style_pointer_cast(a); - const auto p_b = c_style_pointer_cast(b); - - return intrin_mfma_f32_32x32x4bf16(p_a, p_b, reg_c); + intrin_mfma_f32_32x32x4bf16::Run(a, b, reg_c); } }; @@ -329,24 +335,37 @@ struct mfma_type static constexpr index_t k_per_blk = 2; static constexpr bool is_k_reduction = true; - template - __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - const auto p_a = c_style_pointer_cast(a); - const auto p_b = c_style_pointer_cast(b); + intrin_mfma_f32_16x16x8bf16::Run(a, b, reg_c); + } +}; - return intrin_mfma_f32_16x16x8bf16(p_a, p_b, reg_c); +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_i32_32x32x8i8::Run(a, b, reg_c); } }; template <> -struct mfma_type +struct mfma_type { static constexpr index_t group_size = 4; static constexpr index_t num_groups_per_blk = 1; @@ -354,59 +373,40 @@ struct mfma_type static constexpr index_t num_threads_per_blk = 16; static constexpr index_t wave_size = 64; static constexpr index_t num_input_blks = 4; - static constexpr index_t num_output_blks = 4; + static constexpr index_t num_output_blks = 1; static constexpr index_t m_per_blk = 16; static constexpr index_t n_per_blk = 16; - static constexpr index_t k_per_blk = 2; - static constexpr bool is_k_reduction = false; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; - template - __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - const auto p_a = c_style_pointer_cast(a); - const auto p_b = c_style_pointer_cast(b); - - return intrin_mfma_f32_16x16x2bf16(p_a, p_b, reg_c); + 
intrin_mfma_i32_16x16x16i8::Run(a, b, reg_c); } }; template <> -struct mfma_type +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_per_blk = 1; - static constexpr index_t num_regs_per_blk = 4; - static constexpr index_t num_threads_per_blk = 64; + static constexpr index_t group_size = 1; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 4; // group_size * num_groups_per_blk; + static constexpr index_t num_threads_per_blk = 16; static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = 1; + static constexpr index_t num_input_blks = 4; // wave_size / num_threads_per_blk; static constexpr index_t num_output_blks = 1; - static constexpr index_t m_per_blk = 4; - static constexpr index_t n_per_blk = 64; - static constexpr index_t k_per_blk = 2; - static constexpr bool is_k_reduction = false; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = true; - template - __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - const auto p_a = c_style_pointer_cast(a); - const auto p_b = c_style_pointer_cast(b); - - return intrin_mfma_f32_4x4x2bf16::run(p_a, p_b, reg_c); + intrin_mfma_f64_16x16x4f64::Run(a, b, reg_c); } }; -#endif template struct MfmaSelector @@ -414,6 +414,12 @@ struct MfmaSelector template static constexpr auto GetMfma(); + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f64_16x16x4f64; + } + template <> static constexpr auto GetMfma() { @@ -498,77 +504,41 @@ struct MfmaSelector return MfmaInstr::mfma_f32_4x4x4f16; } -#if 0 - template <> - static constexpr auto GetMfma() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetMfma() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetMfma() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetMfma() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetMfma() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetMfma() - { - return xdlops_info{}; - } - template <> - static constexpr auto GetMfma() + static constexpr auto GetMfma() { - return xdlops_info{}; - } - - template <> - static constexpr auto GetMfma() - { - return xdlops_info{}; +#if defined(CK_USE_AMD_MFMA_BF16_1K_OP) + return MfmaInstr::mfma_f32_32x32x8bf16_1k; +#else + return MfmaInstr::mfma_f32_32x32x4bf16; +#endif } template <> - static constexpr auto GetMfma() + static constexpr auto GetMfma() { - return xdlops_info{}; +#if defined(CK_USE_AMD_MFMA_BF16_1K_OP) + return MfmaInstr::mfma_f32_16x16x16bf16_1k; +#else + return MfmaInstr::mfma_f32_16x16x8bf16; +#endif } template <> - static constexpr auto GetMfma() + static constexpr auto GetMfma() { - return xdlops_info{}; + return MfmaInstr::mfma_i32_32x32x8i8; } template <> - static constexpr auto GetMfma() + static constexpr auto GetMfma() { - return xdlops_info{}; + return MfmaInstr::mfma_i32_16x16x16i8; } -#endif static constexpr auto selected_mfma = mfma_type()>{}; - __host__ __device__ static constexpr void mfma_check() + __host__ __device__ constexpr MfmaSelector() { static_assert(selected_mfma.group_size * selected_mfma.num_groups_per_blk == selected_mfma.num_regs_per_blk, @@ -594,8 +564,6 @@ struct MfmaSelector "is_k_reduction wrong!"); } - 
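// Spelled-out consistency arithmetic for one of the trait tables above
// (mfma_f32_32x32x8bf16_1k); these are the same relations the MfmaSelector
// constructor's static_asserts above enforce. Values copied from the table,
// sketch only:
constexpr int wave_size           = 64;
constexpr int num_threads_per_blk = 32;
constexpr int num_input_blks      = wave_size / num_threads_per_blk; // 2
constexpr int group_size          = 4;
constexpr int num_groups_per_blk  = 4;
constexpr int k_per_blk           = 4;
static_assert(group_size * num_groups_per_blk == 16, "num_regs_per_blk");
static_assert(num_input_blks * k_per_blk == 8, "K per instruction matches the '32x32x8' name");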
__host__ __device__ constexpr MfmaSelector() { mfma_check(); } - static constexpr bool IsABroadcast() { static_assert(NPerXdlops >= MPerXdlops, "only support ABroadcast"); @@ -608,7 +576,7 @@ struct MfmaSelector selected_mfma.k_per_blk; } - static constexpr index_t GetKPerThread() { return selected_mfma.k_per_blk; } + static constexpr index_t GetK1PerXdlops() { return selected_mfma.k_per_blk; } }; template @@ -644,17 +612,17 @@ struct XdlopsGemm static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk"); } - template + template __host__ __device__ static constexpr auto - MakeCM0N0M1N1M2M3M4N2Descriptor(const CM0N0M1N1M2N2Desc& c_m0_n0_m1_n1_m2_n2_desc) + MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2) { - const auto M0 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I0); - const auto N0 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I1); - const auto M1 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I2); - const auto N1 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I3); + const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); + const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); + const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); + const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); return transform_tensor_descriptor( - c_m0_n0_m1_n1_m2_n2_desc, + c_desc_m0_n0_m1_n1_m2_n2, make_tuple(make_pass_through_transform(M0), make_pass_through_transform(N0), make_pass_through_transform(M1), @@ -677,17 +645,57 @@ struct XdlopsGemm Sequence<7>{})); } + template + __host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2) + { + const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0); + const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1); + const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2); + const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3); + const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4); + + return transform_tensor_descriptor( + c_desc_g_m0_n0_m1_n1_m2_n2, + make_tuple(make_pass_through_transform(G), + make_pass_through_transform(M0), + make_pass_through_transform(N0), + make_pass_through_transform(M1), + make_pass_through_transform(N1), + make_unmerge_transform(make_tuple(mfma_instr.num_groups_per_blk, + mfma_instr.num_input_blks, + mfma_instr.group_size)), + make_pass_through_transform(mfma_instr.num_threads_per_blk)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6, 7>{}, + Sequence<8>{})); + } + __device__ static constexpr index_t GetRegSizePerXdlops() { return MPerXdlops * NPerXdlops / mfma_instr.wave_size; } + __device__ static constexpr index_t GetWaveSize() { return mfma_instr.wave_size; } + template __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const { - static_assert(is_same::value || is_same::value || - is_same::value, - "base base_type must be float, half, ushort!"); + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "base base_type must be double, float, half, bfloat16, and int8_t!"); static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { mfma_instr.template run(p_a_wave[k], p_b_wave[k], p_c_thread); @@ -769,7 +777,7 @@ struct XdlopsGemm static constexpr auto mfma_instr = mfma.selected_mfma; static constexpr auto 
KPerXdlops = mfma.GetKPerXdlops(); - static constexpr auto K1PerXdlops = mfma.GetKPerThread(); + static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops(); static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops; __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths() @@ -780,4 +788,3 @@ struct XdlopsGemm }; } // namespace ck -#endif diff --git a/composable_kernel/include/utility/amd_address_space.hpp b/include/ck/utility/amd_address_space.hpp similarity index 66% rename from composable_kernel/include/utility/amd_address_space.hpp rename to include/ck/utility/amd_address_space.hpp index 24c95b27..9f152591 100644 --- a/composable_kernel/include/utility/amd_address_space.hpp +++ b/include/ck/utility/amd_address_space.hpp @@ -1,7 +1,9 @@ -#ifndef CK_AMD_ADDRESS_SPACE_HPP -#define CK_AMD_ADDRESS_SPACE_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "config.hpp" +#pragma once + +#include "ck/ck.hpp" #include "c_style_pointer_cast.hpp" // Address Space for AMDGCN @@ -9,7 +11,7 @@ namespace ck { -enum AddressSpaceEnum_t +enum struct AddressSpaceEnum { Generic, Global, @@ -19,7 +21,7 @@ enum AddressSpaceEnum_t }; template -__device__ T* cast_pointer_to_generic_address_space(T CONSTANT* p) +__device__ T* cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE* p) { // cast a pointer in "Constant" address space (4) to "Generic" address space (0) // only c-style pointer cast seems be able to be compiled @@ -30,15 +32,14 @@ __device__ T* cast_pointer_to_generic_address_space(T CONSTANT* p) } template -__host__ __device__ T CONSTANT* cast_pointer_to_constant_address_space(T* p) +__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_address_space(T* p) { // cast a pointer in "Generic" address space (0) to "Constant" address space (4) // only c-style pointer cast seems be able to be compiled #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wold-style-cast" - return (T CONSTANT*)p; // NOLINT(old-style-cast) + return (T CK_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast) #pragma clang diagnostic pop } } // namespace ck -#endif diff --git a/composable_kernel/include/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp similarity index 73% rename from composable_kernel/include/utility/amd_buffer_addressing.hpp rename to include/ck/utility/amd_buffer_addressing.hpp index 3df53bda..cc503cf0 100644 --- a/composable_kernel/include/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -1,6 +1,7 @@ -#ifndef CK_AMD_BUFFER_ADDRESSING_HPP -#define CK_AMD_BUFFER_ADDRESSING_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
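// Usage sketch for the address-space casts defined above (the kernel and
// argument names are hypothetical): kernel arguments arrive in the Constant
// address space (4) and are cast back to Generic (0) before dereferencing.
__global__ void example_kernel(const float CK_CONSTANT_ADDRESS_SPACE* p_const_arg)
{
    const float* p = ck::cast_pointer_to_generic_address_space(p_const_arg);
    float x = p[0]; // ordinary generic-address-space load from here on
    (void)x;
}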
+#pragma once #include "data_type.hpp" namespace ck { @@ -8,6 +9,8 @@ namespace ck { template union BufferResource { + __device__ constexpr BufferResource() : content{} {} + // 128 bit SGPRs to supply buffer resource in buffer instructions // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions int32x4_t content; @@ -31,7 +34,7 @@ __device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t element_space_ return wave_buffer_resource.content; } -// load +// buffer load i8 __device__ int8_t llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc, index_t voffset, @@ -50,11 +53,26 @@ llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8"); -__device__ int16_t +// buffer load i16 +__device__ bhalf_t llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc, index_t voffset, index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32"); + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16"); + +__device__ bhalf2_t +llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16"); + +__device__ bhalf4_t +llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i16"); + +// buffer load i32 __device__ int32_t llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc, index_t voffset, @@ -72,7 +90,8 @@ llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32"); -// half + +// buffer load fp16 __device__ half_t llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc, index_t voffset, @@ -91,7 +110,7 @@ llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16"); -// float +// buffer load fp32 __device__ float llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc, index_t voffset, @@ -110,7 +129,7 @@ llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32"); -// store +// buffer store i8 __device__ void llvm_amdgcn_raw_buffer_store_i8(int8_t vdata, int32x4_t rsrc, @@ -132,13 +151,29 @@ llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8"); +// buffer store i16 __device__ void -llvm_amdgcn_raw_buffer_store_i16(int16_t vdata, +llvm_amdgcn_raw_buffer_store_i16(bhalf_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16"); +__device__ void +llvm_amdgcn_raw_buffer_store_i16x2(bhalf2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16"); + +__device__ void +llvm_amdgcn_raw_buffer_store_i16x4(bhalf4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16"); + +// buffer store i32 __device__ void llvm_amdgcn_raw_buffer_store_i32(int32_t vdata, int32x4_t rsrc, @@ -160,7 +195,7 @@ llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32"); -// half +// buffer store fp16 __device__ void llvm_amdgcn_raw_buffer_store_fp16(half_t vdata, int32x4_t rsrc, @@ -181,7 +216,8 @@ llvm_amdgcn_raw_buffer_store_fp16x4(half4_t vdata, index_t voffset, index_t soffset, 
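// How the pieces above combine, as a hedged sketch (the helper name is
// illustrative): build the wavewise 128-bit buffer resource, then issue a raw
// 4-wide load through it. Offsets are in bytes; anything past the range dword
// of the resource is clamped by hardware.
__device__ float4_t load4(float* p_global, index_t n_elements, index_t elem_offset)
{
    const int32x4_t srsrc = make_wave_buffer_resource(p_global, n_elements);

    return llvm_amdgcn_raw_buffer_load_fp32x4(
        srsrc,
        static_cast<index_t>(elem_offset * sizeof(float)) /*voffset*/,
        0 /*soffset*/,
        0 /*glc_slc*/);
}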
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16");

-// float
+
+// buffer store fp32
 __device__ void
 llvm_amdgcn_raw_buffer_store_fp32(float vdata,
                                   int32x4_t rsrc,
@@ -202,8 +238,16 @@ llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata,
                                     index_t voffset,
                                     index_t soffset,
                                     index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");

-// atomic add
-// int
+
+// buffer atomic-add fp16
+__device__ half2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
+    half2_t vdata,
+    int32x4_t rsrc,
+    index_t voffset,
+    index_t soffset,
+    index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16");
+
+// buffer atomic-add i32
 __device__ int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
     int32_t vdata,
     int32x4_t rsrc,
@@ -211,7 +255,7 @@
     index_t soffset,
     index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32");

-// float
+// buffer atomic-add fp32
 __device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32(
     float vdata,
     int32x4_t rsrc,
@@ -219,6 +263,14 @@
     index_t soffset,
     index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32");

+// buffer atomic-max fp64
+__device__ double
+llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
+                                       int32x4_t rsrc, // dst_wave_buffer_resource
+                                       int voffset,    // dst_thread_addr_offset
+                                       int soffset,    // dst_wave_addr_offset
+                                       int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");
+
 template
 __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource,
                                                            index_t src_thread_addr_offset,
@@ -228,6 +280,7 @@ (is_same::value && (N == 1 || N == 2 || N == 4)) ||
                   (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
                   (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
+                  (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
                   (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
                   (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
                  "wrong!
not implemented"); @@ -240,14 +293,14 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w const float2_t tmp = llvm_amdgcn_raw_buffer_load_fp32x2( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - return as_type(tmp); + return bit_cast(tmp); } else if constexpr(N == 2) { const float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - return as_type(tmp); + return bit_cast(tmp); } else if constexpr(N == 4) { @@ -261,8 +314,8 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w 0); vector_type tmp; - tmp.AsType()(Number<0>{}) = as_type(f32_0); - tmp.AsType()(Number<1>{}) = as_type(f32_1); + tmp.AsType()(Number<0>{}) = bit_cast(f32_0); + tmp.AsType()(Number<1>{}) = bit_cast(f32_1); return tmp.AsType()(Number<0>{}); } @@ -323,7 +376,32 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - return as_type(tmp); + return bit_cast(tmp); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + return llvm_amdgcn_raw_buffer_load_i16( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 2) + { + return llvm_amdgcn_raw_buffer_load_i16x2( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 4) + { + return llvm_amdgcn_raw_buffer_load_i16x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 8) + { + int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + return bit_cast(tmp); } } else if constexpr(is_same::value) @@ -374,7 +452,7 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w int16_t tmp = llvm_amdgcn_raw_buffer_load_i16( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - return as_type(tmp); + return bit_cast(tmp); #endif } else if constexpr(N == 4) @@ -386,7 +464,7 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w int32_t tmp = llvm_amdgcn_raw_buffer_load_i32( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - return as_type(tmp); + return bit_cast(tmp); #endif } else if constexpr(N == 8) @@ -408,7 +486,7 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - return as_type(tmp); + return bit_cast(tmp); #endif } else if constexpr(N == 16) @@ -442,7 +520,7 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - return as_type(tmp); + return bit_cast(tmp); #endif } } @@ -458,6 +536,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src (is_same::value && (N == 1 || N == 2)) || (is_same::value && (N == 1 || N == 2 || N == 4)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same::value && (N == 1 || N == 2 || N == 4)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! 
not implemented"); @@ -467,7 +546,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src // use fp32 store to mimic fp64 store if constexpr(N == 1) { - llvm_amdgcn_raw_buffer_store_fp32x2(as_type(src_thread_data), + llvm_amdgcn_raw_buffer_store_fp32x2(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -475,7 +554,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src } else if constexpr(N == 2) { - llvm_amdgcn_raw_buffer_store_fp32x4(as_type(src_thread_data), + llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -537,6 +616,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src } else if constexpr(N == 8) { +#if 0 vector_type tmp{src_thread_data}; llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[Number<0>{}], @@ -550,6 +630,56 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_thread_addr_offset, dst_wave_addr_offset + 4 * sizeof(half_t), 0); +#else + llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); +#endif + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_i16(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_store_i16x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_store_i16x4(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 8) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 4 * sizeof(bhalf_t), + 0); } } else if constexpr(is_same::value) @@ -598,7 +728,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_addr_offset, 0); #else - llvm_amdgcn_raw_buffer_store_i16(as_type(src_thread_data), + llvm_amdgcn_raw_buffer_store_i16(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -614,7 +744,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_addr_offset, 0); #else - llvm_amdgcn_raw_buffer_store_i32(as_type(src_thread_data), + llvm_amdgcn_raw_buffer_store_i32(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -623,7 +753,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src } else if constexpr(N == 8) { - llvm_amdgcn_raw_buffer_store_i32x2(as_type(src_thread_data), + llvm_amdgcn_raw_buffer_store_i32x2(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -631,7 +761,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src } else if constexpr(N == 16) { - llvm_amdgcn_raw_buffer_store_i32x4(as_type(src_thread_data), + llvm_amdgcn_raw_buffer_store_i32x4(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ 
-647,6 +777,7 @@ __device__ void amd_buffer_atomic_add_impl(const typename vector_type::typ index_t dst_wave_addr_offset) { static_assert((is_same::value && (N == 1 || N == 2 || N == 4)) || + (is_same::value && (N == 2 || N == 4 || N == 8)) || (is_same::value && (N == 1 || N == 2 || N == 4)), "wrong! not implemented"); @@ -705,6 +836,41 @@ __device__ void amd_buffer_atomic_add_impl(const typename vector_type::typ 0); } } + else if constexpr(is_same::value) + { + if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_atomic_add_fp16x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 4) + { + vector_type tmp{src_thread_data}; + + static_for<0, 2, 1>{}([&](auto i) { + llvm_amdgcn_raw_buffer_atomic_add_fp16x2(tmp.AsType()[i], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + i * sizeof(half2_t), + 0); + }); + } + else if constexpr(N == 8) + { + vector_type tmp{src_thread_data}; + + static_for<0, 4, 1>{}([&](auto i) { + llvm_amdgcn_raw_buffer_atomic_add_fp16x2(tmp.AsType()[i], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + i * sizeof(half2_t), + 0); + }); + } + } else if constexpr(is_same::value) { if constexpr(N == 1) @@ -762,16 +928,81 @@ __device__ void amd_buffer_atomic_add_impl(const typename vector_type::typ } } +template +__device__ void amd_buffer_atomic_max_impl(const typename vector_type::type src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert((is_same::value && (N == 1 || N == 2 || N == 4)), + "wrong! not implemented"); + if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_atomic_max_fp64(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(double), + 0); + } + else if constexpr(N == 4) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(double), + 0); + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 2 * sizeof(double), + 0); + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 3 * sizeof(double), + 0); + } + } +} + // buffer_load requires: // 1) p_src_wave must point to global memory space // 2) p_src_wave must be a wavewise pointer. // It is user's responsibility to make sure that is true. 
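// Note on the bf16 paths above: there is no dedicated bf16 buffer intrinsic;
// since bhalf_t is a 16-bit integer type, the i16 loads/stores move the raw
// bits unchanged. An illustrative round trip (hypothetical helper;
// type_convert is defined in data_type.hpp further down in this diff):
__device__ float load_bf16_as_f32(int32x4_t srsrc, index_t byte_offset)
{
    const bhalf_t raw = llvm_amdgcn_raw_buffer_load_i16(srsrc, byte_offset, 0, 0);
    return type_convert<float>(raw); // widen to fp32 for arithmetic
}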
template __device__ typename vector_type_maker::type::type -amd_buffer_load_invalid_element_return_return_zero(const T* p_src_wave, - index_t src_thread_element_offset, - bool src_thread_element_valid, - index_t src_element_space_size) +amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, + index_t src_thread_element_offset, + bool src_thread_element_valid, + index_t src_element_space_size) { const int32x4_t src_wave_buffer_resource = make_wave_buffer_resource(p_src_wave, src_element_space_size); @@ -893,5 +1124,39 @@ amd_buffer_atomic_add(const typename vector_type_maker::type::type src_thr #endif } -} // namespace ck +// buffer_atomic_max requires: +// 1) p_dst_wave must point to global memory +// 2) p_dst_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. +template +__device__ void +amd_buffer_atomic_max(const typename vector_type_maker::type::type src_thread_data, + T* p_dst_wave, + const index_t dst_thread_element_offset, + const bool dst_thread_element_valid, + const index_t dst_element_space_size) +{ + const int32x4_t dst_wave_buffer_resource = + make_wave_buffer_resource(p_dst_wave, dst_element_space_size); + + index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + constexpr index_t vector_size = scalar_type::vector_size; + +#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK + uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff; + + amd_buffer_atomic_max_impl( + src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); +#else + if(dst_thread_element_valid) + { + amd_buffer_atomic_max_impl( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } #endif +} + +} // namespace ck diff --git a/composable_kernel/include/utility/amd_inline_asm.hpp b/include/ck/utility/amd_inline_asm.hpp similarity index 93% rename from composable_kernel/include/utility/amd_inline_asm.hpp rename to include/ck/utility/amd_inline_asm.hpp index a2d9d5f0..82bf2a5e 100644 --- a/composable_kernel/include/utility/amd_inline_asm.hpp +++ b/include/ck/utility/amd_inline_asm.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
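// The OOB-check offset trick used above, isolated as a hedged sketch
// (hypothetical helper): instead of branching per lane, invalid lanes get a
// byte offset of 0x7fffffff, far outside the declared range of the buffer
// resource, so the hardware turns the access into a no-op (stores/atomics)
// or returns zero (loads).
__device__ uint32_t effective_offset(ck::index_t thread_byte_offset, bool is_valid)
{
    const uint32_t shift = is_valid ? 0 : 0x7fffffff;
    return shift + static_cast<uint32_t>(thread_byte_offset);
}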
+ #ifndef CK_AMD_INLINE_ASM_HPP #define CK_AMD_INLINE_ASM_HPP @@ -211,14 +214,14 @@ amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0 v_dot4_i32_i8 %1, %2, %4, %1\n \ " : "=v"(c0), "=v"(c1) - : "v"(as_type(a)), - "v"(as_type(b0)), - "v"(as_type(b1)), + : "v"(bit_cast(a)), + "v"(bit_cast(b0)), + "v"(bit_cast(b1)), "0"(c0), "1"(c1)); #else - c0 = __builtin_amdgcn_sdot4(as_type(a), as_type(b0), c0, false); - c1 = __builtin_amdgcn_sdot4(as_type(a), as_type(b1), c1, false); + c0 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b0), c0, false); + c1 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b1), c1, false); #endif } @@ -244,20 +247,20 @@ __device__ void amd_assembly_outer_product_1x4(int8x4_t a, v_dot4_i32_i8 %3, %4, %8, %3\n \ " : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) - : "v"(as_type(a)), - "v"(as_type(b0)), - "v"(as_type(b1)), - "v"(as_type(b2)), - "v"(as_type(b3)), + : "v"(bit_cast(a)), + "v"(bit_cast(b0)), + "v"(bit_cast(b1)), + "v"(bit_cast(b2)), + "v"(bit_cast(b3)), "0"(c0), "1"(c1), "2"(c2), "3"(c3)); #else - c0 = __builtin_amdgcn_sdot4(as_type(a), as_type(b0), c0, false); - c1 = __builtin_amdgcn_sdot4(as_type(a), as_type(b1), c1, false); - c2 = __builtin_amdgcn_sdot4(as_type(a), as_type(b2), c2, false); - c3 = __builtin_amdgcn_sdot4(as_type(a), as_type(b3), c3, false); + c0 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b0), c0, false); + c1 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b1), c1, false); + c2 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b2), c2, false); + c3 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b3), c3, false); #endif } diff --git a/composable_kernel/include/utility/amd_llvm_intrinsic.hpp b/include/ck/utility/amd_llvm_intrinsic.hpp similarity index 67% rename from composable_kernel/include/utility/amd_llvm_intrinsic.hpp rename to include/ck/utility/amd_llvm_intrinsic.hpp index 841d48f8..01e77d7b 100644 --- a/composable_kernel/include/utility/amd_llvm_intrinsic.hpp +++ b/include/ck/utility/amd_llvm_intrinsic.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_AMD_LLVM_INTRINSIC_HPP #define CK_AMD_LLVM_INTRINSIC_HPP diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp new file mode 100644 index 00000000..3e22c65c --- /dev/null +++ b/include/ck/utility/amd_xdlops.hpp @@ -0,0 +1,320 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
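// The new header below wraps the compiler's MFMA builtins; in every
// __builtin_amdgcn_mfma_* call the three trailing integers are the broadcast
// controls (cbsz, abid, blgp). A hedged usage sketch of one wrapper, with the
// accumulator sized MPerXdlops * NPerXdlops / wave_size = 64 * 64 / 64 = 64:
__device__ void mfma_step(float a, float b, ck::vector_type<float, 64>& acc)
{
    // one K-step of a 64x64 fp32 tile; expands to two raw builtin calls
    ck::intrin_mfma_f32_32x32x1f32<64, 64>::Run(a, b, acc);
}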
+ +#ifndef CK_AMD_XDLOPS_HPP +#define CK_AMD_XDLOPS_HPP + +#include "data_type.hpp" + +namespace ck { + +// fp32 +template +struct intrin_mfma_f32_32x32x1f32; + +template <> +struct intrin_mfma_f32_32x32x1f32<64, 64> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); + reg_c.template AsType()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 1, 1, 0); + } +}; + +template <> +struct intrin_mfma_f32_32x32x1f32<32, 64> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); + } +}; + +template +struct intrin_mfma_f32_32x32x2f32; + +template <> +struct intrin_mfma_f32_32x32x2f32<32, 32> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x2f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); + } +}; + +template +struct intrin_mfma_f32_16x16x4f32; + +template <> +struct intrin_mfma_f32_16x16x4f32<16, 16> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); + } +}; + +template +struct intrin_mfma_f32_16x16x1f32; + +template <> +struct intrin_mfma_f32_16x16x1f32<16, 64> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 2, 0, 0); + } +}; + +template +struct intrin_mfma_f32_4x4x1f32; + +template <> +struct intrin_mfma_f32_4x4x1f32<4, 64> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); + } +}; + +template <> +struct intrin_mfma_f32_4x4x1f32<8, 64> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); + reg_c.template AsType()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 4, 1, 0); + } +}; + +// fp16 +template +struct intrin_mfma_f32_32x32x4f16; + +template <> +struct intrin_mfma_f32_32x32x4f16<64, 64> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); + reg_c.template AsType()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 1, 1, 0); + } +}; + +template <> +struct intrin_mfma_f32_32x32x4f16<32, 64> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = 
__builtin_amdgcn_mfma_f32_32x32x4f16(
+            reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0);
+    }
+};
+
+template
+struct intrin_mfma_f32_32x32x8f16;
+
+template <>
+struct intrin_mfma_f32_32x32x8f16<32, 32>
+{
+    template
+    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8f16(
+            reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0);
+    }
+};
+
+template
+struct intrin_mfma_f32_16x16x16f16;
+
+template <>
+struct intrin_mfma_f32_16x16x16f16<16, 16>
+{
+    template
+    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16(
+            reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0);
+    }
+};
+
+template
+struct intrin_mfma_f32_16x16x4f16;
+
+template <>
+struct intrin_mfma_f32_16x16x4f16<16, 64>
+{
+    template
+    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f16(
+            reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 2, 0, 0);
+    }
+};
+
+template
+struct intrin_mfma_f32_4x4x4f16;
+
+template <>
+struct intrin_mfma_f32_4x4x4f16<4, 64>
+{
+    template
+    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
+            reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0);
+    }
+};
+
+template <>
+struct intrin_mfma_f32_4x4x4f16<8, 64>
+{
+    template
+    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
+            reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0);
+        reg_c.template AsType()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
+            reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 4, 1, 0);
+    }
+};
+
+// bfp16
+template
+struct intrin_mfma_f32_32x32x8bf16_1k;
+
+template <>
+struct intrin_mfma_f32_32x32x8bf16_1k<32, 32>
+{
+    template
+    __device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
+            reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0);
+    }
+};
+
+template
+struct intrin_mfma_f32_16x16x16bf16_1k;
+
+template <>
+struct intrin_mfma_f32_16x16x16bf16_1k<16, 16>
+{
+    template
+    __device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
+            reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0);
+    }
+};
+
+template
+struct intrin_mfma_f32_32x32x4bf16;
+
+template <>
+struct intrin_mfma_f32_32x32x4bf16<32, 32>
+{
+    template
+    __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
+            reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0);
+    }
+};
+
+template
+struct intrin_mfma_f32_16x16x8bf16;
+
+template <>
+struct intrin_mfma_f32_16x16x8bf16<16, 16>
+{
+    template
+    __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8bf16(
+            reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0);
+    }
+};
+
+template
+struct
intrin_mfma_i32_32x32x8i8; + +template <> +struct intrin_mfma_i32_32x32x8i8<32, 32> +{ + template + __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_i32_32x32x8i8(bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); + } +}; + +template +struct intrin_mfma_i32_16x16x16i8; + +template <> +struct intrin_mfma_i32_16x16x16i8<16, 16> +{ + template + __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_i32_16x16x16i8(bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); + } +}; + +template +struct intrin_mfma_f64_16x16x4f64; + +template <> +struct intrin_mfma_f64_16x16x4f64<16, 16> +{ + template + __device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c) + { +#ifdef __gfx90a__ + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/array.hpp b/include/ck/utility/array.hpp similarity index 87% rename from composable_kernel/include/utility/array.hpp rename to include/ck/utility/array.hpp index 911cefd0..370a457f 100644 --- a/composable_kernel/include/utility/array.hpp +++ b/include/ck/utility/array.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_ARRAY_HPP #define CK_ARRAY_HPP @@ -49,7 +52,7 @@ template __host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs) { using data_type = remove_cvref_t; - return Array{{std::forward(x), std::forward(xs)...}}; + return Array{std::forward(x), std::forward(xs)...}; } // make empty array diff --git a/composable_kernel/include/utility/array_multi_index.hpp b/include/ck/utility/array_multi_index.hpp similarity index 95% rename from composable_kernel/include/utility/array_multi_index.hpp rename to include/ck/utility/array_multi_index.hpp index f692fb51..9b8d5b95 100644 --- a/composable_kernel/include/utility/array_multi_index.hpp +++ b/include/ck/utility/array_multi_index.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_ARRAY_MULTI_INDEX_HPP #define CK_ARRAY_MULTI_INDEX_HPP diff --git a/composable_kernel/include/utility/c_style_pointer_cast.hpp b/include/ck/utility/c_style_pointer_cast.hpp similarity index 83% rename from composable_kernel/include/utility/c_style_pointer_cast.hpp rename to include/ck/utility/c_style_pointer_cast.hpp index 8acf5790..6e8b0081 100644 --- a/composable_kernel/include/utility/c_style_pointer_cast.hpp +++ b/include/ck/utility/c_style_pointer_cast.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_C_STYLE_POINTER_CAST_HPP #define CK_C_STYLE_POINTER_CAST_HPP diff --git a/include/ck/utility/common_header.hpp b/include/ck/utility/common_header.hpp new file mode 100644 index 00000000..1378bbe4 --- /dev/null +++ b/include/ck/utility/common_header.hpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
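// Because intrin_mfma_f64_16x16x4f64 above compiles to a no-op off gfx90a
// (consuming its operands via `ignore` to keep -Werror quiet), callers need
// no architecture guard of their own. Usage sketch with an assumed 4-double
// accumulator (16 * 16 / 64 wave lanes = 4 registers):
__device__ void dgemm_step(double a, double b, ck::vector_type<double, 4>& acc)
{
    ck::intrin_mfma_f64_16x16x4f64<16, 16>::Run(a, b, acc); // no-op unless __gfx90a__
}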
+ +#pragma once + +#include "ck/ck.hpp" +#include "ck/utility/array.hpp" +#include "ck/utility/container_helper.hpp" +#include "ck/utility/statically_indexed_array.hpp" +#include "ck/utility/container_element_picker.hpp" +#include "ck/utility/multi_index.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/functional.hpp" +#include "ck/utility/functional2.hpp" +#include "ck/utility/functional3.hpp" +#include "ck/utility/functional4.hpp" +#include "ck/utility/enable_if.hpp" +#include "ck/utility/ignore.hpp" +#include "ck/utility/integral_constant.hpp" +#include "ck/utility/math.hpp" +#include "ck/utility/number.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/sequence_helper.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/utility/tuple_helper.hpp" +#include "ck/utility/type.hpp" +#include "ck/utility/magic_division.hpp" +#include "ck/utility/c_style_pointer_cast.hpp" +#include "ck/utility/is_known_at_compile_time.hpp" +#include "ck/utility/transpose_vectors.hpp" +#include "ck/utility/inner_product.hpp" +#include "ck/utility/thread_group.hpp" +#include "ck/utility/debug.hpp" + +#include "ck/utility/amd_buffer_addressing.hpp" +#include "ck/utility/generic_memory_space_atomic.hpp" +#include "ck/utility/get_id.hpp" +#include "ck/utility/thread_group.hpp" +#include "ck/utility/synchronization.hpp" +#include "ck/utility/amd_address_space.hpp" +#include "ck/utility/static_buffer.hpp" +#include "ck/utility/dynamic_buffer.hpp" + +// TODO: remove this +#if CK_USE_AMD_INLINE_ASM +#include "ck/utility/amd_inline_asm.hpp" +#endif + +#ifdef CK_USE_AMD_MFMA +#include "ck/utility/amd_xdlops.hpp" +#endif diff --git a/composable_kernel/include/utility/container_element_picker.hpp b/include/ck/utility/container_element_picker.hpp similarity index 97% rename from composable_kernel/include/utility/container_element_picker.hpp rename to include/ck/utility/container_element_picker.hpp index 54915125..abc5185e 100644 --- a/composable_kernel/include/utility/container_element_picker.hpp +++ b/include/ck/utility/container_element_picker.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_CONTAINER_ELEMENT_PICKER_HPP #define CK_CONTAINER_ELEMENT_PICKER_HPP diff --git a/composable_kernel/include/utility/container_helper.hpp b/include/ck/utility/container_helper.hpp similarity index 97% rename from composable_kernel/include/utility/container_helper.hpp rename to include/ck/utility/container_helper.hpp index a7ed8ec0..c8b02bc5 100644 --- a/composable_kernel/include/utility/container_helper.hpp +++ b/include/ck/utility/container_helper.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ #ifndef CK_CONTAINER_HELPER_HPP #define CK_CONTAINER_HELPER_HPP @@ -373,19 +376,6 @@ set_container_subset(Tuple& y, Sequence picks, const Tuple& static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; }); } -template -__host__ __device__ constexpr auto to_tuple_of_number(const Container&) -{ - static_assert(is_known_at_compile_time::value, "wrong!"); - - return generate_tuple( - [&](auto i) { - constexpr index_t tmp = Container::At(i); - return Number{}; - }, - Container::Size()); -} - template __host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence) { diff --git a/composable_kernel/include/utility/data_type.hpp b/include/ck/utility/data_type.hpp similarity index 86% rename from composable_kernel/include/utility/data_type.hpp rename to include/ck/utility/data_type.hpp index 07eceb84..0e0d71a5 100644 --- a/composable_kernel/include/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -1,11 +1,14 @@ -#ifndef CK_FLOAT_TYPE_AMD_HPP -#define CK_FLOAT_TYPE_AMD_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "statically_indexed_array.hpp" +#pragma once + +#include "ck/utility/statically_indexed_array.hpp" namespace ck { -using half_t = _Float16; +using bhalf_t = ushort; +using half_t = _Float16; // vector_type template @@ -58,6 +61,18 @@ __host__ __device__ constexpr auto make_vector_type(Number) template struct scalar_type; +// is_scalar_type +template +struct is_scalar_type +{ + static constexpr bool value = (scalar_type>::vector_size == 1); +}; + +// has_same_scalar_type +template +using has_same_scalar_type = is_same>::type, + typename scalar_type>::type>; + template struct scalar_type { @@ -95,9 +110,9 @@ struct scalar_type }; template <> -struct scalar_type +struct scalar_type { - using type = ushort; + using type = bhalf_t; static constexpr index_t vector_size = 1; }; @@ -892,12 +907,12 @@ using half32_t = typename vector_type::type; using half64_t = typename vector_type::type; // bfp16 -using ushort2_t = typename vector_type::type; -using ushort4_t = typename vector_type::type; -using ushort8_t = typename vector_type::type; -using ushort16_t = typename vector_type::type; -using ushort32_t = typename vector_type::type; -using ushort64_t = typename vector_type::type; +using bhalf2_t = typename vector_type::type; +using bhalf4_t = typename vector_type::type; +using bhalf8_t = typename vector_type::type; +using bhalf16_t = typename vector_type::type; +using bhalf32_t = typename vector_type::type; +using bhalf64_t = typename vector_type::type; // i32 using int32x2_t = typename vector_type::type; @@ -915,97 +930,71 @@ using int8x16_t = typename vector_type::type; using int8x32_t = typename vector_type::type; using int8x64_t = typename vector_type::type; -// data type conversion -template -struct type_convert +// Convert X to Y +template +__host__ __device__ constexpr Y type_convert(X x) { - template - __device__ T operator()(X x) const - { - return static_cast(x); - } -}; - -template <> -template <> -__device__ float type_convert::operator()(ushort x) const -{ - return bfloat16_to_float(x); + return static_cast(x); } +// convert bfp16 to fp32 template <> -template <> -__device__ ushort type_convert::operator()(float x) const +inline __host__ __device__ constexpr float type_convert(bhalf_t x) { - return float_to_bfloat16(x); + union + { + uint32_t int32; + float fp32; + } u = {uint32_t(x) << 16}; + + return u.fp32; } -// TODO: deprecate this -template -struct 
inner_product_with_conversion +// convert fp32 to bfp16 +template <> +inline __host__ __device__ constexpr bhalf_t type_convert(float x) { - static constexpr auto convert = type_convert(); - - template - __device__ T operator()(typename vector_type::type a, - typename vector_type::type b) const + union { - const vector_type a_vector{a}; - const vector_type b_vector{b}; + float fp32; + uint32_t int32; + } u = {x}; - T acc = 0; - - static_for<0, N, 1>{}([&](auto i) { - acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]); - }); - - return acc; - } - - __device__ T operator()(float_t a, float_t b) const { return convert(a) * convert(b); } - - __device__ T operator()(int8x4_t a, int8x4_t b) const + if(~u.int32 & 0x7f800000) { - const vector_type a_vector{a}; - const vector_type b_vector{b}; - - T acc = 0; - - static_for<0, 4, 1>{}([&](auto i) { - acc += convert(a_vector.AsType()[i]) * convert(b_vector.AsType()[i]); - }); - - return acc; + // When the exponent bits are not all 1s, then the value is zero, normal, + // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus + // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). + // This causes the bfloat16's mantissa to be incremented by 1 if the 16 + // least significant bits of the float mantissa are greater than 0x8000, + // or if they are equal to 0x8000 and the least significant bit of the + // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when + // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already + // has the value 0x7f, then incrementing it causes it to become 0x00 and + // the exponent is incremented by one, which is the next higher FP value + // to the unrounded bfloat16 value. When the bfloat16 value is subnormal + // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up + // to a normal value with an exponent of 0x01 and a mantissa of 0x00. + // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, + // incrementing it causes it to become an exponent of 0xFF and a mantissa + // of 0x00, which is Inf, the next higher value to the unrounded value. + u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even } - - __device__ T operator()(int8x8_t a, int8x8_t b) const + else if(u.int32 & 0xffff) { - const vector_type a_vector{a}; - const vector_type b_vector{b}; - - T acc = 0; - - static_for<0, 8, 1>{}([&](auto i) { - acc += convert(a_vector.AsType()[i]) * convert(b_vector.AsType()[i]); - }); - - return acc; + // When all of the exponent bits are 1, the value is Inf or NaN. + // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero + // mantissa bit. Quiet NaN is indicated by the most significant mantissa + // bit being 1. Signaling NaN is indicated by the most significant + // mantissa bit being 0 but some other bit(s) being 1. If any of the + // lower 16 bits of the mantissa are 1, we set the least significant bit + // of the bfloat16 mantissa, in order to preserve signaling NaN in case + // the bloat16's mantissa bits are all 0. 
+ u.int32 |= 0x10000; // Preserve signaling NaN } - __device__ T operator()(int8x16_t a, int8x16_t b) const - { - const vector_type a_vector{a}; - const vector_type b_vector{b}; - - T acc = 0; - - static_for<0, 16, 1>{}([&](auto i) { - acc += convert(a_vector.AsType()[i]) * convert(b_vector.AsType()[i]); - }); - - return acc; - } -}; + return uint16_t(u.int32 >> 16); +} template struct NumericLimits @@ -1015,6 +1004,11 @@ struct NumericLimits __host__ __device__ static constexpr T Max() { return std::numeric_limits::max(); } __host__ __device__ static constexpr T Lowest() { return std::numeric_limits::lowest(); } + + __host__ __device__ static constexpr T QuietNaN() + { + return std::numeric_limits::quiet_NaN(); + } }; template <> @@ -1023,13 +1017,15 @@ struct NumericLimits static constexpr unsigned short binary_min = 0x0400; static constexpr unsigned short binary_max = 0x7BFF; static constexpr unsigned short binary_lowest = 0xFBFF; + static constexpr unsigned short binary_qnan = 0x7FFF; + + __host__ __device__ static constexpr half_t Min() { return bit_cast(binary_min); } - __host__ __device__ static constexpr half_t Min() { return as_type(binary_min); } + __host__ __device__ static constexpr half_t Max() { return bit_cast(binary_max); } - __host__ __device__ static constexpr half_t Max() { return as_type(binary_max); } + __host__ __device__ static constexpr half_t Lowest() { return bit_cast(binary_lowest); } - __host__ __device__ static constexpr half_t Lowest() { return as_type(binary_lowest); } + __host__ __device__ static constexpr half_t QuietNaN() { return bit_cast(binary_qnan); } }; } // namespace ck -#endif diff --git a/include/ck/utility/debug.hpp b/include/ck/utility/debug.hpp new file mode 100644 index 00000000..593bbb71 --- /dev/null +++ b/include/ck/utility/debug.hpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef UTILITY_DEBUG_HPP +#define UTILITY_DEBUG_HPP + +namespace ck { +namespace debug { + +namespace detail { +template +struct PrintAsType; + +template +struct PrintAsType::value>::type> +{ + using type = float; + __host__ __device__ static void Print(const T& p) { printf("%.3f ", static_cast(p)); } +}; + +template <> +struct PrintAsType +{ + using type = float; + __host__ __device__ static void Print(const ck::half_t& p) + { + printf("%.3f ", static_cast(p)); + } +}; + +template +struct PrintAsType::value>::type> +{ + using type = int; + __host__ __device__ static void Print(const T& p) { printf("%d ", static_cast(p)); } +}; +} // namespace detail + +// Print at runtime the data in shared memory in 128 bytes per row format given shared mem pointer +// and the number of elements. Can optionally specify strides between elements and how many bytes' +// worth of data per row. 
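// Two worked halfway cases for the rounding addend above (0x7FFF plus the
// least-significant bf16 mantissa bit): a remainder of exactly 0x8000 with an
// even bf16 lsb stays put, while an odd lsb is bumped up, so ties always land
// on an even mantissa.
static_assert(((0x3F808000u + 0x7FFFu) >> 16) == 0x3F80u, "tie rounds down to even");
static_assert(((0x3F818000u + 0x7FFFu + 1u) >> 16) == 0x3F82u, "tie rounds up to even");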
+// +// Usage example: +// +// debug::print_shared(a_block_buf.p_data_, index_t(a_block_desc_k0_m_k1.GetElementSpaceSize())); +// +template +__device__ void print_shared(T const* p_shared, index_t num_elements) +{ + constexpr index_t row_elements = row_bytes / sizeof(T); + static_assert((element_stride >= 1 && element_stride <= row_elements), + "element_stride should be within [1, row_elements]"); + + index_t wgid = blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z; + index_t tid = + (threadIdx.z * (blockDim.x * blockDim.y)) + (threadIdx.y * blockDim.x) + threadIdx.x; + + __syncthreads(); + + if(tid == 0) + { + printf("\nWorkgroup id %d, bytes per row %d, element stride %d\n\n", + wgid, + row_bytes, + element_stride); + for(index_t i = 0; i < num_elements; i += row_elements) + { + printf("elem %5d: ", i); + for(index_t j = 0; j < row_elements; j += element_stride) + { + detail::PrintAsType::Print(p_shared[i + j]); + } + + printf("\n"); + } + printf("\n"); + } + + __syncthreads(); +} + +} // namespace debug +} // namespace ck + +#endif // UTILITY_DEBUG_HPP diff --git a/include/ck/utility/dynamic_buffer.hpp b/include/ck/utility/dynamic_buffer.hpp new file mode 100644 index 00000000..ad886558 --- /dev/null +++ b/include/ck/utility/dynamic_buffer.hpp @@ -0,0 +1,397 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "enable_if.hpp" +#include "c_style_pointer_cast.hpp" +#include "amd_buffer_addressing.hpp" +#include "generic_memory_space_atomic.hpp" + +namespace ck { + +// T may be scalar or vector +// X may be scalar or vector +// T and X have the same scalar type +// X contains multiple T +template +struct DynamicBuffer +{ + using type = T; + + T* p_data_; + ElementSpaceSize element_space_size_; + T invalid_element_value_ = T{0}; + + __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size) + : p_data_{p_data}, element_space_size_{element_space_size} + { + } + + __host__ __device__ constexpr DynamicBuffer(T* p_data, + ElementSpaceSize element_space_size, + T invalid_element_value) + : p_data_{p_data}, + element_space_size_{element_space_size}, + invalid_element_value_{invalid_element_value} + { + } + + __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() + { + return BufferAddressSpace; + } + + __host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; } + + __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; } + + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong!
X should contain multiple T"); + +#if CK_USE_AMD_BUFFER_LOAD + bool constexpr use_amd_buffer_addressing = true; +#else + bool constexpr use_amd_buffer_addressing = false; +#endif + + if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing) + { + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + if constexpr(InvalidElementUseNumericalZeroValue) + { + return amd_buffer_load_invalid_element_return_zero, t_per_x>( + p_data_, i, is_valid_element, element_space_size_); + } + else + { + return amd_buffer_load_invalid_element_return_customized_value, + t_per_x>( + p_data_, i, is_valid_element, element_space_size_, invalid_element_value_); + } + } + else + { + if(is_valid_element) + { +#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS + X tmp; + + __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X)); + + return tmp; +#else + return *c_style_pointer_cast(&p_data_[i]); +#endif + } + else + { + if constexpr(InvalidElementUseNumericalZeroValue) + { + return X{0}; + } + else + { + return X{invalid_element_value_}; + } + } + } + } + + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + __host__ __device__ void Update(index_t i, bool is_valid_element, const X& x) + { + if constexpr(Op == InMemoryDataOperationEnum::Set) + { + this->template Set(i, is_valid_element, x); + } + else if constexpr(Op == InMemoryDataOperationEnum::AtomicAdd) + { + this->template AtomicAdd(i, is_valid_element, x); + } + else if constexpr(Op == InMemoryDataOperationEnum::AtomicMax) + { + this->template AtomicMax(i, is_valid_element, x); + } + else if constexpr(Op == InMemoryDataOperationEnum::Add) + { + auto tmp = this->template Get(i, is_valid_element); + this->template Set(i, is_valid_element, x + tmp); + // tmp += x; + // this->template Set(i, is_valid_element, tmp); + } + } + + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x) + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! 
X should contain multiple T"); + +#if CK_USE_AMD_BUFFER_STORE + bool constexpr use_amd_buffer_addressing = true; +#else + bool constexpr use_amd_buffer_addressing = false; +#endif + +#if CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE + bool constexpr workaround_int8_ds_write_issue = true; +#else + bool constexpr workaround_int8_ds_write_issue = false; +#endif + + if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing) + { + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_buffer_store, t_per_x>( + x, p_data_, i, is_valid_element, element_space_size_); + } + else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds && + is_same>::type, int8_t>::value && + workaround_int8_ds_write_issue) + { + if(is_valid_element) + { + // HACK: the compiler would lower IR "store address_space(3)" into inefficient + // ISA, so I try to let the compiler emit IR "store", which would be lowered to + // ds_write_b128 + // TODO: remove this after compiler fix + static_assert((is_same, int8_t>::value && + is_same, int8_t>::value) || + (is_same, int8_t>::value && + is_same, int8x2_t>::value) || + (is_same, int8_t>::value && + is_same, int8x4_t>::value) || + (is_same, int8_t>::value && + is_same, int8x8_t>::value) || + (is_same, int8_t>::value && + is_same, int8x16_t>::value) || + (is_same, int8x4_t>::value && + is_same, int8x4_t>::value) || + (is_same, int8x8_t>::value && + is_same, int8x8_t>::value) || + (is_same, int8x16_t>::value && + is_same, int8x16_t>::value), + "wrong! not implemented for this combination, please add " + "implementation"); + + if constexpr(is_same, int8_t>::value && + is_same, int8_t>::value) + { + // HACK: casting the pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8_t>::value && + is_same, int8x2_t>::value) + { + // HACK: casting the pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8_t>::value && + is_same, int8x4_t>::value) + { + // HACK: casting the pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8_t>::value && + is_same, int8x8_t>::value) + { + // HACK: casting the pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8_t>::value && + is_same, int8x16_t>::value) + { + // HACK: casting the pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8x4_t>::value && + is_same, int8x4_t>::value) + { + // HACK: casting the pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8x8_t>::value && + is_same, int8x8_t>::value) + { + // HACK: casting the pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8x16_t>::value && + is_same, int8x16_t>::value) + { + // HACK: casting the pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + } + } + else + { + if(is_valid_element) + { +#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS + X tmp = x;
+ + __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X)); +#else + *c_style_pointer_cast(&p_data_[i]) = x; +#endif + } + } + } + + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + __host__ __device__ void AtomicAdd(index_t i, bool is_valid_element, const X& x) + { + using scalar_t = typename scalar_type>::type; + + // X contains multiple T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + + static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem"); + +#if CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT + bool constexpr use_amd_buffer_addressing = + is_same_v, int32_t> || + is_same_v, float> || + (is_same_v, half_t> && scalar_per_x_vector % 2 == 0); +#elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT) + bool constexpr use_amd_buffer_addressing = is_same_v, int32_t>; +#elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT + bool constexpr use_amd_buffer_addressing = + is_same_v, float> || + (is_same_v, half_t> && scalar_per_x_vector % 2 == 0); +#else + bool constexpr use_amd_buffer_addressing = false; +#endif + + if constexpr(use_amd_buffer_addressing) + { + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_buffer_atomic_add, t_per_x>( + x, p_data_, i, is_valid_element, element_space_size_); + } + else + { + if(is_valid_element) + { + atomic_add(c_style_pointer_cast(&p_data_[i]), x); + } + } + } + + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + __host__ __device__ void AtomicMax(index_t i, bool is_valid_element, const X& x) + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! 
X should contain multiple T"); + + static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem"); + +#if CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 + using scalar_t = typename scalar_type>::type; + bool constexpr use_amd_buffer_addressing = is_same_v, double>; +#else + bool constexpr use_amd_buffer_addressing = false; +#endif + + if constexpr(use_amd_buffer_addressing) + { + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_buffer_atomic_max, t_per_x>( + x, p_data_, i, is_valid_element, element_space_size_); + } + else if(is_valid_element) + { + atomic_max(c_style_pointer_cast(&p_data_[i]), x); + } + } + + __host__ __device__ static constexpr bool IsStaticBuffer() { return false; } + + __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; } +}; + +template +__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size) +{ + return DynamicBuffer{p, element_space_size}; +} + +template < + AddressSpaceEnum BufferAddressSpace, + typename T, + typename ElementSpaceSize, + typename X, + typename enable_if, remove_cvref_t>::value, bool>::type = false> +__host__ __device__ constexpr auto +make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element_value) +{ + return DynamicBuffer{ + p, element_space_size, invalid_element_value}; +} + +} // namespace ck diff --git a/composable_kernel/include/utility/enable_if.hpp b/include/ck/utility/enable_if.hpp similarity index 62% rename from composable_kernel/include/utility/enable_if.hpp rename to include/ck/utility/enable_if.hpp index 501e1bfc..297434b0 100644 --- a/composable_kernel/include/utility/enable_if.hpp +++ b/include/ck/utility/enable_if.hpp @@ -1,5 +1,7 @@ -#ifndef CK_ENABLE_IF_HPP -#define CK_ENABLE_IF_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once namespace ck { @@ -10,4 +12,3 @@ template using enable_if_t = typename std::enable_if::type; } // namespace ck -#endif diff --git a/composable_kernel/include/utility/functional.hpp b/include/ck/utility/functional.hpp similarity index 91% rename from composable_kernel/include/utility/functional.hpp rename to include/ck/utility/functional.hpp index b84b617f..f5721a17 100644 --- a/composable_kernel/include/utility/functional.hpp +++ b/include/ck/utility/functional.hpp @@ -1,8 +1,10 @@ -#ifndef CK_FUNCTIONAL_HPP -#define CK_FUNCTIONAL_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "integral_constant.hpp" -#include "type.hpp" +#pragma once + +#include "ck/utility/integral_constant.hpp" +#include "ck/utility/type.hpp" namespace ck { @@ -113,4 +115,3 @@ template using conditional_t = typename conditional::type; } // namespace ck -#endif diff --git a/composable_kernel/include/utility/functional2.hpp b/include/ck/utility/functional2.hpp similarity index 84% rename from composable_kernel/include/utility/functional2.hpp rename to include/ck/utility/functional2.hpp index 371182a0..6f125ca4 100644 --- a/composable_kernel/include/utility/functional2.hpp +++ b/include/ck/utility/functional2.hpp @@ -1,8 +1,10 @@ -#ifndef CK_FUNCTIONAL2_HPP -#define CK_FUNCTIONAL2_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
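// Illustrative sketch, not from the original patch: functional2.hpp defines
// ck::static_for, the compile-time loop used by the dynamic-buffer,
// inner-product, and transpose code elsewhere in this patch. A minimal usage
// example, with a hypothetical statically indexed buffer `thread_buf`:
//
//     float acc = 0.f;
//     ck::static_for<0, 4, 1>{}([&](auto i) {
//         acc += thread_buf[i]; // i is a ck::Number<I>, usable as a compile-time index
//     });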
-#include "functional.hpp" -#include "sequence.hpp" +#pragma once + +#include "ck/utility/functional.hpp" +#include "ck/utility/sequence.hpp" namespace ck { @@ -45,4 +47,3 @@ struct static_for }; } // namespace ck -#endif diff --git a/composable_kernel/include/utility/functional3.hpp b/include/ck/utility/functional3.hpp similarity index 93% rename from composable_kernel/include/utility/functional3.hpp rename to include/ck/utility/functional3.hpp index 6a400f3c..06b67ef7 100644 --- a/composable_kernel/include/utility/functional3.hpp +++ b/include/ck/utility/functional3.hpp @@ -1,10 +1,13 @@ -#ifndef CK_FUNCTIONAL3_HPP -#define CK_FUNCTIONAL3_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "functional.hpp" -#include "functional2.hpp" -#include "sequence.hpp" -#include "multi_index.hpp" +#pragma once + +#include "ck/ck.hpp" +#include "ck/utility/functional.hpp" +#include "ck/utility/functional2.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/multi_index.hpp" namespace ck { @@ -139,4 +142,3 @@ struct ford }; } // namespace ck -#endif diff --git a/composable_kernel/include/utility/functional4.hpp b/include/ck/utility/functional4.hpp similarity index 94% rename from composable_kernel/include/utility/functional4.hpp rename to include/ck/utility/functional4.hpp index b0396443..6eeaf15c 100644 --- a/composable_kernel/include/utility/functional4.hpp +++ b/include/ck/utility/functional4.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_FUNCTIONAL4_HPP #define CK_FUNCTIONAL4_HPP diff --git a/include/ck/utility/generic_memory_space_atomic.hpp b/include/ck/utility/generic_memory_space_atomic.hpp new file mode 100644 index 00000000..6a1ca966 --- /dev/null +++ b/include/ck/utility/generic_memory_space_atomic.hpp @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include "data_type.hpp" + +namespace ck { + +// Caution: DO NOT REMOVE +// intentionally have only declaration but no definition to cause compilation failure when trying to +// instantiate this template. The purpose is to make the implementation of atomic_add explicit for +// each datatype. 
+template +__device__ X atomic_add(X* p_dst, const X& x); + +template <> +__device__ int32_t atomic_add(int32_t* p_dst, const int32_t& x) +{ + return atomicAdd(p_dst, x); +} + +template <> +__device__ uint32_t atomic_add(uint32_t* p_dst, const uint32_t& x) +{ + return atomicAdd(p_dst, x); +} + +template <> +__device__ float atomic_add(float* p_dst, const float& x) +{ + return atomicAdd(p_dst, x); +} + +template <> +__device__ double atomic_add(double* p_dst, const double& x) +{ + return atomicAdd(p_dst, x); +} + +template <> +__device__ float2_t atomic_add(float2_t* p_dst, const float2_t& x) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + const vector_type vx{x}; + vector_type vy{0}; + + vy.template AsType()(I0) = + atomicAdd(c_style_pointer_cast(p_dst), vx.template AsType()[I0]); + vy.template AsType()(I1) = + atomicAdd(c_style_pointer_cast(p_dst) + 1, vx.template AsType()[I1]); + + return vy.template AsType()[I0]; +} + +template <> +__device__ double2_t atomic_add(double2_t* p_dst, const double2_t& x) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + const vector_type vx{x}; + vector_type vy{0}; + + vy.template AsType()(I0) = + atomicAdd(c_style_pointer_cast(p_dst), vx.template AsType()[I0]); + vy.template AsType()(I1) = + atomicAdd(c_style_pointer_cast(p_dst) + 1, vx.template AsType()[I1]); + + return vy.template AsType()[I0]; +} + +// Caution: DO NOT REMOVE +// the primary template intentionally has only a declaration and no definition, so that +// instantiating it causes a compilation failure. The purpose is to make the implementation of +// atomic_max explicit for each data type. + +template +__device__ X atomic_max(X* p_dst, const X& x); + +template <> +__device__ int32_t atomic_max(int32_t* p_dst, const int32_t& x) +{ + return atomicMax(p_dst, x); +} + +template <> +__device__ uint32_t atomic_max(uint32_t* p_dst, const uint32_t& x) +{ + return atomicMax(p_dst, x); +} + +template <> +__device__ float atomic_max(float* p_dst, const float& x) +{ + return atomicMax(p_dst, x); +} + +template <> +__device__ double atomic_max(double* p_dst, const double& x) +{ + return atomicMax(p_dst, x); +} + +template <> +__device__ float2_t atomic_max(float2_t* p_dst, const float2_t& x) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + const vector_type vx{x}; + vector_type vy{0}; + + vy.template AsType()(I0) = + atomicMax(c_style_pointer_cast(p_dst), vx.template AsType()[I0]); + vy.template AsType()(I1) = + atomicMax(c_style_pointer_cast(p_dst) + 1, vx.template AsType()[I1]); + + return vy.template AsType()[I0]; +} + +} // namespace ck diff --git a/include/ck/utility/get_id.hpp b/include/ck/utility/get_id.hpp new file mode 100644 index 00000000..44ff4381 --- /dev/null +++ b/include/ck/utility/get_id.hpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
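// Note on the vector specializations above (editorial observation, not from
// the original patch): the float2_t and double2_t atomic_add/atomic_max
// decompose into two independent scalar atomics, so each 32- or 64-bit lane
// is updated atomically, but the two lanes together are not one atomic
// transaction.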
+ +#pragma once + +#include "ck/ck.hpp" + +namespace ck { + +__host__ __device__ constexpr index_t get_warp_size() +{ + // warpSize is defined by HIP + return warpSize; +} + +__device__ index_t get_thread_local_1d_id() { return threadIdx.x; } + +__device__ index_t get_thread_global_1d_id() { return blockIdx.x * blockDim.x + threadIdx.x; } + +__device__ index_t get_warp_local_1d_id() { return threadIdx.x / get_warp_size(); } + +__device__ index_t get_block_1d_id() { return blockIdx.x; } + +__device__ index_t get_grid_size() { return gridDim.x; } + +__device__ index_t get_block_size() { return blockDim.x; } + +} // namespace ck diff --git a/include/ck/utility/ignore.hpp b/include/ck/utility/ignore.hpp new file mode 100644 index 00000000..01724587 --- /dev/null +++ b/include/ck/utility/ignore.hpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_IGNORE_HPP +#define CK_IGNORE_HPP + +// https://en.cppreference.com/w/cpp/utility/tuple/ignore + +namespace ck { + +namespace detail { +struct ignore_t +{ + template + constexpr void operator=(T&&) const noexcept + { + } +}; +} // namespace detail + +inline constexpr detail::ignore_t ignore; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/inner_product.hpp b/include/ck/utility/inner_product.hpp similarity index 90% rename from composable_kernel/include/utility/inner_product.hpp rename to include/ck/utility/inner_product.hpp index 51753acc..0f45ec17 100644 --- a/composable_kernel/include/utility/inner_product.hpp +++ b/include/ck/utility/inner_product.hpp @@ -1,6 +1,7 @@ -#ifndef CK_INNER_PRODUCT_HPP -#define CK_INNER_PRODUCT_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
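// Illustrative use of ck::ignore defined above (sketch, not from the original
// patch; the parameter name is hypothetical):
//
//     __device__ void g(index_t unused_arg) { ck::ignore = unused_arg; }
//
// Like std::ignore, assigning to it discards the value and silences
// unused-parameter warnings without a cast.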
+#pragma once #include "data_type.hpp" namespace ck { @@ -84,13 +85,12 @@ __device__ void inner_product(const half2_t& a, const h c = __builtin_amdgcn_sdot2(a, b, c, false); #endif #else - const auto convert = type_convert{}; - const vector_type a_vector{a}; const vector_type b_vector{b}; static_for<0, 2, 1>{}([&](auto i) { - c += convert(a_vector.AsType()[i]) * convert(b_vector.AsType()[i]); + c += type_convert(a_vector.AsType()[i]) * + type_convert(b_vector.AsType()[i]); }); #endif } @@ -139,24 +139,23 @@ template <> __device__ void inner_product(const int8x4_t& a, const int8x4_t& b, int32_t& c) { -#if defined(CK_USE_DOT4_I32_I8) +#if defined(CK_USE_AMD_V_DOT4_I32_I8) #if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM asm volatile("\n \ v_dot4_i32_i8 %0, %1, %2, %0\n \ " : "=v"(c) - : "v"(as_type(a)), "v"(as_type(b)), "0"(c)); + : "v"(bit_cast(a)), "v"(bit_cast(b)), "0"(c)); #else - c = __builtin_amdgcn_sdot4(as_type(a), as_type(b), c, false); + c = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b), c, false); #endif #else - const auto convert = type_convert{}; - const vector_type a_vector{a}; const vector_type b_vector{b}; static_for<0, 4, 1>{}([&](auto i) { - c += convert(a_vector.AsType()[i]) * convert(b_vector.AsType()[i]); + c += type_convert(a_vector.AsType()[i]) * + type_convert(b_vector.AsType()[i]); }); #endif } @@ -204,4 +203,3 @@ inner_product(const int8x16_t& a, const int8x16_t } } // namespace ck -#endif diff --git a/include/ck/utility/integral_constant.hpp b/include/ck/utility/integral_constant.hpp new file mode 100644 index 00000000..9aab4e24 --- /dev/null +++ b/include/ck/utility/integral_constant.hpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { + +template +struct integral_constant +{ + static constexpr T value = v; + typedef T value_type; + typedef integral_constant type; + __host__ __device__ constexpr operator value_type() const noexcept { return value; } + __host__ __device__ constexpr value_type operator()() const noexcept { return value; } +}; + +template +__host__ __device__ constexpr auto operator+(integral_constant, integral_constant) +{ + return integral_constant{}; +} + +template +__host__ __device__ constexpr auto operator-(integral_constant, integral_constant) +{ + static_assert(Y <= X, "wrong!"); + return integral_constant{}; +} + +template +__host__ __device__ constexpr auto operator*(integral_constant, integral_constant) +{ + return integral_constant{}; +} + +template +__host__ __device__ constexpr auto operator/(integral_constant, integral_constant) +{ + static_assert(Y > 0, "wrong!"); + return integral_constant{}; +} + +template +__host__ __device__ constexpr auto operator%(integral_constant, integral_constant) +{ + static_assert(Y > 0, "wrong!"); + return integral_constant{}; +} + +} // namespace ck diff --git a/include/ck/utility/is_known_at_compile_time.hpp b/include/ck/utility/is_known_at_compile_time.hpp new file mode 100644 index 00000000..81981544 --- /dev/null +++ b/include/ck/utility/is_known_at_compile_time.hpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
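// Illustrative compile-time arithmetic enabled by the integral_constant
// operators above (sketch, not from the original patch):
//
//     constexpr auto six = ck::Number<2>{} * ck::Number<3>{}; // ck::Number<6>
//     static_assert(six == 6, "the arithmetic happens entirely in the type system");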
 + +#pragma once + +#include "ck/ck.hpp" +#include "integral_constant.hpp" +#include "sequence.hpp" +#include "tuple.hpp" + +namespace ck { + +template +struct is_known_at_compile_time; + +template <> +struct is_known_at_compile_time +{ + static constexpr bool value = false; +}; + +template <> +struct is_known_at_compile_time +{ + static constexpr bool value = false; +}; + +template +struct is_known_at_compile_time> +{ + static constexpr bool value = true; +}; + +template +struct is_known_at_compile_time> +{ + static constexpr bool value = true; +}; + +template +struct is_known_at_compile_time> +{ + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return container_reduce( + Tuple{}, + [](auto x, bool r) { + return is_known_at_compile_time>::value & r; + }, + true); + } + + static constexpr bool value = IsKnownAtCompileTime(); +}; + +} // namespace ck diff --git a/composable_kernel/include/utility/magic_division.hpp b/include/ck/utility/magic_division.hpp similarity index 69% rename from composable_kernel/include/utility/magic_division.hpp rename to include/ck/utility/magic_division.hpp index 612aceea..a5e8e921 100644 --- a/composable_kernel/include/utility/magic_division.hpp +++ b/include/ck/utility/magic_division.hpp @@ -1,7 +1,9 @@ -#ifndef CK_MAGIC_DIVISION_HPP -#define CK_MAGIC_DIVISION_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "config.hpp" +#pragma once + +#include "ck/ck.hpp" #include "integral_constant.hpp" #include "number.hpp" #include "type.hpp" @@ -25,21 +27,30 @@ struct MagicDivision // uint32_t __host__ __device__ static constexpr auto CalculateMagicNumbers(uint32_t divisor) { - // assert(divisior >= 1 && divisior <= INT32_MAX); - uint32_t shift = 0; - for(shift = 0; shift < 32; ++shift) + // WARNING: magic division is only applicable for divisors inside this range. + // You should NOT use the values returned by CalculateMagicNumbers if the divisor is + // outside this range; the "else" branch below exists only to quiet down a run-time error.
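// A worked example of the magic-number computation below (illustrative values,
// not from the original patch): for divisor = 5, the smallest shift with
// (1 << shift) >= 5 is 3, and
//     multiplier = (2^32 * (8 - 5)) / 5 + 1 = 2576980378 = 0x9999999A.
// Then DoMagicDivision(25, 0x9999999A, 3) computes
//     tmp = high 32 bits of 25 * 0x9999999A = 15, and (15 + 25) >> 3 = 5,
// which equals 25 / 5.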
 + if(divisor >= 1 && divisor <= INT32_MAX) { - if((1U << shift) >= divisor) + uint32_t shift = 0; + for(shift = 0; shift < 32; ++shift) { - break; + if((1U << shift) >= divisor) + { + break; + } } - } - uint64_t one = 1; - uint64_t multiplier = ((one << 32) * ((one << shift) - divisor)) / divisor + 1; - // assert(multiplier <= 0xffffffffUL); + uint64_t one = 1; + uint64_t multiplier = ((one << 32) * ((one << shift) - divisor)) / divisor + 1; + // assert(multiplier <= 0xffffffffUL); - return make_tuple(uint32_t(multiplier), shift); + return make_tuple(uint32_t(multiplier), shift); + } + else + { + return make_tuple(uint32_t(0), uint32_t(0)); + } } __host__ __device__ static constexpr uint32_t CalculateMagicMultiplier(uint32_t divisor) @@ -111,26 +122,39 @@ struct MagicDivision } // magic division for uint32_t - __host__ __device__ static constexpr uint32_t + __device__ static constexpr uint32_t DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift) { uint32_t tmp = __umulhi(dividend, multiplier); return (tmp + dividend) >> shift; } + __host__ static constexpr uint32_t + DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift) + { + uint32_t tmp = static_cast(dividend) * multiplier >> 32; + return (tmp + dividend) >> shift; + } + // magic division for int32_t // HACK: use dividend_i32 as if it's uint32_t, dividend_i32 needs to be // non-negative for the result to be correct // TODO: figure out how to do magic number division for int32_t as the dividend - __host__ __device__ static constexpr int32_t + __device__ static constexpr int32_t DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) { - uint32_t dividend_u32 = as_type(dividend_i32); + uint32_t dividend_u32 = bit_cast(dividend_i32); uint32_t tmp = __umulhi(dividend_u32, multiplier); return (tmp + dividend_u32) >> shift; } + + __host__ static constexpr int32_t + DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) + { + uint32_t dividend_u32 = bit_cast(dividend_i32); + uint32_t tmp = static_cast(dividend_u32) * multiplier >> 32; + return (tmp + dividend_u32) >> shift; + } }; } // namespace ck - -#endif diff --git a/composable_kernel/include/utility/math.hpp b/include/ck/utility/math.hpp similarity index 91% rename from composable_kernel/include/utility/math.hpp rename to include/ck/utility/math.hpp index 48438e61..0cfc2f7d 100644 --- a/composable_kernel/include/utility/math.hpp +++ b/include/ck/utility/math.hpp @@ -1,7 +1,9 @@ -#ifndef CK_MATH_HPP -#define CK_MATH_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "config.hpp" +#pragma once + +#include "ck/ck.hpp" #include "integral_constant.hpp" #include "number.hpp" #include "type.hpp" @@ -142,6 +144,24 @@ __host__ __device__ constexpr auto min(X x, Ys...
ys) return min(x, min(ys...)); } +// disallow implicit type casting +template +__device__ T exp(T x); + +// TODO: add f16 support using v_exp_f16 + +template <> +__device__ float exp(float x) +{ + return __expf(x); +} + +template <> +__device__ double exp(double x) +{ + return ::exp(x); // qualified call to the builtin; an unqualified call would recurse into this specialization +} + // greatest common divisor, aka highest common factor __host__ __device__ constexpr index_t gcd(index_t x, index_t y) { @@ -212,5 +232,3 @@ struct less } // namespace math } // namespace ck - -#endif diff --git a/include/ck/utility/math_v2.hpp b/include/ck/utility/math_v2.hpp new file mode 100644 index 00000000..fc264117 --- /dev/null +++ b/include/ck/utility/math_v2.hpp @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/type.hpp" + +namespace ck { +namespace math { + +// math functions for the host, some are implemented by calling C++ std functions + +static inline __host__ float abs(float x) { return std::abs(x); }; + +static inline __host__ double abs(double x) { return std::abs(x); }; + +static inline __host__ int8_t abs(int8_t x) +{ + int8_t sgn = x >> (8 - 1); + + return (x ^ sgn) - sgn; +}; + +static inline __host__ int32_t abs(int32_t x) +{ + int32_t sgn = x >> (32 - 1); + + return (x ^ sgn) - sgn; +}; + +static inline __host__ half_t abs(half_t x) +{ + uint16_t xx = ck::bit_cast(x); + + uint16_t abs_xx = xx & 0x7fff; + + half_t abs_x = ck::bit_cast(abs_xx); + + return abs_x; +}; + +static inline __host__ bool isnan(float x) { return std::isnan(x); }; + +static inline __host__ bool isnan(double x) { return std::isnan(x); }; + +static inline __host__ bool isnan(int8_t x) +{ + (void)x; + return false; +}; + +static inline __host__ bool isnan(int32_t x) +{ + (void)x; + return false; +}; + +static inline __host__ bool isnan(half_t x) +{ + uint16_t xx = ck::bit_cast(x); + + return (xx & 0x7FFF) > 0x7C00; +}; + +static inline __host__ float sqrt(float x) { return std::sqrt(x); }; + +static inline __host__ double sqrt(double x) { return std::sqrt(x); }; + +// math functions for the HIP kernel, some are implemented by calling hip builtin functions + +static inline __device__ float abs(float x) { return ::abs(x); }; + +static inline __device__ double abs(double x) { return ::abs(x); }; + +static inline __device__ int8_t abs(int8_t x) +{ + int8_t sgn = x >> (8 - 1); + + return (x ^ sgn) - sgn; +}; + +static inline __device__ int32_t abs(int32_t x) +{ + int32_t sgn = x >> (32 - 1); + + return (x ^ sgn) - sgn; +}; + +static inline __device__ half_t abs(half_t x) { return ::__habs(x); }; + +static inline __device__ bool isnan(float x) { return ::isnan(x); }; + +static inline __device__ bool isnan(double x) { return ::isnan(x); }; + +static inline __device__ bool isnan(int8_t x) +{ + (void)x; + return false; +}; + +static inline __device__ bool isnan(int32_t x) +{ + (void)x; + return false; +}; + +static inline __device__ bool isnan(half_t x) { return ::__hisnan(x); }; + +static inline __device__ float sqrt(float x) { return ::sqrtf(x); }; + +static inline __device__ double sqrt(double x) { return ::sqrt(x); }; + +} // namespace math +} // namespace ck diff --git a/include/ck/utility/multi_index.hpp b/include/ck/utility/multi_index.hpp new file mode 100644 index 00000000..1d544c09 --- /dev/null +++ b/include/ck/utility/multi_index.hpp @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc.
All rights reserved. + +#pragma once + +#include "common_header.hpp" + +#if CK_EXPERIMENTAL_USE_DYNAMICALLY_INDEXED_MULTI_INDEX +#include "array_multi_index.hpp" +#else +#include "statically_indexed_array_multi_index.hpp" +#endif diff --git a/include/ck/utility/number.hpp b/include/ck/utility/number.hpp new file mode 100644 index 00000000..f3ca6b61 --- /dev/null +++ b/include/ck/utility/number.hpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_NUMBER_HPP +#define CK_NUMBER_HPP + +#include "integral_constant.hpp" + +namespace ck { + +template +using Number = integral_constant; + +template +using LongNumber = integral_constant; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/print.hpp b/include/ck/utility/print.hpp similarity index 81% rename from composable_kernel/include/utility/print.hpp rename to include/ck/utility/print.hpp index d7d58bbb..eed1ca42 100644 --- a/composable_kernel/include/utility/print.hpp +++ b/include/ck/utility/print.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_PRINT_HPP #define CK_PRINT_HPP diff --git a/include/ck/utility/reduction_common.hpp b/include/ck/utility/reduction_common.hpp new file mode 100644 index 00000000..aceef7b2 --- /dev/null +++ b/include/ck/utility/reduction_common.hpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/reduction_enums.hpp" + +namespace ck { + +struct float_equal_one +{ + template + __host__ __device__ inline bool operator()(T x) + { + return x <= static_cast(1.0f) and x >= static_cast(1.0f); + }; +}; + +struct float_equal_zero +{ + template + __host__ __device__ inline bool operator()(T x) + { + return x <= static_cast(0.0f) and x >= static_cast(0.0f); + }; +}; + +template +static constexpr __device__ index_t get_shift() +{ + return (get_shift() + 1); +}; + +template <> +constexpr __device__ index_t get_shift<1>() +{ + return (0); +} + +} // namespace ck diff --git a/include/ck/utility/reduction_enums.hpp b/include/ck/utility/reduction_enums.hpp new file mode 100644 index 00000000..67856331 --- /dev/null +++ b/include/ck/utility/reduction_enums.hpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { + +enum struct ReduceTensorOp +{ + ADD = 0, + MUL = 1, + MIN = 2, + MAX = 3, + AMAX = 4, + AVG = 5, + NORM1 = 6, + NORM2 = 7, + // MUL_NO_ZEROS = 8, +}; + +enum struct NanPropagation +{ + NOT_PROPAGATE_NAN = 0, + PROPAGATE_NAN = 1, +}; + +enum struct ReduceTensorIndices +{ + NO_INDICES = 0, + FLATTENED_INDICES = 1, +}; + +enum struct IndicesType +{ + INDICES_32BIT = 0, + INDICES_64BIT = 1, + INDICES_16BIT = 2, + INDICES_8BIT = 3, +}; + +} // namespace ck diff --git a/include/ck/utility/reduction_functions_accumulate.hpp b/include/ck/utility/reduction_functions_accumulate.hpp new file mode 100644 index 00000000..724e5599 --- /dev/null +++ b/include/ck/utility/reduction_functions_accumulate.hpp @@ -0,0 +1,115 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
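// Illustrative expansion of get_shift defined in reduction_common.hpp above
// (not from the original patch; assumes the halving recursion get_shift<N/2>() + 1):
// for a power-of-two N it computes log2(N), e.g.
//     get_shift<8>() = get_shift<4>() + 1 = get_shift<2>() + 2 = get_shift<1>() + 3 = 3.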
+ +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/math_v2.hpp" +#include "ck/utility/reduction_common.hpp" +#include "ck/utility/reduction_operator.hpp" + +namespace ck { +namespace detail { + +// Check for NaN; guarantee NaNs are NOT propagated to result (i.e., ignore NaNs) +template +struct AccumulateWithNanIgnore +{ + __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal) + { + if(!ck::math::isnan(currVal)) + { + ReduceOperation{}(accuVal, currVal); + } + }; +}; + +template +struct AccumulateWithNanCheck; + +// Does not check for NaN; does not guarantee NaNs be propagated to result +// e.g., given that max(a, b) = a > b ? a : b +// then max(NaN, 1) returns 1 +// max(1, NaN) returns NaN +// since any comparison involving NaNs returns false +template +struct AccumulateWithNanCheck +{ + // cppcheck-suppress constParameter + __host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal) + { + ReduceOperation{}(accuVal, currVal); + }; +}; + +// Check for NaN; guarantees NaNs be propagated to result +template +struct AccumulateWithNanCheck +{ + __host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal) + { + using ck::math::isnan; + + if(isnan(currVal)) + { + accuVal = currVal; + } + else + { + ReduceOperation{}(accuVal, currVal); + }; + }; +}; + +template +struct AccumulateWithIndexAndNanCheck; + +template +struct AccumulateWithIndexAndNanCheck +{ + __host__ __device__ static inline void + // cppcheck-suppress constParameter + Calculate(AccDataType& accuVal, + AccDataType currVal, + IndexDataType& accuIndex, + IndexDataType currIndex) + { + bool changed = false; + + ReduceOperation{}(accuVal, currVal, changed); + + if(changed) + accuIndex = currIndex; + }; +}; + +template +struct AccumulateWithIndexAndNanCheck +{ + // The method is called when the ReduceOperation is indexable and the user asked for indices + __host__ __device__ static inline void Calculate(AccDataType& accuVal, + AccDataType currVal, + IndexDataType& accuIndex, + IndexDataType currIndex) + { + using ck::math::isnan; + + if(isnan(currVal)) + { + accuVal = currVal; + accuIndex = currIndex; + } + else + { + bool changed = false; + + ReduceOperation{}(accuVal, currVal, changed); + + if(changed) + accuIndex = currIndex; + } + }; +}; + +} // namespace detail +} // namespace ck diff --git a/include/ck/utility/reduction_operator.hpp b/include/ck/utility/reduction_operator.hpp new file mode 100644 index 00000000..0e09cc03 --- /dev/null +++ b/include/ck/utility/reduction_operator.hpp @@ -0,0 +1,292 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/type.hpp" + +namespace ck { + +namespace reduce { + +// Every binary operator used in reduction is represented by a templated functor class. Each functor +// class must provide at least +// three members: +// 1) GetIdentityValue() -- the interface to return the "identity element" for the binary +// operator, "identity element" is the unique +// element in the algebraic space that doesn't affect the value of other elements +// when operated against them, and the concept is similar to zero vector in +// vector space +// (http://pages.cs.wisc.edu/~matthewb/pages/notes/pdf/linearalgebra/VectorSpaces.pdf). 
+// 2) IsCompatibleInMemoryDataOperation() -- return true if the reduction task corresponding to this +// operator can use the InMemoryDataOperation to finalize, or else it returns false 3) operator() -- +// the first argument of the operator must be both an input & output, and the corresponding variable +// usually stores +// the accumulated result of many operator() calls; the second argument is only an +// input. For an indexable binary +// operator, the second version of operator() has a third argument (which is an +// output) to indicate whether the +// accumulated value (the first argument) has changed, in which case the recorded +// accumulated index also needs to be +// changed. + +struct Add +{ + template + __host__ __device__ static constexpr T GetIdentityValue() + { + return type_convert(0.0f); + }; + + __host__ __device__ static constexpr bool + IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) + { + return operation == InMemoryDataOperationEnum::AtomicAdd || + operation == InMemoryDataOperationEnum::Set; + }; + + template + __host__ __device__ inline constexpr void operator()(T& a, T b) const + { + static_assert(is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the Add accumulator!"); + + a = a + b; + } +}; + +struct SquaredAdd +{ + template + __host__ __device__ static constexpr T GetIdentityValue() + { + return type_convert(0.0f); + }; + + __host__ __device__ static constexpr bool + IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) + { + return operation == InMemoryDataOperationEnum::AtomicAdd || + operation == InMemoryDataOperationEnum::Set; + }; + + template + __host__ __device__ inline constexpr void operator()(T& a, T b) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the SquaredAdd accumulator!"); + + a = a + b * b; + } +}; + +struct Mul +{ + template + __host__ __device__ static constexpr T GetIdentityValue() + { + return type_convert(1.0f); + }; + + __host__ __device__ static constexpr bool + IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) + { + return operation == InMemoryDataOperationEnum::Set; + }; + + template + __host__ __device__ inline constexpr void operator()(T& a, T b) const + { + static_assert(is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the Mul accumulator!"); + + a = a * b; + } +}; + +struct Max +{ + template + __host__ __device__ static constexpr T GetIdentityValue() + { + return NumericLimits::Lowest(); + }; + + __host__ __device__ static constexpr bool + IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) + { + // ToChange: atomic_max to be added + return operation == InMemoryDataOperationEnum::Set; + }; + + template + __host__ __device__ inline constexpr void operator()(T& a, T b) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the Max accumulator!"); + + if(a < b) + a = b; + } + + template + __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the Max accumulator!"); + + if(a < b) + { + a = b; + changed = true; + } + } +}; + +struct Min +{ + template + __host__ __device__ static constexpr T GetIdentityValue() + 
{ + return NumericLimits::Max(); + }; + + __host__ __device__ static constexpr bool + IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) + { + // ToChange: atomic_min to be added + return operation == InMemoryDataOperationEnum::Set; + }; + + template + __host__ __device__ inline constexpr void operator()(T& a, T b) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the Min accumulator!"); + + if(a > b) + a = b; + } + + template + __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the Min accumulator!"); + + if(a > b) + { + a = b; + changed = true; + } + } +}; + +struct AMax +{ + template + __host__ __device__ static constexpr T GetIdentityValue() + { + return type_convert(0.0f); + }; + + __host__ __device__ static constexpr bool + IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) + { + // ToChange: atomic_max to be added + return operation == InMemoryDataOperationEnum::Set; + }; + + template + __host__ __device__ inline constexpr void operator()(T& a, T b) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the AMax accumulator!"); + + if(a < b) + a = b; + } + + template + __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the AMax accumulator!"); + + if(a < b) + { + a = b; + changed = true; + } + } +}; + +template +constexpr T GetIdentityValueForInMemoryDataOperation(InMemoryDataOperationEnum operation) +{ + T result = ck::type_convert(0.0f); + + if(operation == InMemoryDataOperationEnum::AtomicMax) + result = ck::NumericLimits::Lowest(); + + return (result); +}; + +template +struct InMemoryDataOperatonSupportedOnDataType +{ + static constexpr bool value = false; +}; + +template +struct InMemoryDataOperatonSupportedOnDataType +{ + static constexpr bool value = + is_same::value || is_same::value; +}; + +template +struct InMemoryDataOperatonSupportedOnDataType +{ + static constexpr bool value = + is_same::value || is_same::value; +}; + +template +struct InMemoryDataOperatonSupportedOnDataType +{ + static constexpr bool value = + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value; +}; + +template +struct InMemoryDataOperatonSupportedOnDataType +{ + static constexpr bool value = + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value; +}; + +} // namespace reduce +} // namespace ck diff --git a/composable_kernel/include/utility/sequence.hpp b/include/ck/utility/sequence.hpp similarity index 96% rename from composable_kernel/include/utility/sequence.hpp rename to include/ck/utility/sequence.hpp index b35999d5..97b59722 100644 --- a/composable_kernel/include/utility/sequence.hpp +++ b/include/ck/utility/sequence.hpp @@ -1,10 +1,12 @@ -#ifndef CK_SEQUENCE_HPP -#define CK_SEQUENCE_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
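// Illustrative use of the identity elements defined by the reduce operators
// above (sketch, not from the original patch; `x` is hypothetical): a
// reduction accumulator starts at the identity so the first real value always
// replaces it.
//
//     float acc = ck::reduce::Max::GetIdentityValue<float>(); // NumericLimits<float>::Lowest()
//     ck::reduce::Max{}(acc, x);                              // acc = max(acc, x)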
-#include "integral_constant.hpp" -#include "type.hpp" -#include "functional.hpp" -#include "math.hpp" +#pragma once + +#include "ck/utility/integral_constant.hpp" +#include "ck/utility/type.hpp" +#include "ck/utility/functional.hpp" +#include "ck/utility/math.hpp" namespace ck { @@ -241,7 +243,13 @@ struct arithmetic_sequence_gen } }; - using type = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type; + using type0 = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type; + using type1 = Sequence<>; + + static constexpr bool kHasContent = + (Increment > 0 && IBegin < IEnd) || (Increment < 0 && IBegin > IEnd); + + using type = typename conditional::type; }; // uniform sequence @@ -606,6 +614,12 @@ struct sequence_map_inverse SeqMap::Size()>::type; }; +template +__host__ __device__ constexpr bool operator==(Sequence, Sequence) +{ + return ((Xs == Ys) && ...); +} + template __host__ __device__ constexpr auto operator+(Sequence, Sequence) { @@ -876,5 +890,10 @@ __host__ __device__ constexpr bool sequence_all_of(Seq, F f) return flag; } +template +using sequence_merge_t = typename sequence_merge::type; + +template +using uniform_sequence_gen_t = typename uniform_sequence_gen::type; + } // namespace ck -#endif diff --git a/composable_kernel/include/utility/sequence_helper.hpp b/include/ck/utility/sequence_helper.hpp similarity index 82% rename from composable_kernel/include/utility/sequence_helper.hpp rename to include/ck/utility/sequence_helper.hpp index 88d7da63..db25c27e 100644 --- a/composable_kernel/include/utility/sequence_helper.hpp +++ b/include/ck/utility/sequence_helper.hpp @@ -1,7 +1,9 @@ -#ifndef CK_SEQUENCE_HELPER_HPP -#define CK_SEQUENCE_HELPER_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "tuple.hpp" +#pragma once + +#include "ck/utility/tuple.hpp" namespace ck { @@ -33,4 +35,3 @@ __host__ __device__ constexpr auto to_sequence(Tuple...>) } } // namespace ck -#endif diff --git a/include/ck/utility/static_buffer.hpp b/include/ck/utility/static_buffer.hpp new file mode 100644 index 00000000..638eefa3 --- /dev/null +++ b/include/ck/utility/static_buffer.hpp @@ -0,0 +1,176 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
 + +#ifndef CK_STATIC_BUFFER_HPP +#define CK_STATIC_BUFFER_HPP + +#include "statically_indexed_array.hpp" + +namespace ck { + +// static buffer for scalar +template // TODO remove this bool, no longer needed +struct StaticBuffer : public StaticallyIndexedArray +{ + using type = T; + using base = StaticallyIndexedArray; + + __host__ __device__ constexpr StaticBuffer() : base{} {} + + __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() { return AddressSpace; } + + __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } + + __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } + + // read access + template + __host__ __device__ constexpr const T& operator[](Number i) const + { + return base::operator[](i); + } + + // write access + template + __host__ __device__ constexpr T& operator()(Number i) + { + return base::operator()(i); + } + + __host__ __device__ void Clear() + { + static_for<0, N, 1>{}([&](auto i) { operator()(i) = T{0}; }); + } +}; + +// static buffer for vector +template ::value, bool>::type = false> +struct StaticBufferTupleOfVector + : public StaticallyIndexedArray, NumOfVector> +{ + using V = typename vector_type::type; + using base = StaticallyIndexedArray, NumOfVector>; + + static constexpr auto s_per_v = Number{}; + static constexpr auto num_of_v_ = Number{}; + + __host__ __device__ constexpr StaticBufferTupleOfVector() : base{} {} + + __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() { return AddressSpace; } + + __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } + + __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } + + // Get S + // i is offset of S + template + __host__ __device__ constexpr const S& operator[](Number i) const + { + constexpr auto i_v = i / s_per_v; + constexpr auto i_s = i % s_per_v; + + return base::operator[](i_v).template AsType()[i_s]; + } + + // Set S + // i is offset of S + template + __host__ __device__ constexpr S& operator()(Number i) + { + constexpr auto i_v = i / s_per_v; + constexpr auto i_s = i % s_per_v; + + return base::operator()(i_v).template AsType()(i_s); + } + + // Get X + // i is offset of S, not X. i should be aligned to X + template ::value, bool>::type = false> + __host__ __device__ constexpr auto GetAsType(Number i) const + { + constexpr auto s_per_x = Number>::vector_size>{}; + + static_assert(s_per_v % s_per_x == 0, "wrong! V must contain one or multiple X"); + static_assert(i % s_per_x == 0, "wrong!"); + + constexpr auto i_v = i / s_per_v; + constexpr auto i_x = (i % s_per_v) / s_per_x; + + return base::operator[](i_v).template AsType()[i_x]; + } + + // Set X + // i is offset of S, not X. i should be aligned to X + template ::value, bool>::type = false> + __host__ __device__ constexpr void SetAsType(Number i, X x) + { + constexpr auto s_per_x = Number>::vector_size>{}; + + static_assert(s_per_v % s_per_x == 0, "wrong! V must contain one or multiple X"); + static_assert(i % s_per_x == 0, "wrong!"); + + constexpr auto i_v = i / s_per_v; + constexpr auto i_x = (i % s_per_v) / s_per_x; + + base::operator()(i_v).template AsType()(i_x) = x; + } + + // Get read access to vector_type V + // i is offset of S, not V.
i should be aligned to V + template + __host__ __device__ constexpr const auto& GetVectorTypeReference(Number i) const + { + static_assert(i % s_per_v == 0, "wrong!"); + + constexpr auto i_v = i / s_per_v; + + return base::operator[](i_v); + } + + // Get write access to vector_type V + // i is offset of S, not V. i should be aligned to V + template + __host__ __device__ constexpr auto& GetVectorTypeReference(Number i) + { + static_assert(i % s_per_v == 0, "wrong!"); + + constexpr auto i_v = i / s_per_v; + + return base::operator()(i_v); + } + + __host__ __device__ void Clear() + { + constexpr index_t NumScalars = NumOfVector * ScalarPerVector; + + static_for<0, NumScalars, 1>{}([&](auto i) { SetAsType(i, S{0}); }); + } +}; + +template +__host__ __device__ constexpr auto make_static_buffer(Number) +{ + return StaticBuffer{}; +} + +template +__host__ __device__ constexpr auto make_static_buffer(LongNumber) +{ + return StaticBuffer{}; +} + +} // namespace ck +#endif diff --git a/include/ck/utility/statically_indexed_array.hpp b/include/ck/utility/statically_indexed_array.hpp new file mode 100644 index 00000000..3438776f --- /dev/null +++ b/include/ck/utility/statically_indexed_array.hpp @@ -0,0 +1,105 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_STATICALLY_INDEXED_ARRAY_HPP +#define CK_STATICALLY_INDEXED_ARRAY_HPP + +#include "functional2.hpp" +#include "sequence.hpp" +#include "tuple.hpp" + +namespace ck { + +namespace detail { +template +struct tuple_concat; + +template +struct tuple_concat, Tuple> +{ + using type = Tuple; +}; + +template +struct StaticallyIndexedArrayImpl +{ + using type = + typename tuple_concat::type, + typename StaticallyIndexedArrayImpl::type>::type; +}; + +template +struct StaticallyIndexedArrayImpl +{ + using type = Tuple<>; +}; + +template +struct StaticallyIndexedArrayImpl +{ + using type = Tuple; +}; +} // namespace detail + +template +using StaticallyIndexedArray = typename detail::StaticallyIndexedArrayImpl::type; + +template +__host__ __device__ constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs) +{ + return StaticallyIndexedArray(x, static_cast(xs)...); +} + +// make empty StaticallyIndexedArray +template +__host__ __device__ constexpr auto make_statically_indexed_array() +{ + return StaticallyIndexedArray(); +} + +template +struct StaticallyIndexedArray_v2 +{ + __host__ __device__ constexpr StaticallyIndexedArray_v2() = default; + + __host__ __device__ static constexpr index_t Size() { return N; } + + // read access + template + __host__ __device__ constexpr const auto& At(Number) const + { + static_assert(I < N, "wrong! out of range"); + + return data_[I]; + } + + // write access + template + __host__ __device__ constexpr auto& At(Number) + { + static_assert(I < N, "wrong! 
out of range"); + + return data_[I]; + } + + // read access + template + __host__ __device__ constexpr const auto& operator[](Number i) const + { + return At(i); + } + + // write access + template + __host__ __device__ constexpr auto& operator()(Number i) + { + return At(i); + } + + __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } + + T data_[N]; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/statically_indexed_array_multi_index.hpp b/include/ck/utility/statically_indexed_array_multi_index.hpp similarity index 92% rename from composable_kernel/include/utility/statically_indexed_array_multi_index.hpp rename to include/ck/utility/statically_indexed_array_multi_index.hpp index 9e96f06d..bab5aebf 100644 --- a/composable_kernel/include/utility/statically_indexed_array_multi_index.hpp +++ b/include/ck/utility/statically_indexed_array_multi_index.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP #define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP @@ -93,6 +96,13 @@ __host__ __device__ constexpr auto operator*(index_t a, const Tuple& x) return r; } +// MultiIndex = MultiIndex * index_t +template +__host__ __device__ constexpr auto operator*(const Tuple& x, index_t a) +{ + return a * x; +} + template __host__ __device__ void print_multi_index(const Tuple& x) { diff --git a/include/ck/utility/synchronization.hpp b/include/ck/utility/synchronization.hpp new file mode 100644 index 00000000..caa23cb5 --- /dev/null +++ b/include/ck/utility/synchronization.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" + +namespace ck { + +__device__ void block_sync_lds() +{ +#if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM + asm volatile("\ + s_waitcnt lgkmcnt(0) \n \ + s_barrier \ + " ::); +#else + __syncthreads(); +#endif +} + +} // namespace ck diff --git a/include/ck/utility/thread_group.hpp b/include/ck/utility/thread_group.hpp new file mode 100644 index 00000000..d469dec8 --- /dev/null +++ b/include/ck/utility/thread_group.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "get_id.hpp" + +namespace ck { + +template +struct ThisThreadBlock +{ + static constexpr index_t kNumThread_ = ThreadPerBlock; + + __device__ static constexpr index_t GetNumOfThread() { return kNumThread_; } + + __device__ static constexpr bool IsBelong() { return true; } + + __device__ static index_t GetThreadId() { return get_thread_local_1d_id(); } +}; + +} // namespace ck diff --git a/include/ck/utility/transpose_vectors.hpp b/include/ck/utility/transpose_vectors.hpp new file mode 100644 index 00000000..9f204e27 --- /dev/null +++ b/include/ck/utility/transpose_vectors.hpp @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/ck.hpp" +#include "statically_indexed_array.hpp" +#include "data_type.hpp" + +namespace ck { + +template ::value, bool>::type = false> +struct transpose_vectors; + +// transpose fp16 2x2 +__device__ void transpose_fp16_2x2(const half2_t& x0, const half2_t& x1, half2_t& y0, half2_t& y1) +{ +#if 0 + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + const vector_type vx0{x0}, vx1{x1}; + vector_type vy0, vy1; + + vy0.template AsType()(I0) = vx0.template AsType()[I0]; + vy0.template AsType()(I1) = vx1.template AsType()[I0]; + + vy1.template AsType()(I0) = vx0.template AsType()[I1]; + vy1.template AsType()(I1) = vx1.template AsType()[I1]; + + y0 = vy0.template AsType()[I0]; + y1 = vy1.template AsType()[I0]; +#else + asm volatile("\n \ + v_pack_b32_f16 %0, %1, %2 \n \ + " + : "=v"(y0) + : "v"(x0), "v"(x1)); + + asm volatile("\n \ + v_pack_b32_f16 %0, %1, %2, op_sel:[1, 1] \n \ + " + : "=v"(y1) + : "v"(x0), "v"(x1)); +#endif +} + +template +struct transpose_vectors +{ + // we got [NY * NX] amount of S data to be transposed + static constexpr index_t s_per_x = NY; + static constexpr index_t s_per_y = NX; + + using S = half_t; + using VX = vector_type; + using VY = vector_type; + + __device__ void operator()(const StaticallyIndexedArray& vx_tuple, + StaticallyIndexedArray& vy_tuple) + { + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static_assert((NX % 2 == 0 && NY % 2 == 0), "wrong!"); + + // loop over 2x2 tile and transpose data from vx_tuple into vy_tuple + static_for<0, NY, 2>{}([&](auto iy) { + static_for<0, NX, 2>{}([&](auto ix) { + // reference to 2 half2_t data from vx_tuple + const auto& x_s2_0 = vx_tuple[ix].template AsType()[iy / I2]; + const auto& x_s2_1 = vx_tuple[ix + I1].template AsType()[iy / I2]; + + // reference to 2 half2_t data from vy_tuple + auto& y_s2_0 = vy_tuple(iy).template AsType()(ix / I2); + auto& y_s2_1 = vy_tuple(iy + I1).template AsType()(ix / I2); + + // transpose + transpose_fp16_2x2(x_s2_0, x_s2_1, y_s2_0, y_s2_1); + }); + }); + } +}; + +// transpose int8 4x4 +__device__ void transpose_int8_4x4(const int8x4_t& x0, + const int8x4_t& x1, + const int8x4_t& x2, + const int8x4_t& x3, + int8x4_t& y0, + int8x4_t& y1, + int8x4_t& y2, + int8x4_t& y3) +{ + int32_t t0, t1; + int32_t z0, z1, z2, z3; + constexpr int32_t m0 = 0x05010400; + constexpr int32_t m1 = 0x05040100; + constexpr int32_t m2 = 0x07060302; + constexpr int32_t m3 = 0x07030602; + + // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488 + // -- -- -- -- -- -- -- -- - - - - + // index 7 6 5 4 3 2 1 0 33 77 44 88 + // index is reversed because of little endianness (least significant bits first) + // clang-format off + asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t0) : "v"(bit_cast(x1)), "v"(bit_cast(x0)), "s"(m0)); + asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t1) : "v"(bit_cast(x3)), "v"(bit_cast(x2)), "s"(m0)); + asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z0) : "v"(bit_cast(t1)), "v"(bit_cast(t0)), "s"(m1)); + asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z1) : "v"(bit_cast(t1)), "v"(bit_cast(t0)), "s"(m2)); + asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t0) : "v"(bit_cast(x1)), "v"(bit_cast(x0)), "s"(m3)); + asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t1) : "v"(bit_cast(x3)), "v"(bit_cast(x2)), "s"(m3)); + asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z2) : "v"(bit_cast(t1)), "v"(bit_cast(t0)), "s"(m1)); + asm volatile("v_perm_b32 %0, %1, %2, %3" : 
"=v"(z3) : "v"(bit_cast(t1)), "v"(bit_cast(t0)), "s"(m2)); + // clang-format on + + y0 = bit_cast(z0); + y1 = bit_cast(z1); + y2 = bit_cast(z2); + y3 = bit_cast(z3); +} + +template +struct transpose_vectors +{ + // we got [NY * NX] amount of S data to be transposed + static constexpr index_t s_per_x = NY; + static constexpr index_t s_per_y = NX; + + using S = int8_t; + using VX = vector_type; + using VY = vector_type; + + __device__ void operator()(const StaticallyIndexedArray& vx_tuple, + StaticallyIndexedArray& vy_tuple) + { + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + + static_assert((NX % 4 == 0 && NY % 4 == 0), "wrong!"); + + // loop over 4x4 tile and transpose data from vx_tuple into vy_tuple + static_for<0, NY, 4>{}([&](auto iy) { + static_for<0, NX, 4>{}([&](auto ix) { + // reference to 4 int8 data from vx_tuple + const auto& x_s4_0 = vx_tuple[ix].template AsType()[iy / I4]; + const auto& x_s4_1 = vx_tuple[ix + I1].template AsType()[iy / I4]; + const auto& x_s4_2 = vx_tuple[ix + I2].template AsType()[iy / I4]; + const auto& x_s4_3 = vx_tuple[ix + I3].template AsType()[iy / I4]; + + // reference to 4 int8 data from vy_tuple + auto& y_s4_0 = vy_tuple(iy).template AsType()(ix / I4); + auto& y_s4_1 = vy_tuple(iy + I1).template AsType()(ix / I4); + auto& y_s4_2 = vy_tuple(iy + I2).template AsType()(ix / I4); + auto& y_s4_3 = vy_tuple(iy + I3).template AsType()(ix / I4); + + // transpose + transpose_int8_4x4(x_s4_0, x_s4_1, x_s4_2, x_s4_3, y_s4_0, y_s4_1, y_s4_2, y_s4_3); + }); + }); + } +}; + +} // namespace ck diff --git a/composable_kernel/include/utility/tuple.hpp b/include/ck/utility/tuple.hpp similarity index 51% rename from composable_kernel/include/utility/tuple.hpp rename to include/ck/utility/tuple.hpp index 70f4d77d..d8664be5 100644 --- a/composable_kernel/include/utility/tuple.hpp +++ b/include/ck/utility/tuple.hpp @@ -1,10 +1,12 @@ -#ifndef CK_TUPLE_HPP -#define CK_TUPLE_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
-#include "integral_constant.hpp" -#include "sequence.hpp" -#include "type.hpp" -#include "enable_if.hpp" +#pragma once + +#include "ck/utility/integral_constant.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/type.hpp" +#include "ck/utility/enable_if.hpp" namespace ck { @@ -17,59 +19,77 @@ struct TupleElementKey }; template -struct TupleElement +struct TupleElementKeyData { - __host__ __device__ constexpr TupleElement() = default; + using DataType = Data; + +#if 0 // workaround compiler complaint about implicitly-deleted default constructor + __host__ __device__ constexpr TupleElementKeyData() = default; +#else + __host__ __device__ constexpr TupleElementKeyData() : mData{} {} +#endif template >, TupleElement>::value, + typename enable_if, TupleElementKeyData>::value, bool>::type = false> - __host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward(v)) + __host__ __device__ constexpr TupleElementKeyData(T&& v) : mData(std::forward(v)) { } - Data mData; + DataType mData; }; +// for read access of tuple element template -__host__ __device__ constexpr const Data& get_tuple_element(const TupleElement& x) +__host__ __device__ constexpr const Data& +get_tuple_element_data_reference(const TupleElementKeyData& x) { return static_cast(x.mData); } +// for write access of tuple element template -__host__ __device__ constexpr Data& get_tuple_element(TupleElement& x) +__host__ __device__ constexpr Data& +get_tuple_element_data_reference(TupleElementKeyData& x) { return x.mData; } // TODO: not sure the use of reference is correct template -__host__ __device__ constexpr Data&& get_tuple_element(TupleElement&& x) +__host__ __device__ constexpr Data&& +get_tuple_element_data_reference(TupleElementKeyData&& x) { return static_cast(x.mData); } +// for infering type of tuple element +template +__host__ __device__ constexpr Data get_tuple_element_data(const TupleElementKeyData& x) +{ + return std::forward(x.mData); +} + template struct TupleImpl; template -struct TupleImpl, Xs...> : TupleElement, Xs>... +struct TupleImpl, Xs...> : TupleElementKeyData, Xs>... { __host__ __device__ constexpr TupleImpl() = default; template >, TupleImpl>::value, + !is_same, TupleImpl>::value, bool>::type = false> __host__ __device__ constexpr TupleImpl(Y&& y) - : TupleElement, Xs>(std::forward(y))... + : TupleElementKeyData, Xs>(std::forward(y))... { } template = 2, bool>::type = false> __host__ __device__ constexpr TupleImpl(Ys&&... ys) - : TupleElement, Xs>(std::forward(ys))... + : TupleElementKeyData, Xs>(std::forward(ys))... { static_assert(sizeof...(Is) == sizeof...(Xs) && sizeof...(Is) == sizeof...(Ys), "wrong! 
inconsistent size"); @@ -78,15 +98,15 @@ struct TupleImpl, Xs...> : TupleElement, Xs> __host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); } template - __host__ __device__ constexpr const auto& GetElementByKey(TupleElementKey) const + __host__ __device__ constexpr const auto& GetElementDataByKey(TupleElementKey) const { - return get_tuple_element>(*this); + return get_tuple_element_data_reference>(*this); } template - __host__ __device__ constexpr auto& GetElementByKey(TupleElementKey) + __host__ __device__ constexpr auto& GetElementDataByKey(TupleElementKey) { - return get_tuple_element>(*this); + return get_tuple_element_data_reference>(*this); } }; @@ -101,8 +121,7 @@ struct Tuple : detail::TupleImpl>, Tuple>::value, + typename enable_if, Tuple>::value, bool>::type = false> __host__ __device__ constexpr Tuple(Y&& y) : base(std::forward(y)) { @@ -117,26 +136,30 @@ struct Tuple : detail::TupleImpl __host__ __device__ constexpr const auto& At(Number) const { static_assert(I < base::Size(), "wrong! out of range"); - return base::GetElementByKey(detail::TupleElementKey{}); + return base::GetElementDataByKey(detail::TupleElementKey{}); } + // write access template __host__ __device__ constexpr auto& At(Number) { static_assert(I < base::Size(), "wrong! out of range"); - return base::GetElementByKey(detail::TupleElementKey{}); + return base::GetElementDataByKey(detail::TupleElementKey{}); } + // read access template __host__ __device__ constexpr const auto& operator[](Number i) const { return At(i); } + // write access template __host__ __device__ constexpr auto& operator()(Number i) { @@ -156,11 +179,43 @@ struct Tuple : detail::TupleImpl +struct Tuple<> +{ + __host__ __device__ constexpr Tuple() = default; + + __host__ __device__ static constexpr index_t Size() { return 0; } + + template + __host__ __device__ constexpr auto operator=(const T&) + { + return *this; + } + + __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } +}; + +template +struct tuple_element +{ + // type should keep the cv/ref qualifier of original tuple element + using type = decltype(detail::get_tuple_element_data>(TTuple{})); +}; + +template +using tuple_element_t = typename tuple_element::type; + template __host__ __device__ constexpr auto make_tuple(Xs&&... xs) { return Tuple...>(std::forward(xs)...); } +// https://en.cppreference.com/w/cpp/utility/tuple/tie +template +constexpr Tuple tie(Args&... args) noexcept +{ + return {args...}; +} + } // namespace ck -#endif diff --git a/composable_kernel/include/utility/tuple_helper.hpp b/include/ck/utility/tuple_helper.hpp similarity index 69% rename from composable_kernel/include/utility/tuple_helper.hpp rename to include/ck/utility/tuple_helper.hpp index 55a79d25..6f5b142a 100644 --- a/composable_kernel/include/utility/tuple_helper.hpp +++ b/include/ck/utility/tuple_helper.hpp @@ -1,27 +1,13 @@ -#ifndef CK_TUPLE_HELPER_HPP -#define CK_TUPLE_HELPER_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
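The reworked Tuple keeps CK's read/write convention (operator[] const for reads, operator() for writes), gains an empty Tuple<> specialization, and adds a std-style tie that builds a tuple of lvalue references. A short sketch under those definitions (names illustrative):

using namespace ck;

__host__ __device__ int tuple_example()
{
    auto t = make_tuple(1, 2.5f);  // Tuple of (int, float)

    int first = t[Number<0>{}];    // read access
    t(Number<1>{}) = 3.f;          // write access

    int x = 0, y = 0;
    auto refs = tie(x, y);         // Tuple of (int&, int&)
    refs(Number<0>{}) = 7;         // writes through the reference: x == 7

    return first + x;              // 8
}

Note that tuple_element deliberately preserves the cv/ref qualifiers of the stored element, which is what makes tie-produced reference tuples usable with helpers like generate_tie and concat_tuple_of_reference in tuple_helper.hpp.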
+ +#pragma once #include "functional4.hpp" #include "tuple.hpp" namespace ck { -template -struct is_known_at_compile_time> -{ - __host__ __device__ static constexpr bool IsKnownAtCompileTime() - { - return container_reduce( - Tuple{}, - [](auto x, bool r) { - return is_known_at_compile_time>::value & r; - }, - true); - } - - static constexpr bool value = IsKnownAtCompileTime(); -}; - template __host__ __device__ constexpr auto generate_tuple(F&& f, Number) { @@ -29,6 +15,24 @@ __host__ __device__ constexpr auto generate_tuple(F&& f, Number) typename arithmetic_sequence_gen<0, N, 1>::type{}); } +template +__host__ __device__ constexpr auto generate_tie(F&& f, Number) +{ + return unpack([&f](auto&&... xs) { return tie(f(xs)...); }, + typename arithmetic_sequence_gen<0, N, 1>::type{}); +} + +// tx and ty are tuple of references, return type of will tuple of referennce (not rvalue) +template +__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple& tx, + const Tuple& ty) +{ + return unpack2( + [&](auto&&... zs) { return Tuple{std::forward(zs)...}; }, + tx, + ty); +} + namespace detail { template @@ -75,4 +79,3 @@ __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y, } } // namespace ck -#endif diff --git a/composable_kernel/include/utility/type.hpp b/include/ck/utility/type.hpp similarity index 58% rename from composable_kernel/include/utility/type.hpp rename to include/ck/utility/type.hpp index 89a2bdbd..90b9df29 100644 --- a/composable_kernel/include/utility/type.hpp +++ b/include/ck/utility/type.hpp @@ -1,8 +1,11 @@ -#ifndef CK_TYPE_HPP -#define CK_TYPE_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "integral_constant.hpp" -#include "enable_if.hpp" +#pragma once + +#include "ck/ck.hpp" +#include "ck/utility/integral_constant.hpp" +#include "ck/utility/enable_if.hpp" namespace ck { @@ -16,6 +19,9 @@ struct is_same : public integral_constant { }; +template +inline constexpr bool is_same_v = is_same::value; + template using remove_reference_t = typename std::remove_reference::type; @@ -26,26 +32,21 @@ template using remove_cvref_t = remove_cv_t>; template -inline constexpr bool is_pointer_v = std::is_pointer::value; +using remove_pointer_t = typename std::remove_pointer::type; template -struct is_known_at_compile_time; +inline constexpr bool is_pointer_v = std::is_pointer::value; -template <> -struct is_known_at_compile_time +template ::type = false> +__host__ __device__ constexpr Y bit_cast(const X& x) { - static constexpr bool value = false; -}; +#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST + Y y; -template -struct is_known_at_compile_time> -{ - static constexpr bool value = true; -}; + __builtin_memcpy(&y, &x, sizeof(X)); -template ::type = false> -__host__ __device__ constexpr Y as_type(X x) -{ + return y; +#else union AsType { X x; @@ -53,7 +54,7 @@ __host__ __device__ constexpr Y as_type(X x) }; return AsType{x}.y; +#endif } } // namespace ck -#endif diff --git a/library/CMakeLists.txt b/library/CMakeLists.txt new file mode 100644 index 00000000..90873fdd --- /dev/null +++ b/library/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(src/tensor_operation_instance/gpu) +add_subdirectory(src/utility) diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp new file mode 100644 index 00000000..97ce3dca --- /dev/null +++ 
b/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp @@ -0,0 +1,137 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceBatchedGemm : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_g_m_k, + const Tensor& b_g_k_n, + Tensor& c_g_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : a_g_m_k_{a_g_m_k}, + b_g_k_n_{b_g_k_n}, + c_g_m_n_{c_g_m_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& a_g_m_k_; + const Tensor& b_g_k_n_; + Tensor& c_g_m_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceBatchedGemm::Argument; + + float Run(const Argument& arg) + { + auto f_gmk_gkn_gmn = [&](auto g, auto m, auto n) { + const int K = arg.a_g_m_k_.mDesc.GetLengths()[2]; + + float v_acc = 0; + + for(int k = 0; k < K; ++k) + { + ADataType v_a; + BDataType v_b; + + arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k)); + arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n)); + + v_acc += ck::type_convert(v_a) * ck::type_convert(v_b); + } + + float v_c; + + arg.c_element_op_(v_c, v_acc); + + arg.c_g_m_n_(g, m, n) = ck::type_convert(v_c); + }; + + make_ParallelTensorFunctor(f_gmk_gkn_gmn, + arg.c_g_m_n_.mDesc.GetLengths()[0], + arg.c_g_m_n_.mDesc.GetLengths()[1], + arg.c_g_m_n_.mDesc.GetLengths()[2])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_g_m_k, + const Tensor& b_g_k_n, + Tensor& c_g_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{a_g_m_k, b_g_k_n, c_g_m_n, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceBatchedGemm" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp new file mode 100644 index 00000000..ce0e3374 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp @@ -0,0 +1,183 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
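ReferenceBatchedGemm follows the same MakeArgument/MakeInvoker protocol as CK's device operators, so host-side verification code reads symmetrically with kernel-launch code. A condensed sketch, assuming float data, a PassThrough elementwise functor (defined elsewhere in CK), and the template parameter order (AData, BData, CData, AOp, BOp, COp) inferred from the Argument definition above; all of these are assumptions, not patch code:

using PassThrough = ck::tensor_operation::element_wise::PassThrough; // assumed available

using RefBatchedGemm = ck::tensor_operation::host::
    ReferenceBatchedGemm<float, float, float, PassThrough, PassThrough, PassThrough>;

const std::size_t G = 4, M = 64, N = 64, K = 32; // illustrative sizes

Tensor<float> a_g_m_k({G, M, K}), b_g_k_n({G, K, N}), c_g_m_n({G, M, N});
// ... fill a_g_m_k and b_g_k_n ...

auto argument = RefBatchedGemm::MakeArgument(
    a_g_m_k, b_g_k_n, c_g_m_n, PassThrough{}, PassThrough{}, PassThrough{});

RefBatchedGemm::MakeInvoker().Run(argument); // c_g_m_n now holds the reference result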
+ +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// FIXME: support arbitrary elementwise operation for A/B/C +template < + typename ADataType, + typename BDataType, + typename CDataType, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + enable_if_t< + is_same_v && + is_same_v && + is_same_v, + bool> = false> +struct ReferenceCGemm : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_m_k_real, + const Tensor& a_m_k_imag, + const Tensor& b_k_n_real, + const Tensor& b_k_n_imag, + Tensor& c_m_n_real, + Tensor& c_m_n_imag, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : a_m_k_real_{a_m_k_real}, + a_m_k_imag_{a_m_k_imag}, + b_k_n_real_{b_k_n_real}, + b_k_n_imag_{b_k_n_imag}, + c_m_n_real_{c_m_n_real}, + c_m_n_imag_{c_m_n_imag}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& a_m_k_real_; + const Tensor& a_m_k_imag_; + const Tensor& b_k_n_real_; + const Tensor& b_k_n_imag_; + Tensor& c_m_n_real_; + Tensor& c_m_n_imag_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceCGemm::Argument; + + float Run(const Argument& arg) + { + const std::size_t K = arg.a_m_k_real_.mDesc.GetLengths()[1]; + + if(K != arg.a_m_k_imag_.mDesc.GetLengths()[1]) + { + throw std::runtime_error("wrong! 
Incompatible real and imag sizes in CGEMM"); + } + + auto f_mk_kn_mn_real = [&](auto m, auto n) { + float v_c_real = 0; + + for(std::size_t k = 0; k < K; ++k) + { + float v_a_real = ck::type_convert(arg.a_m_k_real_(m, k)); + float v_a_imag = ck::type_convert(arg.a_m_k_imag_(m, k)); + float v_b_real = ck::type_convert(arg.b_k_n_real_(k, n)); + float v_b_imag = ck::type_convert(arg.b_k_n_imag_(k, n)); + + v_c_real += v_a_real * v_b_real - v_a_imag * v_b_imag; + } + + arg.c_m_n_real_(m, n) = v_c_real; + }; + + auto f_mk_kn_mn_imag = [&](auto m, auto n) { + float v_c_imag = 0; + + for(std::size_t k = 0; k < K; ++k) + { + float v_a_real = ck::type_convert(arg.a_m_k_real_(m, k)); + float v_a_imag = ck::type_convert(arg.a_m_k_imag_(m, k)); + float v_b_real = ck::type_convert(arg.b_k_n_real_(k, n)); + float v_b_imag = ck::type_convert(arg.b_k_n_imag_(k, n)); + + v_c_imag += v_a_real * v_b_imag + v_a_imag * v_b_real; + } + + arg.c_m_n_imag_(m, n) = v_c_imag; + }; + + make_ParallelTensorFunctor(f_mk_kn_mn_real, + arg.c_m_n_real_.mDesc.GetLengths()[0], + arg.c_m_n_real_.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + make_ParallelTensorFunctor(f_mk_kn_mn_imag, + arg.c_m_n_imag_.mDesc.GetLengths()[0], + arg.c_m_n_imag_.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_m_k_real, + const Tensor& a_m_k_imag, + const Tensor& b_k_n_real, + const Tensor& b_k_n_imag, + Tensor& c_m_n_real, + Tensor& c_m_n_imag, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{a_m_k_real, + a_m_k_imag, + b_k_n_real, + b_k_n_imag, + c_m_n_real, + c_m_n_imag, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceCGemm" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp new file mode 100644 index 00000000..4dc47255 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp @@ -0,0 +1,310 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
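The split real/imaginary loops in ReferenceCGemm are the standard complex product written over two separate tensor planes: for each (m, n), c = sum over k of (a_r + i*a_i)(b_r + i*b_i), so the real loop accumulates a_r*b_r - a_i*b_i and the imaginary loop accumulates a_r*b_i + a_i*b_r. The up-front K check guards this, since both loops share the same reduction length. A scalar sanity check of the identity (illustration only, not patch code):

#include <cassert>
#include <complex>

void cgemm_identity_check()
{
    std::complex<float> a{1.f, 2.f}, b{3.f, -4.f};

    // the two accumulations used by f_mk_kn_mn_real / f_mk_kn_mn_imag
    float c_real = a.real() * b.real() - a.imag() * b.imag(); // 11
    float c_imag = a.real() * b.imag() + a.imag() * b.real(); //  2

    assert(std::complex<float>(c_real, c_imag) == a * b);
}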
+ +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// out[N, K, Ho, Wo] = in[N, C, Hi, Wi] * wei[K, C, Y, X] +template = 1 && NumDimSpatial <= 3, bool>::type = false> +struct ReferenceConvBwdWeight : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& in_n_c_hi_wi, + Tensor& wei_k_c_y_x, + const Tensor& out_n_k_ho_wo, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : input_{in_n_c_hi_wi}, + weight_{wei_k_c_y_x}, + output_{out_n_k_ho_wo}, + conv_strides_{conv_filter_strides}, + conv_dilations_{conv_filter_dilations}, + in_left_pads_{input_left_pads}, + in_right_pads_{input_right_pads}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op} + { + } + + const Tensor& input_; + Tensor& weight_; + const Tensor& output_; + + std::vector conv_strides_; + std::vector conv_dilations_; + std::vector in_left_pads_; + std::vector in_right_pads_; + + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceConvBwdWeight::Argument; + + float Run(const Argument& arg) + { + if constexpr(NumDimSpatial == 1) + { + constexpr auto I0 = Number<0>{}; + auto f_kcx = [&](auto k, auto c, auto x) { + float v_acc = 0; + for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n) + { + for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[2]; ++wo) + { + auto wi = + ck::type_convert(wo * arg.conv_strides_[I0]) + + ck::type_convert(x * arg.conv_dilations_[I0]) - + ck::type_convert(arg.in_left_pads_[I0]); + if(wi >= 0 && + ck::type_convert(wi) < arg.input_.mDesc.GetLengths()[2]) + { + float v_out; + float v_in; + + arg.out_element_op_(v_out, + ck::type_convert(arg.output_(n, k, wo))); + arg.in_element_op_(v_in, + ck::type_convert(arg.input_(n, c, wi))); + + v_acc += v_out * v_in; + } + } + } + float v_wei; + + arg.wei_element_op_(v_wei, v_acc); + + arg.weight_(k, c, x) = ck::type_convert(v_wei); + }; + + make_ParallelTensorFunctor(f_kcx, + arg.weight_.mDesc.GetLengths()[0], + arg.weight_.mDesc.GetLengths()[1], + arg.weight_.mDesc.GetLengths()[2])( + std::thread::hardware_concurrency()); + + return 0; + } + else if constexpr(NumDimSpatial == 2) + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + auto f_kcyx = [&](auto k, auto c, auto y, auto x) { + float v_acc = 0; + for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n) + { + for(std::size_t ho = 0; ho < arg.output_.mDesc.GetLengths()[2]; ++ho) + { + auto hi = + ck::type_convert(ho * arg.conv_strides_[I0]) + + ck::type_convert(y * arg.conv_dilations_[I0]) - + ck::type_convert(arg.in_left_pads_[I0]); + for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[3]; ++wo) + { + auto wi = + ck::type_convert(wo * arg.conv_strides_[I1]) + + ck::type_convert(x * + arg.conv_dilations_[I1]) - + ck::type_convert(arg.in_left_pads_[I1]); + if(hi >= 0 && + ck::type_convert(hi) < + arg.input_.mDesc.GetLengths()[2] && + wi >= 0 && + ck::type_convert(wi) < + 
arg.input_.mDesc.GetLengths()[3]) + { + float v_out; + float v_in; + + arg.out_element_op_( + v_out, ck::type_convert(arg.output_(n, k, ho, wo))); + arg.in_element_op_( + v_in, ck::type_convert(arg.input_(n, c, hi, wi))); + + v_acc += v_out * v_in; + } + } + } + } + float v_wei; + + arg.wei_element_op_(v_wei, v_acc); + + arg.weight_(k, c, y, x) = ck::type_convert(v_wei); + }; + + make_ParallelTensorFunctor(f_kcyx, + arg.weight_.mDesc.GetLengths()[0], + arg.weight_.mDesc.GetLengths()[1], + arg.weight_.mDesc.GetLengths()[2], + arg.weight_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + + return 0; + } + else if constexpr(NumDimSpatial == 3) + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + auto f_kczyx = [&](auto k, auto c, auto z, auto y, auto x) { + float v_acc = 0; + for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n) + { + for(std::size_t do_ = 0; do_ < arg.output_.mDesc.GetLengths()[2]; ++do_) + { + auto di = + ck::type_convert(do_ * arg.conv_strides_[I0]) + + ck::type_convert(z * arg.conv_dilations_[I0]) - + ck::type_convert(arg.in_left_pads_[I0]); + for(std::size_t ho = 0; ho < arg.output_.mDesc.GetLengths()[3]; ++ho) + { + auto hi = + ck::type_convert(ho * arg.conv_strides_[I1]) + + ck::type_convert(y * + arg.conv_dilations_[I1]) - + ck::type_convert(arg.in_left_pads_[I1]); + for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[4]; + ++wo) + { + auto wi = + ck::type_convert(wo * + arg.conv_strides_[I2]) + + ck::type_convert( + x * arg.conv_dilations_[I2]) - + ck::type_convert(arg.in_left_pads_[I2]); + if(di >= 0 && + ck::type_convert(di) < + arg.input_.mDesc.GetLengths()[2] && + hi >= 0 && + ck::type_convert(hi) < + arg.input_.mDesc.GetLengths()[3] && + wi >= 0 && + ck::type_convert(wi) < + arg.input_.mDesc.GetLengths()[4]) + { + float v_out; + float v_in; + + arg.out_element_op_(v_out, + ck::type_convert( + arg.output_(n, k, do_, ho, wo))); + arg.in_element_op_( + v_in, + ck::type_convert(arg.input_(n, c, di, hi, wi))); + + v_acc += v_out * v_in; + } + } + } + } + } + float v_wei; + + arg.wei_element_op_(v_wei, v_acc); + + arg.weight_(k, c, z, y, x) = ck::type_convert(v_wei); + }; + + make_ParallelTensorFunctor(f_kczyx, + arg.weight_.mDesc.GetLengths()[0], + arg.weight_.mDesc.GetLengths()[1], + arg.weight_.mDesc.GetLengths()[2], + arg.weight_.mDesc.GetLengths()[3], + arg.weight_.mDesc.GetLengths()[4])( + std::thread::hardware_concurrency()); + + return 0; + } + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /*stream_config*/ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& in_n_c_hi_wi, + Tensor& wei_k_c_y_x, + const Tensor& out_n_k_ho_wo, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr 
MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceConvBwdWeight" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp new file mode 100644 index 00000000..a244b6ec --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp @@ -0,0 +1,357 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// out[N, K, Ho, Wo] = in[N, C, Hi, Wi] * wei[K, C, Y, X] +template = 1 && NumDimSpatial <= 3, bool>::type = false> +struct ReferenceConvBwdData : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(Tensor& input, + const Tensor& weight, + const Tensor& output, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : input_{input}, + weight_{weight}, + output_{output}, + conv_strides_{conv_filter_strides}, + conv_dilations_{conv_filter_dilations}, + in_left_pads_{input_left_pads}, + in_right_pads_{input_right_pads}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op} + { + } + + Tensor& input_; + const Tensor& weight_; + const Tensor& output_; + + std::vector conv_strides_; + std::vector conv_dilations_; + std::vector in_left_pads_; + std::vector in_right_pads_; + + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceConvBwdData::Argument; + + float Run(const Argument& arg) + { + if constexpr(NumDimSpatial == 1) + { + auto f_ncw = [&](auto n, auto c, auto wi) { + std::size_t K = arg.weight_.mDesc.GetLengths()[0]; + std::size_t X = arg.weight_.mDesc.GetLengths()[2]; + std::size_t Wo = arg.output_.mDesc.GetLengths()[2]; + + AccDataType v_acc = 0; + + for(std::size_t x = 0; x < X; ++x) + { + auto w_tmp = ck::type_convert(wi) + + ck::type_convert(arg.in_left_pads_[0]) - + ck::type_convert(x * arg.conv_dilations_[0]); + if(w_tmp % arg.conv_strides_[0] == 0) + { + auto wo = ck::type_convert(w_tmp) / + ck::type_convert(arg.conv_strides_[0]); + if(wo >= 0 && ck::type_convert(wo) < Wo) + { + for(std::size_t k = 0; k < K; ++k) + { + AccDataType v_out = 0; + AccDataType v_wei = 0; + + arg.out_element_op_( + v_out, + ck::type_convert(arg.output_(n, k, wo))); + arg.wei_element_op_( + v_wei, ck::type_convert(arg.weight_(k, c, x))); + + v_acc += v_out * v_wei; + } + } + } + } + + arg.in_element_op_(v_acc, v_acc); + arg.input_(n, c, wi) = ck::type_convert(v_acc); + }; + + make_ParallelTensorFunctor(f_ncw, + arg.input_.mDesc.GetLengths()[0], + arg.input_.mDesc.GetLengths()[1], + arg.input_.mDesc.GetLengths()[2])( + 
std::thread::hardware_concurrency()); + + return 0; + } + else if constexpr(NumDimSpatial == 2) + { + auto f_nchw = [&](auto n, auto c, auto hi, auto wi) { + std::size_t K = arg.weight_.mDesc.GetLengths()[0]; + std::size_t Y = arg.weight_.mDesc.GetLengths()[2]; + std::size_t X = arg.weight_.mDesc.GetLengths()[3]; + + std::size_t Ho = arg.output_.mDesc.GetLengths()[2]; + std::size_t Wo = arg.output_.mDesc.GetLengths()[3]; + + AccDataType v_acc = 0; + + for(std::size_t y = 0; y < Y; ++y) + { + auto h_tmp = ck::type_convert(hi) + + ck::type_convert(arg.in_left_pads_[0]) - + ck::type_convert(y * arg.conv_dilations_[0]); + if(h_tmp % arg.conv_strides_[0] == 0) + { + auto ho = ck::type_convert(h_tmp) / + ck::type_convert(arg.conv_strides_[0]); + if(ho >= 0 && ck::type_convert(ho) < Ho) + { + for(std::size_t x = 0; x < X; ++x) + { + auto w_tmp = + ck::type_convert(wi) + + ck::type_convert(arg.in_left_pads_[1]) - + ck::type_convert(x * + arg.conv_dilations_[1]); + if(w_tmp % arg.conv_strides_[1] == 0) + { + auto wo = ck::type_convert(w_tmp) / + ck::type_convert( + arg.conv_strides_[1]); + if(wo >= 0 && ck::type_convert(wo) < Wo) + { + for(std::size_t k = 0; k < K; ++k) + { + AccDataType v_out = 0; + AccDataType v_wei = 0; + + arg.out_element_op_(v_out, + ck::type_convert( + arg.output_(n, k, ho, wo))); + arg.wei_element_op_(v_wei, + ck::type_convert( + arg.weight_(k, c, y, x))); + + v_acc += v_out * v_wei; + } + } + } + } + } + } + } + + AccDataType v_in; + arg.in_element_op_(v_in, v_acc); + arg.input_(n, c, hi, wi) = ck::type_convert(v_in); + }; + + make_ParallelTensorFunctor(f_nchw, + arg.input_.mDesc.GetLengths()[0], + arg.input_.mDesc.GetLengths()[1], + arg.input_.mDesc.GetLengths()[2], + arg.input_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + + return 0; + } + else if constexpr(NumDimSpatial == 3) + { + auto f_ncdhw = [&](auto n, auto c, auto di, auto hi, auto wi) { + std::size_t K = arg.weight_.mDesc.GetLengths()[0]; + std::size_t Z = arg.weight_.mDesc.GetLengths()[2]; + std::size_t Y = arg.weight_.mDesc.GetLengths()[3]; + std::size_t X = arg.weight_.mDesc.GetLengths()[4]; + + std::size_t Do = arg.output_.mDesc.GetLengths()[2]; + std::size_t Ho = arg.output_.mDesc.GetLengths()[3]; + std::size_t Wo = arg.output_.mDesc.GetLengths()[4]; + + AccDataType v_acc = 0; + + for(std::size_t z = 0; z < Z; ++z) + { + auto d_tmp = ck::type_convert(di) + + ck::type_convert(arg.in_left_pads_[0]) - + ck::type_convert(z * arg.conv_dilations_[0]); + if(d_tmp % arg.conv_strides_[0] == 0) + { + auto do_ = ck::type_convert(d_tmp) / + ck::type_convert(arg.conv_strides_[0]); + if(do_ >= 0 && ck::type_convert(do_) < Do) + { + for(std::size_t y = 0; y < Y; ++y) + { + auto h_tmp = + ck::type_convert(hi) + + ck::type_convert(arg.in_left_pads_[1]) - + ck::type_convert(y * + arg.conv_dilations_[1]); + if(h_tmp % arg.conv_strides_[1] == 0) + { + auto ho = ck::type_convert(h_tmp) / + ck::type_convert( + arg.conv_strides_[1]); + if(ho >= 0 && ck::type_convert(ho) < Ho) + { + for(std::size_t x = 0; x < X; ++x) + { + auto w_tmp = + ck::type_convert(wi) + + ck::type_convert( + arg.in_left_pads_[2]) - + ck::type_convert( + x * arg.conv_dilations_[2]); + if(w_tmp % arg.conv_strides_[2] == 0) + { + auto wo = + ck::type_convert(w_tmp) / + ck::type_convert( + arg.conv_strides_[2]); + if(wo >= 0 && + ck::type_convert(wo) < Wo) + { + for(std::size_t k = 0; k < K; ++k) + { + AccDataType v_out = 0; + AccDataType v_wei = 0; + + arg.out_element_op_( + v_out, + ck::type_convert( + arg.output_( + n, k, do_, ho, wo))); 
+ arg.wei_element_op_( + v_wei, + ck::type_convert( + arg.weight_(k, c, z, y, x))); + + v_acc += v_out * v_wei; + } + } + } + } + } + } + } + } + } + } + + AccDataType v_in; + arg.in_element_op_(v_in, v_acc); + arg.input_(n, c, di, hi, wi) = ck::type_convert(v_in); + }; + + make_ParallelTensorFunctor(f_ncdhw, + arg.input_.mDesc.GetLengths()[0], + arg.input_.mDesc.GetLengths()[1], + arg.input_.mDesc.GetLengths()[2], + arg.input_.mDesc.GetLengths()[3], + arg.input_.mDesc.GetLengths()[4])( + std::thread::hardware_concurrency()); + + return 0; + } + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(Tensor& input, + const Tensor& weight, + const Tensor& output, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{input, + weight, + output, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceConvBwdData" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp new file mode 100644 index 00000000..f79352d6 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp @@ -0,0 +1,418 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// +// @brief Reference implementation for forward convolution. +// +// @paragraph Supports both NCHW as well as NHWC formats (and their respective +// counterparts for weight and output) as long as tensor descriptor +// lengths is in NCHW. +// +// @tparam InDataType Input tensor data type. +// @tparam WeiDataType Weights tensor data type. +// @tparam OutDataType Output tensor data type. +// @tparam InElementwiseOperation Functor for input tensor elementwise +// operation. +// @tparam WeiElementwiseOperation Functor for weights tensor elementwise +// operation. +// @tparam NumDimSpatial Number of spatial dimensions. +// +// FIXME: only support NDimSpatial = 1 to 3; only support NCHW and NHWC layout. 
+// Need to be more general +template = 1 && NumDimSpatial <= 3, bool>::type = false> +struct ReferenceConvFwd : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& input, + const Tensor& weight, + Tensor& output, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : input_{input}, + weight_{weight}, + output_{output}, + conv_strides_{conv_filter_strides}, + conv_dilations_{conv_filter_dilations}, + in_left_pads_{input_left_pads}, + in_right_pads_{input_right_pads}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op} + { + } + + const Tensor& input_; + const Tensor& weight_; + Tensor& output_; + + std::vector conv_strides_; + std::vector conv_dilations_; + std::vector in_left_pads_; + std::vector in_right_pads_; + + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + }; + + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceConvFwd::Argument; + + // FIXME: properly implement "TensorView" for doing transpose or refer to dimension by name + float Run(const Argument& arg) + { + // tensor descriptor in NCHW/KXYC/NKHW dimensional order + HostTensorDescriptor in_desc = arg.input_.mDesc; + HostTensorDescriptor wei_desc = arg.weight_.mDesc; + HostTensorDescriptor out_desc = arg.output_.mDesc; + + // input + if constexpr(is_same_v) + { + in_desc = transpose_host_tensor_descriptor_given_new2old( + arg.input_.mDesc, std::vector{0, 2, 1}); + } + else if constexpr(is_same_v) + { + in_desc = transpose_host_tensor_descriptor_given_new2old( + arg.input_.mDesc, std::vector{0, 3, 1, 2}); + } + else if constexpr(is_same_v) + { + in_desc = transpose_host_tensor_descriptor_given_new2old( + arg.input_.mDesc, std::vector{0, 4, 1, 2, 3}); + } + + // weight + if constexpr(is_same_v) + { + wei_desc = transpose_host_tensor_descriptor_given_new2old( + arg.weight_.mDesc, std::vector{0, 2, 1}); + } + else if constexpr(is_same_v) + { + wei_desc = transpose_host_tensor_descriptor_given_new2old( + arg.weight_.mDesc, std::vector{0, 3, 1, 2}); + } + else if constexpr(is_same_v) + { + wei_desc = transpose_host_tensor_descriptor_given_new2old( + arg.weight_.mDesc, std::vector{0, 4, 1, 2, 3}); + } + + // output + if constexpr(is_same_v) + { + out_desc = transpose_host_tensor_descriptor_given_new2old( + arg.output_.mDesc, std::vector{0, 2, 1}); + } + else if constexpr(is_same_v) + { + out_desc = transpose_host_tensor_descriptor_given_new2old( + arg.output_.mDesc, std::vector{0, 3, 1, 2}); + } + else if constexpr(is_same_v) + { + out_desc = transpose_host_tensor_descriptor_given_new2old( + arg.output_.mDesc, std::vector{0, 4, 1, 2, 3}); + } + + if constexpr(NumDimSpatial == 1) + { + auto f_ncw = [&](auto n, auto k, auto wo) { + float v_acc = 0; + + for(std::size_t c = 0; c < wei_desc.GetLengths()[1]; ++c) + { + for(std::size_t x = 0; x < wei_desc.GetLengths()[2]; ++x) + { + auto wi = + ck::type_convert(wo * arg.conv_strides_[0]) + + ck::type_convert(x * arg.conv_dilations_[0]) - + ck::type_convert(arg.in_left_pads_[0]); + + if(wi >= 0 && + ck::type_convert(wi) < in_desc.GetLengths()[2]) + { + float v_in; + float v_wei; + + // FIXME hacky + arg.in_element_op_( + v_in, + ck::type_convert( + arg.input_ + 
.mData[in_desc.GetOffsetFromMultiIndex(n, c, wi)])); + + // FIXME hacky + arg.wei_element_op_( + v_wei, + ck::type_convert( + arg.weight_ + .mData[wei_desc.GetOffsetFromMultiIndex(k, c, x)])); + + v_acc += v_in * v_wei; + } + } + } + + float v_out; + + arg.out_element_op_(v_out, v_acc); + + // FIXME hacky + arg.output_.mData[out_desc.GetOffsetFromMultiIndex({n, k, wo})] = + ck::type_convert(v_out); + }; + + make_ParallelTensorFunctor(f_ncw, + out_desc.GetLengths()[0], + out_desc.GetLengths()[1], + out_desc.GetLengths()[2])( + std::thread::hardware_concurrency()); + + return 0; + } + else if constexpr(NumDimSpatial == 2) + { + auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { + float v_acc = 0; + + for(std::size_t c = 0; c < wei_desc.GetLengths()[1]; ++c) + { + for(std::size_t y = 0; y < wei_desc.GetLengths()[2]; ++y) + { + auto hi = + ck::type_convert(ho * arg.conv_strides_[0]) + + ck::type_convert(y * arg.conv_dilations_[0]) - + ck::type_convert(arg.in_left_pads_[0]); + + for(std::size_t x = 0; x < wei_desc.GetLengths()[3]; ++x) + { + auto wi = + ck::type_convert(wo * arg.conv_strides_[1]) + + ck::type_convert(x * arg.conv_dilations_[1]) - + ck::type_convert(arg.in_left_pads_[1]); + + if(hi >= 0 && + ck::type_convert(hi) < in_desc.GetLengths()[2] && + wi >= 0 && + ck::type_convert(wi) < in_desc.GetLengths()[3]) + { + float v_in; + float v_wei; + + // FIXME hacky + arg.in_element_op_( + v_in, + ck::type_convert( + arg.input_.mData[in_desc.GetOffsetFromMultiIndex( + n, c, hi, wi)])); + + // FIXME hacky + arg.wei_element_op_( + v_wei, + ck::type_convert( + arg.weight_.mData[wei_desc.GetOffsetFromMultiIndex( + k, c, y, x)])); + + v_acc += v_in * v_wei; + } + } + } + } + + float v_out; + + arg.out_element_op_(v_out, v_acc); + + // FIXME hacky + arg.output_.mData[out_desc.GetOffsetFromMultiIndex({n, k, ho, wo})] = + ck::type_convert(v_out); + }; + + make_ParallelTensorFunctor(f_nchw, + out_desc.GetLengths()[0], + out_desc.GetLengths()[1], + out_desc.GetLengths()[2], + out_desc.GetLengths()[3])( + std::thread::hardware_concurrency()); + + return 0; + } + else if constexpr(NumDimSpatial == 3) + { + auto f_nchw = [&](auto n, auto k, auto d_o, auto ho, auto wo) { + float v_acc = 0; + + for(std::size_t c = 0; c < wei_desc.GetLengths()[1]; ++c) + { + for(std::size_t z = 0; z < wei_desc.GetLengths()[2]; ++z) + { + auto di = + ck::type_convert(d_o * arg.conv_strides_[0]) + + ck::type_convert(z * arg.conv_dilations_[0]) - + ck::type_convert(arg.in_left_pads_[0]); + for(std::size_t y = 0; y < wei_desc.GetLengths()[3]; ++y) + { + auto hi = + ck::type_convert(ho * arg.conv_strides_[1]) + + ck::type_convert(y * arg.conv_dilations_[1]) - + ck::type_convert(arg.in_left_pads_[1]); + for(std::size_t x = 0; x < wei_desc.GetLengths()[4]; ++x) + { + auto wi = + ck::type_convert(wo * + arg.conv_strides_[2]) + + ck::type_convert(x * + arg.conv_dilations_[2]) - + ck::type_convert(arg.in_left_pads_[2]); + if(di >= 0 && + ck::type_convert(di) < + in_desc.GetLengths()[2] && + hi >= 0 && + ck::type_convert(hi) < + in_desc.GetLengths()[3] && + wi >= 0 && + ck::type_convert(wi) < in_desc.GetLengths()[4]) + { + float v_in; + float v_wei; + + // FIXME hacky + arg.in_element_op_( + v_in, + ck::type_convert( + arg.input_.mData[in_desc.GetOffsetFromMultiIndex( + n, c, di, hi, wi)])); + + // FIXME hacky + arg.wei_element_op_( + v_wei, + ck::type_convert( + arg.weight_.mData[wei_desc.GetOffsetFromMultiIndex( + k, c, z, y, x)])); + + v_acc += v_in * v_wei; + } + } + } + } + } + + float v_out; + + arg.out_element_op_(v_out, 
v_acc); + + // FIXME hacky + arg.output_.mData[out_desc.GetOffsetFromMultiIndex({n, k, d_o, ho, wo})] = + ck::type_convert(v_out); + }; + + make_ParallelTensorFunctor(f_nchw, + out_desc.GetLengths()[0], + out_desc.GetLengths()[1], + out_desc.GetLengths()[2], + out_desc.GetLengths()[3], + out_desc.GetLengths()[4])( + std::thread::hardware_concurrency()); + + return 0; + } + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /*stream_config*/ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override + { + return NumDimSpatial >= 1 && NumDimSpatial <= 3; + } + + static auto MakeArgument(const Tensor& input, + const Tensor& weight, + Tensor& output, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{input, + weight, + output, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceConvFwd" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp new file mode 100644 index 00000000..be22003f --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp @@ -0,0 +1,192 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
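All of the convolution references above map an output coordinate back into input space as wi = wo*stride + x*dilation - pad_left and simply skip out-of-range taps, which is equivalent to zero padding (the backward-data variant inverts the same relation and keeps only positions where the division by stride is exact). None of them compute the output extent themselves; they trust the output tensor's lengths. For completeness, the standard shape rule those lengths are expected to satisfy, as a small helper (illustration only, not patch code):

#include <cstddef>

std::size_t conv_out_length(std::size_t wi,       // input spatial length
                            std::size_t x,        // filter spatial length
                            std::size_t stride,
                            std::size_t dilation,
                            std::size_t pad_left,
                            std::size_t pad_right)
{
    const std::size_t effective_x = dilation * (x - 1) + 1; // dilated filter extent
    return (wi + pad_left + pad_right - effective_x) / stride + 1;
}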
+ +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// out[N, Ho, Wo, K] = +// activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K]) +template +struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + const Tensor& bias_k, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : in_n_c_hi_wi_{in_n_c_hi_wi}, + wei_k_c_y_x_{wei_k_c_y_x}, + out_n_k_ho_wo_{out_n_k_ho_wo}, + bias_k_{bias_k}, + conv_strides_{conv_filter_strides}, + conv_dilations_{conv_filter_dilations}, + in_left_pads_{input_left_pads}, + in_right_pads_{input_right_pads}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op} + { + } + + const Tensor& in_n_c_hi_wi_; + const Tensor& wei_k_c_y_x_; + Tensor& out_n_k_ho_wo_; + const Tensor& bias_k_; + + std::vector conv_strides_; + std::vector conv_dilations_; + std::vector in_left_pads_; + std::vector in_right_pads_; + + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceConvFwd_Bias_Activation::Argument; + + float Run(const Argument& arg) + { + auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { + float v_acc = 0; + + for(std::size_t c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c) + { + for(std::size_t y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y) + { + auto hi = ck::type_convert(ho * arg.conv_strides_[0]) + + ck::type_convert(y * arg.conv_dilations_[0]) - + ck::type_convert(arg.in_left_pads_[0]); + for(std::size_t x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x) + { + auto wi = + ck::type_convert(wo * arg.conv_strides_[1]) + + ck::type_convert(x * arg.conv_dilations_[1]) - + ck::type_convert(arg.in_left_pads_[1]); + if(hi >= 0 && + ck::type_convert(hi) < + arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && + wi >= 0 && + ck::type_convert(wi) < + arg.in_n_c_hi_wi_.mDesc.GetLengths()[3]) + { + float v_in; + float v_wei; + + arg.in_element_op_( + v_in, + static_cast(arg.in_n_c_hi_wi_(n, c, hi, wi))); + arg.wei_element_op_( + v_wei, static_cast(arg.wei_k_c_y_x_(k, c, y, x))); + + v_acc += v_in * v_wei; + } + } + } + } + + float v_out; + + arg.out_element_op_(v_out, v_acc, static_cast(arg.bias_k_(k))); + + arg.out_n_k_ho_wo_(n, k, ho, wo) = v_out; + }; + + make_ParallelTensorFunctor(f_nchw, + arg.out_n_k_ho_wo_.mDesc.GetLengths()[0], + arg.out_n_k_ho_wo_.mDesc.GetLengths()[1], + arg.out_n_k_ho_wo_.mDesc.GetLengths()[2], + arg.out_n_k_ho_wo_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto 
MakeArgument(const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + const Tensor& bias_k, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo, + bias_k, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceConvFwd_Bias_Activation" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp new file mode 100644 index 00000000..f949f27f --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp @@ -0,0 +1,200 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// out[N, Ho, Wo, K] = +// activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K]) + residual[N, Ho, Wo, K] +template +struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + const Tensor& bias_k, + const Tensor& resi_n_k_ho_wo, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : in_n_c_hi_wi_{in_n_c_hi_wi}, + wei_k_c_y_x_{wei_k_c_y_x}, + out_n_k_ho_wo_{out_n_k_ho_wo}, + bias_k_{bias_k}, + resi_n_k_ho_wo_{resi_n_k_ho_wo}, + conv_strides_{conv_filter_strides}, + conv_dilations_{conv_filter_dilations}, + in_left_pads_{input_left_pads}, + in_right_pads_{input_right_pads}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op} + { + } + + const Tensor& in_n_c_hi_wi_; + const Tensor& wei_k_c_y_x_; + Tensor& out_n_k_ho_wo_; + const Tensor& bias_k_; + const Tensor& resi_n_k_ho_wo_; + + std::vector conv_strides_; + std::vector conv_dilations_; + std::vector in_left_pads_; + std::vector in_right_pads_; + + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceConvFwd_Bias_Activation_Add::Argument; + + float Run(const Argument& arg) + { + auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { + float v_acc = 0; + + for(std::size_t c = 0; c < 
arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c) + { + for(std::size_t y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y) + { + auto hi = ck::type_convert(ho * arg.conv_strides_[0]) + + ck::type_convert(y * arg.conv_dilations_[0]) - + ck::type_convert(arg.in_left_pads_[0]); + for(std::size_t x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x) + { + auto wi = + ck::type_convert(wo * arg.conv_strides_[1]) + + ck::type_convert(x * arg.conv_dilations_[1]) - + ck::type_convert(arg.in_left_pads_[1]); + if(hi >= 0 && + ck::type_convert(hi) < + arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && + wi >= 0 && + ck::type_convert(wi) < + arg.in_n_c_hi_wi_.mDesc.GetLengths()[3]) + { + float v_in; + float v_wei; + + arg.in_element_op_( + v_in, + static_cast(arg.in_n_c_hi_wi_(n, c, hi, wi))); + arg.wei_element_op_( + v_wei, static_cast(arg.wei_k_c_y_x_(k, c, y, x))); + + v_acc += v_in * v_wei; + } + } + } + } + + float v_out; + + arg.out_element_op_(v_out, + v_acc, + static_cast(arg.bias_k_(k)), + static_cast(arg.resi_n_k_ho_wo_(n, k, ho, wo))); + + arg.out_n_k_ho_wo_(n, k, ho, wo) = v_out; + }; + + make_ParallelTensorFunctor(f_nchw, + arg.out_n_k_ho_wo_.mDesc.GetLengths()[0], + arg.out_n_k_ho_wo_.mDesc.GetLengths()[1], + arg.out_n_k_ho_wo_.mDesc.GetLengths()[2], + arg.out_n_k_ho_wo_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /*stream_config*/ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + const Tensor& bias_k, + const Tensor& resi_n_k_ho_wo, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo, + bias_k, + resi_n_k_ho_wo, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceConvFwd_Bias_Activation_Add" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp new file mode 100644 index 00000000..6728bb1f --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -0,0 +1,137 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
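+
+// Note on the conv references above (ReferenceConvFwd_Bias_Activation and its
+// _Add variant): the host tensors are expected to satisfy the usual forward
+// convolution size relation per spatial dimension,
+//
+//   Ho = (Hi + pad_left + pad_right - dilation_y * (Y - 1) - 1) / stride_y + 1
+//
+// e.g. Hi = 28, Y = 3, stride 1, dilation 1, pads 1/1 gives
+// Ho = (28 + 2 - 3) / 1 + 1 = 28; the hi/wi bound checks inside f_nchw
+// implement the zero-padding implied by this relation.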
+
+#pragma once
+
+#include <iostream>
+#include <sstream>
+
+#include "ck/tensor_operation/gpu/device/device_base.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace host {
+
+template <typename ADataType,
+          typename BDataType,
+          typename CDataType,
+          typename AccDataType,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+struct ReferenceGemm : public device::BaseOperator
+{
+    // Argument
+    struct Argument : public device::BaseArgument
+    {
+        Argument(const Tensor<ADataType>& a_m_k,
+                 const Tensor<BDataType>& b_k_n,
+                 Tensor<CDataType>& c_m_n,
+                 AElementwiseOperation a_element_op,
+                 BElementwiseOperation b_element_op,
+                 CElementwiseOperation c_element_op)
+            : a_m_k_{a_m_k},
+              b_k_n_{b_k_n},
+              c_m_n_{c_m_n},
+              a_element_op_{a_element_op},
+              b_element_op_{b_element_op},
+              c_element_op_{c_element_op}
+        {
+        }
+
+        const Tensor<ADataType>& a_m_k_;
+        const Tensor<BDataType>& b_k_n_;
+        Tensor<CDataType>& c_m_n_;
+
+        AElementwiseOperation a_element_op_;
+        BElementwiseOperation b_element_op_;
+        CElementwiseOperation c_element_op_;
+    };
+
+    // Invoker
+    struct Invoker : public device::BaseInvoker
+    {
+        using Argument = ReferenceGemm::Argument;
+
+        float Run(const Argument& arg)
+        {
+            auto f_mk_kn_mn = [&](auto m, auto n) {
+                const int K = arg.a_m_k_.mDesc.GetLengths()[1];
+
+                AccDataType v_acc = 0;
+
+                for(int k = 0; k < K; ++k)
+                {
+                    ADataType v_a;
+                    BDataType v_b;
+
+                    arg.a_element_op_(v_a, arg.a_m_k_(m, k));
+                    arg.b_element_op_(v_b, arg.b_k_n_(k, n));
+
+                    v_acc +=
+                        ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
+                }
+
+                AccDataType v_c;
+
+                arg.c_element_op_(v_c, v_acc);
+
+                arg.c_m_n_(m, n) = ck::type_convert<CDataType>(v_c);
+            };
+
+            make_ParallelTensorFunctor(
+                f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])(
+                std::thread::hardware_concurrency());
+
+            return 0;
+        }
+
+        float Run(const device::BaseArgument* p_arg,
+                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
+        {
+            return Run(*dynamic_cast<const Argument*>(p_arg));
+        }
+    };
+
+    static constexpr bool IsValidCompilationParameter()
+    {
+        // TODO: properly implement this check
+        return true;
+    }
+
+    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
+
+    static auto MakeArgument(const Tensor<ADataType>& a_m_k,
+                             const Tensor<BDataType>& b_k_n,
+                             Tensor<CDataType>& c_m_n,
+                             AElementwiseOperation a_element_op,
+                             BElementwiseOperation b_element_op,
+                             CElementwiseOperation c_element_op)
+    {
+        return Argument{a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, c_element_op};
+    }
+
+    static auto MakeInvoker() { return Invoker{}; }
+
+    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
+    {
+        return std::make_unique<Invoker>(Invoker{});
+    }
+
+    std::string GetTypeString() const override
+    {
+        auto str = std::stringstream();
+
+        // clang-format off
+        str << "ReferenceGemm"
+            << std::endl;
+        // clang-format on
+
+        return str.str();
+    }
+};
+
+} // namespace host
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp
new file mode 100644
index 00000000..c77d22f4
--- /dev/null
+++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp
@@ -0,0 +1,136 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
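+
+// Usage sketch for ReferenceGemm from reference_gemm.hpp above, kept inert in
+// an #if 0 block; the shapes and the PassThrough element ops are illustrative
+// assumptions, not part of the library interface:
+#if 0
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using RefGemm     = ck::tensor_operation::host::
+    ReferenceGemm<float, float, float, float, PassThrough, PassThrough, PassThrough>;
+
+Tensor<float> a_m_k(HostTensorDescriptor(std::vector<std::size_t>{256, 64}));
+Tensor<float> b_k_n(HostTensorDescriptor(std::vector<std::size_t>{64, 128}));
+Tensor<float> c_m_n(HostTensorDescriptor(std::vector<std::size_t>{256, 128}));
+
+auto ref_gemm = RefGemm{};
+auto invoker  = ref_gemm.MakeInvoker();
+auto argument = ref_gemm.MakeArgument(
+    a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
+invoker.Run(argument); // fills c_m_n with the fp32 reference result
+#endif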
+
+#pragma once
+
+#include <iostream>
+#include <sstream>
+
+#include "ck/tensor_operation/gpu/device/device_base.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace host {
+
+template <typename ADataType,
+          typename BDataType,
+          typename CDataType,
+          typename AccDataType,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+struct ReferenceGemmBias2D : public device::BaseOperator
+{
+    // Argument
+    struct Argument : public device::BaseArgument
+    {
+        Argument(const Tensor<ADataType>& a_m_k,
+                 const Tensor<BDataType>& b_k_n,
+                 const Tensor<CDataType>& c0_m_n,
+                 Tensor<CDataType>& c_m_n,
+                 AElementwiseOperation a_element_op,
+                 BElementwiseOperation b_element_op,
+                 CElementwiseOperation c_element_op)
+            : a_m_k_{a_m_k},
+              b_k_n_{b_k_n},
+              c0_m_n_{c0_m_n},
+              c_m_n_{c_m_n},
+              a_element_op_{a_element_op},
+              b_element_op_{b_element_op},
+              c_element_op_{c_element_op}
+        {
+        }
+
+        const Tensor<ADataType>& a_m_k_;
+        const Tensor<BDataType>& b_k_n_;
+        const Tensor<CDataType>& c0_m_n_;
+        Tensor<CDataType>& c_m_n_;
+
+        AElementwiseOperation a_element_op_;
+        BElementwiseOperation b_element_op_;
+        CElementwiseOperation c_element_op_;
+    };
+
+    // Invoker
+    struct Invoker : public device::BaseInvoker
+    {
+        using Argument = ReferenceGemmBias2D::Argument;
+
+        float Run(const Argument& arg)
+        {
+            auto f_mk_kn_mn = [&](auto m, auto n) {
+                const int K = arg.a_m_k_.mDesc.GetLengths()[1];
+
+                AccDataType a   = 0;
+                AccDataType b   = 0;
+                AccDataType acc = 0;
+
+                for(int k = 0; k < K; ++k)
+                {
+                    arg.a_element_op_(a, ck::type_convert<AccDataType>(arg.a_m_k_(m, k)));
+                    arg.b_element_op_(b, ck::type_convert<AccDataType>(arg.b_k_n_(k, n)));
+                    acc += a * b;
+                }
+
+                CDataType cast_acc = static_cast<CDataType>(acc);
+                arg.c_element_op_(arg.c_m_n_(m, n), cast_acc, arg.c0_m_n_(m, n));
+            };
+
+            make_ParallelTensorFunctor(
+                f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])(
+                std::thread::hardware_concurrency());
+
+            return 0;
+        }
+
+        float Run(const device::BaseArgument* p_arg,
+                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
+        {
+            return Run(*dynamic_cast<const Argument*>(p_arg));
+        }
+    };
+
+    static constexpr bool IsValidCompilationParameter()
+    {
+        // TODO: properly implement this check
+        return true;
+    }
+
+    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
+
+    static auto MakeArgument(const Tensor<ADataType>& a_m_k,
+                             const Tensor<BDataType>& b_k_n,
+                             const Tensor<CDataType>& c0_m_n,
+                             Tensor<CDataType>& c_m_n,
+                             AElementwiseOperation a_element_op,
+                             BElementwiseOperation b_element_op,
+                             CElementwiseOperation c_element_op)
+    {
+        return Argument{a_m_k, b_k_n, c0_m_n, c_m_n, a_element_op, b_element_op, c_element_op};
+    }
+
+    static auto MakeInvoker() { return Invoker{}; }
+
+    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
+    {
+        return std::make_unique<Invoker>(Invoker{});
+    }
+
+    std::string GetTypeString() const override
+    {
+        auto str = std::stringstream();
+
+        // clang-format off
+        str << "ReferenceGemmBias2D"
+            << std::endl;
+        // clang-format on
+
+        return str.str();
+    }
+};
+
+} // namespace host
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation.hpp
new file mode 100644
index 00000000..7dfc3c1e
--- /dev/null
+++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation.hpp
@@ -0,0 +1,140 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
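+
+// ReferenceGemmBias2D above hands the casted accumulator and the 2-D tensor
+// C0 to a binary CElementwiseOperation of the form op(y, x0, x1). A minimal
+// compatible functor is sketched here (an illustration, not a library type):
+#if 0
+struct AddBias2D
+{
+    template <typename Y, typename X0, typename X1>
+    void operator()(Y& y, const X0& x0, const X1& x1) const
+    {
+        y = static_cast<Y>(x0 + x1); // out = acc + bias
+    }
+};
+#endif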
+ +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceGemmBiasActivation : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + const Tensor& c0_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : a_m_k_{a_m_k}, + b_k_n_{b_k_n}, + c_m_n_{c_m_n}, + c0_n_{c0_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& a_m_k_; + const Tensor& b_k_n_; + Tensor& c_m_n_; + const Tensor& c0_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceGemmBiasActivation::Argument; + + float Run(const Argument& arg) + { + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = arg.a_m_k_.mDesc.GetLengths()[1]; + + float v_acc = 0; + + for(int k = 0; k < K; ++k) + { + float v_a; + float v_b; + + arg.a_element_op_(v_a, static_cast(arg.a_m_k_(m, k))); + arg.b_element_op_(v_b, static_cast(arg.b_k_n_(k, n))); + + v_acc += v_a * v_b; + } + + float v_c; + + arg.c_element_op_(v_c, v_acc, static_cast(arg.c0_n_(n))); + + arg.c_m_n_(m, n) = v_c; + }; + + make_ParallelTensorFunctor( + f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + const Tensor& c0_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{a_m_k, b_k_n, c_m_n, c0_n, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceGemmBiasActivation" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation_add.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation_add.hpp new file mode 100644 index 00000000..99102a40 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation_add.hpp @@ -0,0 +1,148 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceGemmBiasActivationAdd : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + const Tensor& c0_n, + const Tensor& c1_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : a_m_k_{a_m_k}, + b_k_n_{b_k_n}, + c_m_n_{c_m_n}, + c0_n_{c0_n}, + c1_m_n_{c1_m_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& a_m_k_; + const Tensor& b_k_n_; + Tensor& c_m_n_; + const Tensor& c0_n_; + const Tensor& c1_m_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceGemmBiasActivationAdd::Argument; + + float Run(const Argument& arg) + { + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = arg.a_m_k_.mDesc.GetLengths()[1]; + + float v_acc = 0; + + for(int k = 0; k < K; ++k) + { + float v_a; + float v_b; + + arg.a_element_op_(v_a, static_cast(arg.a_m_k_(m, k))); + arg.b_element_op_(v_b, static_cast(arg.b_k_n_(k, n))); + + v_acc += v_a * v_b; + } + + float v_c; + + arg.c_element_op_(v_c, + v_acc, + static_cast(arg.c0_n_(n)), + static_cast(arg.c1_m_n_(m, n))); + + arg.c_m_n_(m, n) = v_c; + }; + + make_ParallelTensorFunctor( + f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + const Tensor& c0_n, + const Tensor& c1_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{ + a_m_k, b_k_n, c_m_n, c0_n, c1_m_n, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceGemmBiasActivationAdd" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp new file mode 100644 index 00000000..b1e72459 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp @@ -0,0 +1,236 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
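+
+// The bias+activation(+add) reference GEMMs above pass every fused operand to
+// the C element op in one call, e.g. c_element_op_(v_c, v_acc, bias[n]) or
+// c_element_op_(v_c, v_acc, bias[n], residual[m, n]), so conceptually
+//
+//   v_c = activate(v_acc + bias[n])        (+ residual[m, n] for the _Add variant)
+//
+// with the activation supplied by the functor rather than hard-coded in the loop.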
+ +#pragma once + +#include +#include +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta) +template +struct ReferenceGemmLayernorm : public device::BaseOperator +{ + using ReferenceGemmInstance = ReferenceGemm; + + template + static void RunLayernorm(Tensor& result, + const Tensor& acc, // MxN + const Tensor& gamma, // 1xN + const Tensor& beta, // 1xN + const InDataType epsilon = 1e-5) + { + assert(acc.mDesc.GetLengths()[1] == gamma.mDesc.GetLengths()[0] && + acc.mDesc.GetLengths()[1] == beta.mDesc.GetLengths()[0]); + + size_t M = acc.mDesc.GetLengths()[0]; + size_t N = acc.mDesc.GetLengths()[1]; + + Tensor avg_acc_sq(HostTensorDescriptor(std::vector({M}))); + Tensor avg_acc(HostTensorDescriptor(std::vector({M}))); + Tensor acc_layernorm(acc); + + // reduce N dim + for(size_t i = 0; i < M; i++) + { + ComputeDataType sum_acc_sq = 0; + ComputeDataType sum_acc = 0; + for(size_t j = 0; j < N; j++) + { + sum_acc_sq += acc_layernorm(i, j) * acc_layernorm(i, j); + sum_acc += acc_layernorm(i, j); + } + avg_acc_sq(i) = sum_acc_sq / N; + avg_acc(i) = sum_acc / N; + } + + // normalize + acc_layernorm.ForEach([&](auto& self, auto idx) { + self(idx[0], idx[1]) = + (self(idx[0], idx[1]) - avg_acc(idx[0])) / + sqrt(avg_acc_sq(idx[0]) - avg_acc(idx[0]) * avg_acc(idx[0]) + epsilon); + }); + + // affine + acc_layernorm.ForEach([&](auto& self, auto idx) { + self(idx[0], idx[1]) = self(idx[0], idx[1]) * gamma(idx[1]) + beta(idx[1]); + }); + + // cast + result = acc_layernorm.template CopyAsType(); + } + + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + const Tensor& c0_n_bias, // 1xN + const Tensor& c0_m_n_add, // MxN + const Tensor& c0_n_gamma, // 1xN + const Tensor& c0_n_beta, // 1xN + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + CElementwiseOperation c_element_op, + const CDataType epsilon = 1e-5) + : a_m_k_{a_m_k}, + b_k_n_{b_k_n}, + c_m_n_{c_m_n}, + c0_n_bias_{c0_n_bias}, + c0_m_n_add_{c0_m_n_add}, + c0_n_gamma_{c0_n_gamma}, + c0_n_beta_{c0_n_beta}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + acc_element_op_{acc_element_op}, + c_element_op_{c_element_op}, + epsilon_{epsilon} + { + } + + const Tensor& a_m_k_; + const Tensor& b_k_n_; + Tensor& c_m_n_; + const Tensor& c0_n_bias_; + const Tensor& c0_m_n_add_; + const Tensor& c0_n_gamma_; + const Tensor& c0_n_beta_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + AccElementwiseOperation acc_element_op_; + CElementwiseOperation c_element_op_; + + const CDataType epsilon_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + // using Argument = ReferenceGemm::Argument; + + float Run(const Argument& arg) + { + Tensor acc_m_n(arg.c_m_n_.mDesc); + acc_m_n.GenerateTensorValue(GeneratorTensor_1{0}); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + auto ref_argument = ref_gemm.MakeArgument(arg.a_m_k_, + arg.b_k_n_, + acc_m_n, + arg.a_element_op_, + arg.b_element_op_, + element_wise::PassThrough{}); + + // gemm + ref_invoker.Run(ref_argument); + + // activation(acc + bias) + acc_m_n.ForEach([&](auto& self, auto idx) { + AccDataType out; + arg.acc_element_op_(out, acc_m_n(idx[0], idx[1]) + 
arg.c0_n_bias_(idx[1])); + self(idx[0], idx[1]) = out; + }); + + // add from other layers + acc_m_n.ForEach([&](auto& self, auto idx) { + self(idx[0], idx[1]) += arg.c0_m_n_add_(idx[0], idx[1]); + }); + + // layernorm + RunLayernorm(arg.c_m_n_, acc_m_n, arg.c0_n_gamma_, arg.c0_n_beta_); + + // elementwise op + arg.c_m_n_.ForEach([&](auto& self, auto idx) { + arg.c_element_op_(self(idx[0], idx[1]), self(idx[0], idx[1])); + }); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + const Tensor& c0_n_bias, // 1xN + const Tensor& c0_m_n_add, // 1xN + const Tensor& c0_n_gamma, // 1xN + const Tensor& c0_n_beta, // 1xN + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + CElementwiseOperation c_element_op, + const CDataType epsilon = 1e-5) + { + return Argument{a_m_k, + b_k_n, + c_m_n, + c0_n_bias, + c0_m_n_add, + c0_n_gamma, + c0_n_beta, + a_element_op, + b_element_op, + acc_element_op, + c_element_op, + epsilon}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceGemmLayernorm" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp new file mode 100644 index 00000000..bfc6986d --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp @@ -0,0 +1,164 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
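+
+// RunLayernorm in reference_gemm_layernorm.hpp above uses the one-pass moment
+// identity: with m1 = E[x] and m2 = E[x^2] reduced over the N dimension,
+//
+//   var(x) = m2 - m1 * m1
+//   y      = (x - m1) / sqrt(var(x) + epsilon) * gamma + beta
+//
+// which is exactly the avg_acc / avg_acc_sq reduction followed by the
+// normalize and affine passes.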
+ +#pragma once + +#include +#include +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceSoftmax : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& in, + Tensor& out, + AccDataType alpha, + AccDataType beta, + const std::vector sm_reduce_dims) + : in_(in), out_(out), alpha_(alpha), beta_(beta), sm_reduce_dims_(sm_reduce_dims) + { + // std::cout << "debug: scalar dims: "; + for(size_t i = 0; i < in.mDesc.GetNumOfDimension(); i++) + { + if(std::find(sm_reduce_dims.begin(), sm_reduce_dims.end(), i) == + sm_reduce_dims.end()) + { + sm_scalar_dims_.push_back(i); + // std::cout << i << ", "; + } + } + // std::cout << std::endl; + } + + const Tensor& in_; + Tensor& out_; + AccDataType alpha_; + AccDataType beta_; + std::vector sm_reduce_dims_; + std::vector sm_scalar_dims_; // dim after internal max/sum reduction + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + float Run(const Argument& arg) + { + std::vector scalar_lengths; + for(index_t dim : arg.sm_scalar_dims_) + { + scalar_lengths.push_back(arg.in_.mDesc.GetLengths()[dim]); + } + + Tensor reduce_max(scalar_lengths); + reduce_max.GenerateTensorValue( + GeneratorTensor_1{std::numeric_limits::lowest()}); + Tensor reduce_sum(scalar_lengths); + reduce_sum.GenerateTensorValue(GeneratorTensor_1{0}); + + auto to_sm_scalar_idx = [&](auto idx) { + std::vector sm_scalar_idx; + for(index_t dim : arg.sm_scalar_dims_) + { + sm_scalar_idx.push_back(idx[dim]); + } + return sm_scalar_idx; + }; + + arg.in_.ForEach([&](auto& self, auto idx) { + reduce_max(to_sm_scalar_idx(idx)) = std::max(reduce_max(to_sm_scalar_idx(idx)), + static_cast(self(idx))); + }); + + // LogRangeAsType(std::cout << "reduce_max: ", reduce_max.mData, ",") << + // std::endl; + + Tensor in_stable(arg.in_.mDesc); + in_stable.ForEach([&](auto& self, auto idx) { + // numerator = exp(x - max(x)) + self(idx) = std::exp(static_cast(arg.in_(idx)) - + reduce_max(to_sm_scalar_idx(idx))); + }); + + // LogRangeAsType(std::cout << "in_stable: ", in_stable.mData, ",") << std::endl; + + in_stable.ForEach([&](auto& self, auto idx) { + // denominator = sum(exp(x - max(x))) + reduce_sum(to_sm_scalar_idx(idx)) += self(idx); + }); + + // LogRangeAsType(std::cout << "reduce_sum: ", reduce_sum.mData, ",") << + // std::endl; + + arg.out_.ForEach([&](auto& self, auto idx) { + self(idx) = arg.alpha_ * in_stable(idx) / reduce_sum(to_sm_scalar_idx(idx)) + + arg.beta_ * self(idx); + }); + + // LogRangeAsType(std::cout << "out: ", arg.out_.mData, ",") << std::endl; + // reduction along reduce dims + // LogRangeAsType(std::cout << "reduce_max: ", reduce_max.mData, ",") << + // std::endl; LogRangeAsType(std::cout << "reduce_sum: ", reduce_sum.mData, ",") + // << std::endl; + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& in, + Tensor& out, + AccDataType alpha, + AccDataType beta, + const std::vector sm_reduce_dims) + { + 
return Argument{in, out, alpha, beta, sm_reduce_dims}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceSoftmax" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd.hpp b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd.hpp new file mode 100644 index 00000000..df4fca65 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd.hpp @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef NAIVE_CONV_FWD_HPP +#define NAIVE_CONV_FWD_HPP + +namespace ck { +namespace ref { + +/* + * \brief naive implementation of 3D convolution. Layout is (NDHWC, KZYXC, NDHWK). + * + * \param N number of batches + * \param K number of filters + * \param C number of channels of weight + * \param (Di, Hi, Wi) depth, height and width dimension of data + * \param (Z, Y, X) depth, height and width dimensions of weights + * \param (Do, Ho, Wo) depth, height and width dimension of output + * \param (stride_z, stride_y, stride_x) strides + * \param (dilation_z, dilation_y, dilation_x) dilations + * \param (pad_z, pad_y, pad_x) pads + */ +template +__global__ void naive_conv_fwd_ndhwc_kzyxc_ndhwk(const TIn* __restrict__ p_in, + const TWei* __restrict__ p_wei, + TOut* __restrict__ p_out, + index_t N, + index_t K, + index_t C, + index_t Di, + index_t Hi, + index_t Wi, + index_t Z, + index_t Y, + index_t X, + index_t Do, + index_t Ho, + index_t Wo, + index_t stride_z, + index_t stride_y, + index_t stride_x, + index_t dilation_z, + index_t dilation_y, + index_t dilation_x, + index_t pad_z, + index_t pad_y, + index_t pad_x) +{ + const index_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const index_t num_threads = blockDim.x * gridDim.x; + const long_index_t output_length = N * Do * Ho * Wo * K; + + const index_t out_strides[] = {Do * Ho * Wo * K, Ho * Wo * K, Wo * K, K}; + const index_t in_strides[] = {Di * Hi * Wi * C, Hi * Wi * C, Wi * C, C}; + const index_t wei_strides[] = {Z * Y * X * C, Y * X * C, X * C, C}; + + constexpr auto in_op = InElementwiseOperation{}; + constexpr auto wei_op = WeiElementwiseOperation{}; + constexpr auto out_op = OutElementwiseOperation{}; + + TIn in_val; + TWei wei_val; + TOut out_val; + + for(long_index_t ii = tid; ii < output_length; ii += num_threads) + { + const index_t n = ii / out_strides[0]; + index_t k = ii - n * out_strides[0]; + const index_t dO = k / out_strides[1]; + k -= dO * out_strides[1]; + const index_t ho = k / out_strides[2]; + k -= ho * out_strides[2]; + const index_t wo = k / out_strides[3]; + k -= wo * out_strides[3]; + + TAcc acc = static_cast(0); + + const TIn* in_n = p_in + static_cast(n) * in_strides[0]; + const TWei* wei_k = p_wei + static_cast(k) * wei_strides[0]; + + for(index_t z = 0; z < Z; ++z) + { + index_t di = stride_z * dO - pad_z + dilation_z * z; + const TIn* in_n_di = in_n + di * in_strides[1]; + const TWei* wei_k_z = wei_k + z * wei_strides[1]; + + for(index_t y = 0; y < Y; ++y) + { + index_t hi = stride_y * ho - pad_y + dilation_y * y; + const TIn* in_n_di_hi = in_n_di + hi * in_strides[2]; + const TWei* wei_k_z_y 
= wei_k_z + y * wei_strides[2];
+
+                for(index_t x = 0; x < X; ++x)
+                {
+                    index_t wi               = stride_x * wo - pad_x + dilation_x * x;
+                    const TIn* in_n_di_hi_wi = in_n_di_hi + wi * in_strides[3];
+                    const TWei* wei_k_z_y_x  = wei_k_z_y + x * wei_strides[3];
+
+                    if(di >= 0 && di < Di && hi >= 0 && hi < Hi && wi >= 0 && wi < Wi)
+                    {
+                        for(index_t c = 0; c < C; ++c)
+                        {
+                            in_op(in_val, in_n_di_hi_wi[c]);
+                            wei_op(wei_val, wei_k_z_y_x[c]);
+                            acc += in_val * wei_val;
+                        }
+                    }
+                }
+            }
+        }
+
+        out_op(out_val, static_cast<TOut>(acc));
+        p_out[ii] = out_val;
+    }
+}
+} // namespace ref
+} // namespace ck
+
+#endif
diff --git a/library/include/ck/library/tensor_operation_instance/add_device_operation_instance.hpp b/library/include/ck/library/tensor_operation_instance/add_device_operation_instance.hpp
new file mode 100644
index 00000000..20df1b36
--- /dev/null
+++ b/library/include/ck/library/tensor_operation_instance/add_device_operation_instance.hpp
@@ -0,0 +1,35 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <memory>
+#include <vector>
+
+#include "ck/utility/functional2.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+template <typename BaseOp, typename NewOpInstances>
+void add_device_operation_instances(std::vector<std::unique_ptr<BaseOp>>& op_instances,
+                                    const NewOpInstances& new_op_instances)
+{
+    ck::static_for<0, std::tuple_size_v<NewOpInstances>, 1>{}([&](auto i) {
+        const auto new_op_instance = std::get<i>(new_op_instances);
+
+        using NewOpInstance = remove_cvref_t<decltype(new_op_instance)>;
+
+        static_assert(std::is_base_of_v<BaseOp, NewOpInstance>,
+                      "wrong! NewOpInstance should be derived from BaseOp");
+
+        op_instances.push_back(std::make_unique<NewOpInstance>(new_op_instance));
+    });
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
new file mode 100644
index 00000000..72160a9d
--- /dev/null
+++ b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
@@ -0,0 +1,55 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
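+
+// add_device_operation_instances above appends every element of a std::tuple
+// of concrete device ops to a vector of unique_ptr<BaseOp>. A hypothetical
+// registration function would use it like this (DeviceGemmBase and the tuple
+// alias device_gemm_example_instances are illustrative names):
+#if 0
+void add_device_gemm_example_instances(
+    std::vector<std::unique_ptr<DeviceGemmBase>>& instances)
+{
+    // device_gemm_example_instances: a std::tuple of concrete DeviceGemm types
+    add_device_operation_instances(instances, device_gemm_example_instances{});
+}
+#endif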
+
+#pragma once
+
+#include <memory>
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+// aliasing, for commonly used data type
+using F64 = double;
+using F32 = float;
+using F16 = ck::half_t;
+using BF16 = ck::bhalf_t;
+
+using EMPTY_TUPLE = ck::Tuple<>;
+
+using F16_TUPLE = ck::Tuple<F16>;
+using F16_F16_TUPLE = ck::Tuple<F16, F16>;
+
+using F32_TUPLE = ck::Tuple<F32>;
+
+// GEMM layout
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+// Conv layout
+using NWC = ck::tensor_layout::convolution::NWC;
+using NHWC = ck::tensor_layout::convolution::NHWC;
+using NDHWC = ck::tensor_layout::convolution::NDHWC;
+
+using KXC = ck::tensor_layout::convolution::KXC;
+using KYXC = ck::tensor_layout::convolution::KYXC;
+using KZYXC = ck::tensor_layout::convolution::KZYXC;
+
+using NWK = ck::tensor_layout::convolution::NWK;
+using NHWK = ck::tensor_layout::convolution::NHWK;
+using NDHWK = ck::tensor_layout::convolution::NDHWK;
+
+// pointwise functor
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using Scale = ck::tensor_operation::element_wise::Scale;
+using Bilinear = ck::tensor_operation::element_wise::Bilinear;
+using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
+
+template <typename DeviceOp>
+struct DeviceOperationInstanceFactory;
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm.hpp
new file mode 100644
index 00000000..0655fd92
--- /dev/null
+++ b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm.hpp
@@ -0,0 +1,259 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
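+
+// Callers resolve an instance list by naming the abstract device op they want
+// and asking the factory for all registered implementations; a sketch using
+// the aliases from device_operation_instance_factory.hpp above (the concrete
+// GEMM op shown is illustrative):
+#if 0
+using DeviceOp = ck::tensor_operation::device::DeviceGemm<
+    Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>;
+
+const auto op_ptrs = ck::tensor_operation::device::instance::
+    DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
+
+for(const auto& op : op_ptrs)
+    std::cout << op->GetTypeString() << std::endl;
+#endif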
+ +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances( + std::vector>>& instances); + +void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances( + std::vector>>& instances); + +void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances( + std::vector>>& instances); + +void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances( + std::vector>>& instances); + +template +struct DeviceOperationInstanceFactory> +{ + using DeviceOp = DeviceBatchedGemm; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if 
constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp new file mode 100644 index 00000000..9bb8e5ce --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
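+
+// Naming note for the batched-GEMM instances above: the gmk/gkm/gkn/gnk/gmn
+// suffixes prepend the batch dimension G to the usual GEMM layout letters, so
+// e.g. "gmk_gkn_gmn" is A[G, M, K] (Row) times B[G, K, N] (Row) into
+// C[G, M, N] (Row), one independent GEMM per batch index g.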
+ +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance( + std::vector>>& instances); + +// Contraction + Bilinear +template +struct DeviceOperationInstanceFactory, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Bilinear>> +{ + using DeviceOp = DeviceContractionMultipleD, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Bilinear>; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2) + { + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance( + op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp new file mode 100644 index 00000000..6eb5b1d0 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp @@ -0,0 +1,127 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
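+
+// The Bilinear epilogue bound into the contraction instances above computes,
+// per element,
+//
+//   E = alpha * (A x B) + beta * D
+//
+// with alpha and beta supplied when the element_wise::Bilinear functor is
+// constructed; D is the single extra input tensor carried in the Ds tuple.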
+ +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance( + std::vector>>& instances); + +// Contraction + Scale +template +struct DeviceOperationInstanceFactory, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Scale>> +{ + using DeviceOp = DeviceContractionMultipleD, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Scale>; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2) + { + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance( + op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/convolution_forward.hpp new file mode 100644 index 00000000..4bb8339c --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/convolution_forward.hpp @@ -0,0 +1,292 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
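+
+// Naming note for the contraction instances above: "m2_n2_k2" matches the
+// factory's compile-time check NumDimM == NumDimN == NumDimK == 2, i.e. A is
+// indexed by two M and two K dimensions and B by two N and two K dimensions;
+// the kkn/knn/mkn/mnn suffixes distinguish the supported operand layouts.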
+ +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// conv1d_fwd +void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances( + std::vector>>& + instances); + +void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances( + std::vector>>& + instances); + +void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances( + std::vector>>& + instances); + +void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances( + std::vector>>& instances); + +// conv2d_fwd +void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances( + std::vector>>& + instances); + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances( + std::vector>>& instances); + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector>>& + instances); + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector>>& + instances); + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances( + std::vector>>& instances); + +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances( + std::vector>>& instances); + +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector>>& + instances); + +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector>>& + instances); + +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances( + std::vector>>& instances); + +// conv3d_fwd +void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances( + std::vector>>& instances); + +void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances( + std::vector>>& instances); + +void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances( + std::vector>>& instances); + +void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances( + std::vector>>& instances); + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceConvFwd> +{ + using DeviceOp = DeviceConvFwd; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(NumDimSpatial == 1 && is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances(op_ptrs); + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances(op_ptrs); + } + } + else if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs); + add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs); + add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs); + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs); + 
add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs); + add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs); + } + } + else if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances(op_ptrs); + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp new file mode 100644 index 00000000..a9cc8b79 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_elementwise.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Normalize = ck::tensor_operation::element_wise::Normalize; +using DeviceNormalizeFromMeanMeanSquarePtr = + ck::tensor_operation::device::DeviceElementwisePtr<5, 1, 2, Normalize>; + +void add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances( + std::vector& instances); + +template +auto get_device_normalize_from_mean_meansquare_instances() +{ + std::vector op_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value && is_same::value && + is_same::value && is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances(op_ptrs); + } + + return op_ptrs; +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp new file mode 100644 index 00000000..682f5467 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
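+
+// The Normalize elementwise op consumed by device_elementwise_instance.hpp
+// above finishes a layer normalization from precomputed row statistics: given
+// x, mean and meansquare (plus gamma and beta), it can evaluate
+//
+//   var = meansquare - mean * mean
+//   y   = (x - mean) / sqrt(var + epsilon) * gamma + beta
+//
+// so a producer kernel only has to emit per-row mean and mean-of-squares.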
+ +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using DeviceGemmAddAddMeanSquareMeanPtr = ck::tensor_operation::device::DeviceGemmReducePtr<1, 2>; + +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( + std::vector&); +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( + std::vector&); +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances( + std::vector&); +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( + std::vector&); + +template +auto get_device_gemm_add_add_mean_squaremean_instances() +{ + std::vector op_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( + op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( + op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances( + op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( + op_ptrs); + } + } + + return op_ptrs; +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp new file mode 100644 index 00000000..55ca8f42 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp @@ -0,0 +1,383 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
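+
+// Unlike the DeviceOperationInstanceFactory specializations used elsewhere,
+// the two headers above expose plain get_*_instances() function templates. A
+// caller sketch, assuming the template parameters are the A/B/C data types
+// followed by their layouts (an assumption about this interface, not a
+// documented guarantee):
+#if 0
+auto gemm_reduce_ptrs = ck::tensor_operation::device::instance::
+    get_device_gemm_add_add_mean_squaremean_instances<F16, F16, F16, Row, Row, Row>();
+#endif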
+
+#pragma once
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16         = ck::half_t;
+using BF16        = ck::bhalf_t;
+using F32         = float;
+using F64         = double;
+using I8          = int8_t;
+using Row         = ck::tensor_layout::gemm::RowMajor;
+using Col         = ck::tensor_layout::gemm::ColumnMajor;
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+// Shorthand for the instance-list parameter type shared by every declaration below.
+template <typename ALayout, typename BLayout, typename CLayout,
+          typename ADataType, typename BDataType, typename CDataType>
+using DeviceGemmInstances = std::vector<std::unique_ptr<
+    DeviceGemm<ALayout, BLayout, CLayout, ADataType, BDataType, CDataType,
+               PassThrough, PassThrough, PassThrough>>>;
+
+void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(
+    DeviceGemmInstances<Col, Row, Row, F16, F16, F16>& instances);
+
+void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(
+    DeviceGemmInstances<Col, Col, Row, F16, F16, F16>& instances);
+
+void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(
+    DeviceGemmInstances<Row, Row, Row, F16, F16, F16>& instances);
+
+void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(
+    DeviceGemmInstances<Row, Col, Row, F16, F16, F16>& instances);
+
+void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(
+    DeviceGemmInstances<Col, Row, Row, F32, F32, F32>& instances);
+
+void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(
+    DeviceGemmInstances<Col, Col, Row, F32, F32, F32>& instances);
+
+void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(
+    DeviceGemmInstances<Row, Row, Row, F32, F32, F32>& instances);
+
+void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(
+    DeviceGemmInstances<Row, Col, Row, F32, F32, F32>& instances);
+
+void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(
+    DeviceGemmInstances<Col, Row, Row, I8, I8, I8>& instances);
+
+void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(
+    DeviceGemmInstances<Col, Col, Row, I8, I8, I8>& instances);
+
+void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(
+    DeviceGemmInstances<Row, Row, Row, I8, I8, I8>& instances);
+
+void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(
+    DeviceGemmInstances<Row, Col, Row, I8, I8, I8>& instances);
+
+void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
+    DeviceGemmInstances<Row, Col, Row, F16, F16, F16>& instances);
+
+void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(
+    DeviceGemmInstances<Col, Row, Row, BF16, BF16, BF16>& instances);
+
+void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(
+    DeviceGemmInstances<Col, Col, Row, BF16, BF16, BF16>& instances);
+
+void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(
+    DeviceGemmInstances<Row, Row, Row, BF16, BF16, BF16>& instances);
+
+void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(
+    DeviceGemmInstances<Row, Col, Row, BF16, BF16, BF16>& instances);
+
+void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
+    DeviceGemmInstances<Col, Row, Row, F16, F16, F16>& instances);
+
+void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
+    DeviceGemmInstances<Col, Col, Row, F16, F16, F16>& instances);
+
+void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
+    DeviceGemmInstances<Row, Row, Row, F16, F16, F16>& instances);
+
+void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
+    DeviceGemmInstances<Row, Col, Row, F16, F16, F16>& instances);
+
+void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(
+    DeviceGemmInstances<Col, Row, Row, F32, F32, F32>& instances);
+
+void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(
+    DeviceGemmInstances<Col, Col, Row, F32, F32, F32>& instances);
+
+void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(
+    DeviceGemmInstances<Row, Row, Row, F32, F32, F32>& instances);
+
+void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(
+    DeviceGemmInstances<Row, Col, Row, F32, F32, F32>& instances);
+
+void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
+    DeviceGemmInstances<Col, Row, Row, I8, I8, I8>& instances);
+
+void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(
+    DeviceGemmInstances<Col, Col, Row, I8, I8, I8>& instances);
+
+void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(
+    DeviceGemmInstances<Row, Row, Row, I8, I8, I8>& instances);
+
+void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
+    DeviceGemmInstances<Row, Col, Row, I8, I8, I8>& instances);
+
+void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
+    DeviceGemmInstances<Col, Row, Row, F16, F16, F16>& instances);
+
+void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
+    DeviceGemmInstances<Col, Col, Row, F16, F16, F16>& instances);
+
+void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
+    DeviceGemmInstances<Row, Row, Row, F16, F16, F16>& instances);
+
+void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
+    DeviceGemmInstances<Row, Col, Row, F16, F16, F16>& instances);
+
+void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(
+    DeviceGemmInstances<Col, Row, Row, F32, F32, F32>& instances);
+
+void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(
+    DeviceGemmInstances<Col, Col, Row, F32, F32, F32>& instances);
+
+void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(
+    DeviceGemmInstances<Row, Row, Row, F32, F32, F32>& instances);
+
+void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(
+    DeviceGemmInstances<Row, Col, Row, F32, F32, F32>& instances);
+
+void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances(
+    DeviceGemmInstances<Col, Row, Row, F64, F64, F64>& instances);
+
+void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances(
+    DeviceGemmInstances<Col, Col, Row, F64, F64, F64>& instances);
+
+void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances(
+    DeviceGemmInstances<Row, Row, Row, F64, F64, F64>& instances);
+
+void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances(
+    DeviceGemmInstances<Row, Col, Row, F64, F64, F64>& instances);
+
+template <typename ADataType,
+          typename BDataType,
+          typename CDataType,
+          typename ALayout,
+          typename BLayout,
+          typename CLayout>
+struct DeviceOperationInstanceFactory<
+    ck::tensor_operation::device::DeviceGemm<ALayout, BLayout, CLayout,
+                                             ADataType, BDataType, CDataType,
+                                             PassThrough, PassThrough, PassThrough>>
+{
+    using DeviceOp = DeviceGemm<ALayout, BLayout, CLayout, ADataType, BDataType, CDataType,
+                                PassThrough, PassThrough, PassThrough>;
+
+    static auto GetInstances()
+    {
+        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
+
+        if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
+                     is_same_v<CDataType, float>)
+        {
+            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
+                         is_same_v<CLayout, Row>)
+            {
+                add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
+                add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
+                add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
+            }
+            else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
+                              is_same_v<CLayout, Row>)
+            {
+                add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
+                add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
+                add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
+            }
+            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
+                              is_same_v<CLayout, Row>)
+            {
+                add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(op_ptrs);
+                add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(op_ptrs);
+                add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(op_ptrs);
+            }
+            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
+                              is_same_v<CLayout, Row>)
+            {
+                add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(op_ptrs);
+                add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(op_ptrs);
+                add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(op_ptrs);
+            }
+        }
+        else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
+                          is_same_v<CDataType, half_t>)
+        {
+            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
+                         is_same_v<CLayout, Row>)
+            {
+                add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
+                add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
+                add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
+            }
+            else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
+                              is_same_v<CLayout, Row>)
+            {
+                add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
+                add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
+                add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
+                add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
+            }
+            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
+                              is_same_v<CLayout, Row>)
+            {
+                add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
+                add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
+                add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs);
+            }
+            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
+                              is_same_v<CLayout, Row>)
+            {
+                add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
+                add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
+                add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs);
+            }
+        }
+        else if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
+                          is_same_v<CDataType, bhalf_t>)
+        {
+            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
+                         is_same_v<CLayout, Row>)
+            {
+                add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(op_ptrs);
+            }
+            else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
+                              is_same_v<CLayout, Row>)
+            {
+                add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(op_ptrs);
+            }
+            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
+                              is_same_v<CLayout, Row>)
+            {
+                add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(op_ptrs);
+            }
+            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
+                              is_same_v<CLayout, Row>)
+            {
+                add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(op_ptrs);
+            }
+        }
+        else if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
+                          is_same_v<CDataType, int8_t>)
+        {
+            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
+                         is_same_v<CLayout, Row>)
+            {
+                add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs);
+                add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(op_ptrs);
+            }
+            else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
+                              is_same_v<CLayout, Row>)
+            {
+                add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
+                add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
+            }
+            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
+                              is_same_v<CLayout, Row>)
+            {
+                add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs);
+                add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(op_ptrs);
+            }
+            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
+                              is_same_v<CLayout, Row>)
+            {
+                add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs);
+                add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(op_ptrs);
+            }
+        }
+
+        return op_ptrs;
+    }
+};
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp
new file mode 100644
index 00000000..e2cd64b3
--- /dev/null
+++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp
@@ -0,0 +1,143 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
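Reviewer note: the factory above dispatches entirely at compile time; each
`if constexpr` branch maps a (data type, layout) combination onto the matching
precompiled instance lists, and an unsupported combination simply yields an
empty vector. A minimal sketch of how a client asks for instances (names mirror
the header above; nothing here adds to the library):

    // All f16 "NT" (row-major A, column-major B) GEMM instances.
    using F16 = ck::half_t;
    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    using DeviceOp = ck::tensor_operation::device::
        DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>;

    auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
    // op_ptrs now holds the xdl, dl, xdl_c_shuffle and 2-stage instances for
    // exactly this type/layout combination.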
+
+#pragma once
+
+#include <cstdlib>
+#include <memory>
+#include <vector>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16            = ck::half_t;
+using Row            = ck::tensor_layout::gemm::RowMajor;
+using Col            = ck::tensor_layout::gemm::ColumnMajor;
+using PassThrough    = ck::tensor_operation::element_wise::PassThrough;
+using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
+
+void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
+    std::vector<std::unique_ptr<
+        DeviceGemmMultipleD<Row, Row, ck::Tuple<Row, Row>, Row, F16, F16, ck::Tuple<F16, F16>,
+                            F16, PassThrough, PassThrough, AddAddFastGelu>>>&);
+
+void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
+    std::vector<std::unique_ptr<
+        DeviceGemmMultipleD<Row, Col, ck::Tuple<Row, Row>, Row, F16, F16, ck::Tuple<F16, F16>,
+                            F16, PassThrough, PassThrough, AddAddFastGelu>>>&);
+
+void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
+    std::vector<std::unique_ptr<
+        DeviceGemmMultipleD<Col, Row, ck::Tuple<Row, Row>, Row, F16, F16, ck::Tuple<F16, F16>,
+                            F16, PassThrough, PassThrough, AddAddFastGelu>>>&);
+
+void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
+    std::vector<std::unique_ptr<
+        DeviceGemmMultipleD<Col, Col, ck::Tuple<Row, Row>, Row, F16, F16, ck::Tuple<F16, F16>,
+                            F16, PassThrough, PassThrough, AddAddFastGelu>>>&);
+
+// GEMM + Add + Add + FastGelu
+template <typename ALayout,
+          typename BLayout,
+          typename D0Layout,
+          typename D1Layout,
+          typename ELayout,
+          typename ADataType,
+          typename BDataType,
+          typename D0DataType,
+          typename D1DataType,
+          typename EDataType>
+struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD<
+    ALayout,
+    BLayout,
+    ck::Tuple<D0Layout, D1Layout>,
+    ELayout,
+    ADataType,
+    BDataType,
+    ck::Tuple<D0DataType, D1DataType>,
+    EDataType,
+    ck::tensor_operation::element_wise::PassThrough,
+    ck::tensor_operation::element_wise::PassThrough,
+    ck::tensor_operation::element_wise::AddAddFastGelu>>
+{
+    using DeviceOp = DeviceGemmMultipleD<ALayout,
+                                         BLayout,
+                                         ck::Tuple<D0Layout, D1Layout>,
+                                         ELayout,
+                                         ADataType,
+                                         BDataType,
+                                         ck::Tuple<D0DataType, D1DataType>,
+                                         EDataType,
+                                         ck::tensor_operation::element_wise::PassThrough,
+                                         ck::tensor_operation::element_wise::PassThrough,
+                                         ck::tensor_operation::element_wise::AddAddFastGelu>;
+
+    static auto GetInstances()
+    {
+        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
+
+        if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
+                     is_same_v<D0DataType, half_t> && is_same_v<D1DataType, half_t> &&
+                     is_same_v<EDataType, half_t>)
+        {
+            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
+                         is_same_v<ELayout, Row>)
+            {
+                add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
+                    op_ptrs);
+            }
+            else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
+                              is_same_v<ELayout, Row>)
+            {
+                add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
+                    op_ptrs);
+            }
+            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
+                              is_same_v<ELayout, Row>)
+            {
+                add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
+                    op_ptrs);
+            }
+            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
+                              is_same_v<ELayout, Row>)
+            {
+                add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
+                    op_ptrs);
+            }
+        }
+
+        return op_ptrs;
+    }
+};
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp
new file mode 100644
index 00000000..37731fde
--- /dev/null
+++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp
@@ -0,0 +1,137 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
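Reviewer note: the extra output operands of DeviceGemmMultipleD travel as a
ck::Tuple of D tensors, so the same factory pattern extends to fused epilogues.
Here the fused computation is E = FastGelu((A @ B) + D0 + D1). A hedged sketch
of requesting instances (names mirror the header just added):

    using F16 = ck::half_t;
    using Row = ck::tensor_layout::gemm::RowMajor;
    using PassThrough    = ck::tensor_operation::element_wise::PassThrough;
    using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;

    // Row-major A/B/D0/D1/E, all f16.
    using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD<
        Row, Row, ck::Tuple<Row, Row>, Row, F16, F16, ck::Tuple<F16, F16>, F16,
        PassThrough, PassThrough, AddAddFastGelu>;

    auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();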
+
+#pragma once
+
+#include <cstdlib>
+#include <memory>
+#include <vector>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16         = ck::half_t;
+using Row         = ck::tensor_layout::gemm::RowMajor;
+using Col         = ck::tensor_layout::gemm::ColumnMajor;
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using Bilinear    = ck::tensor_operation::element_wise::Bilinear;
+
+void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
+    std::vector<std::unique_ptr<
+        DeviceGemmMultipleD<Col, Row, ck::Tuple<Row>, Row, F16, F16, ck::Tuple<F16>, F16,
+                            PassThrough, PassThrough, Bilinear>>>& instances);
+
+void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
+    std::vector<std::unique_ptr<
+        DeviceGemmMultipleD<Col, Col, ck::Tuple<Row>, Row, F16, F16, ck::Tuple<F16>, F16,
+                            PassThrough, PassThrough, Bilinear>>>& instances);
+
+void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
+    std::vector<std::unique_ptr<
+        DeviceGemmMultipleD<Row, Row, ck::Tuple<Row>, Row, F16, F16, ck::Tuple<F16>, F16,
+                            PassThrough, PassThrough, Bilinear>>>& instances);
+
+void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
+    std::vector<std::unique_ptr<
+        DeviceGemmMultipleD<Row, Col, ck::Tuple<Row>, Row, F16, F16, ck::Tuple<F16>, F16,
+                            PassThrough, PassThrough, Bilinear>>>& instances);
+
+// GEMM + Bilinear
+template <typename ALayout,
+          typename BLayout,
+          typename DLayout,
+          typename ELayout,
+          typename ADataType,
+          typename BDataType,
+          typename DDataType,
+          typename EDataType>
+struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD<
+    ALayout,
+    BLayout,
+    ck::Tuple<DLayout>,
+    ELayout,
+    ADataType,
+    BDataType,
+    ck::Tuple<DDataType>,
+    EDataType,
+    ck::tensor_operation::element_wise::PassThrough,
+    ck::tensor_operation::element_wise::PassThrough,
+    ck::tensor_operation::element_wise::Bilinear>>
+{
+    using DeviceOp = DeviceGemmMultipleD<ALayout,
+                                         BLayout,
+                                         ck::Tuple<DLayout>,
+                                         ELayout,
+                                         ADataType,
+                                         BDataType,
+                                         ck::Tuple<DDataType>,
+                                         EDataType,
+                                         ck::tensor_operation::element_wise::PassThrough,
+                                         ck::tensor_operation::element_wise::PassThrough,
+                                         ck::tensor_operation::element_wise::Bilinear>;
+
+    static auto GetInstances()
+    {
+        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
+
+        if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
+                     is_same_v<DDataType, half_t> && is_same_v<EDataType, half_t>)
+        {
+            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
+                         is_same_v<ELayout, Row>)
+            {
+                add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
+            }
+            else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
+                              is_same_v<ELayout, Row>)
+            {
+                add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
+            }
+            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
+                              is_same_v<ELayout, Row>)
+            {
+                add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs);
+            }
+            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
+                              is_same_v<ELayout, Row>)
+            {
+                add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs);
+            }
+        }
+
+        return op_ptrs;
+    }
+};
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp
new file mode 100644
index 00000000..8986a793
--- /dev/null
+++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp
@@ -0,0 +1,147 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
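Reviewer note: Bilinear here is the axpy-style epilogue E = alpha * (A @ B) +
beta * D, and the two scalars live inside the element-wise functor rather than
being separate kernel arguments. A one-line sketch, assuming the alpha/beta
constructor that element_wise_operation.hpp appears to give Bilinear based on
its use in this PR:

    // E = 2 * (A @ B) + 3 * D; the functor is handed to MakeArgumentPointer as
    // the CDE element-wise operation.
    ck::tensor_operation::element_wise::Bilinear cde_element_op{2.0f, 3.0f};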
+
+#pragma once
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16         = ck::half_t;
+using F32         = float;
+using Row         = ck::tensor_layout::gemm::RowMajor;
+using Col         = ck::tensor_layout::gemm::ColumnMajor;
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(
+    std::vector<std::unique_ptr<DeviceGemmSplitK<Col, Row, Row, F16, F16, F16, PassThrough,
+                                                 PassThrough, PassThrough>>>& instances);
+
+void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(
+    std::vector<std::unique_ptr<DeviceGemmSplitK<Col, Col, Row, F16, F16, F16, PassThrough,
+                                                 PassThrough, PassThrough>>>& instances);
+
+void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
+    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Row, Row, F16, F16, F16, PassThrough,
+                                                 PassThrough, PassThrough>>>& instances);
+
+void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
+    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Col, Row, F16, F16, F16, PassThrough,
+                                                 PassThrough, PassThrough>>>& instances);
+
+void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(
+    std::vector<std::unique_ptr<DeviceGemmSplitK<Col, Row, Row, F32, F32, F32, PassThrough,
+                                                 PassThrough, PassThrough>>>& instances);
+
+void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(
+    std::vector<std::unique_ptr<DeviceGemmSplitK<Col, Col, Row, F32, F32, F32, PassThrough,
+                                                 PassThrough, PassThrough>>>& instances);
+
+void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(
+    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Row, Row, F32, F32, F32, PassThrough,
+                                                 PassThrough, PassThrough>>>& instances);
+
+void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(
+    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Col, Row, F32, F32, F32, PassThrough,
+                                                 PassThrough, PassThrough>>>& instances);
+
+template <typename ADataType,
+          typename BDataType,
+          typename CDataType,
+          typename ALayout,
+          typename BLayout,
+          typename CLayout>
+struct DeviceOperationInstanceFactory<
+    ck::tensor_operation::device::DeviceGemmSplitK<ALayout, BLayout, CLayout,
+                                                   ADataType, BDataType, CDataType,
+                                                   PassThrough, PassThrough, PassThrough>>
+{
+    using DeviceOp = DeviceGemmSplitK<ALayout, BLayout, CLayout, ADataType, BDataType, CDataType,
+                                      PassThrough, PassThrough, PassThrough>;
+
+    static auto GetInstances()
+    {
+        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
+
+        if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
+                     is_same_v<CDataType, float>)
+        {
+            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
+                         is_same_v<CLayout, Row>)
+            {
+                add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
+            }
+            else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
+                              is_same_v<CLayout, Row>)
+            {
+                add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
+            }
+            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
+                              is_same_v<CLayout, Row>)
+            {
+                add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(op_ptrs);
+            }
+            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
+                              is_same_v<CLayout, Row>)
+            {
+                add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(op_ptrs);
+            }
+        }
+        else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
+                          is_same_v<CDataType, half_t>)
+        {
+            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
+                         is_same_v<CLayout, Row>)
+            {
+                add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
+            }
+            else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
+                              is_same_v<CLayout, Row>)
+            {
+                add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
+            }
+            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
+                              is_same_v<CLayout, Row>)
+            {
+                add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(op_ptrs);
+            }
+            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
+                              is_same_v<CLayout, Row>)
+            {
+                add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(op_ptrs);
+            }
+        }
+
+        return op_ptrs;
+    }
+};
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp
new file mode 100644
index 00000000..97e9addf
--- /dev/null
+++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp
@@ -0,0 +1,26 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
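Reviewer note: split-K partitions the GEMM's K loop across several workgroups
whose partial C tiles are then combined, which helps when M x N alone is too
small to occupy the GPU. Fetching instances follows the same factory pattern;
the split factor itself is a runtime knob supplied when the argument is built,
not part of the instance type (the exact MakeArgumentPointer parameter list is
not shown in this header, so treat that as an assumption):

    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using Row = ck::tensor_layout::gemm::RowMajor;
    using DeviceOp = ck::tensor_operation::device::DeviceGemmSplitK<
        Row, Row, Row, float, float, float, PassThrough, PassThrough, PassThrough>;
    auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
    // k_batch (how many workgroups share one K loop) is then passed per problem.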
+ +#pragma once + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16.hpp" diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp new file mode 100644 index 00000000..5fd8c95f --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp @@ -0,0 +1,180 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
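Reviewer note: each header name above encodes the (InDataType, AccDataType,
OutDataType) triple, and the three families differ in how one reduced output is
parallelized: threadwise (roughly one thread per output slice), blockwise (one
workgroup cooperating per output), and multiblock atomic-add (several
workgroups per output, combined with atomic adds). A translation unit that only
needs one combination can skip this umbrella and include just that header:

    // Only the f32 blockwise reduction declarations, not all 21 headers:
    #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.hpp"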
+ +#pragma once + +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using reduce_configuration_1_instances_blockwise = std::tuple< + // clang-format off + // BlockSize | MThreadClusterSize | KThreadClusterSize + ReductionConfiguration_1<256, 128, 2>, + ReductionConfiguration_1<256, 64, 4>, + ReductionConfiguration_1<256, 32, 8>, + ReductionConfiguration_1<256, 16, 16>, + ReductionConfiguration_1<256, 8, 32>, + ReductionConfiguration_1<256, 4, 64>, + ReductionConfiguration_1<256, 2, 128>, + ReductionConfiguration_1<256, 1, 256> + // clang-format on + >; + +#ifdef QUICK_REDUCE_TEST +using reduce_configuration_2_instances_blockwise = std::tuple< + // clang-format off + // InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize + ReductionConfiguration_2<0, 2, 2, 2, 1>, + ReductionConfiguration_2<0, 1, 1, 2, 1>, + ReductionConfiguration_2<1, 2, 1, 1, 2>, + ReductionConfiguration_2<0, 1, 1, 3, 1>, + ReductionConfiguration_2<1, 1, 1, 1, 3> + // clang-format on + >; +#else +using reduce_configuration_2_instances_blockwise = std::tuple< + // clang-format off + // InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize + ReductionConfiguration_2<0, 4, 4, 8, 1>, + ReductionConfiguration_2<0, 4, 4, 4, 1>, + ReductionConfiguration_2<0, 2, 2, 2, 1>, + + ReductionConfiguration_2<1, 4, 1, 1, 8>, + ReductionConfiguration_2<1, 4, 1, 1, 4>, + ReductionConfiguration_2<1, 2, 1, 1, 2>, + + // special instances + ReductionConfiguration_2<0, 1, 1, 3, 1>, + ReductionConfiguration_2<0, 1, 1, 5, 1>, + ReductionConfiguration_2<0, 1, 1, 7, 1>, + ReductionConfiguration_2<0, 1, 1, 11, 1>, + + ReductionConfiguration_2<1, 1, 1, 1, 3>, + ReductionConfiguration_2<1, 1, 1, 1, 5>, + ReductionConfiguration_2<1, 1, 1, 1, 7>, + ReductionConfiguration_2<1, 1, 1, 1, 11> + // clang-format on + >; +#endif + +template +using deviceReduceBlockWisePtrType = DeviceReducePtr< + typename reduce_unary_operator::InElementwiseOperation, + typename reduce_unary_operator::AccElementwiseOperation>; + +template +void add_device_reduce_instance_blockwise( + std::vector>& device_op_instances) +{ + using ReduceOperation = typename reduce_binary_operator::opType; + using InElementwiseOperation = + typename reduce_unary_operator::InElementwiseOperation; + using AccElementwiseOperation = + typename reduce_unary_operator::AccElementwiseOperation; + + constexpr bool Indexable = + (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX || + ReduceOpId == ReduceTensorOp::AMAX); + constexpr bool OutputIndex = Indexable && UseIndex; + + static_for<0, std::tuple_size::value, 1>{}( + [&](auto i) { + using cfg1 = remove_cvref_t(reduce_configuration_1_instances_blockwise{}))>; + + static_for<0, std::tuple_size::value, 1>{}( + [&](auto j) { + using cfg2 = remove_cvref_t(reduce_configuration_2_instances_blockwise{}))>; + + using ReduceOpInstance = + DeviceReduceMultiBlock; + + device_op_instances.push_back( + std::make_unique(ReduceOpInstance{})); + }); + }); +}; + +#define ADD_BLOCKWISE_INST_BY_TYPE( \ + inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \ + template void add_device_reduce_instance_blockwise( \ + std::vector> & device_op_instances) + +#define 
ADD_BLOCKWISE_INST_BY_ID( \ + inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ + ADD_BLOCKWISE_INST_BY_TYPE(inT, \ + compT, \ + outT, \ + static_cast(ReduceOpId), \ + static_cast(NanOpt), \ + static_cast(IndicesOpt), \ + Rank, \ + NumReduceDim) + +#define ADD_BLOCKWISE_INST_REF_BY_TYPE( \ + inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \ + extern template void add_device_reduce_instance_blockwise( \ + std::vector> & device_op_instances) + +#define ADD_BLOCKWISE_INST_REF_BY_ID( \ + inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ + ADD_BLOCKWISE_INST_REF_BY_TYPE(inT, \ + compT, \ + outT, \ + static_cast(ReduceOpId), \ + static_cast(NanOpt), \ + static_cast(IndicesOpt), \ + Rank, \ + NumReduceDim) + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.hpp new file mode 100644 index 00000000..8d1fed04 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 3); // for ADD +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 3); // for AVG +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 3); // for NORM2 +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 2, 1); + +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 3); // for MIN +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 3); // for MAX +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, 
bhalf_t, 4, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 3); // for MIN +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 3); // for MAX +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.hpp new file mode 100644 index 00000000..ae7f13ce --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.hpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 
4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.hpp new file mode 100644 index 00000000..c26e1365 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.hpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2 +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.hpp new file mode 100644 index 00000000..30064d58 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.hpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
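Reviewer note: the numeric columns in these ADD_*_INST_REF_BY_ID lines are
compile-time enum/flag values. Reconstructing the mapping from the inline
"// for XXX" comments in this patch (values not evidenced here, such as MUL and
NORM1, are deliberately omitted):

    // ReduceOpId column, as witnessed by the comments above:
    //   0 = ADD, 2 = MIN, 3 = MAX, 4 = AMAX, 5 = AVG, 7 = NORM2
    // The IndicesOpt column (third from the end) selects the index-producing
    // variants: lists with IndicesOpt = 1 exist only for MIN/MAX/AMAX, which
    // are the operations that can report an argmin/argmax position.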
+ +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2 +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.hpp new file mode 100644 index 00000000..c9f6a1a5 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.hpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. 
All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD +ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG +ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2 +ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64.hpp new file mode 100644 index 00000000..c598e64c --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64.hpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2 +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.hpp new file mode 100644 index 00000000..cd159499 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.hpp @@ -0,0 +1,30 @@ +// 
SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 3); // for ADD +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 3); // for AVG +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8.hpp new file mode 100644 index 00000000..bf62f92a --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8.hpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 3); // for MIN +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 3); // for MAX +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 3); // for MIN +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 3); // for MAX +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 
2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp new file mode 100644 index 00000000..9fc409a0 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +struct ReductionConfiguration_1 +{ + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, "Invalid Configuration!"); + + static constexpr int BlockSize_ = BlockSize; + static constexpr int MThreadClusterSize_ = MThreadClusterSize; + static constexpr int KThreadClusterSize_ = KThreadClusterSize; +}; + +template +struct ReductionConfiguration_2 +{ + static constexpr int InSrcVectorDim_ = InSrcVectorDim; + static constexpr int InSrcVectorSize_ = InSrcVectorSize; + static constexpr int OutDstVectorSize_ = OutDstVectorSize; + static constexpr int MThreadSliceSize_ = MThreadSliceSize; + static constexpr int KThreadSliceSize_ = KThreadSliceSize; +}; + +#define QUICK_REDUCE_TEST 1 + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp new file mode 100644 index 00000000..a74e92ec --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp @@ -0,0 +1,199 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
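Reviewer note: ReductionConfiguration_1's static_assert forces the workgroup's
threads to factor exactly into an M x K thread cluster, so malformed tuning
entries fail at compile time rather than at dispatch. (Also worth flagging:
impl_common.hpp itself defines QUICK_REDUCE_TEST, so the #ifdef branches in the
instance headers will always select the short configuration lists unless that
define is removed.) A minimal illustration of the invariant:

    // 256 == 128 * 2 passes; ReductionConfiguration_1<256, 128, 4> would not compile.
    using Cfg1 = ck::tensor_operation::device::instance::ReductionConfiguration_1<256, 128, 2>;
    static_assert(Cfg1::BlockSize_ == Cfg1::MThreadClusterSize_ * Cfg1::KThreadClusterSize_, "");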
+ +#pragma once + +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using reduce_configuration_1_instances_multiblock_atomic_add = std::tuple< + // clang-format off + // BlockSize | MThreadClusterSize | KThreadClusterSize + ReductionConfiguration_1<256, 128, 2>, + ReductionConfiguration_1<256, 64, 4>, + ReductionConfiguration_1<256, 32, 8>, + ReductionConfiguration_1<256, 16, 16>, + ReductionConfiguration_1<256, 8, 32>, + ReductionConfiguration_1<256, 4, 64>, + ReductionConfiguration_1<256, 2, 128>, + ReductionConfiguration_1<256, 1, 256> + // clang-format on + >; + +#ifdef QUICK_REDUCE_TEST +using reduce_configuration_2_instances_multiblock_atomic_add = std::tuple< + // clang-format off + // InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize + ReductionConfiguration_2<0, 2, 2, 2, 1>, + ReductionConfiguration_2<0, 1, 1, 2, 1>, + ReductionConfiguration_2<1, 2, 1, 1, 2>, + ReductionConfiguration_2<0, 1, 1, 3, 1>, + ReductionConfiguration_2<1, 1, 1, 1, 3> + // clang-format on + >; +#else +using reduce_configuration_2_instances_multiblock_atomic_add = std::tuple< + // clang-format off + // InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize + ReductionConfiguration_2<0, 4, 4, 8, 1>, + ReductionConfiguration_2<0, 4, 4, 4, 1>, + ReductionConfiguration_2<0, 2, 2, 2, 1>, + + ReductionConfiguration_2<1, 4, 1, 1, 8>, + ReductionConfiguration_2<1, 4, 1, 1, 4>, + ReductionConfiguration_2<1, 2, 1, 1, 2>, + + // special instances + ReductionConfiguration_2<0, 1, 1, 3, 1>, + ReductionConfiguration_2<0, 1, 1, 5, 1>, + ReductionConfiguration_2<0, 1, 1, 7, 1>, + ReductionConfiguration_2<0, 1, 1, 11, 1>, + + ReductionConfiguration_2<1, 1, 1, 1, 3>, + ReductionConfiguration_2<1, 1, 1, 1, 5>, + ReductionConfiguration_2<1, 1, 1, 1, 7>, + ReductionConfiguration_2<1, 1, 1, 1, 11> + // clang-format on + >; +#endif + +template +using deviceReduceMultiBlockAtomicAddPtrType = DeviceReducePtr< + typename reduce_unary_operator::InElementwiseOperation, + typename reduce_unary_operator::AccElementwiseOperation>; + +template +void add_device_reduce_instance_multiblock_atomic_add( + std::vector>& device_op_instances) +{ + using ReduceOperation = typename reduce_binary_operator::opType; + using InElementwiseOperation = + typename reduce_unary_operator::InElementwiseOperation; + using AccElementwiseOperation = + typename reduce_unary_operator::AccElementwiseOperation; + + constexpr bool Indexable = + (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX || + ReduceOpId == ReduceTensorOp::AMAX); + constexpr bool OutputIndex = Indexable && UseIndex; + + static_assert(UseIndex == false, + "AtomicAdd can only be used with reduction operations using no index!"); + + constexpr bool op_acceptable = + (ReduceOpId == ReduceTensorOp::ADD || ReduceOpId == ReduceTensorOp::MUL || + ReduceOpId == ReduceTensorOp::AVG || ReduceOpId == ReduceTensorOp::NORM1); + + constexpr bool out_type_acceptable = + (std::is_same::value || std::is_same::value); + + if constexpr(!op_acceptable || !out_type_acceptable) + return; + else + { + static_for<0, + std::tuple_size::value, + 1>{}([&](auto i) { + using cfg1 = 
remove_cvref_t(reduce_configuration_1_instances_multiblock_atomic_add{}))>; + + static_for< + 0, + std::tuple_size::value, + 1>{}([&](auto j) { + using cfg2 = remove_cvref_t(reduce_configuration_2_instances_multiblock_atomic_add{}))>; + + using ReduceOpInstance = + DeviceReduceMultiBlock; + + device_op_instances.push_back( + std::make_unique(ReduceOpInstance{})); + }); + }); + } +}; + +#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_TYPE( \ + inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \ + template void add_device_reduce_instance_multiblock_atomic_add( \ + std::vector> & device_op_instances) + +#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID( \ + inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ + ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_TYPE(inT, \ + compT, \ + outT, \ + static_cast(ReduceOpId), \ + static_cast(NanOpt), \ + static_cast(IndicesOpt), \ + Rank, \ + NumReduceDim) + +#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE( \ + inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \ + extern template void add_device_reduce_instance_multiblock_atomic_add( \ + std::vector> & device_op_instances) + +#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID( \ + inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ + ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE(inT, \ + compT, \ + outT, \ + static_cast(ReduceOpId), \ + static_cast(NanOpt), \ + static_cast(IndicesOpt), \ + Rank, \ + NumReduceDim) + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp new file mode 100644 index 00000000..3efc5850 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
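Reviewer note: because this family merges the partial results of several
workgroups with hardware atomic adds on the destination buffer, the generator
statically filters what it will instantiate: indexed reductions cannot be
expressed as an atomic accumulation, and the destination type must support
atomic add (f32/f64 here). A distilled sketch of that compile-time gate, under
the assumption that ReduceTensorOp lives in namespace ck as its uses above
suggest (this mirrors the header's logic; it is not new library code):

    #include <type_traits>

    template <ck::ReduceTensorOp Op, typename OutT>
    constexpr bool multiblock_atomic_add_viable()
    {
        // Only the arithmetic reductions listed in the header's op_acceptable...
        constexpr bool op_ok = Op == ck::ReduceTensorOp::ADD || Op == ck::ReduceTensorOp::MUL ||
                               Op == ck::ReduceTensorOp::AVG || Op == ck::ReduceTensorOp::NORM1;
        // ...written to a type with atomic-add support get instances on this path.
        constexpr bool out_ok = std::is_same<OutT, float>::value ||
                                std::is_same<OutT, double>::value;
        return op_ok && out_ok;
    }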
+ +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(bhalf_t, float, float, 0, 0, 0, 4, 3); // for ADD +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(bhalf_t, float, float, 0, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(bhalf_t, float, float, 0, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(bhalf_t, float, float, 0, 0, 0, 2, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(bhalf_t, float, float, 5, 0, 0, 4, 3); // for AVG +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(bhalf_t, float, float, 5, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(bhalf_t, float, float, 5, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(bhalf_t, float, float, 5, 0, 0, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp new file mode 100644 index 00000000..804cba12 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 4, 3); // for ADD +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 2, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 4, 3); // for AVG +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp new file mode 100644 index 00000000..32eb843a --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp new file mode 100644 index 00000000..9f2a8924 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 2, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64.hpp new file mode 100644 index 00000000..bd200699 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64.hpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
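Reviewer note: the f32_f64_f32 lists that just ended pair f32 data with an f64
accumulator. The reason is plain floating-point arithmetic: summing many f32
values in f32 discards low-order bits, while accumulating in double and
converting once on write-out stays much closer to the exact sum. A minimal CPU
analogue of the same AccDataType idea:

    #include <vector>

    float reduce_add_f32_via_f64(const std::vector<float>& values)
    {
        double acc = 0.0;                   // AccDataType = double
        for(float x : values)
            acc += x;                       // accumulate wide
        return static_cast<float>(acc);     // OutDataType = float, one rounding
    }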
+ +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp new file mode 100644 index 00000000..6b84b25d --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp @@ -0,0 +1,157 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce_threadwise.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef QUICK_REDUCE_TEST +using reduce_configuration_2_instances_threadwise = std::tuple< + // clang-format off + // InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize + ReductionConfiguration_2<0, 2, 2, 2, 1>, + ReductionConfiguration_2<0, 1, 1, 2, 1>, + ReductionConfiguration_2<1, 2, 1, 1, 2>, + ReductionConfiguration_2<0, 1, 1, 3, 1>, + ReductionConfiguration_2<1, 1, 1, 1, 3> + // clang-format on + >; +#else +using reduce_configuration_2_instances_threadwise = std::tuple< + // clang-format off + // InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize + ReductionConfiguration_2<0, 4, 4, 8, 1>, + ReductionConfiguration_2<0, 4, 4, 4, 1>, + ReductionConfiguration_2<0, 2, 2, 2, 1>, + + ReductionConfiguration_2<1, 4, 1, 1, 8>, + ReductionConfiguration_2<1, 4, 1, 1, 4>, + ReductionConfiguration_2<1, 2, 1, 1, 2>, + + // special instances + ReductionConfiguration_2<0, 1, 1, 3, 1>, + ReductionConfiguration_2<0, 1, 1, 5, 1>, + ReductionConfiguration_2<0, 1, 1, 7, 1>, + ReductionConfiguration_2<0, 1, 1, 11, 1>, + + ReductionConfiguration_2<1, 1, 1, 1, 3>, + ReductionConfiguration_2<1, 1, 1, 1, 5>, + ReductionConfiguration_2<1, 1, 1, 1, 7>, + ReductionConfiguration_2<1, 1, 1, 1, 11> + // clang-format on + >; +#endif + +template +using deviceReduceThreadWisePtrType = DeviceReducePtr< + typename reduce_unary_operator::InElementwiseOperation, + typename reduce_unary_operator::AccElementwiseOperation>; + 
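Note: the QUICK_REDUCE_TEST branch above only shrinks the configuration sweep; the underlying pattern is a tuple of compile-time configuration types, walked once per element to stamp out one concrete kernel instantiation each. A self-contained analogue of that mechanism, using std::index_sequence where the library uses ck::static_for (all names below are invented for the sketch):

#include <cstddef>
#include <memory>
#include <tuple>
#include <utility>
#include <vector>

// Stand-in for ReductionConfiguration_2: a pure compile-time parameter pack.
template <int VecDim, int VecSize>
struct Config
{
    static constexpr int vec_dim  = VecDim;
    static constexpr int vec_size = VecSize;
};

struct OpBase
{
    virtual ~OpBase()               = default;
    virtual int vector_size() const = 0;
};

// Stand-in for DeviceReduceThreadWise: one concrete type per configuration.
template <typename Cfg>
struct Op final : OpBase
{
    int vector_size() const override { return Cfg::vec_size; }
};

using Configs = std::tuple<Config<0, 4>, Config<0, 2>, Config<1, 1>>;

// Compile-time walk over the tuple, pushing one type-erased instance per entry.
template <std::size_t... Is>
void add_instances(std::vector<std::unique_ptr<OpBase>>& ops, std::index_sequence<Is...>)
{
    (ops.push_back(std::make_unique<Op<std::tuple_element_t<Is, Configs>>>()), ...);
}

int main()
{
    std::vector<std::unique_ptr<OpBase>> ops;
    add_instances(ops, std::make_index_sequence<std::tuple_size_v<Configs>>{});
    return ops.size() == 3 ? 0 : 1;
}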
+template +void add_device_reduce_instance_threadwise( + std::vector>& device_op_instances) +{ + using ReduceOperation = typename reduce_binary_operator::opType; + using InElementwiseOperation = + typename reduce_unary_operator::InElementwiseOperation; + using AccElementwiseOperation = + typename reduce_unary_operator::AccElementwiseOperation; + + constexpr bool Indexable = + (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX || + ReduceOpId == ReduceTensorOp::AMAX); + constexpr bool OutputIndex = Indexable && UseIndex; + + using cfg1 = ReductionConfiguration_1<256, 256, 1>; + + static_for<0, std::tuple_size::value, 1>{}( + [&](auto j) { + using cfg2 = remove_cvref_t(reduce_configuration_2_instances_threadwise{}))>; + + using ReduceOpInstance = DeviceReduceThreadWise; + + device_op_instances.push_back(std::make_unique(ReduceOpInstance{})); + }); +}; + +#define ADD_THREADWISE_INST_BY_TYPE( \ + inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \ + template void add_device_reduce_instance_threadwise( \ + std::vector> & device_op_instances) + +#define ADD_THREADWISE_INST_BY_ID( \ + inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ + ADD_THREADWISE_INST_BY_TYPE(inT, \ + compT, \ + outT, \ + static_cast(ReduceOpId), \ + static_cast(NanOpt), \ + static_cast(IndicesOpt), \ + Rank, \ + NumReduceDim) + +#define ADD_THREADWISE_INST_REF_BY_TYPE( \ + inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \ + extern template void add_device_reduce_instance_threadwise( \ + std::vector> & device_op_instances) + +#define ADD_THREADWISE_INST_REF_BY_ID( \ + inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ + ADD_THREADWISE_INST_REF_BY_TYPE(inT, \ + compT, \ + outT, \ + static_cast(ReduceOpId), \ + static_cast(NanOpt), \ + static_cast(IndicesOpt), \ + Rank, \ + NumReduceDim) + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16.hpp new file mode 100644 index 00000000..5f7f5c7a --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
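Note: the BY_*/REF_BY_* macro pair above implements the classic explicit-instantiation split: the REF variants emit extern template declarations for consumers (the per-type headers that follow), while the non-REF variants emit the matching definitions, compiled once under library/src. A minimal sketch of the idiom with a made-up function name:

#include <vector>

// --- registry.hpp (the shape of what a *_INST_REF_BY_TYPE macro declares) ---
template <typename T>
void add_instance(std::vector<T>& v);

// Tells every includer: do not instantiate; a definition exists elsewhere.
extern template void add_instance<float>(std::vector<float>&);

// --- registry_f32.cpp (the shape of what a *_INST_BY_TYPE macro defines) ---
template <typename T>
void add_instance(std::vector<T>& v)
{
    v.push_back(T{1});
}

// The one and only instantiation for float, paid for in a single translation unit.
template void add_instance<float>(std::vector<float>&);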
+ +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 3); // for ADD +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 3); // for AVG +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 3); // for NORM2 +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 2, 1); + +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 3); // for MIN +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 3); // for MAX +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 3); // for AMAX +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 3); // for MIN +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 3); // for MAX +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 3); // for AMAX +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.hpp new file mode 100644 index 00000000..3c21b408 --- /dev/null +++ 
b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.hpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.hpp new file mode 100644 index 00000000..cd116986 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.hpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD +ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG +ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2 +ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.hpp new file mode 100644 index 00000000..a764735f --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.hpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2 +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.hpp new file mode 100644 index 00000000..7d47c79f --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.hpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, 
Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD +ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG +ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2 +ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.hpp new file mode 100644 index 00000000..faced808 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.hpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2 +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.hpp new file mode 100644 index 00000000..111ba7a0 --- /dev/null +++ 
b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.hpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 3); // for ADD +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 3); // for AVG +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.hpp new file mode 100644 index 00000000..c771f057 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.hpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
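Note: for int8, ADD and AVG are only instantiated with an int32 accumulator (this header), while MIN/MAX/AMAX keep an all-int8 path (next header). Comparisons can never leave the int8 range, but sums exceed it almost immediately, as this tiny standalone demonstration shows:

#include <cstdint>
#include <cstdio>

int main()
{
    std::int32_t acc32 = 0;
    for(int i = 0; i < 10; ++i)
        acc32 += 100; // ten int8-representable inputs

    // 1000 is far above INT8_MAX (127), so an int8 accumulator would already
    // have wrapped; the widened accumulator keeps the exact sum.
    std::printf("sum = %d (INT8_MAX = %d)\n", acc32, INT8_MAX);
    return 0;
}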
+ +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 3); // for MIN +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 3); // for MAX +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 3); // for AMAX +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 3); // for MIN +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 3); // for MAX +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 2, 1); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 3); // for AMAX +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 4); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1); +ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/utility/check_err.hpp b/library/include/ck/library/utility/check_err.hpp new file mode 100644 index 00000000..cb4d0172 --- /dev/null +++ b/library/include/ck/library/utility/check_err.hpp @@ -0,0 +1,206 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/utility/io.hpp" + +namespace ck { +namespace utils { + +template +typename std::enable_if::value && !std::is_same::value, + bool>::type +check_err(const std::vector& out, + const std::vector& ref, + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-5, + double atol = 3e-6) +{ + if(out.size() != ref.size()) + { + std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl + << msg << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + double err = 0; + double max_err = std::numeric_limits::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + err = std::abs(out[i] - ref[i]); + if(err > atol + rtol * std::abs(ref[i]) || !std::isfinite(out[i]) || !std::isfinite(ref[i])) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cout << std::setw(12) << std::setprecision(7) << "out[" << i << "] != ref[" + << i << "]: " << out[i] << " != " << ref[i] << std::endl + << msg << std::endl; + } + res = false; + } + } + if(!res) + { + std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; + } + return res; +} + +template +typename std::enable_if::value, bool>::type +check_err(const std::vector& out, + const std::vector& ref, + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-3, + double atol = 1e-3) +{ + if(out.size() != ref.size()) + { + std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl + << msg << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + double err = 0; + // TODO: This is a hack. We should have proper specialization for bhalf_t data type. + double max_err = std::numeric_limits::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + double o = type_convert(out[i]); + double r = type_convert(ref[i]); + err = std::abs(o - r); + if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cout << std::setw(12) << std::setprecision(7) << "out[" << i << "] != ref[" + << i << "]: " << o << " != " << r << std::endl + << msg << std::endl; + } + res = false; + } + } + if(!res) + { + std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; + } + return res; +} + +template +typename std::enable_if::value, bool>::type +check_err(const std::vector& out, + const std::vector& ref, + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-3, + double atol = 1e-3) +{ + if(out.size() != ref.size()) + { + std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl + << msg << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + double err = 0; + double max_err = std::numeric_limits::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + double o = type_convert(out[i]); + double r = type_convert(ref[i]); + err = std::abs(o - r); + if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + { + max_err = err > max_err ? 
err : max_err; + err_count++; + if(err_count < 5) + { + std::cout << std::setw(12) << std::setprecision(7) << "out[" << i << "] != ref[" + << i << "]: " << o << " != " << r << std::endl + << msg << std::endl; + } + res = false; + } + } + if(!res) + { + std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; + } + return res; +} + +template +typename std::enable_if::value && !std::is_same::value, bool>::type +check_err(const std::vector& out, + const std::vector& ref, + const std::string& msg = "Error: Incorrect results!", + double = 0, + double atol = 0) +{ + if(out.size() != ref.size()) + { + std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl + << msg << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + int64_t err = 0; + int64_t max_err = std::numeric_limits::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + int64_t o = out[i]; + int64_t r = ref[i]; + err = std::abs(o - r); + + if(err > atol) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cout << "out[" << i << "] != ref[" << i << "]: " << static_cast(out[i]) + << " != " << static_cast(ref[i]) << std::endl + << msg << std::endl; + } + res = false; + } + } + if(!res) + { + std::cout << "max err: " << max_err << std::endl; + } + return res; +} + +} // namespace utils +} // namespace ck diff --git a/host/host_tensor/include/conv_common.hpp b/library/include/ck/library/utility/conv_common.hpp similarity index 92% rename from host/host_tensor/include/conv_common.hpp rename to library/include/ck/library/utility/conv_common.hpp index 4bf2c234..6fad9f7d 100644 --- a/host/host_tensor/include/conv_common.hpp +++ b/library/include/ck/library/utility/conv_common.hpp @@ -1,16 +1,9 @@ -#ifndef CONV_COMMON_HPP -#define CONV_COMMON_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
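Note: a usage sketch for the check_err overloads above. A floating-point element passes when |out[i] - ref[i]| <= atol + rtol * |ref[i]| and both values are finite; the function prints the first few offenders plus the maximum error before returning false. The float overload defaults to rtol = 1e-5, atol = 3e-6. Assuming only the header added in this PR:

#include <vector>

#include "ck/library/utility/check_err.hpp"

int main()
{
    std::vector<float> ref{1.0f, 2.0f, 3.0f};
    std::vector<float> out{1.0f, 2.0f + 1e-7f, 3.0f}; // within atol of ref

    bool ok = ck::utils::check_err(out, ref, "example mismatch", 1e-5, 3e-6);
    return ok ? 0 : 1;
}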
-#include "tensor_descriptor.hpp" +#pragma once -enum ConvTensorLayout -{ - NCHW, - NHWC, - CHWN, - NCHWc, - NHWCc -}; +#include "ck/tensor_description/tensor_descriptor.hpp" template +#include +#include +#include + +#include "ck/ck.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +struct ConvParams +{ + ConvParams(); + ConvParams(ck::index_t n_dim, + ck::index_t n_batch, + ck::index_t n_out_channels, + ck::index_t n_in_channels, + const std::vector& filters_len, + const std::vector& input_len, + const std::vector& strides, + const std::vector& dilations, + const std::vector& left_pads, + const std::vector& right_pads); + + ck::index_t num_dim_spatial_; + ck::index_t N_; + ck::index_t K_; + ck::index_t C_; + + std::vector filter_spatial_lengths_; + std::vector input_spatial_lengths_; + std::vector output_spatial_lengths_; + + std::vector conv_filter_strides_; + std::vector conv_filter_dilations_; + + std::vector input_left_pads_; + std::vector input_right_pads_; + + std::vector GetOutputSpatialLengths() const; + + std::size_t GetFlops() const; + + template + std::size_t GetByte() const + { + // sizeof(InDataType) * (N * C * ) + + // sizeof(WeiDataType) * (K * C * ) + + // sizeof(OutDataType) * (N * K * ); + return sizeof(InDataType) * (N_ * C_ * + std::accumulate(std::begin(input_spatial_lengths_), + std::end(input_spatial_lengths_), + static_cast(1), + std::multiplies())) + + sizeof(WeiDataType) * (K_ * C_ * + std::accumulate(std::begin(filter_spatial_lengths_), + std::end(filter_spatial_lengths_), + static_cast(1), + std::multiplies())) + + sizeof(OutDataType) * (N_ * K_ * + std::accumulate(std::begin(output_spatial_lengths_), + std::end(output_spatial_lengths_), + static_cast(1), + std::multiplies())); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck + +std::ostream& operator<<(std::ostream& os, const ck::tensor_operation::device::ConvParams& p); diff --git a/library/include/ck/library/utility/device_memory.hpp b/library/include/ck/library/utility/device_memory.hpp new file mode 100644 index 00000000..5667db7f --- /dev/null +++ b/library/include/ck/library/utility/device_memory.hpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +template +__global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size) +{ + for(uint64_t i = threadIdx.x; i < buffer_element_size; i += blockDim.x) + { + p[i] = x; + } +} + +struct DeviceMem +{ + DeviceMem() = delete; + DeviceMem(std::size_t mem_size); + void* GetDeviceBuffer(); + std::size_t GetBufferSize(); + void ToDevice(const void* p); + void FromDevice(void* p); + void SetZero(); + template + void SetValue(T x) + { + if(mMemSize % sizeof(T) != 0) + { + throw std::runtime_error("wrong! not entire DeviceMem will be set"); + } + + set_buffer_value<<<1, 1024>>>(static_cast(mpDeviceBuf), x, mMemSize / sizeof(T)); + } + ~DeviceMem(); + + void* mpDeviceBuf; + std::size_t mMemSize; +}; diff --git a/library/include/ck/library/utility/fill.hpp b/library/include/ck/library/utility/fill.hpp new file mode 100644 index 00000000..6a764427 --- /dev/null +++ b/library/include/ck/library/utility/fill.hpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
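Note: GetByte() above sizes one forward pass as sizeof(In) * N * C * prod(input spatial) + sizeof(Wei) * K * C * prod(filter spatial) + sizeof(Out) * N * K * prod(output spatial). A quick worked check with made-up float 2-D sizes (N = 2, C = 32, K = 64, 28x28 input, 3x3 filter, stride 1, no padding, hence 26x26 output):

input : 4 * 2 * 32 * 28 * 28 = 200704 bytes
weight: 4 * 64 * 32 * 3 * 3  =  73728 bytes
output: 4 * 2 * 64 * 26 * 26 = 346112 bytes
total                        = 620544 bytes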
+
+#pragma once
+
+#include <algorithm>
+#include <cmath>
+#include <random>
+
+#include "ck/utility/data_type.hpp"
+
+namespace ck {
+namespace utils {
+
+template <typename T>
+struct FillUniformDistribution
+{
+    float a_{-5.f};
+    float b_{5.f};
+
+    template <typename ForwardIter>
+    void operator()(ForwardIter first, ForwardIter last) const
+    {
+        std::mt19937 gen(11939);
+        std::uniform_real_distribution<float> dis(a_, b_);
+        std::generate(first, last, [&dis, &gen]() { return ck::type_convert<T>(dis(gen)); });
+    }
+};
+
+// Normally FillUniformDistributionIntegerValue should use std::uniform_int_distribution as
+// below. However this produces segfaults in std::mt19937 which look like an infinite loop.
+// template <typename T>
+// struct FillUniformDistributionIntegerValue
+// {
+//     int a_{-5};
+//     int b_{5};
+//
+//     template <typename ForwardIter>
+//     void operator()(ForwardIter first, ForwardIter last) const
+//     {
+//         std::mt19937 gen(11939);
+//         std::uniform_int_distribution<int> dis(a_, b_);
+//         std::generate(
+//             first, last, [&dis, &gen]() { return ck::type_convert<T>(dis(gen)); });
+//     }
+// };
+
+// Workaround for uniform_int_distribution not working as expected. See note above.
+template <typename T>
+struct FillUniformDistributionIntegerValue
+{
+    float a_{-5.f};
+    float b_{5.f};
+
+    template <typename ForwardIter>
+    void operator()(ForwardIter first, ForwardIter last) const
+    {
+        std::mt19937 gen(11939);
+        std::uniform_real_distribution<float> dis(a_, b_);
+        std::generate(
+            first, last, [&dis, &gen]() { return ck::type_convert<T>(std::round(dis(gen))); });
+    }
+};
+
+template <typename T>
+struct FillMonotonicSeq
+{
+    T init_value_{0};
+    T step_{1};
+
+    template <typename ForwardIter>
+    void operator()(ForwardIter first, ForwardIter last) const
+    {
+        std::generate(first, last, [=, n = init_value_]() mutable {
+            auto tmp = n;
+            n += step_;
+            return tmp;
+        });
+    }
+};
+
+template <typename T>
+struct FillConstant
+{
+    T value_{0};
+
+    template <typename ForwardIter>
+    void operator()(ForwardIter first, ForwardIter last) const
+    {
+        std::fill(first, last, value_);
+    }
+};
+
+} // namespace utils
+} // namespace ck
diff --git a/library/include/ck/library/utility/host_common_util.hpp b/library/include/ck/library/utility/host_common_util.hpp
new file mode 100644
index 00000000..31e5571e
--- /dev/null
+++ b/library/include/ck/library/utility/host_common_util.hpp
@@ -0,0 +1,76 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
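Note: the fill functors above are plain aggregates over iterator ranges, so they compose with any host container; the hard-coded seed (11939) makes every run reproducible. A usage sketch assuming only fill.hpp:

#include <vector>

#include "ck/library/utility/fill.hpp"

int main()
{
    std::vector<float> host_buf(64);

    // Reproducible pseudo-random values in [-1, 1].
    ck::utils::FillUniformDistribution<float>{-1.f, 1.f}(host_buf.begin(), host_buf.end());

    // Overwrite the first eight entries with a constant.
    ck::utils::FillConstant<float>{0.5f}(host_buf.begin(), host_buf.begin() + 8);
    return 0;
}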
+ +#pragma once + +#include +#include +#include +#include + +#include "ck/ck.hpp" + +namespace ck { + +namespace host_common { + +template +static inline void dumpBufferToFile(const char* fileName, T* data, size_t dataNumItems) +{ + std::ofstream outFile(fileName, std::ios::binary); + if(outFile) + { + outFile.write(reinterpret_cast(data), dataNumItems * sizeof(T)); + outFile.close(); + std::cout << "Write output to file " << fileName << std::endl; + } + else + { + std::cout << "Could not open file " << fileName << " for writing" << std::endl; + } +}; + +template +static inline T getSingleValueFromString(const std::string& valueStr) +{ + std::istringstream iss(valueStr); + + T val; + + iss >> val; + + return (val); +}; + +template +static inline std::vector getTypeValuesFromString(const char* cstr_values) +{ + std::string valuesStr(cstr_values); + + std::vector values; + std::size_t pos = 0; + std::size_t new_pos; + + new_pos = valuesStr.find(',', pos); + while(new_pos != std::string::npos) + { + const std::string sliceStr = valuesStr.substr(pos, new_pos - pos); + + T val = getSingleValueFromString(sliceStr); + + values.push_back(val); + + pos = new_pos + 1; + new_pos = valuesStr.find(',', pos); + }; + + std::string sliceStr = valuesStr.substr(pos); + T val = getSingleValueFromString(sliceStr); + + values.push_back(val); + + return (values); +} + +} // namespace host_common +} // namespace ck diff --git a/library/include/ck/library/utility/host_conv.hpp b/library/include/ck/library/utility/host_conv.hpp new file mode 100644 index 00000000..8348a308 --- /dev/null +++ b/library/include/ck/library/utility/host_conv.hpp @@ -0,0 +1,152 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include "host_tensor.hpp" +#include "conv_common.hpp" + +template +void host_conv_nchw_kcyx_nkhw(const Tensor& in, + const Tensor& wei, + Tensor& out, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads&) +{ + constexpr auto I0 = ck::Number<0>{}; + constexpr auto I1 = ck::Number<1>{}; + + auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { + float v = 0; + for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c) + { + for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y) + { + int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; + for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x) + { + int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; + if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && + wi < in.mDesc.GetLengths()[3]) + { + v += ck::type_convert(in(n, c, hi, wi)) * + ck::type_convert(wei(k, c, y, x)); + } + } + } + } + out(n, k, ho, wo) = ck::type_convert(v); + }; + + make_ParallelTensorFunctor(f_nchw, + out.mDesc.GetLengths()[0], + out.mDesc.GetLengths()[1], + out.mDesc.GetLengths()[2], + out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); +} + +template +void host_conv3d_ndhwc_kzyxc_ndhwk(const Tensor& in, + const Tensor& wei, + Tensor& out, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads&) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + const auto Di = in.mDesc.GetLengths()[1]; + const auto Hi = in.mDesc.GetLengths()[2]; + const auto Wi = in.mDesc.GetLengths()[3]; + const auto Z = wei.mDesc.GetLengths()[1]; + const auto Y = 
wei.mDesc.GetLengths()[2];
+    const auto X = wei.mDesc.GetLengths()[3];
+    const auto C = wei.mDesc.GetLengths()[4];
+
+    auto f_ndhwc = [&](auto n, auto do_tmp, auto ho_tmp, auto wo_tmp, auto k) {
+        // do_ must be converted to a signed integer, otherwise zmin might be wrong in cases of
+        // negative values.
+        const int do_ = static_cast<int>(do_tmp);
+        const int ho  = static_cast<int>(ho_tmp);
+        const int wo  = static_cast<int>(wo_tmp);
+        const int zmin =
+            std::max(0,
+                     (in_left_pads[I0] - do_ * conv_strides[I0] + conv_dilations[I0] - 1) /
+                         conv_dilations[I0]);
+        const int ymin =
+            std::max(0,
+                     (in_left_pads[I1] - ho * conv_strides[I1] + conv_dilations[I1] - 1) /
+                         conv_dilations[I1]);
+        const int xmin =
+            std::max(0,
+                     (in_left_pads[I2] - wo * conv_strides[I2] + conv_dilations[I2] - 1) /
+                         conv_dilations[I2]);
+        const int zmax =
+            std::min(Z, (in_left_pads[I0] - do_ * conv_strides[I0] + Di) / conv_dilations[I0]);
+        const int ymax =
+            std::min(Y, (in_left_pads[I1] - ho * conv_strides[I1] + Hi) / conv_dilations[I1]);
+        const int xmax =
+            std::min(X, (in_left_pads[I2] - wo * conv_strides[I2] + Wi) / conv_dilations[I2]);
+        const int di_min = do_ * conv_strides[I0] + zmin * conv_dilations[I0] - in_left_pads[I0];
+        const int hi_min = ho * conv_strides[I1] + ymin * conv_dilations[I1] - in_left_pads[I1];
+        const int wi_min = wo * conv_strides[I2] + xmin * conv_dilations[I2] - in_left_pads[I2];
+
+        double v = 0;
+
+        const TIn* in_n   = in.mData.data() + n * Di * Hi * Wi * C;
+        const TWei* wei_k = wei.mData.data() + k * Z * Y * X * C;
+
+        int di = di_min;
+        for(int z = zmin; z < zmax; ++z, di += conv_dilations[I0])
+        {
+            const TIn* in_n_di  = in_n + di * Hi * Wi * C;
+            const TWei* wei_k_z = wei_k + z * Y * X * C;
+            int hi = hi_min;
+
+            for(int y = ymin; y < ymax; ++y, hi += conv_dilations[I1])
+            {
+                const TIn* in_n_di_hi  = in_n_di + hi * Wi * C;
+                const TWei* wei_k_z_y  = wei_k_z + y * X * C;
+                int wi = wi_min;
+
+                for(int x = xmin; x < xmax; ++x, wi += conv_dilations[I2])
+                {
+                    const TIn* in_n_di_hi_wi = in_n_di_hi + wi * C;
+                    const TWei* wei_k_z_y_x  = wei_k_z_y + x * C;
+
+                    for(int c = 0; c < C; ++c)
+                    {
+                        v += static_cast<double>(in_n_di_hi_wi[c]) *
+                             static_cast<double>(wei_k_z_y_x[c]);
+                    }
+                }
+            }
+        }
+
+        out(n, do_, ho, wo, k) = v;
+    };
+
+    make_ParallelTensorFunctor(f_ndhwc,
+                               out.mDesc.GetLengths()[0],
+                               out.mDesc.GetLengths()[1],
+                               out.mDesc.GetLengths()[2],
+                               out.mDesc.GetLengths()[3],
+                               out.mDesc.GetLengths()[4])(std::thread::hardware_concurrency() - 4);
+}
diff --git a/library/include/ck/library/utility/host_gemm.hpp b/library/include/ck/library/utility/host_gemm.hpp
new file mode 100644
index 00000000..44036d02
--- /dev/null
+++ b/library/include/ck/library/utility/host_gemm.hpp
@@ -0,0 +1,47 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
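Note: unlike the 2-D reference above, the 3-D path hoists the boundary test out of the inner loops by clipping the filter-tap ranges up front. Deriving the z bounds as a sanity check, with s = stride, d = dilation, p = left pad along depth:

di(z) = do_ * s + z * d - p
di(z) >= 0   =>  z >= (p - do_ * s) / d          (rounded up)
             =>  zmin = max(0, (p - do_ * s + d - 1) / d)
di(z) <  Di  =>  z <  (p - do_ * s + Di) / d     (truncating integer division)
             =>  zmax = min(Z, (p - do_ * s + Di) / d)

This matches the zmin/zmax expressions in the code; the y and x bounds follow the same pattern.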
+ +#pragma once + +#include "host_tensor.hpp" + +template +void host_gemm_mk_kn_mn(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op) +{ + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = a_m_k.mDesc.GetLengths()[1]; + + float v_acc = 0; + + for(int k = 0; k < K; ++k) + { + float v_a; + float v_b; + + a_element_op(v_a, static_cast(a_m_k(m, k))); + b_element_op(v_b, static_cast(b_k_n(k, n))); + + v_acc += v_a * v_b; + } + + float v_c; + + c_element_op(v_c, v_acc); + + c_m_n(m, n) = v_c; + }; + + make_ParallelTensorFunctor(f_mk_kn_mn, + c_m_n.mDesc.GetLengths()[0], + c_m_n.mDesc.GetLengths()[1])(std::thread::hardware_concurrency()); +} diff --git a/library/include/ck/library/utility/host_reduction.hpp b/library/include/ck/library/utility/host_reduction.hpp new file mode 100644 index 00000000..f02ebcd7 --- /dev/null +++ b/library/include/ck/library/utility/host_reduction.hpp @@ -0,0 +1,378 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/utility/reduction_common.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/library/utility/host_tensor.hpp" + +template +static void get_all_indexes(const std::array& dimLengths, + std::vector>& indexes) +{ + static_assert(NDim >= 1, "NDim >= 1 is required to use this function!"); + + if constexpr(NDim == 1) + { + for(size_t i = 0; i < dimLengths[0]; i++) + { + std::array index{i}; + + indexes.push_back(index); + }; + } + else + { + std::array partial_dim_lengths; + + for(int i = 0; i < NDim - 1; i++) + partial_dim_lengths[i] = dimLengths[i + 1]; + + std::vector> partial_indexes; + + get_all_indexes(partial_dim_lengths, partial_indexes); + + for(size_t i = 0; i < dimLengths[0]; i++) + for(const auto& index : partial_indexes) + { + std::array extIndex; + + extIndex[0] = i; + + for(int k = 0; k < NDim - 1; k++) + extIndex[k + 1] = index[k]; + + indexes.push_back(extIndex); + }; + }; +}; + +template +static size_t get_offset_from_index(const std::array& strides, + const std::array& index) +{ + size_t offset = 0; + + for(int i = 0; i < NDim; i++) + offset += strides[i] * index[i]; + + return (offset); +}; + +template +static size_t get_offset_from_index(const std::vector& strides, + const std::array& index) +{ + size_t offset = 0; + + for(int i = 0; i < NDim; i++) + offset += strides[i] * index[i]; + + return (offset); +}; + +template +struct ReductionHost +{ + using IndexDataType = int32_t; + + static constexpr int NumInvariantDim = Rank - NumReduceDim; + + std::vector outStrides; + std::vector invariantDims; + std::vector reduceDims; + + IndexDataType divider; + std::array reduceLengths; + std::array reduceStrides; + std::array invariantLengths; + std::array invariantStrides; + + std::vector> reduce_dim_indexes; + std::vector> invariant_dim_indexes; + + ReductionHost(HostTensorDescriptor& inDesc, + HostTensorDescriptor& outDesc, + const std::vector& invariantDims_, + const std::vector& reduceDims_) + { + // this->outLengths = to_int_vector(outDesc.GetLengths()); + this->outStrides = outDesc.GetStrides(); + + this->invariantDims = invariantDims_; + this->reduceDims = reduceDims_; + + int product = 1; + + for(int i = 0; i < 
NumReduceDim; i++) + { + reduceLengths[i] = inDesc.GetLengths()[reduceDims[i]]; + reduceStrides[i] = inDesc.GetStrides()[reduceDims[i]]; + product *= inDesc.GetLengths()[reduceDims[i]]; + }; + + divider = product; + + for(int i = 0; i < NumInvariantDim; i++) + { + invariantLengths[i] = inDesc.GetLengths()[invariantDims[i]]; + invariantStrides[i] = inDesc.GetStrides()[invariantDims[i]]; + }; + + reduce_dim_indexes.clear(); + get_all_indexes(reduceLengths, reduce_dim_indexes); + + if constexpr(NumInvariantDim > 0) + { + invariant_dim_indexes.clear(); + get_all_indexes(invariantLengths, invariant_dim_indexes); + }; + }; + + void Run(float alpha, + const InDataType* in_data, + float beta, + OutDataType* out_data, + IndexDataType* out_indices, + InElementwiseOperation in_elementwise_op, + AccElementwiseOperation acc_elementwise_op) + { + if constexpr(OutputIndex) + { + RunImpl_with_index( + alpha, in_data, beta, out_data, out_indices, in_elementwise_op, acc_elementwise_op); + } + else + { + RunImpl_no_index(alpha, in_data, beta, out_data, in_elementwise_op, acc_elementwise_op); + }; + }; + + void RunImpl_with_index(float alpha, + const InDataType* in_data, + float beta, + OutDataType* out_data, + IndexDataType* out_indices, + InElementwiseOperation in_elementwise_op, + AccElementwiseOperation acc_elementwise_op) + { + using ck::float_equal_one; + using ck::float_equal_zero; + using ck::type_convert; + + using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck; + + if constexpr(NumInvariantDim == 0) + { + AccDataType accuVal = ReduceOperation::template GetIdentityValue(); + IndexDataType accuIndex = 0; + + for(std::size_t i = 0; i < reduce_dim_indexes.size(); i++) + { + auto offset_reduce = + get_offset_from_index(reduceStrides, reduce_dim_indexes[i]); + + auto currVal = type_convert(in_data[offset_reduce]); + + in_elementwise_op(currVal, currVal); + + auto currIndex = static_cast(i); + + Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex); + }; + + acc_elementwise_op(accuVal, accuVal); + + if(!float_equal_one{}(alpha)) + accuVal *= type_convert(alpha); + + if(!float_equal_zero{}(beta)) + accuVal += type_convert(out_data[0]) * type_convert(beta); + + out_data[0] = type_convert(accuVal); + out_indices[0] = accuIndex; + } + else + { + auto thread_reduce_func = [&](auto invariant_index) { + AccDataType accuVal = ReduceOperation::template GetIdentityValue(); + IndexDataType accuIndex = 0; + + auto offset_invariant = + get_offset_from_index(invariantStrides, invariant_index); + + for(std::size_t i = 0; i < reduce_dim_indexes.size(); i++) + { + auto offset_reduce = + get_offset_from_index(reduceStrides, reduce_dim_indexes[i]); + + auto currVal = + type_convert(in_data[offset_invariant + offset_reduce]); + + in_elementwise_op(currVal, currVal); + + auto currIndex = static_cast(i); + + Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex); + }; + + acc_elementwise_op(accuVal, accuVal); + + if(!float_equal_one{}(alpha)) + accuVal *= type_convert(alpha); + + auto dst_offset = + get_offset_from_index(outStrides, invariant_index); + + if(!float_equal_zero{}(beta)) + accuVal += type_convert(out_data[dst_offset]) * + type_convert(beta); + + out_data[dst_offset] = type_convert(accuVal); + out_indices[dst_offset] = accuIndex; + }; + + std::size_t num_thread = 1; + std::size_t work_per_thread = + (invariant_dim_indexes.size() + num_thread - 1) / num_thread; + + std::vector threads(num_thread); + + for(std::size_t it = 0; it < num_thread; ++it) + { + std::size_t iw_begin = it * 
work_per_thread; + std::size_t iw_end = + std::min((it + 1) * work_per_thread, invariant_dim_indexes.size()); + + auto f = [=] { + for(std::size_t iw = iw_begin; iw < iw_end; ++iw) + { + thread_reduce_func(invariant_dim_indexes[iw]); + } + }; + + threads[it] = joinable_thread(f); + } + }; + }; + + void RunImpl_no_index(float alpha, + const InDataType* in_data, + float beta, + OutDataType* out_data, + InElementwiseOperation in_elementwise_op, + AccElementwiseOperation acc_elementwise_op) + { + using ck::float_equal_one; + using ck::float_equal_zero; + using ck::type_convert; + + using Accumulation = + ck::detail::AccumulateWithNanCheck; + + if constexpr(NumInvariantDim == 0) + { + AccDataType accuVal = ReduceOperation::template GetIdentityValue(); + + for(const auto& reduce_index : reduce_dim_indexes) + { + auto offset_reduce = + get_offset_from_index(reduceStrides, reduce_index); + + auto currVal = type_convert(in_data[offset_reduce]); + + in_elementwise_op(currVal, currVal); + + Accumulation::Calculate(accuVal, currVal); + }; + + acc_elementwise_op(accuVal, accuVal); + + if(!float_equal_one{}(alpha)) + accuVal *= type_convert(alpha); + + if(!float_equal_zero{}(beta)) + accuVal += type_convert(out_data[0]) * type_convert(beta); + + out_data[0] = type_convert(accuVal); + } + else + { + auto thread_reduce_func = [&](auto invariant_index) { + AccDataType accuVal = ReduceOperation::template GetIdentityValue(); + + auto offset_invariant = + get_offset_from_index(invariantStrides, invariant_index); + + for(const auto& reduce_index : reduce_dim_indexes) + { + auto offset_reduce = + get_offset_from_index(reduceStrides, reduce_index); + + auto currVal = + type_convert(in_data[offset_invariant + offset_reduce]); + + in_elementwise_op(currVal, currVal); + + Accumulation::Calculate(accuVal, currVal); + }; + + acc_elementwise_op(accuVal, accuVal); + + if(!float_equal_one{}(alpha)) + accuVal *= type_convert(alpha); + + auto dst_offset = + get_offset_from_index(outStrides, invariant_index); + + if(!float_equal_zero{}(beta)) + accuVal += type_convert(out_data[dst_offset]) * + type_convert(beta); + + out_data[dst_offset] = type_convert(accuVal); + }; + + std::size_t num_thread = 1; + std::size_t work_per_thread = + (invariant_dim_indexes.size() + num_thread - 1) / num_thread; + + std::vector threads(num_thread); + + for(std::size_t it = 0; it < num_thread; ++it) + { + std::size_t iw_begin = it * work_per_thread; + std::size_t iw_end = + std::min((it + 1) * work_per_thread, invariant_dim_indexes.size()); + + auto f = [=] { + for(std::size_t iw = iw_begin; iw < iw_end; ++iw) + { + thread_reduce_func(invariant_dim_indexes[iw]); + } + }; + + threads[it] = joinable_thread(f); + } + }; + }; +}; diff --git a/host/host_tensor/include/host_tensor.hpp b/library/include/ck/library/utility/host_tensor.hpp similarity index 56% rename from host/host_tensor/include/host_tensor.hpp rename to library/include/ck/library/utility/host_tensor.hpp index 06aed0a0..b31dbd8b 100644 --- a/host/host_tensor/include/host_tensor.hpp +++ b/library/include/ck/library/utility/host_tensor.hpp @@ -1,5 +1,7 @@ -#ifndef HOST_TENSOR_HPP -#define HOST_TENSOR_HPP +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
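Note: ReductionHost above splits the invariant indexes into ceil(n / num_thread) chunks (via the (n + t - 1) / t idiom) and relies on joinable_thread joining in its destructor, so the std::vector of workers acts as a scope-exit barrier; num_thread is currently fixed at 1. A sketch of that RAII wrapper, matching the shape of the one declared in the host_tensor.hpp hunk below:

#include <thread>
#include <utility>

struct joinable_thread : std::thread
{
    template <typename... Xs>
    joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...)
    {
    }

    joinable_thread(joinable_thread&&) = default;
    joinable_thread& operator=(joinable_thread&&) = default;

    // Joining here means destroying the vector of workers waits for all of them.
    ~joinable_thread()
    {
        if(this->joinable())
            this->join();
    }
};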
+
+#pragma once
 
 #include <algorithm>
 #include <cassert>
@@ -9,6 +11,8 @@
 #include <thread>
 #include <vector>
 
+#include "ck/utility/data_type.hpp"
+
 template <typename Range>
 std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim)
 {
@@ -39,20 +43,6 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
     return os;
 }
 
-typedef enum
-{
-    Half  = 0,
-    Float = 1,
-} DataType_t;
-
-template <typename T>
-struct DataType;
-
-template <>
-struct DataType<float> : std::integral_constant<DataType_t, DataType_t::Float>
-{
-};
-
 template <typename F, typename T, std::size_t... Is>
 auto call_f_unpack_args_impl(F f, T args, std::index_sequence<Is...>)
 {
@@ -83,13 +73,13 @@ auto construct_f_unpack_args(F, T args)
 
 struct HostTensorDescriptor
 {
-    HostTensorDescriptor() = delete;
+    HostTensorDescriptor() = default;
 
     template <typename X>
-    HostTensorDescriptor(std::vector<X> lens);
+    HostTensorDescriptor(const std::vector<X>& lens);
 
     template <typename X, typename Y>
-    HostTensorDescriptor(std::vector<X> lens, std::vector<Y> strides);
+    HostTensorDescriptor(const std::vector<X>& lens, const std::vector<Y>& strides);
 
     void CalculateStrides();
 
@@ -120,11 +110,34 @@ struct HostTensorDescriptor
         return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
     }
 
+    std::size_t GetOffsetFromMultiIndex(std::vector<std::size_t> iss) const
+    {
+        return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
+    }
+
+    friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc);
+
     private:
     std::vector<std::size_t> mLens;
     std::vector<std::size_t> mStrides;
 };
 
+template <typename New2Old>
+HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old(const HostTensorDescriptor& a,
+                                                                    const New2Old& new2old)
+{
+    std::vector<std::size_t> new_lengths(a.GetNumOfDimension());
+    std::vector<std::size_t> new_strides(a.GetNumOfDimension());
+
+    for(std::size_t i = 0; i < a.GetNumOfDimension(); i++)
+    {
+        new_lengths[i] = a.GetLengths()[new2old[i]];
+        new_strides[i] = a.GetStrides()[new2old[i]];
+    }
+
+    return HostTensorDescriptor(new_lengths, new_strides);
+}
+
 struct joinable_thread : std::thread
 {
     template <typename... Xs>
@@ -165,7 +178,7 @@ struct ParallelTensorFunctor
     {
         std::array<std::size_t, NDIM> indices;
 
-        for(int idim = 0; idim < NDIM; ++idim)
+        for(std::size_t idim = 0; idim < NDIM; ++idim)
         {
             indices[idim] = i / mStrides[idim];
             i -= indices[idim] * mStrides[idim];
@@ -174,7 +187,7 @@ struct ParallelTensorFunctor
         return indices;
     }
 
-    void operator()(std::size_t num_thread = std::thread::hardware_concurrency()) const
+    void operator()(std::size_t num_thread = 1) const
     {
         std::size_t work_per_thread = (mN1d + num_thread - 1) / num_thread;
@@ -223,6 +236,80 @@ struct Tensor
 
     Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpace()) {}
 
+    template <typename OutT>
+    Tensor<OutT> CopyAsType()
+    {
+        Tensor<OutT> ret(mDesc);
+        for(size_t i = 0; i < mData.size(); i++)
+        {
+            ret.mData[i] = static_cast<OutT>(mData[i]);
+        }
+        return ret;
+    }
+
+    Tensor(const Tensor& other) : mDesc(other.mDesc), mData(other.mData) {}
+
+    Tensor& operator=(const Tensor& other)
+    {
+        mDesc = other.mDesc;
+        mData = other.mData;
+        return *this;
+    }
+
+    void SetZero()
+    {
+        for(auto& v : mData)
+        {
+            v = T{0};
+        }
+    }
+
+    template <typename F>
+    void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)
+    {
+        if(rank == mDesc.GetNumOfDimension())
+        {
+            f(*this, idx);
+            return;
+        }
+        // else
+        for(size_t i = 0; i < mDesc.GetLengths()[rank]; i++)
+        {
+            idx[rank] = i;
+            ForEach_impl(std::forward<F>(f), idx, rank + 1);
+        }
+    }
+
+    template <typename F>
+    void ForEach(F&& f)
+    {
+        std::vector<size_t> idx(mDesc.GetNumOfDimension(), 0);
+        ForEach_impl(std::forward<F>(f), idx, size_t(0));
+    }
+
+    template <typename F>
+    void ForEach_impl(const F&& f, std::vector<size_t>& idx, size_t rank) const
+    {
+        if(rank == mDesc.GetNumOfDimension())
+        {
+            f(*this, idx);
+            return;
+        }
+        // else
+        for(size_t i = 0; i < mDesc.GetLengths()[rank]; i++)
+        {
+            idx[rank] = i;
+            ForEach_impl(std::forward<const F>(f), idx, rank + 1);
+        }
+    }
+
+    template <typename F>
+    void ForEach(const F&& f) const
+    {
+        std::vector<size_t> idx(mDesc.GetNumOfDimension(), 0);
+        ForEach_impl(std::forward<const F>(f), idx, size_t(0));
+    }
+
     template <typename G>
     void GenerateTensorValue(G g, std::size_t num_thread = 1)
     {
@@ -255,6 +342,18 @@ struct Tensor
                                      mDesc.GetLengths()[3])(num_thread);
             break;
         }
+        case 5: {
+            auto f = [&](auto i0, auto i1, auto i2, auto i3, auto i4) {
+                (*this)(i0, i1, i2, i3, i4) = g(i0, i1, i2, i3, i4);
+            };
+            make_ParallelTensorFunctor(f,
+                                       mDesc.GetLengths()[0],
+                                       mDesc.GetLengths()[1],
+                                       mDesc.GetLengths()[2],
+                                       mDesc.GetLengths()[3],
+                                       mDesc.GetLengths()[4])(num_thread);
+            break;
+        }
         default: throw std::runtime_error("unspported dimension");
         }
     }
@@ -271,6 +370,16 @@ struct Tensor
         return mData[mDesc.GetOffsetFromMultiIndex(is...)];
     }
 
+    T& operator()(std::vector<std::size_t> idx)
+    {
+        return mData[mDesc.GetOffsetFromMultiIndex(idx)];
+    }
+
+    const T& operator()(std::vector<std::size_t> idx) const
+    {
+        return mData[mDesc.GetOffsetFromMultiIndex(idx)];
+    }
+
     typename std::vector<T>::iterator begin() { return mData.begin(); }
 
     typename std::vector<T>::iterator end() { return mData.end(); }
@@ -284,39 +393,64 @@ struct Tensor
 };
 
 template <typename X>
-HostTensorDescriptor::HostTensorDescriptor(std::vector<X> lens) : mLens(lens)
+HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens)
+    : mLens(lens.begin(), lens.end())
 {
     this->CalculateStrides();
 }
 
 template <typename X, typename Y>
-HostTensorDescriptor::HostTensorDescriptor(std::vector<X> lens, std::vector<Y> strides)
-    : mLens(lens), mStrides(strides)
+HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens,
+                                           const std::vector<Y>& strides)
+    : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
 {
 }
 
-void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os = std::cout);
-
+#if 1
+// FIXME: remove
 template <typename T>
-void check_error(const Tensor<T>& ref, const Tensor<T>& result)
+float check_error(const Tensor<T>& ref, const Tensor<T>& result)
 {
-    float error    = 0;
-    float max_diff = -1;
-    float ref_value = 0, result_value = 0;
-    for(int i = 0; i < ref.mData.size(); ++i)
+    float l1_error       = 0;
+    float linf_error     = -1;
+    float linf_rel_error = -1;
+
+    float linf_ref_value = 0, linf_result_value = 0;
+    float linf_rel_ref_value = 0, linf_rel_result_value = 0;
+
+    constexpr float eps = 1e-10;
+
+    for(std::size_t i = 0; i < ref.mData.size(); ++i)
     {
-        error += std::abs(double(ref.mData[i]) - double(result.mData[i]));
-        float diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
-        if(max_diff < diff)
+        float ref_v    = ck::type_convert<float>(ref.mData[i]);
+        float result_v = ck::type_convert<float>(result.mData[i]);
+
+        float diff     = std::abs(ref_v - result_v);
+        float rel_diff = diff / std::max(std::abs(ref_v), eps);
+
+        l1_error += diff;
+
+        if(linf_error < diff)
         {
-            max_diff     = diff;
-            ref_value    = ref.mData[i];
-            result_value = result.mData[i];
+            linf_error        = diff;
+            linf_ref_value    = ref_v;
+            linf_result_value = result_v;
+        }
+
+        if(linf_rel_error < rel_diff)
+        {
+            linf_rel_error        = rel_diff;
+            linf_rel_ref_value    = ref_v;
+            linf_rel_result_value = result_v;
         }
     }
 
-    std::cout << "error: " << error << std::endl;
-    std::cout << "max_diff: " << max_diff << ", " << ref_value << ", " << result_value << std::endl;
-}
+    std::cout << "Absolute Error L1 Norm (sum of abs diff): " << l1_error << std::endl;
+    std::cout << "Absolute Error L-inf Norm (max abs diff): " << linf_error << ", ref "
+              << linf_ref_value << ", result " << linf_result_value << std::endl;
+    std::cout << "Relative Error L-inf Norm (max relative abs diff): " << linf_rel_error << ", ref "
+              << linf_rel_ref_value << ", result " << linf_rel_result_value << std::endl;
+    return linf_error;
+}
 #endif
diff --git a/library/include/ck/library/utility/host_tensor_generator.hpp b/library/include/ck/library/utility/host_tensor_generator.hpp
new file mode 100644
index 00000000..e0bd4991
--- /dev/null
+++ b/library/include/ck/library/utility/host_tensor_generator.hpp
@@ -0,0 +1,153 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <cmath>
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+
+template <typename T>
+struct GeneratorTensor_0
+{
+    template <typename... Is>
+    T operator()(Is...)
+    {
+        return T{0};
+    }
+};
+
+template <typename T>
+struct GeneratorTensor_1
+{
+    T value = 1;
+
+    template <typename... Is>
+    T operator()(Is...)
+    {
+        return value;
+    }
+};
+
+template <>
+struct GeneratorTensor_1<ck::bhalf_t>
+{
+    float value = 1.0;
+
+    template <typename... Is>
+    ck::bhalf_t operator()(Is...)
+    {
+        return ck::type_convert<ck::bhalf_t>(value);
+    }
+};
+
+template <>
+struct GeneratorTensor_1<int8_t>
+{
+    int8_t value = 1;
+
+    template <typename... Is>
+    int8_t operator()(Is...)
+    {
+        return value;
+    }
+};
+
+template <typename T>
+struct GeneratorTensor_2
+{
+    int min_value = 0;
+    int max_value = 1;
+
+    template <typename... Is>
+    T operator()(Is...)
+    {
+        return static_cast<T>((std::rand() % (max_value - min_value)) + min_value);
+    }
+};
+
+template <>
+struct GeneratorTensor_2<ck::bhalf_t>
+{
+    int min_value = 0;
+    int max_value = 1;
+
+    template <typename... Is>
+    ck::bhalf_t operator()(Is...)
+    {
+        float tmp = (std::rand() % (max_value - min_value)) + min_value;
+        return ck::type_convert<ck::bhalf_t>(tmp);
+    }
+};
+
+template <>
+struct GeneratorTensor_2<int8_t>
+{
+    int min_value = 0;
+    int max_value = 1;
+
+    template <typename... Is>
+    int8_t operator()(Is...)
+    {
+        return (std::rand() % (max_value - min_value)) + min_value;
+    }
+};
+
+template <typename T>
+struct GeneratorTensor_3
+{
+    float min_value = 0;
+    float max_value = 1;
+
+    template <typename... Is>
+    T operator()(Is...)
+    {
+        float tmp = float(std::rand()) / float(RAND_MAX);
+
+        return static_cast<T>(min_value + tmp * (max_value - min_value));
+    }
+};
+
+template <>
+struct GeneratorTensor_3<ck::bhalf_t>
+{
+    float min_value = 0;
+    float max_value = 1;
+
+    template <typename... Is>
+    ck::bhalf_t operator()(Is...)
+    {
+        float tmp = float(std::rand()) / float(RAND_MAX);
+
+        float fp32_tmp = min_value + tmp * (max_value - min_value);
+
+        return ck::type_convert<ck::bhalf_t>(fp32_tmp);
+    }
+};
+
+struct GeneratorTensor_Checkboard
+{
+    template <typename... Ts>
+    float operator()(Ts... Xs) const
+    {
+        std::array<ck::index_t, sizeof...(Ts)> dims = {static_cast<ck::index_t>(Xs)...};
+        return std::accumulate(dims.begin(),
+                               dims.end(),
+                               true,
+                               [](bool init, ck::index_t x) -> int { return init != (x % 2); })
+                   ? 1
+                   : -1;
+    }
+};
+
+template <ck::index_t Dim>
+struct GeneratorTensor_Sequential
+{
+    template <typename... Ts>
+    float operator()(Ts... Xs) const
+    {
+        std::array<ck::index_t, sizeof...(Ts)> dims = {{static_cast<ck::index_t>(Xs)...}};
+        return dims[Dim];
+    }
+};
diff --git a/library/include/ck/library/utility/io.hpp b/library/include/ck/library/utility/io.hpp
new file mode 100644
index 00000000..c91c0ad1
--- /dev/null
+++ b/library/include/ck/library/utility/io.hpp
@@ -0,0 +1,15 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <algorithm>
+#include <iostream>
+#include <vector>
+
+template <typename T>
+std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
+{
+    std::copy(std::begin(v), std::end(v), std::ostream_iterator<T>(os, " "));
+    return os;
+}
diff --git a/library/include/ck/library/utility/op_instance_engine.hpp b/library/include/ck/library/utility/op_instance_engine.hpp
new file mode 100644
index 00000000..346b69f1
--- /dev/null
+++ b/library/include/ck/library/utility/op_instance_engine.hpp
@@ -0,0 +1,249 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <functional>
+#include <iostream>
+#include <limits>
+#include <memory>
+#include <stdexcept>
+#include <tuple>
+#include <type_traits>
+#include <vector>
+
+#include "ck/utility/functional2.hpp"
+#include "ck/tensor_operation/gpu/device/device_base.hpp"
+
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+
+namespace ck {
+namespace utils {
+
+struct ProfileBestConfig
+{
+    std::string best_op_name;
+    float best_avg_time   = std::numeric_limits<float>::max();
+    float best_tflops     = std::numeric_limits<float>::max();
+    float best_gb_per_sec = std::numeric_limits<float>::max();
+};
+
+/**
+ * @brief This class describes an operation instance (or instances).
+ *
+ * An op instance defines a particular specialization of an operator
+ * template. Thanks to its specific input/output data types, data
+ * layouts and modifying elementwise operations, it is able to create
+ * its input/output tensors, provide pointers to instances which
+ * can execute it, and all operation-specific parameters.
+ */
+template <typename OutDataType, typename... InArgTypes>
+class OpInstance
+{
+    public:
+    template <typename T>
+    using TensorPtr      = std::unique_ptr<Tensor<T>>;
+    using InTensorsTuple = std::tuple<TensorPtr<InArgTypes>...>;
+    using DeviceMemPtr   = std::unique_ptr<DeviceMem>;
+    using DeviceBuffers  = std::vector<DeviceMemPtr>;
+
+    OpInstance()                  = default;
+    OpInstance(const OpInstance&) = default;
+    OpInstance& operator=(const OpInstance&) = default;
+    virtual ~OpInstance(){};
+
+    virtual InTensorsTuple GetInputTensors() const         = 0;
+    virtual TensorPtr<OutDataType> GetOutputTensor() const = 0;
+    virtual std::unique_ptr<tensor_operation::device::BaseInvoker>
+    MakeInvokerPointer(tensor_operation::device::BaseOperator*) const = 0;
+    virtual std::unique_ptr<tensor_operation::device::BaseArgument>
+    MakeArgumentPointer(tensor_operation::device::BaseOperator*,
+                        const DeviceBuffers&,
+                        const DeviceMemPtr&) const = 0;
+    virtual std::size_t GetFlops() const = 0;
+    virtual std::size_t GetBtype() const = 0;
+};
+
+/**
+ * @brief A generic operation instance run engine.
+ */
+template <typename OutDataType, typename... InArgTypes>
+class OpInstanceRunEngine
+{
+    public:
+    using OpInstanceT = OpInstance<OutDataType, InArgTypes...>;
+    template <typename T>
+    using TensorPtr        = std::unique_ptr<Tensor<T>>;
+    using DeviceMemPtr     = std::unique_ptr<DeviceMem>;
+    using InTensorsTuple   = std::tuple<TensorPtr<InArgTypes>...>;
+    using DeviceBuffers    = std::vector<DeviceMemPtr>;
+    using InArgsTypesTuple = std::tuple<InArgTypes...>;
+
+    OpInstanceRunEngine() = delete;
+
+    template <typename ReferenceOp =
+                  std::function<void(const Tensor<InArgTypes>&..., Tensor<OutDataType>&)>>
+    OpInstanceRunEngine(const OpInstanceT& op_instance,
+                        const ReferenceOp& reference_op = ReferenceOp{},
+                        bool do_verification            = true)
+        : op_instance_{op_instance}
+    {
+        in_tensors_ = op_instance_.GetInputTensors();
+        out_tensor_ = op_instance_.GetOutputTensor();
+
+        if constexpr(std::is_invocable_v<ReferenceOp,
+                                         const Tensor<InArgTypes>&...,
+                                         Tensor<OutDataType>&>)
+        {
+            if(do_verification)
+            {
+                ref_output_ = op_instance_.GetOutputTensor();
+                CallRefOpUnpackArgs(reference_op, std::make_index_sequence<kNInArgs_>{});
+            }
+        }
+        AllocateDeviceInputTensors(std::make_index_sequence<kNInArgs_>{});
+        out_device_buffer_ =
+            std::make_unique<DeviceMem>(sizeof(OutDataType) * out_tensor_->mDesc.GetElementSpace());
+        out_device_buffer_->SetZero();
+    }
+
+    virtual ~OpInstanceRunEngine(){};
+
+    template <typename OpInstancePtr>
+    bool Test(const std::vector<OpInstancePtr>& op_ptrs)
+    {
+        bool res{true};
+        for(auto& op_ptr : op_ptrs)
+        {
+            auto invoker  = op_instance_.MakeInvokerPointer(op_ptr.get());
+            auto argument = op_instance_.MakeArgumentPointer(
+                op_ptr.get(), in_device_buffers_, out_device_buffer_);
+            if(op_ptr->IsSupportedArgument(argument.get()))
+            {
+                std::cout << "Testing instance: " << op_ptr->GetTypeString() << std::endl;
+                invoker->Run(argument.get());
+                out_device_buffer_->FromDevice(out_tensor_->mData.data());
+                if(!ref_output_)
+                {
+                    throw std::runtime_error(
+                        "OpInstanceRunEngine::Test: Reference value not available."
+                        " You have to provide reference function.");
+                }
+                // TODO: enable flexible use of custom check_error functions
+                bool inst_res = CheckErr(out_tensor_->mData, ref_output_->mData);
+                std::cout << (inst_res ? "SUCCESS" : "FAILURE") << std::endl;
+                res = res && inst_res;
+                out_device_buffer_->SetZero();
+            }
+            else
+            {
+                std::cout << "Given conv problem is not supported by instance: \n\t>>>>"
+                          << op_ptr->GetTypeString() << std::endl;
+            }
+        }
+        return res;
+    }
+
+    template <typename OpInstancePtr>
+    ProfileBestConfig Profile(const std::vector<OpInstancePtr>& op_ptrs,
+                              bool time_kernel     = false,
+                              bool do_verification = false,
+                              bool do_log          = false)
+    {
+        ProfileBestConfig best_config;
+
+        for(auto& op_ptr : op_ptrs)
+        {
+            auto invoker  = op_instance_.MakeInvokerPointer(op_ptr.get());
+            auto argument = op_instance_.MakeArgumentPointer(
+                op_ptr.get(), in_device_buffers_, out_device_buffer_);
+            if(op_ptr->IsSupportedArgument(argument.get()))
+            {
+                std::string op_name = op_ptr->GetTypeString();
+                float avg_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
+
+                std::size_t flops     = op_instance_.GetFlops();
+                std::size_t num_btype = op_instance_.GetBtype();
+                float tflops          = static_cast<float>(flops) / 1.E9 / avg_time;
+                float gb_per_sec      = num_btype / 1.E6 / avg_time;
+
+                std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec
+                          << " GB/s, " << op_name << std::endl;
+
+                if(avg_time < best_config.best_avg_time)
+                {
+                    best_config.best_op_name    = op_name;
+                    best_config.best_tflops     = tflops;
+                    best_config.best_gb_per_sec = gb_per_sec;
+                    best_config.best_avg_time   = avg_time;
+                }
+
+                if(do_verification)
+                {
+                    out_device_buffer_->FromDevice(out_tensor_->mData.data());
+                    if(!ref_output_)
+                    {
+                        throw std::runtime_error(
+                            "OpInstanceRunEngine::Profile: Reference value not available."
+                            " You have to provide reference function.");
+                    }
+                    // TODO: enable flexible use of custom check_error functions
+                    CheckErr(out_tensor_->mData, ref_output_->mData);
+
+                    if(do_log) {}
+                }
+                out_device_buffer_->SetZero();
+            }
+        }
+        return best_config;
+    }
+
+    void SetAtol(double a) { atol_ = a; }
+    void SetRtol(double r) { rtol_ = r; }
+
+    private:
+    template <typename F, std::size_t... Is>
+    void CallRefOpUnpackArgs(const F& f, std::index_sequence<Is...>) const
+    {
+        f(*std::get<Is>(in_tensors_)..., *ref_output_);
+    }
+
+    template <std::size_t... Is>
+    void AllocateDeviceInputTensors(std::index_sequence<Is...>)
+    {
+        (AllocateDeviceInputTensorsImpl<Is>(), ...);
+    }
+
+    template <std::size_t I>
+    void AllocateDeviceInputTensorsImpl()
+    {
+        const auto& ts = std::get<I>(in_tensors_);
+        in_device_buffers_
+            .emplace_back(
+                std::make_unique<DeviceMem>(sizeof(std::tuple_element_t<I, InArgsTypesTuple>) *
+                                            ts->mDesc.GetElementSpace()))
+            ->ToDevice(ts->mData.data());
+    }
+
+    static constexpr std::size_t kNInArgs_ = std::tuple_size_v<InTensorsTuple>;
+    const OpInstanceT& op_instance_;
+    double rtol_{1e-5};
+    double atol_{1e-8};
+
+    InTensorsTuple in_tensors_;
+    TensorPtr<OutDataType> out_tensor_;
+    TensorPtr<OutDataType> ref_output_;
+
+    DeviceBuffers in_device_buffers_;
+    DeviceMemPtr out_device_buffer_;
+
+    template <typename T>
+    bool CheckErr(const std::vector<T>& dev_out, const std::vector<T>& ref_out) const
+    {
+        return ck::utils::check_err(dev_out, ref_out, "Error: incorrect results!", rtol_, atol_);
+    }
+};
+
+} // namespace utils
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt
new file mode 100644
index 00000000..5de5e7d8
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt
@@ -0,0 +1,100 @@
+function(add_instance_library INSTANCE_NAME)
+    message("adding instance ${INSTANCE_NAME}")
+    add_library(${INSTANCE_NAME} OBJECT ${ARGN})
+    target_compile_features(${INSTANCE_NAME} PUBLIC)
+    set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
+endfunction(add_instance_library INSTANCE_NAME)
+
+add_subdirectory(gemm)
+add_subdirectory(gemm_splitk)
+add_subdirectory(gemm_bilinear)
+add_subdirectory(gemm_add_add_fastgelu)
+add_subdirectory(gemm_reduce)
+add_subdirectory(gemm_bias_add_reduce)
+add_subdirectory(batched_gemm)
+add_subdirectory(batched_gemm_reduce)
+add_subdirectory(grouped_gemm)
+add_subdirectory(contraction_scale)
+add_subdirectory(contraction_bilinear)
+add_subdirectory(conv1d_fwd)
+add_subdirectory(conv2d_fwd)
+add_subdirectory(conv3d_fwd)
+add_subdirectory(conv2d_fwd_bias_relu)
+add_subdirectory(conv2d_fwd_bias_relu_add)
+add_subdirectory(conv2d_bwd_data)
+add_subdirectory(convnd_bwd_data)
+add_subdirectory(conv2d_bwd_weight)
+add_subdirectory(convnd_bwd_weight)
+add_subdirectory(reduce)
+add_subdirectory(normalization)
+add_subdirectory(elementwise)
+
+add_library(device_operations STATIC
+    $
+    $
+    $
+    $
+    $
+    $
+    $
+    $
+    $
+    $
+    $
+    $
+    $
+    $
+    $
+    $
+    $
+    $
+    $
+    $
+    $
+)
+add_library(composablekernels::device_operations ALIAS device_operations)
+
+
+set(DEV_OPS_INC_DIRS
+    ${PROJECT_SOURCE_DIR}/include/ck/
+    ${PROJECT_SOURCE_DIR}/library/include/ck/
+)
+
+target_compile_features(device_operations PUBLIC)
+set_target_properties(device_operations PROPERTIES POSITION_INDEPENDENT_CODE ON)
+target_include_directories(device_operations PUBLIC
+    $
+    $
+    $
+    $
+    $
+    $
+    $
+    $
+    $
+    $
+    $
+    $
+    $
+    $
+    $
+)
+
+#once new arches are enabled make this an option on the main cmake file
+# and pass down here to be exported
+target_compile_options(device_operations PRIVATE
+    --offload-arch=gfx908
+    --offload-arch=gfx90a
+)
+
+
+# install(TARGETS device_operations LIBRARY DESTINATION lib)
+rocm_install(TARGETS device_operations
+    EXPORT device_operationsTargets)
+
+rocm_install(DIRECTORY ${DEV_OPS_INC_DIRS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck)
+rocm_install(EXPORT device_operationsTargets
+    FILE composable_kerneldevice_operationsTargets.cmake
+    NAMESPACE composable_kernel::
+    DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
+)
diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt
new file mode 100644
index 00000000..016c85f6
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt
@@ -0,0 +1,26 @@
+#device_batched_gemm_instance
+set(DEVICE_BATCHED_GEMM_INSTANCE_SOURCE
+    device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp;
+    device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp;
+    device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp;
+    device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp;
+    device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp;
+    device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp;
+    device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp;
+    device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp;
+    device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp;
+    device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp;
+    device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp;
+    device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp;
+    device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp;
+    device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp;
+    device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp;
+    device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp;
+)
+
+add_library(device_batched_gemm_instance OBJECT ${DEVICE_BATCHED_GEMM_INSTANCE_SOURCE})
+# target_compile_features(device_batched_gemm_instance PUBLIC)
+set_target_properties(device_batched_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
+# install(TARGETS device_batched_gemm_instance LIBRARY DESTINATION lib)
+
+clang_tidy_check(device_batched_gemm_instance)
diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp
new file mode 100644
index 00000000..1cc92524
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp
@@ -0,0 +1,59 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
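Each of the *_instance.cpp files that follow defines one std::tuple of DeviceBatchedGemmXdl specializations plus an add_..._instances() function that appends a heap-allocated copy of every tuple element to a caller-supplied vector. A consumption sketch, where OpPtr stands for the vector's std::unique_ptr element type (its exact base-class template arguments are elided in this diff):

    std::vector<OpPtr> ops;
    ck::tensor_operation::device::instance::
        add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances(ops);

    // Every registered instance can describe its own tuning configuration.
    for(const auto& op : ops)
        std::cout << op->GetTypeString() << std::endl;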
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances = std::tuple< + // clang-format off + //##################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 
2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp new file mode 100644 index 00000000..c35a8d6d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp new file mode 100644 index 00000000..1bbedebe --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
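The file-name suffixes encode the tensor layouts; summarizing the convention, as inferred from the "Compilation parameters" comments and the ALayout/BLayout/CLayout template arguments used throughout this directory:

    // gmk: A is a[g, m, k] (RowMajor)      gkm: A is a[g, k, m] (ColumnMajor)
    // gkn: B is b[g, k, n] (RowMajor)      gnk: B is b[g, n, k] (ColumnMajor)
    // gmn: C is c[g, m, n] (RowMajor)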
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 
2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 256, 4, 8, 32, 32, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp new file mode 100644 index 00000000..2ceaa20b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 
2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp new file mode 100644 index 00000000..36962857 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 
1, 1, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp new file mode 100644 index 00000000..f79d3041 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, 
F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp new file mode 100644 index 00000000..8290e756 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
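The tuning columns compose multiplicatively. A compile-time sanity sketch for the 256-thread, 256x128-tile row that opens most tables in this directory (values copied from those rows; the variable names here are descriptive, not the template's own parameter names):

    constexpr int BlockSize  = 256;
    constexpr int MPerBlock  = 256, NPerBlock = 128;
    constexpr int K0PerBlock = 4, K1 = 8;

    // One main-loop step consumes K0PerBlock * K1 K-elements.
    constexpr int KPerBlock = K0PerBlock * K1;
    static_assert(KPerBlock == 32, "K slice per loop iteration");

    // Per-iteration block-tile footprints, in elements.
    static_assert(MPerBlock * KPerBlock == 8192, "A block tile");
    static_assert(KPerBlock * NPerBlock == 4096, "B block tile");

    // The S<4, 64, 1> ABlockTransfer thread cluster multiplies out to BlockSize.
    static_assert(4 * 64 * 1 == BlockSize, "cluster covers all block threads");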
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 
1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 256, 4, 8, 32, 32, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 16, 64, 4, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp new file mode 100644 index 00000000..f3345eba --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
+using device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances = std::tuple<
+    // clang-format off
+    //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
+    //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
+    //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
+    //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+    DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
+    DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
+    DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
+    DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
+    DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
+    DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
+    DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
+    DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
+    DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
+    DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
+    DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
+    DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
+    DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>
+    // clang-format on
+    >;
+
+void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(
+    std::vector<std::unique_ptr<DeviceBatchedGemm<PassThrough, PassThrough, PassThrough>>>&
+        instances)
+{
+    add_device_operation_instances(instances,
+                                   device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp
new file mode 100644
index 00000000..8b671dfd
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp
@@ -0,0 +1,58 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
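+
+// Illustrative usage sketch (editorial, not part of the upstream file): each of these
+// instance files ends in a factory function that appends one device op per tuple entry.
+// The exact pointer type of the vector is an assumption reconstructed from the
+// signatures elsewhere in this diff; a client would enumerate and filter, e.g.:
+//
+//   std::vector<std::unique_ptr<DeviceBatchedGemm<PassThrough, PassThrough, PassThrough>>> ops;
+//   ck::tensor_operation::device::instance::
+//       add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances(ops);
+//   for(const auto& op : ops)
+//   { /* build an Argument, check op->IsSupportedArgument(...), then invoke */ }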
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
+using device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances = std::tuple<
+    // clang-format off
+    //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
+    //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
+    //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
+    //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>
+    // clang-format on
+    >;
+
+void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances(
+    std::vector<std::unique_ptr<DeviceBatchedGemm<PassThrough, PassThrough, PassThrough>>>&
+        instances)
+{
+    add_device_operation_instances(instances,
+                                   device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp
new file mode 100644
index 00000000..646450e7
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp
@@ -0,0 +1,58 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
+using device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances = std::tuple<
+    // clang-format off
+    //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
+    //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
+    //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
+    //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>
+    // clang-format on
+    >;
+
+void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances(
+    std::vector<std::unique_ptr<DeviceBatchedGemm<PassThrough, PassThrough, PassThrough>>>&
+        instances)
+{
+    add_device_operation_instances(instances,
+                                   device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp
new file mode 100644
index 00000000..1696d297
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp
@@ -0,0 +1,58 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
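+
+// Reading the tuning table below (an editorial worked example, derived from the column
+// headers): the first instance launches 256-thread blocks, each computing a 256 x 128
+// (MPerBlock x NPerBlock) tile of C, stepping K0PerBlock * K1 = 4 * 4 = 16 K-elements
+// per main loop iteration, with 32 x 32 xdlops tiles repeated 4 x 2
+// (MXdlPerWave x NXdlPerWave) per wave.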
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
+using device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances = std::tuple<
+    // clang-format off
+    //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
+    //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
+    //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
+    //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>
+    // clang-format on
+    >;
+
+void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances(
+    std::vector<std::unique_ptr<DeviceBatchedGemm<PassThrough, PassThrough, PassThrough>>>&
+        instances)
+{
+    add_device_operation_instances(instances,
+                                   device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp
new file mode 100644
index 00000000..3dbd6370
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp
@@ -0,0 +1,63 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
+using device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances = std::tuple<
+    // clang-format off
+    //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
+    //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
+    //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
+    //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
+    DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>
+    // clang-format on
+    >;
+
+void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(
+    std::vector<std::unique_ptr<DeviceBatchedGemm<PassThrough, PassThrough, PassThrough>>>&
+        instances)
+{
+    add_device_operation_instances(instances,
+                                   device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp
new file mode 100644
index 00000000..0691f4f8
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp
@@ -0,0 +1,79 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
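+
+// Editorial observation (not in the original sources): K1, the innermost contiguous
+// K-chunk, scales inversely with element size - 4 for f32 and 8 for f16 above, 16 for
+// int8 here - so a SrcScalarPerVector = 16 transfer of int8 stays the same 16-byte
+// vector load as the f16 and f32 variants.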
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using AData = int8_t;
+using BData = int8_t;
+using CData = int8_t;
+using AccData = int32_t;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
+using device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances = std::tuple<
+    // clang-format off
+    //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
+    //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
+    //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
+    //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 16, 16, true, 7, 1>
+    // clang-format on
+    >;
+
+void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances(
+    std::vector<std::unique_ptr<DeviceBatchedGemm<PassThrough, PassThrough, PassThrough>>>& instances)
+{
+    add_device_operation_instances(instances,
+                                   device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp
new file mode 100644
index 00000000..efd49bf1
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp
@@ -0,0 +1,79 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using AData = int8_t;
+using BData = int8_t;
+using CData = int8_t;
+using AccData = int32_t;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
+using device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances = std::tuple<
+    // clang-format off
+    //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
+    //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
+    //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
+    //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>
+    // clang-format on
+    >;
+
+void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances(
+    std::vector<std::unique_ptr<DeviceBatchedGemm<PassThrough, PassThrough, PassThrough>>>& instances)
+{
+    add_device_operation_instances(instances,
+                                   device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp
new file mode 100644
index 00000000..9c3d6609
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp
@@ -0,0 +1,79 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
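+
+// The gmk_gkn suffix encodes per-batch tensor order: a[g, m, k] and b[g, k, n], both
+// contiguous in their last dimension, hence ALayout = Row and BLayout = Row below.
+// A hypothetical stride setup for one such A tensor (illustrative names only):
+//
+//   const ck::index_t StrideA      = K;     // row stride of a [G, M, K] tensor
+//   const ck::index_t BatchStrideA = M * K; // distance between consecutive batches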
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using AData = int8_t;
+using BData = int8_t;
+using CData = int8_t;
+using AccData = int32_t;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
+using device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances = std::tuple<
+    // clang-format off
+    //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
+    //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
+    //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
+    //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>
+    // clang-format on
+    >;
+
+void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances(
+    std::vector<std::unique_ptr<DeviceBatchedGemm<PassThrough, PassThrough, PassThrough>>>& instances)
+{
+    add_device_operation_instances(instances,
+                                   device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp
new file mode 100644
index 00000000..330d1396
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp
@@ -0,0 +1,71 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using AData = int8_t;
+using BData = int8_t;
+using CData = int8_t;
+using AccData = int32_t;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
+using device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances = std::tuple<
+    // clang-format off
+    //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
+    //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
+    //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
+    //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
+    DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
-0,0 +1,12 @@ +set(DEVICE_BATCHED_GEMM_REDUCE_INSTANCE_SOURCE + device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp + device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp + device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp + device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp +) + +add_instance_library(device_batched_gemm_reduce_instance OBJECT ${DEVICE_BATCHED_GEMM_REDUCE_INSTANCE_SOURCE}) +target_compile_features(device_batched_gemm_reduce_instance PUBLIC) +set_target_properties(device_batched_gemm_reduce_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +clang_tidy_check(device_batched_gemm_reduce_instance) + diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp new file mode 100644 index 00000000..1c4541af --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[g, m, n] = a[g, m, k] * b[g, n, k] +using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | 
MemoryData|Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 4, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, 
ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, 
Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + // clang-format on + >; + +void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances( + std::vector<DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>>& instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp new file mode 100644 index 00000000..07eb9b94 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include <cstdlib> + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple<F32*, F32*>; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>; + +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple<Identity, Square>; +using ReduceOutElementOps = ck::Tuple<Identity, Identity>; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd>; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[g, m, n] = a[g, k, m] * b[g, n, k] +using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | |
MemoryData|Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, 
ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, 
F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + // clang-format on + >; + +void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( + std::vector<DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>>& instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp new file mode 100644 index 00000000..2d9cee47 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include <cstdlib> + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple<F32*, F32*>; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>; + +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple<Identity, Square>; +using ReduceOutElementOps = ck::Tuple<Identity, Identity>; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd>; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[g, m, n] = a[g, m, k] * b[g, k, n] +using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | |
MemoryData|Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, 
ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, 
F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + // clang-format on + >; + +void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances( + std::vector<DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>>& instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp new file mode 100644 index 00000000..03ce1ce0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include <cstdlib> + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple<F32*, F32*>; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>; + +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple<Identity, Square>; +using ReduceOutElementOps = ck::Tuple<Identity, Identity>; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd>; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[g, m, n] = a[g, m, k] * b[g, n, k] +using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | |
MemoryData|Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, 
ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1> + // clang-format on + >; + +void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances( + std::vector<DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>>& instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt new file mode 100644 index 00000000..fb38c645 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt @@ -0,0 +1,12 @@ +# device_contraction_bilinear_instance +set(DEVICE_CONTRACTION_BILINEAR_INSTANCE_SOURCE +
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp +) + +add_library(device_contraction_bilinear_instance OBJECT ${DEVICE_CONTRACTION_BILINEAR_INSTANCE_SOURCE}) +set_target_properties(device_contraction_bilinear_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) + +clang_tidy_check(device_contraction_bilinear_instance) diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp new file mode 100644 index 00000000..036818ee --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting. Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include <cstdlib> + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using F32_TUPLE = ck::Tuple<F32>; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/k/n/n are the fast-changing dimensions for A/B/D/E +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar|
AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, 
S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> + // clang-format on + >; + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance( + std::vector<std::unique_ptr<DeviceContractionMultipleD<2, 2, 2, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear>>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp new file mode 100644 index 00000000..b277fb86 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting. Don't use this hack unless absolutely necessary!
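// Reference semantics of every instance in these contraction_bilinear files (a minimal
// sketch; the Bilinear elementwise op is constructed with two floats (alpha, beta) and
// combines the contraction result with D):
//
//   for each (m0, m1, n0, n1):
//       acc = 0;
//       for each (k0, k1):
//           acc += A[m0][m1][k0][k1] * B[n0][n1][k0][k1];
//       E[m0][m1][n0][n1] = alpha * acc + beta * D[m0][m1][n0][n1];
//
// The tuple entries differ only in their blocking and transfer parameters, not in this
// semantic.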
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include <cstdlib> + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using F32_TUPLE = ck::Tuple<F32>; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/n/n/n are the fast-changing dimensions for A/B/D/E +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 1, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1,
16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 1, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 1, 32, 32, 1, 2, 
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance( + std::vector<std::unique_ptr<DeviceContractionMultipleD<2, 2, 2, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear>>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp new file mode 100644 index 00000000..c03ce0b1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting. Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include <cstdlib> + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using F32_TUPLE = ck::Tuple<F32>; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/k/n/n are the fast-changing dimensions for A/B/D/E +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block|
| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 1, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, 
F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 4, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 4, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance( + std::vector<std::unique_ptr<DeviceContractionMultipleD<2, 2, 2, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear>>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp new file mode 100644 index 00000000..ab56c4c1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting. Don't use this hack unless absolutely necessary!
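// ---------------------------------------------------------------------------
// Usage sketch for the bilinear contraction instance lists (a minimal
// illustration, assuming the DeviceContractionMultipleD base interface at this
// revision and the Bilinear semantics E = alpha * (A * B) + beta * D; the
// variable names below are hypothetical):
//
//     std::vector<std::unique_ptr<DeviceContractionMultipleD<2, 2, 2, F32, F32,
//         F32_TUPLE, F32, PassThrough, PassThrough, Bilinear>>> ops;
//     add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance(ops);
//     for(const auto& op : ops)
//         std::cout << op->GetTypeString() << std::endl; // enumerate the tuning variants
// ---------------------------------------------------------------------------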
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include <cstdlib> + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using F32_TUPLE = ck::Tuple<F32>; + +template <index_t... Is> +using S = ck::Sequence<Is...>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 1, 1, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 
16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 1, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 1, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 1, 32, 32, 1, 2, 
S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance( + std::vector<std::unique_ptr<DeviceContractionMultipleD<2, 2, 2, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear>>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt new file mode 100644 index 00000000..32806757 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt @@ -0,0 +1,12 @@ +# device_contraction_scale_instance +set(DEVICE_CONTRACTION_SCALE_INSTANCE_SOURCE + device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp + device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp + device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp + device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp +) + +add_library(device_contraction_scale_instance OBJECT ${DEVICE_CONTRACTION_SCALE_INSTANCE_SOURCE}) +set_target_properties(device_contraction_scale_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) + +clang_tidy_check(device_contraction_scale_instance) diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp new file mode 100644 index 00000000..7f49a986 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting. Don't use this hack unless absolutely necessary!
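// ---------------------------------------------------------------------------
// Usage sketch for the scale contraction instance lists (a minimal
// illustration, assuming the DeviceContractionMultipleD base interface at this
// revision; Scale computes E = alpha * (A * B), so the Ds tuple is empty, and
// the variable names below are hypothetical):
//
//     std::vector<std::unique_ptr<DeviceContractionMultipleD<2, 2, 2, F32, F32,
//         EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale>>> ops;
//     add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance(ops);
//     // a caller would keep only the instances whose IsSupportedArgument()
//     // accepts its concrete lengths/strides, then time the survivors.
// ---------------------------------------------------------------------------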
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include <cstdlib> + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using EMPTY_TUPLE = ck::Tuple<>; + +template <index_t... Is> +using S = ck::Sequence<Is...>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] = E[m0, m1, n0, n1] +// k/k/n are the fast changing dimension for A/B/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + 
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> + // clang-format on + >; + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance( + std::vector<std::unique_ptr<DeviceContractionMultipleD<2, 2, 2, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale>>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp 
b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp new file mode 100644 index 00000000..45ffa63c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting. Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include <cstdlib> + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using EMPTY_TUPLE = ck::Tuple<>; + +template <index_t... Is> +using S = ck::Sequence<Is...>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] = E[m0, m1, n0, n1] +// k/n/n are the fast changing dimension for A/B/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + 
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 1, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 1, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 
2>, S<1, 0, 2>, 2, 4, 4, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 1, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance( + std::vector<std::unique_ptr<DeviceContractionMultipleD<2, 2, 2, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale>>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp new file mode 100644 index 00000000..cc63b06a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting. Don't use this hack unless absolutely necessary!
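// ---------------------------------------------------------------------------
// Naming note for the kkn/knn/mkn/mnn (and the four-letter bilinear) suffixes:
// each letter records which dimension is fastest-varying in memory for
// A/B(/D)/E. This is what drives the block-transfer parameters in these
// tables: when K is fastest for a tensor, its transfer vectorizes along K
// (access order S<1, 0, 2>, SrcVectorDim 2); when M or N is fastest, it
// vectorizes along that dimension instead (access order S<0, 2, 1>,
// SrcVectorDim 1), which is the only difference between the
// otherwise-identical instance flavors.
// ---------------------------------------------------------------------------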
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include <cstdlib> + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using EMPTY_TUPLE = ck::Tuple<>; + +template <index_t... Is> +using S = ck::Sequence<Is...>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] = E[m0, m1, n0, n1] +// m/k/n are the fast changing dimension for A/B/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 1, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + 
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 4, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 4, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 
1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance( + std::vector<std::unique_ptr<DeviceContractionMultipleD<2, 2, 2, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale>>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp new file mode 100644 index 00000000..ce11f255 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting. Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include <cstdlib> + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using EMPTY_TUPLE = ck::Tuple<>; + +template <index_t... Is> +using S = ck::Sequence<Is...>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] = E[m0, m1, n0, n1] +// m/n/n are the fast changing dimension for A/B/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| 
AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 1, 1, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 1, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, 
S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 1, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 1, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance( + std::vector<std::unique_ptr<DeviceContractionMultipleD<2, 2, 2, F32, F32, EMPTY_TUPLE, F32, PassThrough, PassThrough, Scale>>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv1d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv1d_fwd/CMakeLists.txt new file mode 100644 index 00000000..77aa6198 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv1d_fwd/CMakeLists.txt @@ -0,0 +1,14 @@ +# device_conv1d_fwd_instance +set(DEVICE_CONV1D_FWD_INSTANCE_SOURCE + device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp; + device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instance.cpp; + device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp; + device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instance.cpp; +) + +add_library(device_conv1d_fwd_instance OBJECT ${DEVICE_CONV1D_FWD_INSTANCE_SOURCE}) +# target_compile_features(device_conv1d_fwd_instance PUBLIC) +set_target_properties(device_conv1d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +# install(TARGETS device_conv1d_fwd_instance LIBRARY DESTINATION lib) + +clang_tidy_check(device_conv1d_fwd_instance) diff --git a/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp new file mode 100644 index 00000000..24beb28b --- /dev/null +++ 
b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp @@ -0,0 +1,122 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include <cstdlib> + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_fwd_nwc_kxc_nwk_xdl.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using BF16 = bhalf_t; + +template <index_t... Is> +using S = ck::Sequence<Is...>; + +using NWC = ck::tensor_layout::convolution::NWC; +using KXC = ck::tensor_layout::convolution::KXC; +using NWK = ck::tensor_layout::convolution::NWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, wi, c] * wei[k, x, c] = out[n, wo, k] +using device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances = std::tuple< + // clang-format off + //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, 
PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_p0_bf16_instances = std::tuple< + // clang-format off + //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| 
ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // 
clang-format on + >; + +using device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_bf16_instances = std::tuple< + // clang-format off + //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 
2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances{}); + add_device_operation_instances(instances, + device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_p0_bf16_instances{}); + add_device_operation_instances(instances, + device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_bf16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instance.cpp new file mode 100644 index 00000000..30522a0d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instance.cpp @@ -0,0 +1,122 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
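+
+// Note on the pattern shared by every *_instance.cpp in this directory: the file
+// defines one std::tuple of DeviceConvNdFwdNwcKxcNwk_Xdl specializations per
+// ConvolutionForwardSpecialization (Default, Filter1x1Pad0, Filter1x1Stride1Pad0),
+// whose elements differ only in tuning parameters (block size, M/N/K0 tile shape,
+// XDL wave shape, block-transfer cluster shapes and vector widths), plus one
+// add_*_instances() entry point that appends a type-erased instance per tuple
+// element to the caller's vector. A minimal consumption sketch, in which the
+// base-class template arguments and run_and_time() are illustrative assumptions
+// rather than part of this change:
+//
+//   std::vector<std::unique_ptr<DeviceConvFwd</* as in the declaration below */>>> ops;
+//   add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances(ops);
+//   for(const auto& op : ops)
+//       if(op->IsSupportedArgument(argument.get())) // assumed BaseOperator query
+//           run_and_time(*op, argument);            // hypothetical profiling helper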
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_fwd_nwc_kxc_nwk_xdl.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using NWC = ck::tensor_layout::convolution::NWC; +using KXC = ck::tensor_layout::convolution::KXC; +using NWK = ck::tensor_layout::convolution::NWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, wi, c] * wei[k, x, c] = out[n, wo, k] +using device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances = std::tuple< + // clang-format off + //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, 
PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_p0_f16_instances = std::tuple< + // clang-format off + //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, 
PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| 
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, 
PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances{}); + add_device_operation_instances(instances, + device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_p0_f16_instances{}); + add_device_operation_instances(instances, + device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp new file mode 100644 index 00000000..e03ce90e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp @@ -0,0 +1,121 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_fwd_nwc_kxc_nwk_xdl.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using NWC = ck::tensor_layout::convolution::NWC; +using KXC = ck::tensor_layout::convolution::KXC; +using NWK = ck::tensor_layout::convolution::NWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, wi, c] * wei[k, x, c] = out[n, wo, k] +using device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances = std::tuple< + // clang-format off + //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| 
Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 
64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_p0_f32_instances = std::tuple< + // clang-format off + //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 4, 32, 32, 2, 
1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_f32_instances = std::tuple< + // clang-format off + //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, 
S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances{}); + add_device_operation_instances(instances, + device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_p0_f32_instances{}); + add_device_operation_instances(instances, + device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instance.cpp new file mode 100644 index 00000000..d090daa9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instance.cpp @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
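+
+// Note: K1 and the Src/Dst scalar-per-vector parameters scale inversely with the
+// element size across this directory, keeping every vectorized global access 16
+// bytes wide: the f32 instances use 4 (4 B x 4 = 16 B), the f16/bf16 instances
+// use 8 (2 B x 8 = 16 B), and the int8 instances below use 16 (1 B x 16 = 16 B).
+// The int8 instances accumulate in int32_t, mirroring the F32 accumulator of the
+// floating-point instances.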
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_fwd_nwc_kxc_nwk_xdl.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +using NWC = ck::tensor_layout::convolution::NWC; +using KXC = ck::tensor_layout::convolution::KXC; +using NWK = ck::tensor_layout::convolution::NWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, wi, c] * wei[k, x, c] = out[n, wo, k] +using device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances = + std::tuple< + // clang-format off + //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + 
DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +using device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_p0_int8_instances = + std::tuple< + // clang-format off + //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| 
ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 16, 32, 
32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +using device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_int8_instances = + std::tuple< + // clang-format off + //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 
true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances{}); + add_device_operation_instances(instances, + device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_p0_int8_instances{}); + add_device_operation_instances(instances, + device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_int8_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt new file mode 100644 index 00000000..d7882a7d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt @@ -0,0 +1,12 @@ +# device_conv2d_bwd_data_instance +set(DEVICE_CONV2D_BWD_DATA_INSTANCE_SOURCE + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp; + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp; + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp; + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp; +) + +add_library(device_conv2d_bwd_data_instance OBJECT ${DEVICE_CONV2D_BWD_DATA_INSTANCE_SOURCE}) +set_target_properties(device_conv2d_bwd_data_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) + +clang_tidy_check(device_conv2d_bwd_data_instance) diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp new file mode 100644 index 00000000..3b716d64 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
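+
+// Note: the backward-data instances differ from the conv1d_fwd files above in two
+// visible ways: only the Default and Filter1x1Stride1Pad0 backward-data
+// specializations are instantiated (the forward-only Filter1x1Pad0 shortcut is not
+// carried over), and the B-block (weight) transfer uses the transposed access
+// order S<2, 0, 1> / S<0, 2, 1> with SrcVectorDim 1 instead of the forward files'
+// S<1, 0, 2> ordering with SrcVectorDim 2. As with conv1d_fwd, each source file is
+// compiled into a position-independent OBJECT library (see the CMakeLists.txt hunk
+// above), presumably for linking into the umbrella device-operation library.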
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances = std::tuple< + // clang-format off + //####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //####################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + 
DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances = + std::tuple< + // clang-format off + //####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| 
AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //####################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, 
ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
+        DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
+        DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
+        DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>
+        // clang-format on
+        >;
+
+void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
+    std::vector<DeviceConvBwdDataPtr<PassThrough, PassThrough, PassThrough>>& instances)
+{
+    add_device_operation_instances(instances,
+                                   device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances{});
+    add_device_operation_instances(
+        instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp
new file mode 100644
index 00000000..5978ffcd
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp
@@ -0,0 +1,90 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
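+// How to read the rows below (interpretation of the column header): e.g. the
+// first ConvBwdDataDefault instance uses a 256-thread block that computes a
+// 256 x 128 (MPerBlock x NPerBlock) tile of the output, built from 32 x 32
+// XDL (MFMA) tiles with MXdlPerWave = 4 and NXdlPerWave = 2; K0PerBlock = 4
+// and K1 = 8 mean eight fp16 elements (16 bytes) per vector transfer along K.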
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances = std::tuple< + // clang-format off + //####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //####################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, +#if !CK_WORKAROUND_SWDEV_325164 + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 
1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, +#endif + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = + std::tuple< + // clang-format off + //####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| 
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //####################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 32, 4, 8, 32, 32, 2, 
1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
+        DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
+        DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
+        DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>
+        // clang-format on
+        >;
+
+void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(
+    std::vector<DeviceConvBwdDataPtr<PassThrough, PassThrough, PassThrough>>& instances)
+{
+    add_device_operation_instances(instances,
+                                   device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances{});
+    add_device_operation_instances(
+        instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp
new file mode 100644
index 00000000..42e80be1
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp
@@ -0,0 +1,87 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
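+// A pattern worth noting across the f32/f16/bf16/int8 variants of this
+// operation (an observation from the tables, not a documented guarantee):
+// the tile shapes are identical, while K1 and the scalar-per-vector widths
+// scale with element size (4 for f32, 8 for f16/bf16, 16 for int8), which
+// appears to keep the vectorized global accesses at 16 bytes in every case.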
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F32 = float;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+static constexpr auto ConvBwdDataDefault =
+    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default;
+
+static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
+    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0;
+
+// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
+using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances = std::tuple<
+    // clang-format off
+    //####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
+    //####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
+    //####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
+    //####################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
+        DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
+        DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
+        DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
+
DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances = + std::tuple< + // clang-format off + //####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| 
ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //####################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, 
S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>,
+        DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
+        DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
+        DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>
+        // clang-format on
+        >;
+
+void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(
+    std::vector<DeviceConvBwdDataPtr<PassThrough, PassThrough, PassThrough>>& instances)
+{
+    add_device_operation_instances(instances,
+                                   device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances{});
+    add_device_operation_instances(
+        instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp
new file mode 100644
index 00000000..ff15c023
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp
@@ -0,0 +1,88 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
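+// These instances keep the tensor data at int8 but accumulate in int32
+// (the DataType/AccType aliases below): a single int8 x int8 product already
+// needs up to 16 bits, and the reduction over K sums many such products, so a
+// 32-bit accumulator is needed to avoid overflow before the cast back to int8.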
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using DataType = int8_t; +using AccType = int32_t; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances = std::tuple< + // clang-format off + //####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //####################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 
16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1> + // clang-format on + >; + +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances = + std::tuple< + // clang-format off + //#####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + 
//#####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#####################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + 
DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
+        DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>,
+        DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
+        DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
+        DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>
+        // clang-format on
+        >;
+
+void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
+    std::vector<DeviceConvBwdDataPtr<PassThrough, PassThrough, PassThrough>>& instances)
+{
+    add_device_operation_instances(instances,
+                                   device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances{});
+    add_device_operation_instances(
+        instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/CMakeLists.txt
new file mode 100644
index 00000000..7d3c57b2
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/CMakeLists.txt
@@ -0,0 +1,11 @@
+# device_conv2d_bwd_weight_instance
+set(DEVICE_CONV2D_BWD_WEIGHT_INSTANCE_SOURCE
+    device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instance.cpp;
+    device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instance.cpp;
+)
+add_library(device_conv2d_bwd_weight_instance OBJECT ${DEVICE_CONV2D_BWD_WEIGHT_INSTANCE_SOURCE})
+target_compile_features(device_conv2d_bwd_weight_instance PUBLIC)
+set_target_properties(device_conv2d_bwd_weight_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
+rocm_install(TARGETS device_conv2d_bwd_weight_instance)
+
+clang_tidy_check(device_conv2d_bwd_weight_instance)
diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instance.cpp
new file mode 100644
index 00000000..ea9fb8c6
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instance.cpp
@@ -0,0 +1,58 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022,
Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 
8>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instance.cpp new file mode 100644 index 
00000000..744f2f91 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instance.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, 
PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector<DeviceConvBwdWeightPtr<PassThrough, PassThrough, PassThrough>>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances{}); +}
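+ +// Usage sketch (hypothetical caller, for illustration only; names other than the +// function above are assumptions): a client typically collects every variant and keeps +// whichever ones support its problem size, e.g. +// std::vector<DeviceConvBwdWeightPtr<PassThrough, PassThrough, PassThrough>> ops; +// add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(ops); +// for(const auto& op : ops) { /* if op->IsSupportedArgument(...) holds, time it and keep the fastest */ }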
+ +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt new file mode 100644 index 00000000..5f3b416a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt @@ -0,0 +1,16 @@ +# device_conv2d_fwd_instance +set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp; + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp; + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp; + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp; + device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp; + device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp; + device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp; + device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp; + device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp; +) + +add_library(device_conv2d_fwd_instance OBJECT ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE}) +set_target_properties(device_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +clang_tidy_check(device_conv2d_fwd_instance) diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp new file mode 100644 index 00000000..55496f2c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp @@ -0,0 +1,155 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include <cstdlib> + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using NHWC = ck::tensor_layout::convolution::NHWC; +using KYXC = ck::tensor_layout::convolution::KYXC; +using NHWK = ck::tensor_layout::convolution::NHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + +// arbitrary conv +using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances = std::tuple< + // clang-format off + //##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +
//##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 
1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +// 1x1, pad 0 +using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple< + // clang-format off + //##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, 
+ DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + 
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +// 1x1, stride 1, pad 0 +using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 
8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_odd_c_f16_instances = std::tuple< + // clang-format off + //##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| 
SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + 
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 128, 64, 2, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 256, 64, 2, 4, 32, 32, 4, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 128, 64, 2, 4, 32, 32, 2, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 64, 64, 2, 4, 32, 32, 1, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances( + std::vector<std::unique_ptr<DeviceConvFwd<PassThrough, PassThrough, PassThrough>>>& + instances) +{ + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances{}); + add_device_operation_instances( + instances, device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_1x1_p0_f16_instances{}); + add_device_operation_instances( + instances, device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{}); + add_device_operation_instances( + instances, device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_odd_c_f16_instances{}); +}
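+ +// Note: the four tuples registered above all target the same NHWC/KYXC/NHWK layout and differ +// only in the ConvolutionForwardSpecialization template argument. A specialized instance +// (e.g. Filter1x1Stride1Pad0) omits the padding/window arithmetic and instead rejects, via its +// IsSupportedArgument check, any problem whose filter size, strides, or pads do not match.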
+ +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp new file mode 100644 index 00000000..80754f94 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include <cstdlib> + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using NHWC = ck::tensor_layout::convolution::NHWC; +using KYXC = ck::tensor_layout::convolution::KYXC; +using NHWK = ck::tensor_layout::convolution::NHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough,
ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_bf16_instances = std::tuple< + // clang-format off + 
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, 
S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, 
PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances( + std::vector<std::unique_ptr<DeviceConvFwd<PassThrough, PassThrough, PassThrough>>>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_bf16_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances{}); +}
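+ +// Observation: unlike the c-shuffle kernels earlier in this diff, these instances end with +// CThreadTransfer SrcDstVectorDim = 7 and DstScalarPerVector = 1 (the last two template +// arguments in each row), i.e. each thread stores the bf16 output one element at a time, +// while accumulation is still performed in f32 (the AccData type).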
+ +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp new file mode 100644 index 00000000..7b769949 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp @@ -0,0 +1,120 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include <cstdlib> + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using NHWC = ck::tensor_layout::convolution::NHWC; +using KYXC = ck::tensor_layout::convolution::KYXC; +using NHWK = ck::tensor_layout::convolution::NHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, +
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using 
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 
1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + 
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector<std::unique_ptr<DeviceConvFwd<2, NHWC, KYXC, NHWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& + instances) +{ + add_device_operation_instances(instances, device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f16_instances{});
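+    // Each add_device_operation_instances() call in this function appends one
+    // default-constructed kernel object per element type of the given std::tuple
+    // to `instances` (see add_device_operation_instance.hpp), so the caller ends
+    // up owning every tuning variant listed in the tables above and can filter
+    // them at run time with IsSupportedArgument() before dispatching.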
+ add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp new file mode 100644 index 00000000..cf2d451c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp @@ -0,0 +1,119 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include <cstdlib> + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using NHWC = ck::tensor_layout::convolution::NHWC; +using KYXC = ck::tensor_layout::convolution::KYXC; +using NHWK = ck::tensor_layout::convolution::NHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 
1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using 
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 
1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + 
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector<std::unique_ptr<DeviceConvFwd<2, NHWC, KYXC, NHWK, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>& + instances) +{ + add_device_operation_instances(instances, device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_instances{});
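+    // A typical consumer of these registration functions (for example the
+    // conv_fwd path of ckProfiler) would do roughly the following -- a sketch
+    // only, with argument setup and timing omitted:
+    //
+    //   std::vector<std::unique_ptr<DeviceConvFwd<2, NHWC, KYXC, NHWK, F32, F32,
+    //       F32, PassThrough, PassThrough, PassThrough>>> ops;
+    //   add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(ops);
+    //   for(const auto& op : ops)
+    //   {
+    //       // query op->GetTypeString(), check op->IsSupportedArgument(...),
+    //       // then create an invoker and benchmark the instance
+    //   }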
+ add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp new file mode 100644 index 00000000..f8ea7a20 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include <cstdlib> + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using NHWC = ck::tensor_layout::convolution::NHWC; +using KYXC = ck::tensor_layout::convolution::KYXC; +using NHWK = ck::tensor_layout::convolution::NHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 
1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 
32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_int8_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 16, 32, 32, 
2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 
0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< 
int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances( + std::vector<std::unique_ptr<DeviceConvFwd<2, NHWC, KYXC, NHWK, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_int8_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp new file mode 100644 index 00000000..638aed6c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include <cstdlib> + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_fwd_nwc_kxc_nwk_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using NHWC = ck::tensor_layout::convolution::NHWC; +using KYXC = ck::tensor_layout::convolution::KYXC; +using NHWK = ck::tensor_layout::convolution::NHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances = std::tuple< + // clang-format off + //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| 
PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // 
clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_bf16_instances = std::tuple< + // clang-format off + //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 
1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances = std::tuple< + // clang-format off + //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 
2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances( + std::vector<std::unique_ptr<DeviceConvFwd<2, NHWC, KYXC, NHWK, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_bf16_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp new file mode 100644 index 00000000..6c095472 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp @@ -0,0 +1,120 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_convnd_fwd_nwc_kxc_nwk_xdl.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using NHWC = ck::tensor_layout::convolution::NHWC;
+using KYXC = ck::tensor_layout::convolution::KYXC;
+using NHWK = ck::tensor_layout::convolution::NHWK;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+static constexpr auto ConvFwdDefault =
+    ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
+
+static constexpr auto ConvFwd1x1P0 =
+    ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
+
+static constexpr auto ConvFwd1x1S1P0 =
+    ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
+
+// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
+using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances = std::tuple<
+    // clang-format off
+    //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
+    //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
+    //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
+    //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+    DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
+    DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
+    DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
+    DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
+    DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128,
128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple< + // clang-format off + //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 
8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| 
BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 
1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
+    DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
+    DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>
+    // clang-format on
+    >;
+
+void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(
+    std::vector<std::unique_ptr<DeviceConvFwd<2,
+                                              NHWC,
+                                              KYXC,
+                                              NHWK,
+                                              F16,
+                                              F16,
+                                              F16,
+                                              PassThrough,
+                                              PassThrough,
+                                              PassThrough>>>&
+        instances)
+{
+    add_device_operation_instances(instances, device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances{});
+    add_device_operation_instances(instances,
+                                   device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f16_instances{});
+    add_device_operation_instances(instances,
+                                   device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp
new file mode 100644
index 00000000..6d656105
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp
@@ -0,0 +1,119 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_convnd_fwd_nwc_kxc_nwk_xdl.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F32 = float;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using NHWC = ck::tensor_layout::convolution::NHWC;
+using KYXC = ck::tensor_layout::convolution::KYXC;
+using NHWK = ck::tensor_layout::convolution::NHWK;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+static constexpr auto ConvFwdDefault =
+    ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
+
+static constexpr auto ConvFwd1x1P0 =
+    ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
+
+static constexpr auto ConvFwd1x1S1P0 =
+    ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
+
+// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
+using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances = std::tuple<
+    // clang-format off
+    //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
+    //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per|
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_instances = std::tuple< + // clang-format off + //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 
2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances = std::tuple< + // clang-format off + //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 
2>, 2, 4, 4, true, 7, 1>,
+    DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
+    DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
+    DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
+    DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
+    DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
+    DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
+    DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
+    DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
+    DeviceConvNdFwdNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>
+    // clang-format on
+    >;
+
+void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(
+    std::vector<std::unique_ptr<DeviceConvFwd<2,
+                                              NHWC,
+                                              KYXC,
+                                              NHWK,
+                                              F32,
+                                              F32,
+                                              F32,
+                                              PassThrough,
+                                              PassThrough,
+                                              PassThrough>>>&
+        instances)
+{
+    add_device_operation_instances(instances, device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances{});
+    add_device_operation_instances(instances,
+                                   device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_instances{});
+    add_device_operation_instances(instances,
+                                   device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp
new file mode 100644
index 00000000..634f3128
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp
@@ -0,0 +1,130 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
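+
+// Vector-width bookkeeping for this int8 file (a hedged reading of the chosen
+// parameters, not text from the original patch): the K1 and ScalarPerVector
+// values scale inversely with element size so that each global load stays a
+// full 16-byte (128-bit, dword4) access:
+//
+//   f32:  4 elements  x 4 bytes = 16 bytes   (K1 = 4  in the f32 file above)
+//   f16:  8 elements  x 2 bytes = 16 bytes   (K1 = 8  in the f16/bf16 files)
+//   int8: 16 elements x 1 byte  = 16 bytes   (K1 = 16 in the instances below)
+//
+// Accumulation is done in int32_t so that summing many int8 products cannot
+// overflow the accumulator.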
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_convnd_fwd_nwc_kxc_nwk_xdl.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F32 = float;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using NHWC = ck::tensor_layout::convolution::NHWC;
+using KYXC = ck::tensor_layout::convolution::KYXC;
+using NHWK = ck::tensor_layout::convolution::NHWK;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+static constexpr auto ConvFwdDefault =
+    ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
+
+static constexpr auto ConvFwd1x1P0 =
+    ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
+
+static constexpr auto ConvFwd1x1S1P0 =
+    ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
+
+// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
+using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances =
+    std::tuple<
+        // clang-format off
+        //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
+        //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
+        //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
+        //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
+        DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
+        DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
+        DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
+        DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t,
PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_int8_instances = + std::tuple< + // clang-format off + //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + 
//##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 
true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances = + std::tuple< + // clang-format off + //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 2, 
int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
+        DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
+        DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
+        DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
+        DeviceConvNdFwdNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>
+        // clang-format on
+        >;
+
+void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(
+    std::vector<std::unique_ptr<DeviceConvFwd<2,
+                                              NHWC,
+                                              KYXC,
+                                              NHWK,
+                                              int8_t,
+                                              int8_t,
+                                              int8_t,
+                                              PassThrough,
+                                              PassThrough,
+                                              PassThrough>>>& instances)
+{
+    add_device_operation_instances(instances,
+                                   device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances{});
+    add_device_operation_instances(instances,
+                                   device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_int8_instances{});
+    add_device_operation_instances(instances,
+                                   device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt
new file mode 100644
index 00000000..ad66c73b
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt
@@ -0,0 +1,8 @@
+# device_conv2d_fwd_bias_relu_instance
+set(DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE
+    device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp;
+)
+add_library(device_conv2d_fwd_bias_relu_instance OBJECT ${DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE})
+set_target_properties(device_conv2d_fwd_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
+
+clang_tidy_check(device_conv2d_fwd_bias_relu_instance)
diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp
new file mode 100644
index 00000000..ca0f9c81
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp
@@ -0,0 +1,154 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
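+
+// What this file adds (summary): the same XDL tile configurations as the plain
+// conv2d_fwd instances, but with the output elementwise operation set to
+// AddRelu, so the bias-add and ReLU run fused in the GEMM epilogue instead of
+// as separate kernels. AddRelu is defined in element_wise_operation.hpp; the
+// sketch below shows a plausible shape for such a functor (the member
+// signature here is an assumption, not copied from the library):
+//
+//   struct AddReluSketch
+//   {
+//       template <typename T>
+//       __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const
+//       {
+//           const T a = x0 + x1;        // GEMM accumulator + bias
+//           y = a > T{0} ? a : T{0};    // ReLU
+//       }
+//   };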
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using AddRelu = ck::tensor_operation::element_wise::AddRelu;
+
+static constexpr auto MemorySet = ck::InMemoryDataOperationEnum::Set;
+
+static constexpr auto ConvFwdDefault =
+    ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
+
+static constexpr auto ConvFwd1x1P0 =
+    ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
+
+static constexpr auto ConvFwd1x1S1P0 =
+    ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
+
+static constexpr auto ConvFwdOddC =
+    ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC;
+
+// arbitrary conv
+using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances = std::tuple<
+    // clang-format off
+    //##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+    //##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
+    //##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
+    //##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
+    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8,
8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + 
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +// 1x1, pad 0 +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple< + // clang-format off + //##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, 
PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +// 1x1, stride 1, pad 0 +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| 
BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 
true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +// Odd C +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_odd_c_f16_instances = std::tuple< + // clang-format off + //##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | 
PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 128, 32, 4, 8, 
32, 32, 2, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 128, 64, 2, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 256, 64, 2, 4, 32, 32, 4, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 128, 64, 2, 4, 32, 32, 2, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 64, 64, 2, 4, 32, 32, 1, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances{}); + add_device_operation_instances( + instances, device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_p0_f16_instances{}); + add_device_operation_instances( + instances, + device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{}); + add_device_operation_instances( + instances, device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_odd_c_f16_instances{}); +} + 
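The registration function just above is the entire public surface of this translation unit: each device_conv2d_fwd_xdl_..._instances alias is a std::tuple whose elements are distinct compile-time kernel configurations, and add_device_operation_instances walks the tuple, default-constructs one device operation per element, and appends it, type-erased behind a smart pointer to the operation's abstract base, to the caller's vector (the vector's element type is elided in the hunk above). A minimal self-contained sketch of that pattern, with hypothetical stand-in names (OpBase, KernelA, KernelB, add_instances), not CK's actual implementation:

    // C++17 sketch: mirrors the tuple-to-vector registration used by the instance files.
    #include <memory>
    #include <tuple>
    #include <vector>

    struct OpBase { virtual ~OpBase() = default; }; // stands in for the device-op base class
    struct KernelA : OpBase {};                     // stands in for one DeviceConv2dFwdXdl_... configuration
    struct KernelB : OpBase {};                     // stands in for another tile configuration

    template <typename... Ops>
    void add_instances(std::vector<std::unique_ptr<OpBase>>& v, std::tuple<Ops...>)
    {
        // default-construct one heap object per tuple element, erased to the base type
        (v.push_back(std::make_unique<Ops>()), ...);
    }

    int main()
    {
        std::vector<std::unique_ptr<OpBase>> instances;
        add_instances(instances, std::tuple<KernelA, KernelB>{});
        return instances.size() == 2 ? 0 : 1; // a profiler would instead time each instance
    }

A client such as a profiler can then iterate the filled vector, skip instances whose argument check fails for the given problem size, and keep the fastest remaining one.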
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt
new file mode 100644
index 00000000..36b1f6c1
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt
@@ -0,0 +1,8 @@
+# device_conv2d_fwd_bias_relu_add_instance
+set(DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE
+    device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp;
+)
+add_library(device_conv2d_fwd_bias_relu_add_instance OBJECT ${DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE})
+set_target_properties(device_conv2d_fwd_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
+
+clang_tidy_check(device_conv2d_fwd_bias_relu_add_instance)
diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp
new file mode 100644
index 00000000..91aa9182
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp
@@ -0,0 +1,154 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
+
+static constexpr auto ConvFwdDefault =
+    ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
+
+static constexpr auto ConvFwd1x1P0 =
+    ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
+
+static constexpr auto ConvFwd1x1S1P0 =
+    ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
+
+static constexpr auto ConvFwdOddC =
+    ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC;
+
+// arbitrary conv
+using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances = std::tuple<
+    // clang-format off
+    //##############################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+    //##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| 
XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 
1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +// 1x1, pad 0 +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple< + // clang-format off + //##############################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, 
PassThrough, AddReluAdd, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 
32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +// 1x1, stride 1, pad 0 +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //##############################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + 
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +// Odd C +using 
device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_odd_c_f16_instances = std::tuple< + // clang-format off + //##############################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 4, 8>, 
S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 128, 64, 2, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 256, 64, 2, 4, 32, 32, 4, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + 
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 128, 64, 2, 4, 32, 32, 2, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 64, 64, 2, 4, 32, 32, 1, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances{}); + add_device_operation_instances( + instances, + device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_p0_f16_instances{}); + add_device_operation_instances( + instances, + device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{}); + add_device_operation_instances( + instances, + device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_odd_c_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv3d_fwd/CMakeLists.txt new file mode 100644 index 00000000..91a299c7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv3d_fwd/CMakeLists.txt @@ -0,0 +1,12 @@ +# device_conv3d_fwd_instance +set(DEVICE_CONV3D_FWD_INSTANCE_SOURCE + device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp; + device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp; + device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp; + device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp; +) +add_library(device_conv3d_fwd_instance OBJECT ${DEVICE_CONV3D_FWD_INSTANCE_SOURCE}) +target_compile_features(device_conv3d_fwd_instance PUBLIC) +set_target_properties(device_conv3d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) + +clang_tidy_check(device_conv3d_fwd_instance) diff --git a/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp new file mode 100644 index 00000000..3c57106d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp @@ -0,0 +1,130 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
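A hedged reading note for the instance tables in the conv3d file that follows (an editorial sketch, not part of the diff): DeviceConvNdFwdNwcKxcNwk_Xdl lowers the 3D forward convolution to an implicit GEMM, and under the usual NDHWC/KZYXC/NDHWK reading the MPerBlock/NPerBlock columns tile the output matrix while K0PerBlock x K1 is the reduction step per main-loop iteration. The helper below is illustrative only; the names are hypothetical, not CK API:

    // Standard implicit-GEMM dimensions for a 3D forward convolution (illustrative helper).
    struct GemmDims
    {
        long long M; // one row per output voxel: N * Do * Ho * Wo
        long long N; // one column per output channel: K
        long long K; // reduction over the filter window: Z * Y * X * C
    };

    inline GemmDims conv3d_fwd_implicit_gemm_dims(long long N, long long Do, long long Ho,
                                                  long long Wo, long long K, long long Z,
                                                  long long Y, long long X, long long C)
    {
        return {N * Do * Ho * Wo, K, Z * Y * X * C};
    }

For example, the first ConvFwdDefault entry below (BlockSize 256, MPerBlock 256, NPerBlock 128, K0PerBlock 4, K1 8) assigns each workgroup a 256x128 tile of that GEMM and advances the reduction 32 (4 x 8) elements per iteration.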
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_convnd_fwd_nwc_kxc_nwk_xdl.hpp"
+#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F32 = float;
+using BF16 = bhalf_t;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using NDHWC = ck::tensor_layout::convolution::NDHWC;
+using KZYXC = ck::tensor_layout::convolution::KZYXC;
+using NDHWK = ck::tensor_layout::convolution::NDHWK;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+static constexpr auto ConvFwdDefault =
+    ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
+
+static constexpr auto ConvFwd1x1P0 =
+    ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
+
+static constexpr auto ConvFwd1x1S1P0 =
+    ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
+
+// Compilation parameters for in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k]
+using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances = std::tuple<
+    // clang-format off
+    //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
+    //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
+    //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
+    //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
+        DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
+        DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
+        DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
+ 
DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_p0_bf16_instances = std::tuple< + // clang-format off + //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | 
| | | | | | | | | + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_bf16_instances = std::tuple< + // clang-format off + //##########################| NumDim| InData| WeiData| OutData| AccData| 
In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 
2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances( + std::vector<std::unique_ptr<DeviceConvFwd<PassThrough, PassThrough, PassThrough>>>& instances) +{ + add_device_operation_instances(instances, + device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances{}); + add_device_operation_instances(instances, + device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_p0_bf16_instances{}); + add_device_operation_instances( + instances, device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_bf16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp new file mode 100644 index 00000000..433d2b9b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+ +#include <cstdlib> + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_fwd_nwc_kxc_nwk_xdl.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using NDHWC = ck::tensor_layout::convolution::NDHWC; +using KZYXC = ck::tensor_layout::convolution::KZYXC; +using NDHWK = ck::tensor_layout::convolution::NDHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k] +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances = std::tuple< + // clang-format off + //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, +
DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_p0_f16_instances = std::tuple< + // clang-format off + //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + 
DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| 
MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + 
DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances( + std::vector<std::unique_ptr<DeviceConvFwd<PassThrough, PassThrough, PassThrough>>>& instances) +{ + add_device_operation_instances(instances, + device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances{}); + add_device_operation_instances(instances, + device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_p0_f16_instances{}); + add_device_operation_instances( + instances, device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp new file mode 100644 index 00000000..e4b0b161 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+ +#include <cstdlib> + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_fwd_nwc_kxc_nwk_xdl.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using NDHWC = ck::tensor_layout::convolution::NDHWC; +using KZYXC = ck::tensor_layout::convolution::KZYXC; +using NDHWK = ck::tensor_layout::convolution::NDHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k] +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances = std::tuple< + // clang-format off + //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32,
F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_p0_f32_instances = std::tuple< + // clang-format off + //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, 
PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f32_instances = std::tuple< + // clang-format off + //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| 
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, 
PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances( + std::vector<std::unique_ptr<DeviceConvFwd<PassThrough, PassThrough, PassThrough>>>& instances) +{ + add_device_operation_instances(instances, + device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances{}); + add_device_operation_instances(instances, + device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_p0_f32_instances{}); + add_device_operation_instances( + instances, device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp new file mode 100644 index 00000000..f509afb9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp @@ -0,0 +1,131 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+ +#include <cstdlib> + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_fwd_nwc_kxc_nwk_xdl.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using NDHWC = ck::tensor_layout::convolution::NDHWC; +using KZYXC = ck::tensor_layout::convolution::KZYXC; +using NDHWK = ck::tensor_layout::convolution::NDHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k] +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances = + std::tuple< + // clang-format off + //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_p0_int8_instances = + std::tuple< + // clang-format off + //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| 
PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, 
ConvFwd1x1P0, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_int8_instances = + std::tuple< + // clang-format off + //##########################| NumDim| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########################| | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 
64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNdFwdNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances( + std::vector<std::unique_ptr<DeviceConvFwd<PassThrough, PassThrough, PassThrough>>>& instances) +{ + add_device_operation_instances(instances, + device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances{}); + add_device_operation_instances(instances, + device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_p0_int8_instances{}); + add_device_operation_instances( + instances, device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_int8_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/CMakeLists.txt new file mode 100644 index 00000000..dae633b7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/CMakeLists.txt @@ -0,0 +1,22 @@ +# device_convnd_bwd_data_instance +set(DEVICE_CONVND_BWD_DATA_INSTANCE_SOURCE + device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp; + device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp; + device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instance.cpp; + device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instance.cpp; + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp; + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp; + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp; + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp; + device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp; + device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp; + device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp; + device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp; +) + +add_library(device_convnd_bwd_data_instance OBJECT ${DEVICE_CONVND_BWD_DATA_INSTANCE_SOURCE}) +target_compile_features(device_convnd_bwd_data_instance PUBLIC) +set_target_properties(device_convnd_bwd_data_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +rocm_install(TARGETS device_convnd_bwd_data_instance) + +clang_tidy_check(device_convnd_bwd_data_instance) diff --git
a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instance.cpp new file mode 100644 index 00000000..a449a905 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instance.cpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include <cstdlib> + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using F32 = float; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, wi, c] * wei[k, x, c] = out[n, wo, k] +using device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances = + std::tuple< + // clang-format off + //############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
+ DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +using device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_bf16_instances = + std::tuple< + // 
clang-format off + //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 
2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>
+        // clang-format on
+        >;
+
+void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(
+    std::vector<DeviceConvBwdDataPtr<PassThrough, PassThrough, PassThrough>>& instances)
+{
+    add_device_operation_instances(instances,
+                                   device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances{});
+    add_device_operation_instances(
+        instances, device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_bf16_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp
new file mode 100644
index 00000000..fb976740
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp
@@ -0,0 +1,91 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
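+// fp16 counterpart of the bf16 file above: the same tuning table instantiated for
+// ck::half_t with fp32 accumulation. Keeping one data type per translation unit
+// bounds the template-instantiation cost of each object file.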
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+static constexpr auto ConvBwdDataDefault =
+    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default;
+
+static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
+    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0;
+
+// Compilation parameters for in[n, wi, c] * wei[k, x, c] = out[n, wo, k]
+using device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances =
    std::tuple<
+        // clang-format off
+        //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
+        //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
+        //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
+        //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
+#if 1
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0,
2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, +#endif + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +using device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_f16_instances = + std::tuple< + // clang-format off + //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#############################################################################| Type| Type| Type| Type| 
Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 64, 128, 4, 8, 32, 
32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>
+        // clang-format on
+        >;
+
+void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(
+    std::vector<DeviceConvBwdDataPtr<PassThrough, PassThrough, PassThrough>>& instances)
+{
+    add_device_operation_instances(instances,
+                                   device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances{});
+    add_device_operation_instances(
+        instances, device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_f16_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp
new file mode 100644
index 00000000..e8f2a45b
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp
@@ -0,0 +1,88 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
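+// fp32 counterpart: K1 and the SrcScalar/DstScalar PerVector widths drop from 8 to 4
+// relative to the fp16/bf16 tables, so each vectorized global access stays at
+// 16 bytes (4 x 4-byte floats instead of 8 x 2-byte halves).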
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F32 = float;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+static constexpr auto ConvBwdDataDefault =
+    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default;
+
+static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
+    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0;
+
+// Compilation parameters for in[n, wi, c] * wei[k, x, c] = out[n, wo, k]
+using device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances =
    std::tuple<
+        // clang-format off
+        //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
+        //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
+        //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
+        //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>,
S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_f32_instances = + std::tuple< + // clang-format off + //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| 
Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 
0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>
+        // clang-format on
+        >;
+
+void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(
+    std::vector<DeviceConvBwdDataPtr<PassThrough, PassThrough, PassThrough>>& instances)
+{
+    add_device_operation_instances(instances,
+                                   device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances{});
+    add_device_operation_instances(
+        instances, device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_f32_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instance.cpp
new file mode 100644
index 00000000..6aad1f02
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instance.cpp
@@ -0,0 +1,91 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
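+// int8 counterpart: K1 and the vector widths widen to 16 (16 x 1-byte values, again
+// a 16-byte access), and accumulation runs in int32 (AccType below) so the 8-bit
+// products cannot overflow the accumulator.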
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using DataType = int8_t;
+using AccType = int32_t;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+static constexpr auto ConvBwdDataDefault =
+    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default;
+
+static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
+    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0;
+
+// Compilation parameters for in[n, wi, c] * wei[k, x, c] = out[n, wo, k]
+using device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances =
    std::tuple<
+        // clang-format off
+        //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
+        //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
+        //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
+        //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
+ #if 1
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
+ #endif
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough,
PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 1, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1> + // clang-format on + >; + +using device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_int8_instances = + std::tuple< + // clang-format off + //##############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| 
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + 
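// Rows in this 1x1 tuple may only be selected when the filter is 1x1 with unit
+        // stride and zero padding; under that restriction backward data degenerates to a
+        // plain GEMM over the spatial positions, so the specialization lets the general
+        // gather/scatter index transforms be compiled out.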
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 1, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>
+        // clang-format on
+        >;
+
+void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(
+    std::vector<DeviceConvBwdDataPtr<PassThrough, PassThrough, PassThrough>>& instances)
+{
+    add_device_operation_instances(instances,
+                                   device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances{});
+    add_device_operation_instances(
+        instances, device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_int8_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp
new file mode 100644
index 00000000..010291cb
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp
@@ -0,0 +1,89 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
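+// 2D version of the same catalogue: NumDimSpatial is 2 and the tensors are
+// NHWC/KYXC/NHWK, i.e. in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k];
+// the per-instance tuning parameters match the 1D tables above.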
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using BF16 = bhalf_t;
+using F32 = float;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+static constexpr auto ConvBwdDataDefault =
+    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default;
+
+static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
+    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0;
+
+// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
+using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances =
    std::tuple<
+        // clang-format off
+        //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
+        //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
+        //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
+        //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>,
S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances = + std::tuple< + // clang-format off + //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + 
//#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + 
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances{}); + add_device_operation_instances( + instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp new file mode 100644 index 00000000..e7e14717 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
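The add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances function that closes the translation unit above is the file's only exported symbol: it default-constructs one device operation per tuple element and appends type-erased pointers to the caller's registry vector (the element type of `instances` was lost in extraction and is left as-is here). A minimal consumer sketch, assuming a `DeviceOpPtr` alias for whatever pointer type that registry holds and a hypothetical `make_args()` helper that builds the backward-data argument for a concrete problem; `IsSupportedArgument`, `MakeInvokerPointer`, and `Run` are the usual CK device-op entry points:

    std::vector<DeviceOpPtr> ops;
    ck::tensor_operation::device::instance::
        add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(ops);

    for(const auto& op : ops)
    {
        // hypothetical helper: fills N/K/C, spatial lengths, strides, dilations, pads
        auto arg = make_args(*op);
        if(!op->IsSupportedArgument(arg.get()))
            continue; // not every tile shape divides every problem size
        auto invoker = op->MakeInvokerPointer();
        invoker->Run(arg.get(), StreamConfig{});
    }

Profiling loops like this are how the library picks the fastest of the thirteen configurations per data type at run time.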
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+static constexpr auto ConvBwdDataDefault =
+    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default;
+
+static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
+    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0;
+
+// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
+using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances =
+    std::tuple<
+        // clang-format off
+        //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
+        //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
+        //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
+        //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = + std::tuple< + // clang-format off + //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#############################################################################| Type| Type| Type| Type| 
Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 64, 128, 4, 8, 32, 
32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances{}); + add_device_operation_instances( + instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp new file mode 100644 index 00000000..357ddabd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
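Each of these files defines two tuples that differ only in the `ConvolutionBackwardDataSpecialization` argument. `Filter1x1Stride1Pad0` covers the pointwise case: with a 1x1 filter, unit stride, and no padding there is no spatial overlap, so backward-data convolution degenerates to a GEMM over the channel dimensions and the specialized kernels can drop the window arithmetic. A dispatcher therefore has to gate on the problem shape before selecting one of these instances; a self-contained sketch of that guard (names illustrative, not from this diff):

    #include <algorithm>
    #include <vector>

    // Illustrative only: the shape test a dispatcher might apply before
    // choosing a Filter1x1Stride1Pad0 instance over a Default one.
    bool is_filter1x1_stride1_pad0(const std::vector<int>& filter_lengths,
                                   const std::vector<int>& strides,
                                   const std::vector<int>& left_pads,
                                   const std::vector<int>& right_pads)
    {
        auto all_equal = [](const std::vector<int>& v, int x) {
            return std::all_of(v.begin(), v.end(), [x](int e) { return e == x; });
        };
        return all_equal(filter_lengths, 1) && all_equal(strides, 1) &&
               all_equal(left_pads, 0) && all_equal(right_pads, 0);
    }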
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F32 = float;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+static constexpr auto ConvBwdDataDefault =
+    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default;
+
+static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
+    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0;
+
+// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
+using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances =
+    std::tuple<
+        // clang-format off
+        //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
+        //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
+        //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
+        //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64,
1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances = + std::tuple< + // clang-format off + //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| 
Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, 
S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances{}); + add_device_operation_instances( + instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp new file mode 100644 index 00000000..3eadb0bd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
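Read against the column headers, one row of these tables is a complete tuning configuration. Spreading out the first `f32` row above purely for readability (the comments restate the header labels; nothing new is added):

    using ExampleF32Instance =
        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K<
            F32, F32, F32, F32,                    // InData, WeiData, OutData, AccData types
            PassThrough, PassThrough, PassThrough, // In/Wei/Out elementwise operations
            ConvBwdDataDefault,                    // ConvBackwardData specialization
            2,                                     // NumDimSpatial
            256,                                   // BlockSize
            256, 128, 4, 4,                        // MPerBlock, NPerBlock, K0PerBlock, K1
            32, 32,                                // MPerXDL, NPerXDL
            4, 2,                                  // MXdlPerWave, NXdlPerWave
            S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,   // ABlockTransfer: cluster lengths K0_M_K1, arrange order, src access order
            2, 4, 4, true,                         // A: SrcVectorDim, SrcScalarPerVector, DstScalarPerVector_K1, AddExtraM
            S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>,   // BBlockTransfer: cluster lengths K0_N_K1, arrange order, src access order
            1, 2, 4, true,                         // B: SrcVectorDim, SrcScalarPerVector, DstScalarPerVector_K1, AddExtraN
            7, 1>;                                 // CThreadTransfer: SrcDstVectorDim, DstScalarPerVector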
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using DataType = int8_t;
+using AccType = int32_t;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+static constexpr auto ConvBwdDataDefault =
+    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default;
+
+static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
+    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0;
+
+// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
+using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances =
+    std::tuple<
+        // clang-format off
+        //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
+        //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
+        //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
+        //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough,
ConvBwdDataDefault, 2, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 2, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1> + // clang-format on + >; + +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances = + std::tuple< + // clang-format off + //##############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| 
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + 
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 2, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances{}); + add_device_operation_instances( + instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp new file mode 100644 index 00000000..6b5f71ff --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using BF16 = bhalf_t;
+using F32 = float;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+static constexpr auto ConvBwdDataDefault =
+    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default;
+
+static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
+    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0;
+
+// Compilation parameters for in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k]
+using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances =
+    std::tuple<
+        // clang-format off
+        //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
+        //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
+        //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
+        //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0,
2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_bf16_instances = + std::tuple< + // clang-format off + //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + 
//#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
+        //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
+        //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
+
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances{}); + add_device_operation_instances( + instances, device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_bf16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp new file mode 100644 index 00000000..214aea28 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
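From here the same configurations are repeated for 3-D convolution: `NumDimSpatial` becomes 3 and the layouts become NDHWC/KZYXC/NDHWK. In backward data the kernel writes the input-gradient tensor, whose spatial extent follows the standard transposed-convolution relation; a quick check of that arithmetic (textbook convolution algebra, not taken from this diff):

    // Standard relation between forward-conv output length and the input
    // length reconstructed by backward data, per spatial dimension:
    //   in = (out - 1) * stride - left_pad - right_pad + dilation * (filter - 1) + 1
    int bwd_data_input_length(int out, int stride, int left_pad, int right_pad,
                              int dilation, int filter)
    {
        return (out - 1) * stride - left_pad - right_pad + dilation * (filter - 1) + 1;
    }
    // e.g. out = 56, stride = 1, pads = 1, dilation = 1, filter = 3  ->  in = 56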
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances = + std::tuple< + // clang-format off + //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 
2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f16_instances = + std::tuple< + // clang-format off + //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#############################################################################| Type| Type| Type| Type| 
Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 64, 128, 4, 8, 32, 
32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances{}); + add_device_operation_instances( + instances, device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp new file mode 100644 index 00000000..c3e8b5e8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
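A quick way to read the tile columns in these tables: MPerBlock x NPerBlock is the macro-tile a workgroup computes, each XDL (MFMA) instruction covers MPerXdl x NPerXdl, and each 64-lane wave owns MXdlPerWave x NXdlPerWave of those instructions. The snippet below is my own consistency check of those columns against the BlockSize column (plain C++, not CK code), using two of the f16 instances above:

    // Waves per block = (MPerBlock / (MPerXdl * MXdlPerWave))
    //                 * (NPerBlock / (NPerXdl * NXdlPerWave));
    // each wave has 64 lanes on the XDL-capable (CDNA) GPUs these target.
    constexpr int waves_per_block(int m_blk, int n_blk, int m_xdl, int n_xdl,
                                  int m_xdl_per_wave, int n_xdl_per_wave)
    {
        return (m_blk / (m_xdl * m_xdl_per_wave)) * (n_blk / (n_xdl * n_xdl_per_wave));
    }

    // First f16 instance: BlockSize 256, tile 256x128, XDL 32x32, per-wave 4x2.
    static_assert(waves_per_block(256, 128, 32, 32, 4, 2) * 64 == 256, "tile/block mismatch");

    // Smallest f16 instance: BlockSize 64, tile 32x64, XDL 32x32, per-wave 1x2.
    static_assert(waves_per_block(32, 64, 32, 32, 1, 2) * 64 == 64, "tile/block mismatch");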
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances = + std::tuple< + // clang-format off + //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 
64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f32_instances = + std::tuple< + // clang-format off + //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| 
Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, 
S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances{}); + add_device_operation_instances( + instances, device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp new file mode 100644 index 00000000..9142b804 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
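The int8 file that follows pairs int8 data with an int32 accumulator (AccType) and a K1 of 16, twice the f16 vector width. The widened accumulator is what makes deep K-reductions safe: a dot product of K int8 values can reach about K * 127 * 127, which overflows int8 and even int16 long before the K sizes these kernels target. A standalone illustration of that point (plain C++, unrelated to the XDL hardware path):

    #include <cstdint>
    #include <cstdio>

    // Accumulate int8 products in int32 from the very first term.
    int32_t dot_i8(const int8_t* a, const int8_t* b, int k)
    {
        int32_t acc = 0;
        for (int i = 0; i < k; ++i)
            acc += static_cast<int32_t>(a[i]) * static_cast<int32_t>(b[i]);
        return acc;
    }

    int main()
    {
        int8_t a[64], b[64];
        for (int i = 0; i < 64; ++i) { a[i] = 127; b[i] = 127; }
        // Prints 1032256, far outside the int8 (127) and int16 (32767) ranges.
        std::printf("%d\n", dot_i8(a, b, 64));
        return 0;
    }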
+
+#include <cstdint>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using DataType = int8_t;
+using AccType = int32_t;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+static constexpr auto ConvBwdDataDefault =
+    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default;
+
+static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
+    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0;
+
+// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
+using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances =
+    std::tuple<
+        // clang-format off
+        //#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
+        //#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
+        //#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
+        //#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
+        DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough,
PassThrough, ConvBwdDataDefault, 3, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 3, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1> + // clang-format on + >; + +using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_int8_instances = + std::tuple< + // clang-format off + //##############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| 
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + 
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 3, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1> + // clang-format on + >; + +void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances{}); + add_device_operation_instances( + instances, device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_int8_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/CMakeLists.txt new file mode 100644 index 00000000..7272163f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/CMakeLists.txt @@ -0,0 +1,19 @@ +#device_convnd_bwd_weight_instance +set(DEVICE_CONVND_BWD_WEIGHT_INSTANCE_SOURCE + device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instance.cpp; + device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instance.cpp; + device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_instance.cpp; + device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instance.cpp; + device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instance.cpp; + device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp; + device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp; + device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp; + device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp; +) + 
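The nine sources listed above each register the instances for one spatial-rank/precision combination of backward-weight convolution. As a reference for what those kernels compute, here is my own scalar sketch of the 1D case, under my reading of the layout naming (input N,W,C; weight K,X,C; output gradient N,W,K) and simplified to stride 1 with zero padding:

    #include <vector>

    // Backward-weight: each filter tap (k, x, c) accumulates the product of
    // the output gradient with the input element that tap touched.
    void conv1d_bwd_weight_ref(const std::vector<float>& in,       // [N][Wi][C]
                               std::vector<float>& wei_grad,       // [K][X][C]
                               const std::vector<float>& out_grad, // [N][Wo][K]
                               int N, int Wi, int C, int K, int X)
    {
        const int Wo = Wi - X + 1; // stride 1, no padding
        for (int k = 0; k < K; ++k)
            for (int x = 0; x < X; ++x)
                for (int c = 0; c < C; ++c)
                {
                    float acc = 0.f; // accumulation type mirrors AccData (F32)
                    for (int n = 0; n < N; ++n)
                        for (int wo = 0; wo < Wo; ++wo)
                            acc += out_grad[(n * Wo + wo) * K + k] *
                                   in[(n * Wi + wo + x) * C + c];
                    wei_grad[(k * X + x) * C + c] = acc;
                }
    }

Stride, dilation, and padding only change which input element each (wo, x) pair reads; the reduction over n and wo, which is the large dimension these xdl_c_shuffle kernels tile, stays the same.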
+add_library(device_convnd_bwd_weight_instance OBJECT ${DEVICE_CONVND_BWD_WEIGHT_INSTANCE_SOURCE}) +target_compile_features(device_convnd_bwd_weight_instance PUBLIC) +set_target_properties(device_convnd_bwd_weight_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +rocm_install(TARGETS device_convnd_bwd_weight_instance) + +clang_tidy_check(device_convnd_bwd_weight_instance) diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_instance.cpp new file mode 100644 index 00000000..c8aae435 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_instance.cpp @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv1d_bwd_weight_xdl_c_shuffle_nwc_kxc_nwk_bf16_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + 
DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + 
DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +using device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_1x1_s1_p0_bf16_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, 
ConvBwdWeightFilter1x1Stride1Pad0, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + 
DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_conv1d_bwd_weight_xdl_c_shuffle_nwc_kxc_nwk_bf16_instances{}); + add_device_operation_instances( + instances, device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_1x1_s1_p0_bf16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instance.cpp new file mode 100644 index 00000000..6e4964ce --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instance.cpp @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv1d_bwd_weight_xdl_c_shuffle_nwc_kxc_nwk_f16_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| 
Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 
32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +using device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 
ConvBwdWeightFilter1x1Stride1Pad0, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + 
DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_conv1d_bwd_weight_xdl_c_shuffle_nwc_kxc_nwk_f16_instances{}); + add_device_operation_instances( + instances, device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instance.cpp new file mode 100644 index 00000000..ed25442d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instance.cpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv1d_bwd_weight_xdl_c_shuffle_nwc_kxc_nwk_f32_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| 
SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + 
DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 1, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +using device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_1x1_s1_p0_f32_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 256, 256, 128, 4, 4, 
32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, 
PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 1, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_conv1d_bwd_weight_xdl_c_shuffle_nwc_kxc_nwk_f32_instances{}); + add_device_operation_instances( + instances, device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp new file mode 100644 index 00000000..3a0dfeb6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
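+//
+// Reading the instance tables in this file: after the three element-wise ops and the
+// specialization tag, the integer parameters follow the column header comments above,
+// e.g. "2, 256, 256, 128, 4, 8, 32, 32, 4, 2" means NumDimSpatial=2, BlockSize=256,
+// MPerBlock=256, NPerBlock=128, K0PerBlock=4, K1=8, MPerXDL=32, NPerXDL=32,
+// MXdlPerWave=4, NXdlPerWave=2; the S<...> sequences give the A/B block-transfer
+// thread-cluster lengths and access orders, and the trailing S<...> plus scalar give
+// the C-shuffle block-transfer cluster and its scalars per vector.
+//
+// A minimal consumption sketch (hypothetical; the exact element type of the
+// instances vector comes from the declaration header and is assumed here to be an
+// owning pointer to the backward-weight device-op interface):
+//
+//   std::vector<InstancePtr> instances; // InstancePtr: assumed owning pointer type
+//   add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_instances(instances);
+//   for(const auto& op : instances)
+//   {
+//       // construct the op's Argument for the concrete problem size, then run it
+//       // only if op->IsSupportedArgument(...) returns true; profiling each
+//       // supported instance and keeping the fastest is the usual selection loop.
+//   }
+//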
+ +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_bwd_weight_xdl_c_shuffle_nhwc_kyxc_nhwk_bf16_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, 
S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 
2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +using device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + 
DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_conv2d_bwd_weight_xdl_c_shuffle_nhwc_kyxc_nhwk_bf16_instances{}); + add_device_operation_instances( + instances, device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instance.cpp 
b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instance.cpp new file mode 100644 index 00000000..025c7c86 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instance.cpp @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_bwd_weight_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_default_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 256, 128, 256, 4, 8, 32, 
32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 
2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +using device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + 
DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_convnd_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_conv2d_bwd_weight_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_default_instances{}); + add_device_operation_instances( + 
instances, device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instance.cpp new file mode 100644 index 00000000..cde50d77 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instance.cpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_bwd_weight_xdl_c_shuffle_nhwc_kyxc_nhwk_f32_default_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 
4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 
1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +using device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, 
F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // 
clang-format on + >; + +void add_device_convnd_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector<DeviceConvBwdWeightPtr<PassThrough, PassThrough, PassThrough>>& instances) +{ + add_device_operation_instances( + instances, device_conv2d_bwd_weight_xdl_c_shuffle_nhwc_kyxc_nhwk_f32_default_instances{}); + add_device_operation_instances( + instances, device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp new file mode 100644 index 00000000..1e2ad43a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include <cstdlib> +#include <vector> +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using F32 = float; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv3d_bwd_weight_xdl_c_shuffle_ndhwc_kzyxc_ndhwk_bf16_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | 
| | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, 
S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +using device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_bf16_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, 
BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 
16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +using device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_bf16_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
//#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 
4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +using device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, + 
DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 
S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances( + std::vector<DeviceConvBwdWeightPtr<PassThrough, PassThrough, PassThrough>>& instances) +{ + add_device_operation_instances( + instances, + device_conv3d_bwd_weight_xdl_c_shuffle_ndhwc_kzyxc_ndhwk_f16_default_instances{}); + add_device_operation_instances( + instances, device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp new file mode 100644 index 00000000..40754a09 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_weight/device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
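Each instance translation unit in this patch ends the same way: an add_device_*_instances() entry point expands a std::tuple of tuning configurations into type-erased device operations. A minimal sketch of what that registration amounts to, assuming a std::unique_ptr-based container (the helper name append_instances_sketch is hypothetical; the real helper is add_device_operation_instances() from add_device_operation_instance.hpp and may differ in detail):

    #include <memory>
    #include <tuple>
    #include <vector>

    // Append one default-constructed instance of every operation type in the tuple.
    template <typename BaseOp, typename... Ops>
    void append_instances_sketch(std::vector<std::unique_ptr<BaseOp>>& op_ptrs, std::tuple<Ops...>)
    {
        (op_ptrs.push_back(std::make_unique<Ops>()), ...);
    }

Every DeviceConvndBwdWeightXdl_C_Shuffle_* specialization in the tuples is one fully compiled kernel configuration, so a single source file can expose a dozen candidate kernels through one registration call.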
+ +#include <cstdlib> +#include <vector> +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv3d_bwd_weight_xdl_c_shuffle_ndhwc_kzyxc_ndhwk_f32_default_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2,
1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 3, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 
4> + // clang-format on + >; + +using device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f32_instances = std::tuple< + // clang-format off + //#################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths|ScalarPerVector| + //#################################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, 
PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 3, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances( + std::vector<DeviceConvBwdWeightPtr<PassThrough, PassThrough, PassThrough>>& instances) +{ + add_device_operation_instances( + instances, + device_conv3d_bwd_weight_xdl_c_shuffle_ndhwc_kzyxc_ndhwk_f32_default_instances{}); + add_device_operation_instances( + instances, device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/elementwise/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/elementwise/CMakeLists.txt new file mode 100644 index 00000000..465ba4e9 --- /dev/null +++ 
b/library/src/tensor_operation_instance/gpu/elementwise/CMakeLists.txt @@ -0,0 +1,10 @@ +set(DEVICE_ELEMENTWISE_INSTANCE_SOURCE + device_normalize_instance.cpp +) + +add_instance_library(device_elementwise_instance ${DEVICE_ELEMENTWISE_INSTANCE_SOURCE}) + +target_compile_features(device_elementwise_instance PUBLIC) +set_target_properties(device_elementwise_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) + +clang_tidy_check(device_elementwise_instance) diff --git a/library/src/tensor_operation_instance/gpu/elementwise/device_normalize_instance.cpp b/library/src/tensor_operation_instance/gpu/elementwise/device_normalize_instance.cpp new file mode 100644 index 00000000..12f7901c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/elementwise/device_normalize_instance.cpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using inputType = F16; +using MeanType = F32; +using SquareMeanType = F32; +using GammaDataType = F16; +using BetaDataType = F16; +using outputType = F16; + +using Normalize = ck::tensor_operation::element_wise::Normalize; +using device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances = std::tuple< + // clang-format off + //###################|in | mean| square_mean| gamma| beta| out| ComputeDataType| functor| NDim| MPerThread| in, mean, square_mean, gamma, beta, out ScalarPerVector| + //###################|in | mean| square_mean| gamma| beta| out| ComputeDataType| functor| NDim| MPerThread| in, mean, square_mean, gamma, beta, out ScalarPerVector| + //###################|in | mean| square_mean| gamma| beta| out| ComputeDataType| functor| NDim| MPerThread| in, mean, square_mean, gamma, beta, out ScalarPerVector| + //###################|in | mean| square_mean| gamma| beta| out| ComputeDataType| functor| NDim| MPerThread| in, mean, square_mean, gamma, beta, out ScalarPerVector| + Device5AryElementwise, + Device5AryElementwise, + Device5AryElementwise, + Device5AryElementwise + // clang-format on + >; + +void add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt new file mode 100644 index 00000000..ce66b56a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt @@ -0,0 +1,48 @@ +set(DEVICE_GEMM_INSTANCE_SOURCE + device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp; + device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp; + device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp; + device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp; + device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp; + device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp; + device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp; + 
device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp; + device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp; + device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp; + device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp; + device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp; + device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp; + device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp; + device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp; + device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp; + device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp; + device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp; + device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp; + device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp; + device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp; + device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp; + device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp; + device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp; +) + +add_library(device_gemm_instance OBJECT ${DEVICE_GEMM_INSTANCE_SOURCE}) + +target_compile_features(device_gemm_instance PUBLIC) +set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 00000000..41efbdcc --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
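// Aside: a scalar reference for the Normalize functor registered by
// device_normalize_instance.cpp above. The functor body is not shown in this
// patch; this sketch assumes the usual layernorm-from-moments form, with the
// variance recovered from the precomputed mean and square-mean inputs:

#include <cmath>

// y = gamma * (x - mean) / sqrt(var + eps) + beta, with var = E[x^2] - E[x]^2
inline float normalize_ref(
    float x, float mean, float square_mean, float gamma, float beta, float eps = 1e-5f)
{
    const float var = square_mean - mean * mean; // variance from the two moments
    return gamma * (x - mean) / std::sqrt(var + eps) + beta;
}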
+ +#include <cstdlib> + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_dl_f16_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 00000000..0e6d6239 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
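// Aside: add_device_operation_instance.hpp, included by every file in this
// directory, is what turns the instance tuples into runtime registries. A
// minimal C++17 sketch of that idiom (the real helper walks the tuple with
// ck::static_for; the function name below is hypothetical):

#include <memory>
#include <tuple>
#include <vector>

template <typename OpPtr, typename... Ops>
void add_instances_sketch(std::vector<OpPtr>& op_ptrs, const std::tuple<Ops...>&)
{
    // Fold expression: append one default-constructed, heap-allocated,
    // type-erased copy of every concrete instance in the tuple.
    (op_ptrs.push_back(std::make_unique<Ops>()), ...);
}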
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_dl_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 00000000..bc2186e5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_dl_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 00000000..e2000afb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_dl_f16_f16_f16_mk_nk_mn_instances = + std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp new file mode 100644 index 00000000..267e3d76 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
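// Aside: the mk/km, kn/nk and mn tokens in these file names encode the
// layouts: "mk" is A stored [M, K] row-major (Row), "km" is A stored [K, M],
// i.e. column-major (Col) from the GEMM's point of view; likewise "kn"/"nk"
// for B, and "mn" is C stored [M, N] row-major. For packed tensors the
// leading dimensions follow directly; small helpers (hypothetical names)
// make the mapping explicit:

// Leading dimension of a packed operand, given which index varies fastest.
inline int lda_packed(bool a_is_mk, int M, int K) { return a_is_mk ? K : M; }
inline int ldb_packed(bool b_is_kn, int K, int N) { return b_is_kn ? N : K; }
inline int ldc_packed(int N) { return N; } // c[m, n] is row-major in all of these files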
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_dl_f32_f32_f32_km_kn_mn_instances = std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f32_f32_f32_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp new file mode 100644 index 00000000..f8bb758b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_dl_f32_f32_f32_km_nk_mn_instances = + std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f32_f32_f32_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp new file mode 100644 index 00000000..54bb6810 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_dl_f32_f32_f32_mk_kn_mn_instances = + std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f32_f32_f32_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp new file mode 100644 index 00000000..1ce46ec7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
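// Aside: a sketch of how a client enumerates these instances. "DeviceGemmPtr"
// stands in for the library's type-erased pointer alias (its exact spelling
// was lost from the add_* signatures in this extract), and GetTypeString() /
// IsSupportedArgument() are assumed to be the usual ck device-op virtuals:
//
//     std::vector<DeviceGemmPtr> ops;
//     ck::tensor_operation::device::instance::
//         add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(ops);
//     for(const auto& op : ops)
//         std::cout << op->GetTypeString() << '\n';
//
// A profiler goes one step further: it builds an argument for the concrete
// problem size, skips instances whose IsSupportedArgument() returns false,
// and times the rest.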
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_dl_f32_f32_f32_mk_nk_mn_instances = + std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f32_f32_f32_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp new file mode 100644 index 00000000..f18adfee --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
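// Aside: the int8 instances that follow accumulate in int32 (AccData =
// int32_t) and narrow to int8 only when writing C. A scalar reference with
// the same contract, assuming a saturating cast on output (the device op's
// PassThrough conversion may instead truncate):

#include <algorithm>
#include <cstdint>

inline std::int8_t saturate_to_i8(std::int32_t v)
{
    return static_cast<std::int8_t>(std::min(127, std::max(-128, v)));
}

// c[m, n] = sat( sum_k a[k, m] * b[k, n] ), matching the km_kn_mn file below
void ref_gemm_i8_km_kn_mn(const std::int8_t* a, const std::int8_t* b,
                          std::int8_t* c, int M, int N, int K)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            std::int32_t acc = 0; // 32-bit accumulator, as in AccData
            for(int k = 0; k < K; ++k)
                acc += std::int32_t(a[k * M + m]) * std::int32_t(b[k * N + n]);
            c[m * N + n] = saturate_to_i8(acc);
        }
}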
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_dl_i8_i8_i8_km_kn_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dl_i8_i8_i8_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp new file mode 100644 index 00000000..91277b54 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_dl_i8_i8_i8_km_nk_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dl_i8_i8_i8_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp new file mode 100644 index 00000000..a56d9d2c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_dl_i8_i8_i8_mk_kn_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dl_i8_i8_i8_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp new file mode 100644 index 00000000..63794ac3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_dl_i8_i8_i8_mk_nk_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dl_i8_i8_i8_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 00000000..16037f70 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
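// Aside: NumGemmKPrefetchStage = 2 in the instances below appears to select a
// double-buffered ("ping-pong") K loop: while one LDS buffer feeds the MFMA
// stage, the next K-slab is fetched into the other buffer. Schematically
// (load/compute are placeholders for the real pipeline stages):
//
//     load(buf[0], 0);
//     for(int k = 0; k < K_blocks; ++k)
//     {
//         if(k + 1 < K_blocks)
//             load(buf[(k + 1) % 2], k + 1); // prefetch the next slab
//         compute(buf[k % 2]);               // consume the current slab
//     }
//
// The extra stage costs LDS capacity, which is presumably why these entries
// otherwise mirror the single-stage c_shuffle f16 instances.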
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 
2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp new file mode 100644 index 00000000..9ce9dc48 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
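// Aside: BF16 here is ck::bhalf_t, i.e. bfloat16: the top 16 bits of an
// IEEE-754 binary32 value. A round-trip by truncation (CK's exact rounding
// mode may differ):

#include <cstdint>
#include <cstring>

inline std::uint16_t f32_to_bf16_trunc(float f)
{
    std::uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return static_cast<std::uint16_t>(bits >> 16); // keep sign, exponent, top mantissa bits
}

inline float bf16_to_f32(std::uint16_t h)
{
    std::uint32_t bits = static_cast<std::uint32_t>(h) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

// The paired AK1/BK1 = 2 and 8 variants in the tuple below appear to trade
// the vector width of these 16-bit loads against LDS padding (AddExtraM/N).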
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 
32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances{}); +} + +} // namespace 
instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp new file mode 100644 index 00000000..83b01e26 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 
1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp new file mode 100644 index 00000000..2a364511 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, 
PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 
1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp new file mode 100644 index 00000000..938c99cb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
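A note on naming, since the pattern repeats for every file in this diff: the suffix before _instance.cpp encodes the operand layouts, and each file's "Compilation parameters for ..." comment spells out the same contract. Using the Row/Col aliases these files already define, the mapping is:

    // suffix     contract                      (ALayout, BLayout, CLayout)
    // mk_kn_mn   a[m, k] * b[k, n] = c[m, n]   (Row, Row, Row)
    // mk_nk_mn   a[m, k] * b[n, k] = c[m, n]   (Row, Col, Row)
    // km_kn_mn   a[k, m] * b[k, n] = c[m, n]   (Col, Row, Row)
    // km_nk_mn   a[k, m] * b[n, k] = c[m, n]   (Col, Col, Row)

In the bf16 and f16 variants the AccData type is F32 while the CShuffle data type is the narrow type again: accumulation runs in fp32 and the C-shuffle stage narrows the result before the store to global memory.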
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 
32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 00000000..7066be07 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
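Every one of these translation units ends the same way: the tuple of fully specialized DeviceGemm_Xdl_CShuffle types is passed to add_device_operation_instances, which materializes one heap object per tuple element into the caller's type-erased vector. Conceptually it reduces to the sketch below; this is a paraphrase, not the CK header, and the name add_instances_sketch is invented for illustration:

    #include <memory>
    #include <tuple>
    #include <vector>

    template <typename BaseOp, typename... ConcreteOps>
    void add_instances_sketch(std::vector<std::unique_ptr<BaseOp>>& instances,
                              const std::tuple<ConcreteOps...>& ops)
    {
        // Copy each (stateless) concrete kernel descriptor onto the heap
        // and store it behind the common base interface.
        std::apply(
            [&](const auto&... op) {
                (instances.push_back(
                     std::make_unique<std::decay_t<decltype(op)>>(op)),
                 ...);
            },
            ops);
    }

The concrete instance types carry no runtime state beyond their template parameters, so the cost is one small allocation per instance; all behavior lives in their virtual overrides.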
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 
1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace 
instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 00000000..39b2e73c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, 
PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 
0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 00000000..b4b8cc33 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + 
DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 
8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 00000000..8f0996c3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
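For context on how these registrations are consumed (the pattern the profiler and tests elsewhere in this diff follow): the caller declares a vector of matching DeviceGemm base pointers, calls the add_..._instances function, and probes each candidate for shape support before running it. A hedged sketch only: p_a, p_b, p_c are assumed device buffers, M/N/K and the strides describe the problem, and the MakeArgumentPointer signature is the one this revision's DeviceGemm interface declares, which may differ in other versions:

    std::vector<std::unique_ptr<DeviceGemm<Row, Col, Row, F16, F16, F16,
                                           PassThrough, PassThrough, PassThrough>>>
        gemm_ptrs;
    add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);

    for(auto& op : gemm_ptrs)
    {
        auto arg = op->MakeArgumentPointer(p_a, p_b, p_c, M, N, K,
                                           StrideA, StrideB, StrideC,
                                           PassThrough{}, PassThrough{},
                                           PassThrough{});

        // Not every tile configuration divides every (M, N, K); skip the
        // instances whose vector widths or block tiles do not fit.
        if(op->IsSupportedArgument(arg.get()))
        {
            op->MakeInvokerPointer()->Run(arg.get());
            break;
        }
    }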
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 
1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp new file mode 100644 index 00000000..5c7e7d35 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
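The f32 tables that follow repeat the same grid of configurations with halved element counts: KPerBlock drops from 32 to 16, AK1/BK1 from 2 or 8 down to 1 or 4, and the C-store ScalarPerVector from 8 to 4. Vector transfers are sized in elements, and an f32 element is twice as wide as f16/bf16, so the underlying loads and stores stay at the same byte width. One invariant every row in these tables satisfies, and a quick sanity check for a hand-edited row, is that the wave tiling must exactly cover the block tile. A sketch with the constants copied from the first f32 km_kn instance; the variable names here are descriptive, not CK identifiers:

    constexpr int BlockSize   = 256;                   // threads per workgroup
    constexpr int MPerBlock   = 256, NPerBlock   = 128;
    constexpr int MPerXdl     = 32,  NPerXdl     = 32; // MFMA instruction tile
    constexpr int MXdlPerWave = 4,   NXdlPerWave = 2;
    constexpr int WaveSize    = 64;                    // gfx9 wavefront

    constexpr int MWaves = MPerBlock / (MXdlPerWave * MPerXdl); // = 2
    constexpr int NWaves = NPerBlock / (NXdlPerWave * NPerXdl); // = 2
    static_assert(MWaves * NWaves * WaveSize == BlockSize,
                  "waves must tile the M x N block exactly");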
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 16, 1, 1, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 
S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 16, 1, 1, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 16, 1, 1, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 16, 1, 1, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff 
--git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp new file mode 100644 index 00000000..45ae6c51 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 16, 1, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 
1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 16, 1, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 16, 1, 4, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 16, 1, 4, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, 
GemmDefault, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp new file mode 100644 index 00000000..455d786f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + 
DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 16, 4, 1, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 16, 4, 1, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 
16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 16, 4, 1, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp new file mode 100644 index 00000000..5667bce3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| 
| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 
2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp new file mode 100644 index 00000000..ee88c9a0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances = + std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, 
GemmDefault, 1, 256, 256, 128, 64, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 64, 4, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 64, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 64, 4, 4, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + 
DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 64, 4, 4, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 4, 4, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp new file mode 100644 index 00000000..35405578 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
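
Each instance file in this hunk follows the same pattern: a std::tuple enumerating fully specialized DeviceGemm_Xdl_CShuffle kernels (one per block/tile/transfer configuration) and an add_device_gemm_xdl_..._instances() function that default-constructs every tuple element and appends it to the caller's vector, type-erased behind the DeviceGemm base interface. Below is a minimal consumer sketch, assuming the DeviceGemm interface from ck/tensor_operation/gpu/device/device_gemm.hpp and StreamConfig from ck/stream_config.hpp; the device pointers p_a/p_b/p_c and the sizes M/N/K with their strides are hypothetical and do not come from this patch:

    #include <memory>
    #include <vector>
    // plus the CK headers included at the top of these instance files

    // Collect every registered int8 instance for the a[k, m] * b[k, n] layout.
    std::vector<std::unique_ptr<DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t,
                                           PassThrough, PassThrough, PassThrough>>> ops;
    add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(ops);

    for(const auto& op : ops)
    {
        // Bind the (hypothetical) buffers, sizes, and strides to this instance.
        auto arg = op->MakeArgumentPointer(p_a, p_b, p_c, M, N, K,
                                           StrideA, StrideB, StrideC,
                                           PassThrough{}, PassThrough{}, PassThrough{});
        // Not every tile configuration supports every problem shape/alignment.
        if(op->IsSupportedArgument(arg.get()))
        {
            op->MakeInvokerPointer()->Run(arg.get(), StreamConfig{});
            break; // or time every supported instance and keep the fastest
        }
    }

The profiler target added elsewhere in this patch presumably iterates these same vectors to benchmark every supported instance per problem size.
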
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances = + std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 64, 4, 16, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, 
GemmDefault, 1, 128, 128, 128, 64, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 64, 4, 16, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 64, 4, 16, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 4, 16, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16> + 
// clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp new file mode 100644 index 00000000..a1090695 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances = + std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, 
S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 64, 16, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 64, 16, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 64, 16, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, 
int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp new file mode 100644 index 00000000..be8de8be --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances = + std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| 
Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 
128, 32, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 64, 64, 16, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 00000000..5fee5384 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_f16_f16_f16_km_kn_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | 
| | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 00000000..4363bfe9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
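
Unlike the *_c_shuffle_* files above, the f16 files instantiate the plain DeviceGemmXdl kernel, whose parameter table is shorter: there is no CShuffle accumulator type and no C-block-transfer cluster-length row, since the epilogue is a per-thread CThreadTransfer rather than a shuffled block store. Registration works identically in both cases. The following is only a conceptual sketch of the add_device_operation_instances helper these files call; the real implementation lives in add_device_operation_instance.hpp, and this fold-expression version is assumed, not taken from the patch:

    #include <memory>
    #include <tuple>
    #include <vector>

    // Default-construct one op per tuple element and hand each to the caller,
    // type-erased behind the common base interface (here: DeviceGemm<...>).
    template <typename BaseOp, typename... DeviceOps>
    void add_device_operation_instances(std::vector<std::unique_ptr<BaseOp>>& instances,
                                        std::tuple<DeviceOps...>)
    {
        (instances.push_back(std::make_unique<DeviceOps>()), ...);
    }

Keeping the configurations in a std::tuple (rather than, say, a macro list) lets the compiler instantiate each kernel exactly once per translation unit, while the runtime side only ever sees base-class pointers.
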
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_f16_f16_f16_km_nk_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, 
PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 00000000..544eb02f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, 
PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 256, 4, 8, 32, 32, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 64, 4, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 
1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 00000000..8ce8eb08 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances = + std::tuple< + // clang-format off + //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 
PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +// irregular tile size +using device_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances = + std::tuple< + // clang-format off + //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| 
Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances{}); + add_device_operation_instances(instances, + device_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp new file mode 100644 index 00000000..b99c023d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
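The two "irregular tile size" entries just above use GemmSpecialization::MNPadding instead of GemmDefault: an output tile of 128 x 144 (NPerBlock = 144 = 9 x 16, from NXdlPerWave = 9 with 16x16 XDL tiles) does not evenly divide common problem sizes, so the kernel rounds M and N up to whole tiles and masks off the fringe. A sketch of the assumed rounding rule (illustrative helper, not library code):

// Round a problem dimension up to a whole number of block tiles.
ck::index_t PaddedSize(ck::index_t size, ck::index_t size_per_block)
{
    return ((size + size_per_block - 1) / size_per_block) * size_per_block;
}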
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_f32_f32_f32_km_kn_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, 
PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f32_f32_f32_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp new file mode 100644 index 00000000..99a2383c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_f32_f32_f32_km_nk_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, 
PassThrough, GemmDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f32_f32_f32_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp new file mode 100644 index 00000000..8794275d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
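Each tuple entry is a complete kernel type in its own right, so an instance can also be launched directly rather than through the registry. A minimal sketch for the first f32 km_nk instance above, assuming the MakeArgument/MakeInvoker/IsSupportedArgument interface DeviceGemmXdl exposes in CK's examples of this period (the function name, pointers, and StreamConfig usage here are illustrative):

// Launch one instance directly: a is stored as a[k, m], b as b[n, k], c as c[m, n].
void run_one_f32_km_nk_gemm(const float* p_a, const float* p_b, float* p_c,
                            ck::index_t M, ck::index_t N, ck::index_t K,
                            ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC)
{
    using Gemm = DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row,
                                PassThrough, PassThrough, PassThrough, GemmDefault,
                                256, 256, 128, 4, 4, 32, 32, 4, 2,
                                S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true,
                                S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>;

    auto gemm     = Gemm{};
    auto invoker  = gemm.MakeInvoker();
    auto argument = gemm.MakeArgument(p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC,
                                      PassThrough{}, PassThrough{}, PassThrough{});

    // Skip tile shapes that cannot cover the given problem size.
    if(gemm.IsSupportedArgument(argument))
        invoker.Run(argument, StreamConfig{nullptr, false});
}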
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, 
PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp new file mode 100644 index 00000000..4b62cec6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, 
PassThrough, GemmDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp new file mode 100644 index 00000000..a02763bc --- /dev/null +++ 
b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F64 = double; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_f64_f64_f64_km_kn_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F64, F64, F64, F64, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances( + std::vector>>& + instances) +{ + 
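    // add_device_operation_instances, per its usual definition in
    // add_device_operation_instance.hpp, default-constructs every type in the instance
    // tuple and appends each one to instances behind a unique_ptr to the DeviceGemm base
    // interface; a profiler then walks the vector, skips entries whose
    // IsSupportedArgument() returns false, and times the rest.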
add_device_operation_instances(instances, device_gemm_xdl_f64_f64_f64_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp new file mode 100644 index 00000000..1275197f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F64 = double; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_f64_f64_f64_km_nk_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F64, F64, F64, F64, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 
2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f64_f64_f64_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp new file mode 100644 index 00000000..d763c68f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F64 = double; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F64, F64, F64, F64, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 
S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp new file mode 100644 index 00000000..e52e3ff6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F64 = double; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | 
| | | | | | | | | | | + DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 2, 16, 16, 4, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 2, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 2, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 2, 16, 16, 2, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt new file mode 100644 index 00000000..789c5b62 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt @@ -0,0 +1,14 @@ +# device_gemm_add_add_fastgelu_instance +set(DEVICE_GEMM_ADD_ADD_FASTGELU_INSTANCE_SOURCE + device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp; + device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp; + device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp; + 
+    device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp;
+)
+
+add_library(device_gemm_add_add_fastgelu_instance OBJECT ${DEVICE_GEMM_ADD_ADD_FASTGELU_INSTANCE_SOURCE})
+
+target_compile_features(device_gemm_add_add_fastgelu_instance PUBLIC)
+set_target_properties(device_gemm_add_add_fastgelu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
+
+clang_tidy_check(device_gemm_add_add_fastgelu_instance)
diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
new file mode 100644
index 00000000..e00a66c5
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
@@ -0,0 +1,81 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+using F16_F16_TUPLE = ck::Tuple<F16, F16>;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+
+// e = elementwise((a * b), d0, d1)
+// output: e[m, n]
+// input: a[k, m], b[k, n], d0[m, n], d1[m, n]
+using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances = std::tuple<
+    // clang-format off
+ //##############################| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+ //##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
+ //##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
+ //##############################| | | | | | | | | | | | | | | | | | | | | | | | | |
| | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, 
F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
+    // clang-format on
+    >;
+
+void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
+    std::vector<std::unique_ptr<DeviceGemmMultipleD<Col, Row, ck::Tuple<Row, Row>, Row, F16, F16, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu>>>& instances)
+{
+    add_device_operation_instances(
+        instances, device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
new file mode 100644
index 00000000..a5f39893
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
@@ -0,0 +1,81 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
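+//
+// Usage sketch (illustrative only; the consumer shown here is not part of
+// this file, and the template arguments mirror the registration function
+// below): each *_instance.cpp translation unit in this directory merely
+// registers a tuple of pre-tuned kernel configurations, e.g.
+//
+//     std::vector<std::unique_ptr<DeviceGemmMultipleD<
+//         Col, Col, ck::Tuple<Row, Row>, Row, F16, F16, ck::Tuple<F16, F16>,
+//         F16, PassThrough, PassThrough, AddAddFastGelu>>> instances;
+//     add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(instances);
+//
+// A caller (such as the profiler) then filters the list with
+// IsSupportedArgument() and times the surviving instances to pick the
+// fastest one for a given problem size.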
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+using F16_F16_TUPLE = ck::Tuple<F16, F16>;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+
+// e = elementwise((a * b), d0, d1)
+// output: e[m, n]
+// input: a[k, m], b[n, k], d0[m, n], d1[m, n]
+using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances = std::tuple<
+    // clang-format off
+ //##############################| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+ //##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
+ //##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
+ //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+ DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32,
F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, 
S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
+    // clang-format on
+    >;
+
+void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
+    std::vector<std::unique_ptr<DeviceGemmMultipleD<Col, Col, ck::Tuple<Row, Row>, Row, F16, F16, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu>>>& instances)
+{
+    add_device_operation_instances(
+        instances, device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp
new file mode 100644
index 00000000..8e2b5cf6
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp
@@ -0,0 +1,81 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+using F16_F16_TUPLE = ck::Tuple<F16, F16>;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+
+// e = elementwise((a * b), d0, d1)
+// output: e[m, n]
+// input: a[m, k], b[k, n], d0[m, n], d1[m, n]
+using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::tuple<
+    // clang-format off
+ //##############################| ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+ //##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
+
//##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + 
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
+    // clang-format on
+    >;
+
+void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
+    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Row, ck::Tuple<Row, Row>, Row, F16, F16, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu>>>& instances)
+{
+    add_device_operation_instances(
+        instances, device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
new file mode 100644
index 00000000..e28889a2
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
@@ -0,0 +1,78 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
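+//
+// Reading guide for the instance tuple below (the column names sit in the
+// comment table inside the clang-format off block): in the first entry,
+// GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2 means default GEMM
+// specialization, 1 GEMM-K prefetch stage, 256 threads per block, a
+// 256 x 128 (M x N) tile per block with a K-per-block of 32, A/B K1 vector
+// widths of 8, 32 x 32 XDL tiles, and 4 x 2 XDL tiles per wave. The S<...>
+// sequences are the thread-cluster shapes and access orders of the
+// block-wise A/B copies and of the C-shuffle output copy.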
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+using F16_F16_TUPLE = ck::Tuple<F16, F16>;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+
+// e = elementwise((a * b), d0, d1)
+// output: e[m, n]
+// input: a[m, k], b[n, k], d0[m, n], d1[m, n]
+using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::tuple<
+    // clang-format off
+ //##############################| ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+ //##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
+ //##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
+ //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
+    // clang-format on
+    >;
+
+void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
+    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, ck::Tuple<Row, Row>, Row, F16, F16, F16_F16_TUPLE, F16, PassThrough, PassThrough, AddAddFastGelu>>>& instances)
+{
+    add_device_operation_instances(
+        instances, device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt
new file mode 100644
index 00000000..85a7f3f0
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt
@@ -0,0 +1,13 @@
+set(DEVICE_GEMM_BIAS_ADD_REDUCE_INSTANCE_SOURCE
+    device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp
+    device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp
+    device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp
+    device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp
+)
+
+add_library(device_gemm_bias_add_reduce_instance OBJECT ${DEVICE_GEMM_BIAS_ADD_REDUCE_INSTANCE_SOURCE})
+
+target_compile_features(device_gemm_bias_add_reduce_instance PUBLIC)
+set_target_properties(device_gemm_bias_add_reduce_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
+
+clang_tidy_check(device_gemm_bias_add_reduce_instance)
diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp
new file mode 100644
index 00000000..aec29f2a
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp
@@ -0,0 +1,83 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/utility/reduction_operator.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+using ReducePtrsGlobal = ck::Tuple<F32*, F32*>;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using ReduceSum = ck::reduce::Add;
+using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
+
+using Div = ck::tensor_operation::element_wise::UnaryDivide;
+using Identity = ck::tensor_operation::element_wise::PassThrough;
+using Square = ck::tensor_operation::element_wise::UnarySquare;
+using ReduceInElementOps = ck::Tuple<Identity, Square>;
+using ReduceOutElementOps = ck::Tuple<Div, Div>;
+
+using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<InMemoryDataOperationEnum::AtomicAdd,
+                                                          InMemoryDataOperationEnum::AtomicAdd>;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+
+// c[m, n] = a[k, m] * b[k, n]
+using device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances =
+    std::tuple<
+    // clang-format off
+ //##################################| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| C1| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer|
ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + 
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, 
ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + // clang-format on + >; + +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp new file mode 100644 index 00000000..9ab8e707 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
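+//
+// How the two reductions in these files are wired (this follows from the
+// element-wise aliases declared below): per output row m, with N the reduced
+// extent,
+//
+//     r0[m] = (1/N) * sum_n c[m, n]        // mean:       Identity then Div
+//     r1[m] = (1/N) * sum_n c[m, n]^2      // squaremean: Square then Div
+//
+// i.e. ReduceInElementOps = Tuple<Identity, Square> transforms each input to
+// the reduction and ReduceOutElementOps = Tuple<Div, Div> rescales the sums.
+// A typical consumer (an assumption, not stated in this file) is a fused
+// layernorm epilogue, which forms the variance as r1[m] - r0[m]^2.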
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/utility/reduction_operator.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+using ReducePtrsGlobal = ck::Tuple<F32*, F32*>;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using ReduceSum = ck::reduce::Add;
+using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
+
+using Div = ck::tensor_operation::element_wise::UnaryDivide;
+using Identity = ck::tensor_operation::element_wise::PassThrough;
+using Square = ck::tensor_operation::element_wise::UnarySquare;
+using ReduceInElementOps = ck::Tuple<Identity, Square>;
+using ReduceOutElementOps = ck::Tuple<Div, Div>;
+
+using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<InMemoryDataOperationEnum::AtomicAdd,
+                                                          InMemoryDataOperationEnum::AtomicAdd>;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+
+// c[m, n] = a[k, m] * b[n, k]
+using device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances =
+    std::tuple<
+    // clang-format off
+ //##################################| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| C1| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
+ //##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
+ //##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
+ //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+ DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64,
1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, 
S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + // clang-format on + >; + +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git 
a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp new file mode 100644 index 00000000..31377ef8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include <cstdlib> + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple<F32*, F32*>; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple<Identity, Square>; +using ReduceOutElementOps = ck::Tuple<Div, Div>; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd, ck::InMemoryDataOperationEnum::AtomicAdd>; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[m, n] = a[m, k] * b[k, n] +using device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| C1| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| 
Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, 
PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 
32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + // clang-format on + >; + +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( + std::vector<DeviceGemmBiasAddReducePtr<PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp new file mode 100644 index 00000000..d313fc36 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include <cstdlib> + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple<F32*, F32*>; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple<Identity, Square>; +using ReduceOutElementOps = ck::Tuple<Div, Div>; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd, ck::InMemoryDataOperationEnum::AtomicAdd>; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[m, n] = a[m, k] * b[n, k] +using device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| C1| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| 
Reduce| | | MemoryData| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + 
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1> + // clang-format on + >; + +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( + std::vector<DeviceGemmBiasAddReducePtr<PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt 
b/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt new file mode 100644 index 00000000..e6c93da8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt @@ -0,0 +1,12 @@ +# device_gemm_bilinear_instance +set(DEVICE_GEMM_BILINEAR_INSTANCE_SOURCE + device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp; + device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp; + device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp; + device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp; +) + +add_library(device_gemm_bilinear_instance OBJECT ${DEVICE_GEMM_BILINEAR_INSTANCE_SOURCE}) +set_target_properties(device_gemm_bilinear_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) + +clang_tidy_check(device_gemm_bilinear_instance) diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 00000000..4b8777a4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include <cstdlib> + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_TUPLE = ck::Tuple<F16>; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off + // no padding + //##############################| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| 
ScalarPerVector| + //##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, 
PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + + // M/N/K Padding + //##############################| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 
128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, 
F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( + std::vector<std::unique_ptr<DeviceGemmMultipleD<Col, Row, Row, F16, F16, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear>>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 00000000..589e4bf6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include <cstdlib> + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_TUPLE = ck::Tuple<F16>; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + // no padding + //##############################| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, 
F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 
8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + + // M/N/K Padding + //##############################| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, 
Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( + std::vector<std::unique_ptr<DeviceGemmMultipleD<Col, Col, Row, F16, F16, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear>>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp 
b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 00000000..d18b7c26 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include <cstdlib> + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_TUPLE = ck::Tuple<F16>; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template <ck::index_t... Is> +using S = ck::Sequence<Is...>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + // no padding + //##############################| ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 
8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 64, 
32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + + // M/N/K padding + //##############################| ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, 
S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( + 
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Row, ck::Tuple<Row>, Row, F16, F16, ck::Tuple<F16>, F16, PassThrough, PassThrough, Bilinear>>>& instances)
+{
+    add_device_operation_instances(
+        instances, device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
new file mode 100644
index 00000000..29763ea4
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
@@ -0,0 +1,97 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+using F16_TUPLE = ck::Tuple<F16>;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using Bilinear = ck::tensor_operation::element_wise::Bilinear;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
+using device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::tuple<
+ // clang-format off
+ // no padding
+ //##############################| ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+ //##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
+ //##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
+ //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear,
GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, 
F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
+
+ // M/N/K padding
+ //##############################| ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+ //##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
+ //##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
+ //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_TUPLE, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
+ // clang-format on
+ >;
+
+void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
+    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, ck::Tuple<Row>, Row, F16, F16, ck::Tuple<F16>, F16, PassThrough, PassThrough, Bilinear>>>& instances)
+{
+    add_device_operation_instances(
+        instances, device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt
new file mode 100644
index 00000000..5fbdc28d
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt
@@ -0,0 +1,10 @@
+set(DEVICE_GEMM_REDUCE_INSTANCE_SOURCE
+    device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp
+    device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp
+    device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp
+    device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp
+)
+
+add_instance_library(device_gemm_reduce_instance ${DEVICE_GEMM_REDUCE_INSTANCE_SOURCE})
+rocm_install(TARGETS device_gemm_reduce_instance)
+clang_tidy_check(device_gemm_reduce_instance)
diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp
new file mode 100644
index 00000000..f32303db
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp
@@ -0,0 +1,80 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/utility/reduction_operator.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+using ReducePtrsGlobal = ck::Tuple<F32*, F32*>;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using ReduceSum = ck::reduce::Add;
+using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
+
+using Div = ck::tensor_operation::element_wise::UnaryDivide;
+using Identity = ck::tensor_operation::element_wise::PassThrough;
+using Square = ck::tensor_operation::element_wise::UnarySquare;
+using ReduceInElementOps = ck::Tuple<Identity, Square>;
+using ReduceOutElementOps = ck::Tuple<Div, Div>;
+
+using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd, ck::InMemoryDataOperationEnum::AtomicAdd>;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+
+// c[m, n] = a[k, m] * b[k, n]
+using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances = std::tuple<
+ // clang-format off
+ //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
+ //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
+ //###########################| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
+ //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+ DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps,
ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, 
PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
+ DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
+ DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
+ DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
+ DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
+ DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
+ DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>
+ // clang-format on
+ >;
+
+void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances(
+    std::vector<DeviceGemmReducePtr<0, 2>>& instances)
+{
+    add_device_operation_instances(
+        instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp
new file mode 100644
index 00000000..82acbcce
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp
@@ -0,0 +1,80 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/utility/reduction_operator.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+using ReducePtrsGlobal = ck::Tuple<F32*, F32*>;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using ReduceSum = ck::reduce::Add;
+using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
+
+using Div = ck::tensor_operation::element_wise::UnaryDivide;
+using Identity = ck::tensor_operation::element_wise::PassThrough;
+using Square = ck::tensor_operation::element_wise::UnarySquare;
+using ReduceInElementOps = ck::Tuple<Identity, Square>;
+using ReduceOutElementOps = ck::Tuple<Div, Div>;
+
+using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd, ck::InMemoryDataOperationEnum::AtomicAdd>;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+
+// c[m, n] = a[k, m] * b[n, k]
+using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances = std::tuple<
+ // clang-format off
+ //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
+ //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
+ //###########################| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
+ //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+ DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 
2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
+ DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
+ DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
+ DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
+ DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
+ DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
+ DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>
+ // clang-format on
+ >;
+
+void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances(
+    std::vector<DeviceGemmReducePtr<0, 2>>& instances)
+{
+    add_device_operation_instances(
+        instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp
new file mode 100644
index 00000000..978a4cb3
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp
@@ -0,0 +1,80 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
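+// Usage sketch (illustrative only, not part of the original patch): each of
+// these gemm_reduce translation units exposes one add_*_instances entry point
+// that appends every template instantiation in the file to a vector of
+// type-erased handles. Assuming the DeviceGemmReducePtr<0, 2> handle type used
+// in the signatures here (0 extra D tensors, 2 global reductions), a client
+// such as a profiler could enumerate them like this (arg and run_and_time are
+// hypothetical):
+//
+//   std::vector<ck::tensor_operation::device::DeviceGemmReducePtr<0, 2>> ops;
+//   ck::tensor_operation::device::instance::
+//       add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances(ops);
+//   for(auto& op : ops)
+//       if(op->IsSupportedArgument(arg.get())) // arg built for the problem size
+//           run_and_time(*op);                 // keep the fastest instance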
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/utility/reduction_operator.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+using ReducePtrsGlobal = ck::Tuple<F32*, F32*>;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using ReduceSum = ck::reduce::Add;
+using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
+
+using Div = ck::tensor_operation::element_wise::UnaryDivide;
+using Identity = ck::tensor_operation::element_wise::PassThrough;
+using Square = ck::tensor_operation::element_wise::UnarySquare;
+using ReduceInElementOps = ck::Tuple<Identity, Square>;
+using ReduceOutElementOps = ck::Tuple<Div, Div>;
+
+using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd, ck::InMemoryDataOperationEnum::AtomicAdd>;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+
+// c[m, n] = a[m, k] * b[k, n]
+using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances = std::tuple<
+ // clang-format off
+ //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
+ //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
+ //###########################| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
+ //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+ DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 
2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
+ DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
+ DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
+ DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
+ DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
+ DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
+ DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>
+ // clang-format on
+ >;
+
+void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances(
+    std::vector<DeviceGemmReducePtr<0, 2>>& instances)
+{
+    add_device_operation_instances(
+        instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp
new file mode 100644
index 00000000..a067449f
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp
@@ -0,0 +1,77 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
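+// Reading one entry against the parameter table below (illustrative note): in
+//
+//   DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ...,
+//                                  GemmDefault, 1, 256, 256, 128, 32, 8, 8, ...>
+//
+// the first three arguments are the A/B/C layouts, the next six the value and
+// accumulator types, and the numbers after the specialization are NumGemmK
+// Prefetch Stage (1), BlockSize (256), MPerBlock (256), NPerBlock (128) and
+// KPerBlock (32): each 256-thread workgroup computes a 256x128 tile of C,
+// stepping through K in chunks of 32, with AK1 = BK1 = 8 as the vector length
+// used when loading A and B.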
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/utility/reduction_operator.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+using ReducePtrsGlobal = ck::Tuple<F32*, F32*>;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using ReduceSum = ck::reduce::Add;
+using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
+
+using Div = ck::tensor_operation::element_wise::UnaryDivide;
+using Identity = ck::tensor_operation::element_wise::PassThrough;
+using Square = ck::tensor_operation::element_wise::UnarySquare;
+using ReduceInElementOps = ck::Tuple<Identity, Square>;
+using ReduceOutElementOps = ck::Tuple<Div, Div>;
+
+using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd, ck::InMemoryDataOperationEnum::AtomicAdd>;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+
+// c[m, n] = a[m, k] * b[n, k]
+using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances = std::tuple<
+ // clang-format off
+ //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
+ //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
+ //###########################| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
+ //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+ DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 
true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1> + // clang-format on + >; + +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt new file mode 100644 index 00000000..3700ddf1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt @@ -0,0 +1,15 @@ +set(DEVICE_GEMM_SPLITK_INSTANCE_SOURCE + device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp; + device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp; + device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp; + device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp; + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp; + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp; + device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp; + device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp; +) + +add_library(device_gemm_splitk_instance OBJECT ${DEVICE_GEMM_SPLITK_INSTANCE_SOURCE}) + +target_compile_features(device_gemm_splitk_instance PUBLIC) +set_target_properties(device_gemm_splitk_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 00000000..da59b91f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
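The reduction plumbing in the gemm_reduce instances above composes per-element ops with a summation: each GEMM output row is passed through Identity and Square, reduced with ck::reduce::Add into the two ReducePtrsGlobal buffers, and finished with UnaryDivide. Below is a serial reference of that epilogue; the division by N is an assumption (the usual row-mean / row-mean-of-squares pair consumed by a following layernorm), since in CK the divisor is supplied when the kernel argument is built.

    // Serial sketch of the fused reduce epilogue configured above.
    void reduce_epilogue_ref(const float* C, float* r0, float* r1, int M, int N)
    {
        for(int m = 0; m < M; ++m)
        {
            float s = 0.f, ss = 0.f;
            for(int n = 0; n < N; ++n)
            {
                const float c = C[m * N + n];
                s  += c;     // ReduceSum of Identity(c)
                ss += c * c; // ReduceSum of Square(c)
            }
            r0[m] = s / N;   // UnaryDivide -> first ReducePtrsGlobal buffer
            r1[m] = ss / N;  // UnaryDivide -> second buffer
        }
    }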
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 
2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 00000000..aa65e134 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
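The DeviceGemmXdlSplitKCShuffle instances in this directory split the K dimension across work-groups in addition to tiling M and N. A plain CPU sketch of the decomposition follows; it is not CK's kernel (the device runs the slices concurrently and combines partials rather than looping), but it produces the same values.

    #include <algorithm>

    // Reference split-K GEMM: K is cut into kbatch slices, each slice
    // computes a partial product, and the partials are summed into C.
    void gemm_splitk_ref(const float* A, const float* B, float* C,
                         int M, int N, int K, int kbatch)
    {
        for(int i = 0; i < M * N; ++i)
            C[i] = 0.f;
        const int k_per_slice = (K + kbatch - 1) / kbatch;
        for(int kb = 0; kb < kbatch; ++kb)
        {
            const int k_begin = kb * k_per_slice;
            const int k_end   = std::min(K, k_begin + k_per_slice);
            for(int m = 0; m < M; ++m)
                for(int n = 0; n < N; ++n)
                {
                    float partial = 0.f;
                    for(int k = k_begin; k < k_end; ++k)
                        partial += A[m * K + k] * B[k * N + n]; // row-major A, B
                    C[m * N + n] += partial; // the combine step
                }
        }
    }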
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 
2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 00000000..32b229c6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
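Every instance file in this patch follows the same registration idiom: a std::tuple whose elements are fully specialized device operations, plus one add_*_instances function that hands the tuple to add_device_operation_instances, which appends a heap-allocated copy of each element to the caller's vector. A minimal stand-alone sketch of that mechanism, with illustrative names (OpBase, OpA, OpB, add_instances are not CK types):

    #include <memory>
    #include <tuple>
    #include <type_traits>
    #include <vector>

    struct OpBase { virtual ~OpBase() = default; };
    struct OpA : OpBase {}; // stand-ins for two fully specialized device ops
    struct OpB : OpBase {};

    // The tuple is only a carrier for the type list; exactly one instance of
    // each element lands in the caller's vector.
    template <typename... Ops>
    void add_instances(std::vector<std::unique_ptr<OpBase>>& out, std::tuple<Ops...> ops)
    {
        std::apply(
            [&out](auto&... op) {
                (out.push_back(std::make_unique<std::decay_t<decltype(op)>>(op)), ...);
            },
            ops);
    }

    // Usage, mirroring this patch: add_instances(v, std::tuple<OpA, OpB>{});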
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 
2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 00000000..004143af --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
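The file-name suffixes encode the layout triple wired to the ALayout/BLayout/CLayout template arguments: the first token gives A's index order, the second B's, the third C's.

    // Suffix -> layout triple (A_B_C) in these instance files:
    //   mk : A indexed [m][k] -> RowMajor      km : A indexed [k][m] -> ColumnMajor
    //   kn : B indexed [k][n] -> RowMajor      nk : B indexed [n][k] -> ColumnMajor
    //   mn : C indexed [m][n] -> RowMajor
    // e.g. km_nk_mn  =>  DeviceGemmXdlSplitKCShuffle< ..., Col, Col, Row, ... >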
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 
2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp new file mode 100644 index 00000000..051ff652 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
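Nothing selects among these tuned variants at compile time; the vectors populated by the add_*_instances functions exist so a client (such as the profiler added elsewhere in this patch) can filter by problem support and time each candidate. The skeleton below is hypothetical: supported() and run() stand in for CK's argument-support query and timed invoker, whose exact signatures this diff does not show.

    #include <memory>
    #include <vector>

    struct DeviceOp
    {
        virtual ~DeviceOp() = default;
        virtual bool supported() const = 0; // can this instance handle the problem?
        virtual float run()            = 0; // launch and return elapsed milliseconds
    };

    // Scan a vector filled by an add_*_instances function and keep the
    // fastest instance that accepts the problem.
    inline DeviceOp* pick_best(const std::vector<std::unique_ptr<DeviceOp>>& ops)
    {
        DeviceOp* best = nullptr;
        float best_ms  = 1e30f;
        for(const auto& op : ops)
        {
            if(!op->supported())
                continue;
            const float ms = op->run();
            if(ms < best_ms)
            {
                best_ms = ms;
                best    = op.get();
            }
        }
        return best;
    }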
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 
3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp new file mode 100644 index 00000000..5d3cbf89 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + 
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp new file mode 100644 index 00000000..9a9b05a3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
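Each tuple row is read against the banner columns above it. Decoding the first f32 km_nk entry, for example:

    // DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, ..., GemmDefault,
    //   256              BlockSize            threads per work-group
    //   256, 128         MPerBlock, NPerBlock C tile owned by one work-group
    //   4, 4             K0PerBlock, K1       K advances K0*K1 = 16 per main-loop step
    //   32, 32           MPerXDL, NPerXDL     shape of one xdlops (MFMA) tile
    //   4, 2             MXdlPerWave, NXdlPerWave
    //   S<1, 4, 64, 1>   ABlockTransferThreadClusterLengths_K0_M_K1
    //                    (product 1*4*64*1 = 256, matching BlockSize)
    //   ...
    //   7, 1             CThreadTransferSrcDstVectorDim, DstScalarPerVector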
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances = std::tuple< + // clang-format off + //###################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM|Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //###################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //###################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 96, 128, 4, 8, 16, 16, 3, 4, S<1, 4, 32, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 64, 4, 4, 32, 32, 2, 2, 
S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 32, 256, 4, 4, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 16, 256, 4, 4, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 16, 128, 4, 4, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp new file mode 100644 index 00000000..50dc9305 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
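Unlike its siblings, the f32 mk_kn file above uses GemmSpecialization::MNPadding and carries extra rows with skinny 96-, 32-, and 16-wide tiles. With MNPadding, M and N need not divide the tile sizes: the extents are rounded up to the tile grid and the overhang is masked, for example:

    // GemmSpecialization::Default assumes the extents already fit the tiling;
    // MNPadding pads M and N up to the tile grid, e.g. for a 256 x 128 tile
    // with M = 1000, N = 200:
    //   M_pad = ceil(1000 / 256) * 256 = 1024
    //   N_pad = ceil( 200 / 128) * 128 =  256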
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 
1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt new file mode 100644 index 00000000..4d1115ce --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt @@ -0,0 +1,15 @@ +# device_grouped_gemm_instance +set(DEVICE_GROUPED_GEMM_INSTANCE_SOURCE + device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp; + device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp; + device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp; + device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp; +) + +add_library(device_grouped_gemm_instance OBJECT ${DEVICE_GROUPED_GEMM_INSTANCE_SOURCE}) + +target_compile_features(device_grouped_gemm_instance PUBLIC) +set_target_properties(device_grouped_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +rocm_install(TARGETS device_grouped_gemm_instance) + +clang_tidy_check(device_grouped_gemm_instance) diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp 
b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 00000000..ebc4cc95 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, 
PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 00000000..e604f15e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
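A grouped GEMM runs many independent problems, each with its own pointers and extents, from a single launch; the DeviceGroupedGemmXdl instances take the per-group descriptors at argument time. Below is a serial reference of the semantics (row-major everywhere just to keep the sketch short; on the device the groups are dispatched to work-groups within one kernel).

    #include <vector>

    struct GemmDesc { const float *A, *B; float* C; int M, N, K; };

    // Serial reference: one call handles every group in the batch.
    void grouped_gemm_ref(const std::vector<GemmDesc>& groups)
    {
        for(const GemmDesc& g : groups)
            for(int m = 0; m < g.M; ++m)
                for(int n = 0; n < g.N; ++n)
                {
                    float acc = 0.f;
                    for(int k = 0; k < g.K; ++k)
                        acc += g.A[m * g.K + k] * g.B[k * g.N + n];
                    g.C[m * g.N + n] = acc;
                }
    }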
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 
1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 00000000..1b7ecb58 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 
32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 256, 4, 8, 32, 32, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 
GemmDefault, 128, 16, 64, 4, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1> + // clang-format on + >; + +void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 00000000..65c88817 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //##################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| 
PerVector_K1| | | PerVector| + //##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +// irregular tile size +using 
device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances = std::tuple< + // clang-format off + //##################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1> + // clang-format on + >; + +void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances{}); + add_device_operation_instances( + instances, device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/normalization/CMakeLists.txt new file mode 100644 index 00000000..a6ae07ba --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization/CMakeLists.txt @@ -0,0 +1,10 @@ +# device_normalization_instance +set(DEVICE_NORMALIZATION_INSTANCE_SOURCE + device_softmax_f32_f32_instance.cpp + device_softmax_f16_f16_instance.cpp +) + +add_library(device_normalization_instance OBJECT ${DEVICE_NORMALIZATION_INSTANCE_SOURCE}) +set_target_properties(device_normalization_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) + +clang_tidy_check(device_normalization_instance) diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_softmax_f16_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_softmax_f16_f16_instance.cpp new file mode 100644 index 00000000..8465baa1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization/device_softmax_f16_f16_instance.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
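The three grouped-GEMM layout files above (km_nk, mk_kn and mk_nk, plus the padded irregular-tile variants) all rely on the same registration idiom: a std::tuple of fully specialized DeviceGroupedGemmXdl types is handed to add_device_operation_instances, which materializes one object per tuple element into the caller's vector. The snippet below is a minimal, self-contained sketch of that tuple-to-vector expansion; BaseOp, OpA, OpB and add_instances are illustrative stand-ins, not CK types.

```cpp
// Minimal sketch (illustrative stand-ins, not CK's implementation): expand a
// std::tuple of concrete device-op types into a vector of owned base-class
// pointers, which is the pattern the *_instances.cpp files above rely on.
#include <iostream>
#include <memory>
#include <tuple>
#include <vector>

struct BaseOp {
    virtual ~BaseOp() = default;
    virtual const char* Name() const = 0;
};
struct OpA : BaseOp { const char* Name() const override { return "OpA"; } };
struct OpB : BaseOp { const char* Name() const override { return "OpB"; } };

// Default-construct one instance per tuple element (C++17 fold expression).
template <typename... Ops>
void add_instances(std::vector<std::unique_ptr<BaseOp>>& out, std::tuple<Ops...>)
{
    (out.push_back(std::make_unique<Ops>()), ...);
}

int main()
{
    std::vector<std::unique_ptr<BaseOp>> instances;
    add_instances(instances, std::tuple<OpA, OpB>{});
    for (const auto& op : instances)
        std::cout << op->Name() << '\n'; // prints OpA, then OpB
}
```

The payoff of the pattern is that a caller can enumerate and benchmark every tuning variant through the base-class interface without ever naming a concrete instantiation.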
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using device_softmax_f16_f16_instances = std::tuple< + // clang-format off + // InDataType, AccDataType, OutDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize> + DeviceSoftmax, // fallback kernel + DeviceSoftmax, + DeviceSoftmax, + DeviceSoftmax, + DeviceSoftmax, + DeviceSoftmax, + DeviceSoftmax, + DeviceSoftmax, + DeviceSoftmax + // clang-format on + >; + +void add_device_softmax_f16_f16_rank3_instances(std::vector& instances) +{ + add_device_operation_instances(instances, device_softmax_f16_f16_instances<3, 1>{}); + add_device_operation_instances(instances, device_softmax_f16_f16_instances<3, 2>{}); +} + +void add_device_softmax_f16_f16_rank4_instances(std::vector& instances) +{ + add_device_operation_instances(instances, device_softmax_f16_f16_instances<4, 1>{}); + add_device_operation_instances(instances, device_softmax_f16_f16_instances<4, 2>{}); + add_device_operation_instances(instances, device_softmax_f16_f16_instances<4, 3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_softmax_f32_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_softmax_f32_f32_instance.cpp new file mode 100644 index 00000000..73ecf747 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization/device_softmax_f32_f32_instance.cpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
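device_softmax_f16_f16_instances above is parameterized on Rank and NumReduceDim, and the two adder functions stamp it out for (3,1), (3,2), (4,1), (4,2) and (4,3). As a mental model, the kernel flattens the invariant axes into a length M and the reduced axes into a length K, matching the MThreadClusterSize/KThreadClusterSize columns in the comment header. Below is a host-side reference under the simplifying assumption that the reduced axes are the trailing ones; it is a reference model, not the GPU kernel.

```cpp
// Host-side reference (assumption: the reduced axes are the trailing ones):
// flatten leading axes into M rows and trailing axes into a reduce length K,
// then run a numerically stable softmax over each row.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

void reference_softmax(std::vector<float>& data,
                       const std::vector<std::size_t>& lengths,
                       std::size_t num_reduce_dim)
{
    const std::size_t rank = lengths.size();
    std::size_t m = 1, k = 1;
    for (std::size_t d = 0; d < rank - num_reduce_dim; ++d) m *= lengths[d];
    for (std::size_t d = rank - num_reduce_dim; d < rank; ++d) k *= lengths[d];

    for (std::size_t row = 0; row < m; ++row) {
        float* x = data.data() + row * k;
        const float mx = *std::max_element(x, x + k); // subtract the max for stability
        float sum = 0.f;
        for (std::size_t i = 0; i < k; ++i) { x[i] = std::exp(x[i] - mx); sum += x[i]; }
        for (std::size_t i = 0; i < k; ++i) x[i] /= sum;
    }
}

int main()
{
    // Rank 3, NumReduceDim 1: one of the combinations the rank3 adder covers.
    std::vector<float> t(2 * 3 * 4, 1.0f);
    reference_softmax(t, {2, 3, 4}, 1);
    std::cout << t[0] << '\n'; // uniform input reduces over K = 4, so 0.25
}
```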
+ +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" +#include "ck/utility/data_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +template +using device_softmax_f32_f32_instances = std::tuple< + // clang-format off + // InDataType, AccDataType, OutDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize> + DeviceSoftmax, // fallback kernel + DeviceSoftmax, + DeviceSoftmax, + DeviceSoftmax, + DeviceSoftmax, + DeviceSoftmax, + DeviceSoftmax, + DeviceSoftmax, + DeviceSoftmax + // clang-format on + >; + +void add_device_softmax_f32_f32_rank3_instances(std::vector& instances) +{ + add_device_operation_instances(instances, device_softmax_f32_f32_instances<3, 1>{}); + add_device_operation_instances(instances, device_softmax_f32_f32_instances<3, 2>{}); +} + +void add_device_softmax_f32_f32_rank4_instances(std::vector& instances) +{ + add_device_operation_instances(instances, device_softmax_f32_f32_instances<4, 1>{}); + add_device_operation_instances(instances, device_softmax_f32_f32_instances<4, 2>{}); + add_device_operation_instances(instances, device_softmax_f32_f32_instances<4, 3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/reduce/CMakeLists.txt new file mode 100644 index 00000000..d566796c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/CMakeLists.txt @@ -0,0 +1,29 @@ +# device_reduce_instance +set(DEVICE_REDUCE_INSTANCE_SOURCE + device_reduce_instance_blockwise_f16_f16_f16.cpp; + device_reduce_instance_blockwise_f16_f32_f16.cpp; + device_reduce_instance_blockwise_f32_f32_f32.cpp; + device_reduce_instance_blockwise_f32_f64_f32.cpp; + device_reduce_instance_blockwise_f64_f64_f64.cpp; + device_reduce_instance_blockwise_i8_i32_i8.cpp; + device_reduce_instance_blockwise_i8_i8_i8.cpp; + device_reduce_instance_blockwise_b16_f32_b16.cpp; + device_reduce_instance_threadwise_f16_f16_f16.cpp; + device_reduce_instance_threadwise_f16_f32_f16.cpp; + device_reduce_instance_threadwise_f32_f32_f32.cpp; + device_reduce_instance_threadwise_f32_f64_f32.cpp; + device_reduce_instance_threadwise_f64_f64_f64.cpp; + device_reduce_instance_threadwise_i8_i32_i8.cpp; + device_reduce_instance_threadwise_i8_i8_i8.cpp; + device_reduce_instance_threadwise_b16_f32_b16.cpp; + device_reduce_instance_multiblock_atomic_add_f16_f32_f32.cpp; + device_reduce_instance_multiblock_atomic_add_f32_f32_f32.cpp; + device_reduce_instance_multiblock_atomic_add_f32_f64_f32.cpp; + device_reduce_instance_multiblock_atomic_add_f64_f64_f64.cpp; + device_reduce_instance_multiblock_atomic_add_b16_f32_f32.cpp; +) + +add_library(device_reduce_instance OBJECT ${DEVICE_REDUCE_INSTANCE_SOURCE}) +set_target_properties(device_reduce_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) + +clang_tidy_check(device_reduce_instance) diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.cpp new file mode 100644 index 00000000..c97efbc9 --- /dev/null +++ 
b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.cpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 3); // for ADD +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 3); // for AVG +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 3); // for NORM2 +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 2, 1); + +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 3); // for MIN +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 3); // for MAX +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 3); // for MIN +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 3); // for MAX +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.cpp new file mode 100644 index 00000000..5e73b3d8 --- /dev/null +++ 
b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.cpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.cpp new file mode 100644 index 00000000..93d3e270 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
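The integer ReduceOpId column in these reduce instance files is opaque on its own, but the trailing comments pin down the encoding used here: 0 is ADD, 2 MIN, 3 MAX, 4 AMAX, 5 AVG and 7 NORM2 (ids 1 and 6 never appear in these files). The enum below is a local illustration of that mapping, not a CK type.

```cpp
// Local illustration of the ReduceOpId encoding implied by the comments in
// the ADD_*_INST_BY_ID lists above; not a CK type.
#include <iostream>

enum class ReduceOpId : int {
    Add   = 0, // sum
    Min   = 2,
    Max   = 3,
    Amax  = 4, // max of absolute values
    Avg   = 5, // sum divided by the reduce length
    Norm2 = 7, // square root of the sum of squares
};

const char* name(ReduceOpId id)
{
    switch (id) {
        case ReduceOpId::Add:   return "ADD";
        case ReduceOpId::Min:   return "MIN";
        case ReduceOpId::Max:   return "MAX";
        case ReduceOpId::Amax:  return "AMAX";
        case ReduceOpId::Avg:   return "AVG";
        case ReduceOpId::Norm2: return "NORM2";
    }
    return "UNKNOWN";
}

int main() { std::cout << name(ReduceOpId::Norm2) << '\n'; } // prints NORM2
```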
+ +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD +ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG +ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2 +ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.cpp new file mode 100644 index 00000000..38800ddd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2 +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 4); 
+ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.cpp new file mode 100644 index 00000000..b821aeee --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD +ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG +ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2 +ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64.cpp new file mode 100644 index 00000000..074d0cfd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
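Note the f32_f64_f32 variants above: inputs and outputs stay float while AccDataType widens to double. For NORM2 (id 7) the benefit is easy to show on the host: squaring can overflow float range even when both the inputs and the final norm fit comfortably, and the wider accumulator sidesteps the intermediate blow-up. A small demonstration:

```cpp
// NORM2 with a narrow vs. wide accumulator: the squared intermediate can
// overflow float (max ~3.4e38) even though the input and the final norm
// are both representable in float.
#include <cmath>
#include <cstdio>

int main()
{
    const float x = 3.0e19f; // representable in float

    const float  ssq_f = x * x;                         // 9e38 overflows to inf
    const double ssq_d = static_cast<double>(x) * x;    // 9e38 is fine in double

    std::printf("float norm:  %g\n", std::sqrt(ssq_f));                     // inf
    std::printf("double norm: %g\n", static_cast<float>(std::sqrt(ssq_d))); // 3e+19
}
```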
+ +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2 +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.cpp new file mode 100644 index 00000000..e803fb84 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
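MIN, MAX and AMAX are instantiated twice in the lists above, once with IndicesOpt 0 and once with 1. Reading the flag as "also produce the index of the winning element" (an argmin/argmax-style reduction, which would also explain why ADD, AVG and NORM2 never get an IndicesOpt = 1 row: a sum has no single winner), the scalar semantics would be as sketched below; the struct and function are illustrative, not CK's interface.

```cpp
// Illustrative scalar semantics for an AMAX reduction that also tracks the
// winning index (the IndicesOpt = 1 case, as interpreted above).
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

struct AmaxResult {
    double value;      // max |x_i|
    std::size_t index; // position of the winner
};

AmaxResult amax_with_index(const std::vector<double>& x)
{
    AmaxResult r{0.0, 0};
    for (std::size_t i = 0; i < x.size(); ++i) {
        const double a = std::abs(x[i]);
        if (a > r.value) { r.value = a; r.index = i; }
    }
    return r;
}

int main()
{
    const auto r = amax_with_index({0.5, -3.25, 2.0});
    std::cout << r.value << " at " << r.index << '\n'; // 3.25 at 1
}
```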
+ +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 3); // for ADD +ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 3); // for AVG +ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8.cpp new file mode 100644 index 00000000..4bf4139d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8.cpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 3); // for MIN +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 3); // for MAX +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 3); // for MIN +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 3); // for MAX +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 2, 1); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 3); // for AMAX +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 4); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1); +ADD_BLOCKWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace 
tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32.cpp new file mode 100644 index 00000000..a571655c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 0, 0, 0, 4, 3); // for ADD +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 0, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 0, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 0, 0, 0, 2, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 5, 0, 0, 4, 3); // for AVG +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 5, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 5, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 5, 0, 0, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.cpp new file mode 100644 index 00000000..9ad9a630 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
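The multiblock_atomic_add files instantiate only ADD (0) and AVG (5). That restriction is consistent with the strategy the name suggests: several workgroups each reduce a disjoint slice, then combine their partial results into the output with atomic adds, a composition that is only valid for additive operations. It also fits that even the bhalf_t variant writes a float output, since the destination must be a type the atomic add can target. A host-side analogue, in spirit only:

```cpp
// Host-side analogue (assumption: this mirrors the GPU strategy in spirit
// only): each worker reduces a private slice, then publishes its partial sum
// into a single shared cell with an atomic add.
#include <atomic>
#include <cstddef>
#include <cstdio>
#include <thread>
#include <vector>

// C++17-friendly atomic add for float (fetch_add on floats is C++20).
void atomic_add(std::atomic<float>& target, float v)
{
    float old = target.load();
    while (!target.compare_exchange_weak(old, old + v)) {
        // 'old' is refreshed on failure; retry with the updated value
    }
}

int main()
{
    std::vector<float> data(1u << 20, 1.0f);
    std::atomic<float> total{0.0f};

    const std::size_t workers = 4;
    const std::size_t chunk   = data.size() / workers;

    std::vector<std::thread> pool;
    for (std::size_t w = 0; w < workers; ++w)
        pool.emplace_back([&, w] {
            float partial = 0.f; // reduce the private slice first...
            for (std::size_t i = w * chunk; i < (w + 1) * chunk; ++i)
                partial += data[i];
            atomic_add(total, partial); // ...then one atomic combine per worker
        });
    for (auto& t : pool) t.join();

    std::printf("%.0f\n", total.load()); // 1048576; divide by the reduce length for AVG
}
```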
+ +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 4, 3); // for ADD +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 2, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 4, 3); // for AVG +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.cpp new file mode 100644 index 00000000..4ee70702 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 2, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.cpp new file mode 100644 index 00000000..8c5fa80e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 2, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64.cpp new file mode 100644 index 00000000..d2b81c48 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(double, double, double, 0, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(double, double, double, 0, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(double, double, double, 5, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(double, double, double, 5, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16.cpp new file mode 100644 index 00000000..8d678e78 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16.cpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
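The source list pairs every blockwise instance set with a threadwise sibling. Judging by the names, a threadwise instance lets a single thread own an entire reduction, which is cheap when the reduce length is short, while a blockwise instance has a group of threads cooperate and then merge their lanes in a tree. The tree step, in miniature and on the host:

```cpp
// Miniature tree combine, the kind of merge step a cooperating thread block
// performs after each lane has its private partial result (host sketch only).
#include <cstddef>
#include <iostream>
#include <vector>

float tree_reduce(std::vector<float> v) // v.size() must be a power of two here
{
    for (std::size_t stride = v.size() / 2; stride > 0; stride /= 2)
        for (std::size_t i = 0; i < stride; ++i)
            v[i] += v[i + stride]; // lane i pairs up with lane i + stride
    return v[0];
}

int main()
{
    std::cout << tree_reduce({1, 2, 3, 4, 5, 6, 7, 8}) << '\n'; // 36
}
```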
+ +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 3); // for ADD +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 3); // for AVG +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 3); // for NORM2 +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 2, 1); + +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 3); // for MIN +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 3); // for MAX +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 3); // for AMAX +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 3); // for MIN +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 4); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 1); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 2, 1); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 3); // for MAX +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 4); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 1); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 2, 1); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 3); // for AMAX +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 4); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1); +ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.cpp new file mode 100644 index 00000000..01056058 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.cpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
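Every bhalf_t list, blockwise and threadwise alike, pairs the input with a float accumulator (b16_f32_b16). That follows from the format itself: bf16 is the top 16 bits of an IEEE-754 float, so widening is a shift, while its 8 significand bits are far too coarse to accumulate in. A sketch of the widen/narrow pair, with uint16_t standing in for bhalf_t:

```cpp
// bf16 <-> float conversion sketch; uint16_t stands in for bhalf_t.
#include <cstdint>
#include <cstdio>
#include <cstring>

float bf16_to_float(std::uint16_t b)
{
    const std::uint32_t u = static_cast<std::uint32_t>(b) << 16; // low bits become zero
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}

std::uint16_t float_to_bf16(float f) // truncating conversion, ignores rounding modes
{
    std::uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    return static_cast<std::uint16_t>(u >> 16);
}

int main()
{
    const float x = 3.14159f;
    std::printf("%.6f\n", bf16_to_float(float_to_bf16(x))); // 3.140625: 8 significand bits survive
}
```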
+ +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 4); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 4); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 4); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1); +ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.cpp new file mode 100644 index 00000000..55c53dfd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
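Like their blockwise counterparts, the threadwise f16 lists split by operation: f16_f16_f16 above carries only MIN, MAX and AMAX, while ADD, AVG and NORM2 live in f16_f32_f16 with a float accumulator. A plausible reason: a comparison always returns one of its exactly representable inputs, whereas a sum manufactures new values that a narrow accumulator may fail to represent. The toy reduce below shows the same split one precision level up, with float playing the role of half_t and double the role of float; the template knob mirrors the InDataType/AccDataType/OutDataType triple.

```cpp
// Toy reductions illustrating why sums want a wider AccDataType while
// comparisons do not; types here are one precision level above the f16 case.
#include <algorithm>
#include <cstdio>
#include <vector>

template <typename In, typename Acc, typename Out>
Out reduce_add(const std::vector<In>& v)
{
    Acc acc = 0;
    for (In x : v) acc += static_cast<Acc>(x);
    return static_cast<Out>(acc);
}

template <typename T>
T reduce_max(const std::vector<T>& v)
{
    T best = v.front();
    for (T x : v) best = std::max(best, x); // never produces a new value
    return best;
}

int main()
{
    std::vector<float> v(101, 1.0f);
    v[0] = 16777216.0f; // 2^24: above this, float's spacing exceeds 1

    std::printf("%.0f\n", reduce_add<float, float, float>(v));  // 16777216: the 1.0 addends vanish
    std::printf("%.0f\n", reduce_add<float, double, float>(v)); // 16777316: exact
    std::printf("%.0f\n", reduce_max(v));                       // 16777216: exact either way
}
```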
+ +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD +ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG +ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2 +ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.cpp new file mode 100644 index 00000000..367cf9a6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD +ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG +ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2 +ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN +ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX +ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX +ADD_THREADWISE_INST_BY_ID(float, float, float, 
4, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN +ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 4); +ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX +ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 4); +ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX +ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 4); +ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 1); +ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.cpp new file mode 100644 index 00000000..18fd0844 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD +ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG +ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2 +ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.cpp new file mode 100644 index 00000000..3d02f3cb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
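Each row also fixes a (Rank, NumReduceDim) pair: a rank-4 input reduced over 3 dimensions leaves a rank-1 output, a rank-2 input reduced over 1 dimension leaves a rank-1 output, and so on. A toy shape computation sketching that bookkeeping (which dimensions get reduced is a runtime argument in the library; reducing the leading dimensions here is purely for illustration):

```cpp
#include <cassert>
#include <vector>

// Toy model of rank bookkeeping for a (Rank=4, NumReduceDim=3) instance.
int main()
{
    const std::vector<int> in_lengths = {128, 16, 16, 64}; // rank-4 input
    const int num_reduce_dim          = 3;                 // fold dims {0, 1, 2}
    std::vector<int> out_lengths(in_lengths.begin() + num_reduce_dim, in_lengths.end());
    assert(out_lengths.size() == in_lengths.size() - num_reduce_dim); // rank-1 result
    assert(out_lengths[0] == 64);
    return 0;
}
```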
+ +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD +ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG +ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2 +ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN +ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX +ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX +ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 4); +ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN +ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 4); +ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX +ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 4); +ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX +ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 4); +ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 1); +ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.cpp new file mode 100644 index 00000000..fcf072a0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+
+#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+// clang-format off
+// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
+ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 3); // for ADD
+ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 4);
+ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 1);
+ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 2, 1);
+ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 3); // for AVG
+ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 4);
+ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 1);
+ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1);
+// clang-format on
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.cpp
new file mode 100644
index 00000000..85d7ce8b
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.cpp
@@ -0,0 +1,43 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+// clang-format off
+// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
+ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 3); // for MIN
+ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 4);
+ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 1);
+ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 2, 1);
+ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 3); // for MAX
+ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 4);
+ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 1);
+ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 2, 1);
+ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 3); // for AMAX
+ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 4);
+ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 1);
+ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 2, 1);
+ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 3); // for MIN
+ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 4);
+ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 1);
+ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 2, 1);
+ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 3); // for MAX
+ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 4);
+ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 1);
+ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 2, 1);
+ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 3); // for AMAX
+ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 4);
+ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1);
+ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1);
+// clang-format on
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+
+} // namespace ck
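The ReduceOpId ordinals used throughout these tables follow the inline comments: 0=ADD, 2=MIN, 3=MAX, 4=AMAX, 5=AVG, 7=NORM2, with IndicesOpt=1 selecting the index-returning MIN/MAX/AMAX variants. A hypothetical sketch of what an ID-keyed instantiation macro like ADD_THREADWISE_INST_BY_ID can expand to — the real definition lives in device_reduce_instance_threadwise.hpp and will differ:

```cpp
// Hypothetical expansion sketch only; names are illustrative, not the CK API.
#include <cstdint>

enum class ReduceOpId : int { ADD = 0, MIN = 2, MAX = 3, AMAX = 4, AVG = 5, NORM2 = 7 };

template <typename InT, typename AccT, typename OutT,
          int OpId, int NanOpt, int IndicesOpt, int Rank, int NumReduceDim>
struct DeviceReduceThreadwiseSketch
{
    // A real instance carries kernel-launch plumbing; this stub only pins the
    // compile-time configuration that one table row selects.
    static constexpr int op_id          = OpId;
    static constexpr int rank           = Rank;
    static constexpr int num_reduce_dim = NumReduceDim;
};

// One explicit instantiation per table row; splitting the rows across
// per-data-type .cpp files keeps translation units small and compilation parallel.
template struct DeviceReduceThreadwiseSketch<int8_t, int32_t, int8_t, 0, 0, 0, 4, 3>;
```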
diff --git a/library/src/utility/CMakeLists.txt b/library/src/utility/CMakeLists.txt
new file mode 100644
index 00000000..69cc0e30
--- /dev/null
+++ b/library/src/utility/CMakeLists.txt
@@ -0,0 +1,31 @@
+## utility
+set(UTILITY_SOURCE
+    device_memory.cpp
+    host_tensor.cpp
+    convolution_parameter.cpp
+)
+
+add_library(utility STATIC ${UTILITY_SOURCE})
+add_library(composable_kernel::utility ALIAS utility)
+
+target_compile_features(utility PUBLIC)
+set_target_properties(utility PROPERTIES POSITION_INDEPENDENT_CODE ON)
+
+target_include_directories(utility PUBLIC
+    "$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/library/include>"
+    "$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}>"
+)
+
+rocm_install(
+    TARGETS utility
+    EXPORT utilityTargets
+)
+
+rocm_install(
+    EXPORT utilityTargets
+    FILE composable_kernelutilityTargets.cmake
+    NAMESPACE composable_kernel::
+    DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
+)
+
+clang_tidy_check(utility)
diff --git a/library/src/utility/convolution_parameter.cpp b/library/src/utility/convolution_parameter.cpp
new file mode 100644
index 00000000..b63b33a7
--- /dev/null
+++ b/library/src/utility/convolution_parameter.cpp
@@ -0,0 +1,98 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "ck/library/utility/convolution_parameter.hpp"
+#include "ck/library/utility/io.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+ConvParams::ConvParams(ck::index_t n_dim,
+                       ck::index_t n_batch,
+                       ck::index_t n_out_channels,
+                       ck::index_t n_in_channels,
+                       const std::vector<ck::index_t>& filters_len,
+                       const std::vector<ck::index_t>& input_len,
+                       const std::vector<ck::index_t>& strides,
+                       const std::vector<ck::index_t>& dilations,
+                       const std::vector<ck::index_t>& left_pads,
+                       const std::vector<ck::index_t>& right_pads)
+    : num_dim_spatial_(n_dim),
+      N_(n_batch),
+      K_(n_out_channels),
+      C_(n_in_channels),
+      filter_spatial_lengths_(filters_len),
+      input_spatial_lengths_(input_len),
+      output_spatial_lengths_(num_dim_spatial_),
+      conv_filter_strides_(strides),
+      conv_filter_dilations_(dilations),
+      input_left_pads_(left_pads),
+      input_right_pads_(right_pads)
+{
+    if(static_cast<ck::index_t>(filter_spatial_lengths_.size()) != num_dim_spatial_ ||
+       static_cast<ck::index_t>(input_spatial_lengths_.size()) != num_dim_spatial_ ||
+       static_cast<ck::index_t>(conv_filter_strides_.size()) != num_dim_spatial_ ||
+       static_cast<ck::index_t>(conv_filter_dilations_.size()) != num_dim_spatial_ ||
+       static_cast<ck::index_t>(input_left_pads_.size()) != num_dim_spatial_ ||
+       static_cast<ck::index_t>(input_right_pads_.size()) != num_dim_spatial_)
+    {
+        throw(
+            std::runtime_error("ConvParams::ConvParams: "
+                               "parameter size is different from number of declared dimensions!"));
+    }
+
+    for(ck::index_t i = 0; i < num_dim_spatial_; ++i)
+    {
+        // XEff = (X - 1) * conv_dilation_w + 1;
+        // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
+        const ck::index_t x_eff = (filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1;
+
+        output_spatial_lengths_[i] =
+            (input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) /
+                conv_filter_strides_[i] +
+            1;
+    }
+}
+
+ConvParams::ConvParams()
+    : ConvParams::ConvParams(2, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1})
+{
+}
+
+std::vector<ck::index_t> ConvParams::GetOutputSpatialLengths() const
+{
+    return output_spatial_lengths_;
+}
+
+std::size_t ConvParams::GetFlops() const
+{
+    // 2 * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
+    return static_cast<std::size_t>(2) * N_ * K_ * C_ *
+           std::accumulate(std::begin(output_spatial_lengths_),
+                           std::end(output_spatial_lengths_),
+                           static_cast<std::size_t>(1),
+                           std::multiplies<std::size_t>()) *
+           std::accumulate(std::begin(filter_spatial_lengths_),
+                           std::end(filter_spatial_lengths_),
+                           static_cast<std::size_t>(1),
+                           std::multiplies<std::size_t>());
+}
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
+
+std::ostream& operator<<(std::ostream& os, const ck::tensor_operation::device::ConvParams& p)
+{
+    os << "ConvParams {"
+       << "\nnum_dim_spatial: " << p.num_dim_spatial_ << "\nN: " << p.N_ << "\nK: " << p.K_
+       << "\nC: " << p.C_ << "\nfilter_spatial_lengths: " << p.filter_spatial_lengths_
+       << "\ninput_spatial_lengths: " << p.input_spatial_lengths_
+       << "\nconv_filter_strides: " << p.conv_filter_strides_
+       << "\nconv_filter_dilations: " << p.conv_filter_dilations_
+       << "\ninput_left_pads: " << p.input_left_pads_
+       << "\ninput_right_pads: " << p.input_right_pads_;
+
+    return os;
+}
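To make the padded output-length formula above concrete: for the default 2D problem in the delegating constructor (Hi = Wi = 71, 3x3 filter, stride 2, dilation 1, pads 1/1), XEff = (3 - 1) * 1 + 1 = 3 and Wo = (71 + 1 + 1 - 3) / 2 + 1 = 36, matching the {36, 36} output spatial lengths printed in the profiler README further below. A standalone check of that arithmetic:

```cpp
#include <cassert>

// Standalone check of ConvParams' output-length computation for the
// default 2D problem (values taken from the default constructor above).
int main()
{
    const int X = 3, Wi = 71, stride = 2, dilation = 1, left_pad = 1, right_pad = 1;
    const int x_eff = (X - 1) * dilation + 1; // effective (dilated) filter extent
    const int Wo    = (Wi + left_pad + right_pad - x_eff) / stride + 1;
    assert(Wo == 36); // agrees with out_n_k_ho_wo lengths {.., .., 36, 36} in the README
    return 0;
}
```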
diff --git a/library/src/utility/device_memory.cpp b/library/src/utility/device_memory.cpp
new file mode 100644
index 00000000..bb4cadba
--- /dev/null
+++ b/library/src/utility/device_memory.cpp
@@ -0,0 +1,29 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "ck/device_utility/hip_check_error.hpp"
+
+#include "ck/library/utility/device_memory.hpp"
+
+DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
+{
+    hip_check_error(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
+}
+
+void* DeviceMem::GetDeviceBuffer() { return mpDeviceBuf; }
+
+std::size_t DeviceMem::GetBufferSize() { return mMemSize; }
+
+void DeviceMem::ToDevice(const void* p)
+{
+    hip_check_error(hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
+}
+
+void DeviceMem::FromDevice(void* p)
+{
+    hip_check_error(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
+}
+
+void DeviceMem::SetZero() { hip_check_error(hipMemset(mpDeviceBuf, 0, mMemSize)); }
+
+DeviceMem::~DeviceMem() { hip_check_error(hipFree(mpDeviceBuf)); }
diff --git a/host/host_tensor/src/host_tensor.cpp b/library/src/utility/host_tensor.cpp
similarity index 72%
rename from host/host_tensor/src/host_tensor.cpp
rename to library/src/utility/host_tensor.cpp
index e840baf7..9223fedf 100644
--- a/host/host_tensor/src/host_tensor.cpp
+++ b/library/src/utility/host_tensor.cpp
@@ -1,7 +1,9 @@
-#include <boost/range/adaptors.hpp>
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
 #include <numeric>
 
-#include "host_tensor.hpp"
+#include "ck/library/utility/host_tensor.hpp"
 
 void HostTensorDescriptor::CalculateStrides()
 {
@@ -26,15 +28,19 @@ std::size_t HostTensorDescriptor::GetElementSize() const
 std::size_t HostTensorDescriptor::GetElementSpace() const
 {
-    auto ls = mLens | boost::adaptors::transformed([](std::size_t v) { return v - 1; });
-    return std::inner_product(ls.begin(), ls.end(), mStrides.begin(), std::size_t{0}) + 1;
+    std::size_t space = 1;
+    for(std::size_t i = 0; i < mLens.size(); ++i)
+    {
+        space += (mLens[i] - 1) * mStrides[i];
+    }
+    return space;
 }
 
 const std::vector<std::size_t>& HostTensorDescriptor::GetLengths() const { return mLens; }
 
 const std::vector<std::size_t>& HostTensorDescriptor::GetStrides() const { return mStrides; }
 
-void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os)
+std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc)
 {
     os << "dim " << desc.GetNumOfDimension() << ", ";
@@ -44,5 +50,7 @@ void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream
     os << "strides {";
     LogRange(os, desc.GetStrides(), ", ");
-    os << "}" << std::endl;
+    os << "}";
+
+    return os;
 }
diff --git a/profiler/CMakeLists.txt b/profiler/CMakeLists.txt
new file mode 100644
index 00000000..08df8e4d
--- /dev/null
+++ b/profiler/CMakeLists.txt
@@ -0,0 +1,48 @@
+include_directories(BEFORE
+    ${PROJECT_SOURCE_DIR}/
+)
+
+# ck_profiler
+set(PROFILER_SOURCE
+    src/profiler.cpp
+    src/profile_gemm.cpp
+    src/profile_gemm_splitk.cpp
+    src/profile_gemm_bilinear.cpp
+    src/profile_gemm_bias_add_reduce.cpp
+    src/profile_gemm_add_add_fastgelu.cpp
+    src/profile_gemm_reduce.cpp
+    src/profile_batched_gemm.cpp
+    src/profile_batched_gemm_reduce.cpp
+    src/profile_grouped_gemm.cpp
+    src/profile_conv_fwd.cpp
+    src/profile_conv_fwd_bias_relu.cpp
+    src/profile_conv_fwd_bias_relu_add.cpp
+    src/profile_convnd_bwd_data.cpp
+    src/profile_conv_bwd_weight.cpp
+    src/profile_convnd_bwd_weight.cpp
+    src/profile_reduce.cpp
+    src/profile_normalization.cpp
+)
+
+add_executable(ckProfiler ${PROFILER_SOURCE})
+
+target_link_libraries(ckProfiler PRIVATE utility)
+target_link_libraries(ckProfiler PRIVATE device_gemm_instance)
+target_link_libraries(ckProfiler PRIVATE device_gemm_splitk_instance)
+target_link_libraries(ckProfiler PRIVATE device_gemm_bilinear_instance)
+target_link_libraries(ckProfiler PRIVATE device_gemm_add_add_fastgelu_instance)
+target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance)
+target_link_libraries(ckProfiler PRIVATE device_gemm_bias_add_reduce_instance)
+target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance)
+target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance)
+target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance)
+target_link_libraries(ckProfiler PRIVATE device_conv1d_fwd_instance)
+target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance)
+target_link_libraries(ckProfiler PRIVATE device_conv3d_fwd_instance)
+target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance)
+target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance)
+target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance)
+target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_weight_instance)
+target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_weight_instance)
+target_link_libraries(ckProfiler PRIVATE device_normalization_instance)
+target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
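The boost-free GetElementSpace above computes 1 + sum_i (len_i - 1) * stride_i, i.e. the index of the last addressable element plus one. For the packed 3840x4096 row-major tensor from the README below (lengths {3840, 4096}, strides {4096, 1}) that is 1 + 3839*4096 + 4095 = 15728640 = 3840*4096, as a quick standalone check shows:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Standalone model of HostTensorDescriptor::GetElementSpace() above.
int main()
{
    const std::vector<std::size_t> lens    = {3840, 4096};
    const std::vector<std::size_t> strides = {4096, 1}; // packed row-major
    std::size_t space = 1;
    for(std::size_t i = 0; i < lens.size(); ++i)
        space += (lens[i] - 1) * strides[i];
    assert(space == 3840 * 4096); // packed layout: element space == element count
    return 0;
}
```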
diff --git a/profiler/README.md b/profiler/README.md
new file mode 100644
index 00000000..bfd6a3a5
--- /dev/null
+++ b/profiler/README.md
@@ -0,0 +1,48 @@
+## Profile GEMM kernels
+```bash
+#arg1: tensor operation (gemm=GEMM)
+#arg2: data type (0=fp32, 1=fp16)
+#arg3: matrix layout (0=NN, 1=NT, 2=TN, 3=TT)
+#arg4: verification (0=no, 1=yes)
+#arg5: initialization (0=no init, 1=integer value, 2=decimal value)
+#arg6: print matrix value (0=no, 1=yes)
+#arg7: run kernel # of times (>1)
+#arg8 to 13: M, N, K, StrideA, StrideB, StrideC
+
+################ op   datatype layout verify init log repeat M___ N___ K___ StrideA StrideB StrideC
+./bin/ckProfiler gemm 1        1      1      1    0   5      3840 4096 4096 4096    4096    4096
+```
+
+Result (MI100 @ 1087 MHz, 133.5 TFlops peak FP16)
+```bash
+a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
+b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
+c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
+....
+Best Perf: 1.1933 ms, 107.977 TFlops, 79.0848 GB/s
+```
+
+## Profile 2D forward convolution kernels
+```bash
+#arg1: tensor operation (conv=Convolution)
+#arg2: data type (0=fp32, 1=fp16)
+#arg3: input tensor layout (0=NCHW, 1=NHWC)
+#arg4: weight tensor layout (0=KCYX, 1=KYXC)
+#arg5: output tensor layout (0=NKHW, 1=NHWK)
+#arg6: verification (0=no, 1=yes)
+#arg7: initialization (0=no init, 1=integer value, 2=decimal value)
+#arg8: print matrix value (0=no, 1=yes)
+#arg9: run kernel # of times (>1)
+#arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx
+ ################ op         datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads
+ ./bin/ckProfiler conv2d_fwd 1        1         1          1          1      1    0   5      128 256  192  3 3 71   71   2 2     1 1       1 1      1 1
+```
+
+Result (MI100 @ 1087 MHz, 133.5 TFlops peak FP16)
+```
+in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192}
+wei_k_c_y_x: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192}
+out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256}
+....
+Best Perf: 1.42509 ms, 102.988 TFlops, 234.086 GB/s
+```
diff --git a/composable_kernel/include/utility/data_type_enum.hpp b/profiler/include/data_type_enum.hpp
similarity index 55%
rename from composable_kernel/include/utility/data_type_enum.hpp
rename to profiler/include/data_type_enum.hpp
index 35df0067..afcd6fea 100644
--- a/composable_kernel/include/utility/data_type_enum.hpp
+++ b/profiler/include/data_type_enum.hpp
@@ -1,9 +1,11 @@
-#ifndef CK_DATA_TYPE_ENUM_HPP
-#define CK_DATA_TYPE_ENUM_HPP
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
 
 namespace ck {
 
-enum DataTypeEnum_t
+enum struct DataTypeEnum
 {
     Half   = 0,
     Float  = 1,
@@ -16,4 +18,3 @@ enum DataTypeEnum_t
 };
 
 } // namespace ck
-#endif
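The ckProfiler output quoted in the README above comes from a selection loop that all of the profile_*_impl headers below share: every instance that passes IsSupportedArgument is timed, throughput is derived as 2*M*N*K flop over the measured time, and the fastest instance wins. A self-contained mock of that loop (mock types, not the CK API) that also reproduces the README's 107.977 TFlops figure from its 1.1933 ms timing:

```cpp
#include <cstdio>
#include <string>
#include <vector>

// Mock of the profiler's best-instance selection; names are illustrative only.
struct MockInstance
{
    std::string name;
    bool supported; // stand-in for IsSupportedArgument()
    double ms;      // stand-in for the measured kernel time
};

int main()
{
    const double M = 3840, N = 4096, K = 4096;
    const std::vector<MockInstance> instances = {
        {"tile_256x128", true, 1.3021},
        {"tile_128x256", true, 1.1933}, // the README's winning time
        {"tile_64x64", false, 0.0},     // skipped: unsupported configuration
    };

    double best_tflops = 0.0;
    std::string best_name;
    for(const auto& inst : instances)
    {
        if(!inst.supported)
            continue;
        const double tflops = 2.0 * M * N * K / 1.0e9 / inst.ms; // flop / 1e9 / ms = TFLOP/s
        if(tflops > best_tflops)
        {
            best_tflops = tflops;
            best_name   = inst.name;
        }
    }
    // Prints ~107.977 TFlops, matching the README's "Best Perf" line.
    std::printf("Best Perf: %.3f TFlops, %s\n", best_tflops, best_name.c_str());
    return 0;
}
```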
diff --git a/profiler/include/data_type_enum_helper.hpp b/profiler/include/data_type_enum_helper.hpp
new file mode 100644
index 00000000..6f8ef2b9
--- /dev/null
+++ b/profiler/include/data_type_enum_helper.hpp
@@ -0,0 +1,77 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck/utility/data_type.hpp"
+#include "profiler/include/data_type_enum.hpp"
+
+namespace ck {
+
+template <DataTypeEnum>
+struct get_datatype_from_enum;
+
+template <>
+struct get_datatype_from_enum<DataTypeEnum::Int8>
+{
+    using type = int8_t;
+};
+
+template <>
+struct get_datatype_from_enum<DataTypeEnum::Int32>
+{
+    using type = int32_t;
+};
+
+template <>
+struct get_datatype_from_enum<DataTypeEnum::Half>
+{
+    using type = half_t;
+};
+
+template <>
+struct get_datatype_from_enum<DataTypeEnum::Float>
+{
+    using type = float;
+};
+
+template <>
+struct get_datatype_from_enum<DataTypeEnum::Double>
+{
+    using type = double;
+};
+
+template <typename T>
+struct get_datatype_enum_from_type;
+
+template <>
+struct get_datatype_enum_from_type<int8_t>
+{
+    static constexpr DataTypeEnum value = DataTypeEnum::Int8;
+};
+
+template <>
+struct get_datatype_enum_from_type<int32_t>
+{
+    static constexpr DataTypeEnum value = DataTypeEnum::Int32;
+};
+
+template <>
+struct get_datatype_enum_from_type<half_t>
+{
+    static constexpr DataTypeEnum value = DataTypeEnum::Half;
+};
+
+template <>
+struct get_datatype_enum_from_type<float>
+{
+    static constexpr DataTypeEnum value = DataTypeEnum::Float;
+};
+
+template <>
+struct get_datatype_enum_from_type<double>
+{
+    static constexpr DataTypeEnum value = DataTypeEnum::Double;
+};
+
+} // namespace ck
diff --git a/profiler/include/profile_batched_gemm_impl.hpp b/profiler/include/profile_batched_gemm_impl.hpp
new file mode 100644
index 00000000..f8fa43ae
--- /dev/null
+++ b/profiler/include/profile_batched_gemm_impl.hpp
@@ -0,0 +1,232 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <iostream>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
+
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
+
+namespace ck {
+namespace profiler {
+
+template <typename ADataType,
+          typename BDataType,
+          typename CDataType,
+          typename ALayout,
+          typename BLayout,
+          typename CLayout>
+bool profile_batched_gemm_impl(int do_verification,
+                               int init_method,
+                               bool do_log,
+                               bool time_kernel,
+                               int M,
+                               int N,
+                               int K,
+                               int BatchStrideA,
+                               int BatchStrideB,
+                               int BatchStrideC,
+                               int StrideA,
+                               int StrideB,
+                               int StrideC,
+                               int BatchCount)
+{
+    bool pass = true;
+
+    auto f_host_tensor_descriptor = [](std::size_t batch_count,
+                                       std::size_t row,
+                                       std::size_t col,
+                                       std::size_t stride,
+                                       std::size_t batch_stride,
+                                       auto layout) {
+        if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
+        {
+            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
+                                        std::vector<std::size_t>({batch_stride, stride, 1}));
+        }
+        else
+        {
+            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
+                                        std::vector<std::size_t>({batch_stride, 1, stride}));
+        }
+    };
+
+    Tensor<ADataType> a_g_m_k(
+        f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
+    Tensor<BDataType> b_g_k_n(
+        f_host_tensor_descriptor(BatchCount, K, N, StrideB, BatchStrideB, BLayout{}));
+    Tensor<CDataType> c_g_m_n_host_result(
+        f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));
+    Tensor<CDataType> c_g_m_n_device_result(
+        f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));
+
+    std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
+    std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
+    std::cout << "c_g_m_n: " << c_g_m_n_host_result.mDesc << std::endl;
+
+    
switch(init_method) + { + case 0: break; + case 1: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + if(do_verification) + { + using ReferenceBatchedGemmInstance = + ck::tensor_operation::host::ReferenceBatchedGemm; + + auto ref_batched_gemm = ReferenceBatchedGemmInstance{}; + auto ref_invoker = ref_batched_gemm.MakeInvoker(); + + auto ref_argument = ref_batched_gemm.MakeArgument( + a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_g_m_k.mData.data()); + b_device_buf.ToDevice(b_g_k_n.mData.data()); + c_device_buf.ToDevice(c_g_m_n_device_result.mData.data()); + + using DeviceOp = ck::tensor_operation::device::DeviceBatchedGemm; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device op instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + BatchStrideA, + BatchStrideB, + BatchStrideC, + BatchCount, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init C to zero before profiling next kernel + c_device_buf.SetZero(); + + std::string op_name = op_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * BatchCount * M * N * K; + + std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * M * N) * + BatchCount; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_device_buf.FromDevice(c_g_m_n_device_result.mData.data()); + + pass = pass & + ck::utils::check_err(c_g_m_n_device_result.mData, c_g_m_n_host_result.mData); + + if(do_log) + { + 
LogRangeAsType(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host: ", c_g_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_device: ", c_g_m_n_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_batched_gemm_reduce_impl.hpp b/profiler/include/profile_batched_gemm_reduce_impl.hpp new file mode 100644 index 00000000..7920808f --- /dev/null +++ b/profiler/include/profile_batched_gemm_reduce_impl.hpp @@ -0,0 +1,364 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using F16 = ck::half_t; +using ReducePtrsGlobal = ck::Tuple; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using DeviceGemmReduceNoOpPtr = + ck::tensor_operation::device::DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>; + +void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances( + std::vector&); + +void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances( + std::vector&); + +void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances( + std::vector&); + +void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( + std::vector&); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +bool profile_batched_gemm_reduce_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC, + int BatchCount) +{ + bool pass = true; + + auto f_host_tensor_descriptor = [](std::size_t batch_count, + std::size_t row, + std::size_t col, + std::size_t stride, + auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({batch_count, row, col}), + std::vector({row * stride, stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({batch_count, row, col}), + std::vector({col * stride, 1, stride})); + } + }; + + Tensor a_g_m_k(f_host_tensor_descriptor(BatchCount, M, K, StrideA, ALayout{})); + Tensor b_g_k_n(f_host_tensor_descriptor(BatchCount, K, N, StrideB, 
BLayout{})); + + Tensor c_g_m_n_host_result( + f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); + Tensor d0_g_m_host_result(HostTensorDescriptor(std::vector( + {static_cast(BatchCount), static_cast(M)}))); + Tensor d1_g_m_host_result(HostTensorDescriptor(std::vector( + {static_cast(BatchCount), static_cast(M)}))); + + Tensor c_g_m_n_device_result( + f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); + Tensor d0_g_m_device_result(HostTensorDescriptor(std::vector( + {static_cast(BatchCount), static_cast(M)}))); + Tensor d1_g_m_device_result(HostTensorDescriptor(std::vector( + {static_cast(BatchCount), static_cast(M)}))); + + std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; + std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl; + std::cout << "c_g_m_n: " << c_g_m_n_host_result.mDesc << std::endl; + std::cout << "d0_g_m: " << d0_g_m_host_result.mDesc << std::endl; + std::cout << "d1_g_m: " << d1_g_m_host_result.mDesc << std::endl; + + std::size_t num_thread = std::thread::hardware_concurrency(); + switch(init_method) + { + case 0: break; + case 1: + std::srand(0); + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + std::srand(0); + a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + b_g_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + using ReduceOp0 = ck::reduce::Add; + using ReduceOp1 = ck::reduce::Add; + using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough; + using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + std::array gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op}; + + const auto reduce0_op = ReduceOp0{}; + const auto reduce1_op = ReduceOp1{}; + + auto passthrough = UnaryIdenticElementOp{}; + auto square = UnarySquareElementOp{}; + std::array reduce_in_element_ops = {&passthrough, &square}; + std::array reduce_out_element_ops = {&passthrough, &passthrough}; + + if(do_verification) + { + using ReferenceBatchedGemmInstance = + ck::tensor_operation::host::ReferenceBatchedGemm; + + using ReduceAccDataType = ReduceDataType; + + auto ref_batched_gemm = ReferenceBatchedGemmInstance{}; + auto ref_invoker = ref_batched_gemm.MakeInvoker(); + + auto ref_argument = ref_batched_gemm.MakeArgument( + a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + for(int batch = 0; batch < BatchCount; ++batch) + { + for(int m = 0; m < M; ++m) + { + auto reduce0_acc = reduce0_op.GetIdentityValue(); + auto reduce1_acc = reduce1_op.GetIdentityValue(); + + for(int n = 0; n < N; ++n) + { + ReduceAccDataType d0_val = + ck::type_convert(c_g_m_n_host_result(batch, m, n)); + ReduceAccDataType d1_val; + + square(d1_val, d0_val); + reduce0_op(reduce0_acc, d0_val); + reduce1_op(reduce1_acc, d1_val); + } + + d0_g_m_host_result(batch, m) = ck::type_convert(reduce0_acc); + d1_g_m_host_result(batch, m) = ck::type_convert(reduce1_acc); + } + } + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * 
b_g_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpace()); + DeviceMem reduce0_device_buf(sizeof(ReduceDataType) * + d0_g_m_device_result.mDesc.GetElementSpace()); + DeviceMem reduce1_device_buf(sizeof(ReduceDataType) * + d1_g_m_device_result.mDesc.GetElementSpace()); + + std::array p_reduces = {reduce0_device_buf.GetDeviceBuffer(), + reduce1_device_buf.GetDeviceBuffer()}; + + a_device_buf.ToDevice(a_g_m_k.mData.data()); + b_device_buf.ToDevice(b_g_k_n.mData.data()); + + // add device GEMM instances + std::vector gemm_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( + gemm_ptrs); + } + } + + if(gemm_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device GEMM instances + for(auto& gemm_ptr : gemm_ptrs) + { + auto argument_ptr = gemm_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + nullptr, + {}, + c_device_buf.GetDeviceBuffer(), + p_reduces, + M, + N, + K, + StrideA, + StrideB, + StrideC, + {}, + gemm_element_ops, + {}, + reduce_in_element_ops, + reduce_out_element_ops, + BatchCount); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + // init DO, D1 to 0 + reduce0_device_buf.SetZero(); + reduce1_device_buf.SetZero(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::string gemm_name = gemm_ptr->GetTypeString(); + + std::size_t flop = std::size_t(2) * BatchCount * M * N * K; + std::size_t num_btype = sizeof(ADataType) * BatchCount * M * K + + sizeof(BDataType) * BatchCount * K * N + + sizeof(CDataType) * BatchCount * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm_name << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_device_buf.FromDevice(c_g_m_n_device_result.mData.data()); + reduce0_device_buf.FromDevice(d0_g_m_device_result.mData.data()); + reduce1_device_buf.FromDevice(d1_g_m_device_result.mData.data()); + + float c_error = check_error(c_g_m_n_host_result, c_g_m_n_device_result); + float d0_error = check_error(d0_g_m_host_result, d0_g_m_device_result); + 
float d1_error = check_error(d1_g_m_host_result, d1_g_m_device_result); + + pass = pass && (c_error < 1E-6); + pass = pass && (d0_error < 1E-6); + pass = pass && (d1_error < 1E-6); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host: ", c_g_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_device: ", c_g_m_n_device_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "d0_host: ", d0_g_m_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "d0_device: ", d0_g_m_device_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "d1_host: ", d1_g_m_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "d1_device: ", d1_g_m_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << "does not support this GEMM problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_conv_bwd_weight_impl.hpp b/profiler/include/profile_conv_bwd_weight_impl.hpp new file mode 100644 index 00000000..7fa2556e --- /dev/null +++ b/profiler/include/profile_conv_bwd_weight_impl.hpp @@ -0,0 +1,285 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_backward_weight.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using DeviceConvBwdWeightNoOpPtr = + DeviceConvBwdWeightPtr; + +void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector&); + +void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector&); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +bool profile_conv_bwd_weight_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t split_k) +{ + const ck::index_t Y = filter_spatial_lengths[0]; + const ck::index_t X = filter_spatial_lengths[1]; + + const ck::index_t Hi = input_spatial_lengths[0]; + const ck::index_t Wi = input_spatial_lengths[1]; + + const ck::index_t Ho = output_spatial_lengths[0]; + const ck::index_t Wo = output_spatial_lengths[1]; + + auto f_host_tensor_descriptor = + [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W, auto layout) { + if constexpr(is_same::value || + 
is_same::value || + is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, H * W, W, 1})); + } + else if constexpr(is_same::value || + is_same::value || + is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, 1, W * C_, C_})); + } + }; + + Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); + Tensor wei_k_c_y_x_host_result(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{})); + Tensor wei_k_c_y_x_device_result( + f_host_tensor_descriptor(K, C, Y, X, WeiLayout{})); + Tensor out_n_k_ho_wo(f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_host_result.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1{1}); + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1{1}); + } + + using InElementOp = ck::tensor_operation::element_wise::PassThrough; + using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + if(do_verification) + { + using ReferenceConvBwdWeightInstance = + ck::tensor_operation::host::ReferenceConvBwdWeight; + + auto ref_conv = ReferenceConvBwdWeightInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi, + wei_k_c_y_x_host_result, + out_n_k_ho_wo, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + ref_invoker.Run(ref_argument); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * + wei_k_c_y_x_device_result.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + out_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + using DeviceConvBwdWeightNoOpPtr = + ck::tensor_operation::device::DeviceConvBwdWeightPtr; + + // add device Conv instances + std::vector conv_ptrs; + + if constexpr(ck::is_same_v, float> && + ck::is_same_v, float> && + ck::is_same_v, float>) + { + ck::tensor_operation::device::instance:: + add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); + } + else if constexpr(ck::is_same_v, ck::half_t> && + ck::is_same_v, ck::half_t> && + ck::is_same_v, ck::half_t>) + { + ck::tensor_operation::device::instance:: + add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); + } + + if(conv_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! 
no device Conv instance found"); + } + + std::string best_conv_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device Conv instances + bool pass = true; + + for(auto& conv_ptr : conv_ptrs) + { + // using atomic, so need to reset input + if(split_k > 1) + { + wei_device_buf.SetZero(); + } + + auto argument_ptr = conv_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op, + split_k); + + auto invoker_ptr = conv_ptr->MakeInvokerPointer(); + + if(conv_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string conv_name = conv_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; + + std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) + + sizeof(WeiDataType) * (K * C * Y * X) + + sizeof(OutDataType) * (N * K * Ho * Wo); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << conv_name << std::endl; + + if(tflops > best_tflops) + { + best_conv_name = conv_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + wei_device_buf.FromDevice(wei_k_c_y_x_device_result.mData.data()); + + float max_error = check_error(wei_k_c_y_x_host_result, wei_k_c_y_x_device_result); + + if(max_error > 8) + { + pass = false; + std::cout << "Fail info:" << conv_ptr->GetTypeString() << std::endl; + } + + if(do_log) + { + LogRangeAsType(std::cout << "out: ", out_n_k_ho_wo.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "in : ", in_n_c_hi_wi.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "wei_host : ", wei_k_c_y_x_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "wei_device: ", wei_k_c_y_x_device_result.mData, ",") + << std::endl; + } + } + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp b/profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp new file mode 100644 index 00000000..28c95c37 --- /dev/null +++ b/profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp @@ -0,0 +1,279 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
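One detail worth calling out in profile_conv_bwd_weight_impl above: when split_k > 1 the backward-weight kernels combine their per-K-slice partial results with atomic adds, which is why the weight buffer is zeroed before each run; anything already in the buffer would be folded into dW. A host-side model of that contract (hypothetical values, illustration only):

```cpp
#include <cassert>

// Host model of split-K accumulation into one weight-gradient element.
int main()
{
    const int split_k     = 4;
    const float partial[] = {1.0f, 2.0f, 3.0f, 4.0f}; // per-K-slice contributions
    float w               = 0.0f;                     // buffer must start zeroed
    for(int s = 0; s < split_k; ++s)
        w += partial[s]; // the device side performs these adds atomically
    assert(w == 10.0f);  // stale data in w would have been accumulated as well
    return 0;
}
```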
+ +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using DeviceConvFwdBiasReluAddPtr = + DeviceConvFwdBiasActivationAddPtr; + +void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances( + std::vector&); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +void profile_conv_fwd_bias_relu_add_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) +{ + const ck::index_t Y = filter_spatial_lengths[0]; + const ck::index_t X = filter_spatial_lengths[1]; + + const ck::index_t Hi = input_spatial_lengths[0]; + const ck::index_t Wi = input_spatial_lengths[1]; + + const ck::index_t Ho = output_spatial_lengths[0]; + const ck::index_t Wo = output_spatial_lengths[1]; + + auto f_host_tensor_descriptor = + [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W, auto layout) { + if constexpr(is_same::value || + is_same::value || + is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, H * W, W, 1})); + } + else if constexpr(is_same::value || + is_same::value || + is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, 1, W * C_, C_})); + } + }; + + Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); + Tensor wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{})); + Tensor out_n_k_ho_wo_host_result( + f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + Tensor out_n_k_ho_wo_device_result( + f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + + // bias: assume contiguous 1d vector + Tensor bias_k( + HostTensorDescriptor(std::vector({static_cast(K)}))); + + // residual: assume same layout as output tensor + Tensor resi_n_k_ho_wo(f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo_host_result.mDesc << std::endl; + std::cout << "bias_k: " << bias_k.mDesc << std::endl; + std::cout << "resi_n_k_ho_wo: " << resi_n_k_ho_wo.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + bias_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + resi_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + 
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
+        bias_k.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
+        resi_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
+    }
+
+    using InElementOp  = ck::tensor_operation::element_wise::PassThrough;
+    using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
+    using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd;
+
+    const auto in_element_op  = InElementOp{};
+    const auto wei_element_op = WeiElementOp{};
+    const auto out_element_op = OutElementOp{};
+
+    if(do_verification)
+    {
+        using ReferenceConvFwdInstance =
+            ck::tensor_operation::host::ReferenceConvFwd_Bias_Activation_Add<InDataType,
+                                                                             WeiDataType,
+                                                                             OutDataType,
+                                                                             InElementOp,
+                                                                             WeiElementOp,
+                                                                             OutElementOp>;
+
+        auto ref_conv    = ReferenceConvFwdInstance{};
+        auto ref_invoker = ref_conv.MakeInvoker();
+
+        auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi,
+                                                  wei_k_c_y_x,
+                                                  out_n_k_ho_wo_host_result,
+                                                  bias_k,
+                                                  resi_n_k_ho_wo,
+                                                  conv_filter_strides,
+                                                  conv_filter_dilations,
+                                                  input_left_pads,
+                                                  input_right_pads,
+                                                  in_element_op,
+                                                  wei_element_op,
+                                                  out_element_op);
+
+        ref_invoker.Run(ref_argument);
+    }
+
+    DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace());
+    DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace());
+    DeviceMem out_device_buf(sizeof(OutDataType) *
+                             out_n_k_ho_wo_device_result.mDesc.GetElementSpace());
+    DeviceMem bias_device_buf(sizeof(OutDataType) * bias_k.mDesc.GetElementSpace());
+    DeviceMem resi_device_buf(sizeof(OutDataType) * resi_n_k_ho_wo.mDesc.GetElementSpace());
+
+    in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
+    wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
+    bias_device_buf.ToDevice(bias_k.mData.data());
+    resi_device_buf.ToDevice(resi_n_k_ho_wo.mData.data());
+
+    using DeviceConvFwdBiasReluAddPtr = ck::tensor_operation::device::
+        DeviceConvFwdBiasActivationAddPtr<InElementOp, WeiElementOp, OutElementOp>;
+
+    // add device operator instances
+    std::vector<DeviceConvFwdBiasReluAddPtr> op_ptrs;
+
+    if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ck::half_t> &&
+                 ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> &&
+                 ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>)
+    {
+        ck::tensor_operation::device::instance::
+            add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
+    }
+
+    if(op_ptrs.size() <= 0)
+    {
+        throw std::runtime_error("wrong! 
no device Conv instance found"); + } + + std::string best_conv_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device Conv instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(bias_device_buf.GetDeviceBuffer()), + static_cast(resi_device_buf.GetDeviceBuffer()), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string conv_name = op_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; + + std::size_t num_btype = + sizeof(InDataType) * (N * C * Hi * Wi) + sizeof(WeiDataType) * (K * C * Y * X) + + sizeof(OutDataType) * (N * K * Ho * Wo) + sizeof(OutDataType) * (K) + + sizeof(OutDataType) * (N * K * Ho * Wo); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << conv_name << std::endl; + + if(tflops > best_tflops) + { + best_conv_name = conv_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); + + ck::utils::check_err(out_n_k_ho_wo_device_result.mData, + out_n_k_ho_wo_host_result.mData); + + if(do_log) + { + LogRangeAsType(std::cout << "in : ", in_n_c_hi_wi.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "wei: ", wei_k_c_y_x.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "out_host : ", out_n_k_ho_wo_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "out_device: ", out_n_k_ho_wo_device_result.mData, ",") + << std::endl; + } + } + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_conv_fwd_bias_relu_impl.hpp b/profiler/include/profile_conv_fwd_bias_relu_impl.hpp new file mode 100644 index 00000000..478d4f37 --- /dev/null +++ b/profiler/include/profile_conv_fwd_bias_relu_impl.hpp @@ -0,0 +1,267 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
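// Editor's note (not part of this patch): this header differs from
// profile_conv_fwd_bias_relu_add_impl.hpp above only in the fused epilogue. Per
// output element, with c the convolution accumulator:
//
//     AddRelu:    e = max(c + bias, 0)
//     AddReluAdd: e = max(c + bias, 0) + residual
//
// which is also why the bias+relu+add profiler charges two extra OutDataType
// streams (the K-element bias and the N*K*Ho*Wo residual) to num_btype, while
// this profiler charges only the bias.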
+
+#pragma once
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp"
+
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using DeviceConvFwdBiasReluPtr =
+    DeviceConvFwdBiasActivationPtr<ck::tensor_operation::element_wise::PassThrough,
+                                   ck::tensor_operation::element_wise::PassThrough,
+                                   ck::tensor_operation::element_wise::AddRelu>;
+
+void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances(
+    std::vector<DeviceConvFwdBiasReluPtr>&);
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
+
+namespace ck {
+namespace profiler {
+
+template <typename InDataType,
+          typename WeiDataType,
+          typename OutDataType,
+          typename InLayout,
+          typename WeiLayout,
+          typename OutLayout>
+void profile_conv_fwd_bias_relu_impl(int do_verification,
+                                     int init_method,
+                                     bool do_log,
+                                     bool time_kernel,
+                                     ck::index_t N,
+                                     ck::index_t K,
+                                     ck::index_t C,
+                                     std::vector<ck::index_t> input_spatial_lengths,
+                                     std::vector<ck::index_t> filter_spatial_lengths,
+                                     std::vector<ck::index_t> output_spatial_lengths,
+                                     std::vector<ck::index_t> conv_filter_strides,
+                                     std::vector<ck::index_t> conv_filter_dilations,
+                                     std::vector<ck::index_t> input_left_pads,
+                                     std::vector<ck::index_t> input_right_pads)
+{
+    const ck::index_t Y = filter_spatial_lengths[0];
+    const ck::index_t X = filter_spatial_lengths[1];
+
+    const ck::index_t Hi = input_spatial_lengths[0];
+    const ck::index_t Wi = input_spatial_lengths[1];
+
+    const ck::index_t Ho = output_spatial_lengths[0];
+    const ck::index_t Wo = output_spatial_lengths[1];
+
+    auto f_host_tensor_descriptor =
+        [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W, auto layout) {
+            if constexpr(is_same<decltype(layout), ck::tensor_layout::convolution::NCHW>::value ||
+                         is_same<decltype(layout), ck::tensor_layout::convolution::KCYX>::value ||
+                         is_same<decltype(layout), ck::tensor_layout::convolution::NKHW>::value)
+            {
+                return HostTensorDescriptor(std::vector<std::size_t>({N_, C_, H, W}),
+                                            std::vector<std::size_t>({C_ * H * W, H * W, W, 1}));
+            }
+            else if constexpr(is_same<decltype(layout), ck::tensor_layout::convolution::NHWC>::value ||
+                              is_same<decltype(layout), ck::tensor_layout::convolution::KYXC>::value ||
+                              is_same<decltype(layout), ck::tensor_layout::convolution::NHWK>::value)
+            {
+                return HostTensorDescriptor(std::vector<std::size_t>({N_, C_, H, W}),
+                                            std::vector<std::size_t>({C_ * H * W, 1, W * C_, C_}));
+            }
+        };
+
+    Tensor<InDataType> in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{}));
+    Tensor<WeiDataType> wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{}));
+    Tensor<OutDataType> out_n_k_ho_wo_host_result(
+        f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{}));
+    Tensor<OutDataType> out_n_k_ho_wo_device_result(
+        f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{}));
+
+    // bias: assume contiguous 1d vector
+    Tensor<OutDataType> bias_k(
+        HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(K)})));
+
+    std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl;
+    std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl;
+    std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo_host_result.mDesc << std::endl;
+    std::cout << "bias_k: " << bias_k.mDesc << std::endl;
+
+    switch(init_method)
+    {
+    case 0: break;
+    case 1:
+        in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
+        wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
+        bias_k.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
+        break;
+    default:
+        in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
+        wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
+        bias_k.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
+    }
+
+    using InElementOp  = ck::tensor_operation::element_wise::PassThrough;
+    using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
+    using OutElementOp =
ck::tensor_operation::element_wise::AddRelu; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + if(do_verification) + { + using ReferenceConvFwdInstance = + ck::tensor_operation::host::ReferenceConvFwd_Bias_Activation; + + auto ref_conv = ReferenceConvFwdInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo_host_result, + bias_k, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + ref_invoker.Run(ref_argument); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * + out_n_k_ho_wo_device_result.mDesc.GetElementSpace()); + DeviceMem bias_device_buf(sizeof(OutDataType) * bias_k.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + bias_device_buf.ToDevice(bias_k.mData.data()); + + using DeviceConvFwdBiasReluPtr = ck::tensor_operation::device:: + DeviceConvFwdBiasActivationPtr; + + // add device operator instances + std::vector op_ptrs; + + if constexpr(ck::is_same_v, ck::half_t> && + ck::is_same_v, ck::half_t> && + ck::is_same_v, ck::half_t>) + { + ck::tensor_operation::device::instance:: + add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances(op_ptrs); + } + + if(op_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device Conv instance found"); + } + + std::string best_conv_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device Conv instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(bias_device_buf.GetDeviceBuffer()), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string conv_name = op_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; + + std::size_t num_btype = + sizeof(InDataType) * (N * C * Hi * Wi) + sizeof(WeiDataType) * (K * C * Y * X) + + sizeof(OutDataType) * (N * K * Ho * Wo) + sizeof(OutDataType) * (K); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << conv_name << std::endl; + + if(tflops > best_tflops) + { + best_conv_name = conv_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); + + ck::utils::check_err(out_n_k_ho_wo_device_result.mData, + out_n_k_ho_wo_host_result.mData); + + if(do_log) + { + LogRangeAsType(std::cout << "in : ", 
in_n_c_hi_wi.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "wei: ", wei_k_c_y_x.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "out_host : ", out_n_k_ho_wo_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "out_device: ", out_n_k_ho_wo_device_result.mData, ",") + << std::endl; + } + } + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_conv_fwd_impl.hpp b/profiler/include/profile_conv_fwd_impl.hpp new file mode 100644 index 00000000..65594b4c --- /dev/null +++ b/profiler/include/profile_conv_fwd_impl.hpp @@ -0,0 +1,285 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/convolution_forward.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +namespace ck { +namespace profiler { + +// FIXME: only support NCHW and NHWC layout, need to be more general +template +int profile_conv_fwd_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + const ck::tensor_operation::device::ConvParams& params) +{ + bool pass = true; + + // make host tensor descritpor + auto f_nhwc_host_tensor_descriptor = + [](ck::index_t n, ck::index_t c, std::vector spatial_lengths) { + std::vector nhwc_lengths{static_cast(n), + static_cast(c)}; + nhwc_lengths.insert( + nhwc_lengths.begin() + 1, spatial_lengths.begin(), spatial_lengths.end()); + + return HostTensorDescriptor(nhwc_lengths); + }; + + auto f_nchw_host_tensor_descriptor = + [](ck::index_t n, ck::index_t c, std::vector spatial_lengths) { + std::vector nchw_lengths{static_cast(n), + static_cast(c)}; + nchw_lengths.insert(nchw_lengths.end(), spatial_lengths.begin(), spatial_lengths.end()); + + return HostTensorDescriptor(nchw_lengths); + }; + + HostTensorDescriptor in_desc, wei_desc, out_desc; + + // FIXME: properly implement "make host descriptor" for different layout + if constexpr(is_same_v || + is_same_v || + is_same_v) + { + in_desc = + f_nhwc_host_tensor_descriptor(params.N_, params.C_, params.input_spatial_lengths_); + } + else if constexpr(is_same_v || + is_same_v || + is_same_v) + { + in_desc = + f_nchw_host_tensor_descriptor(params.N_, params.C_, params.input_spatial_lengths_); + } + + // FIXME: properly implement "make host descriptor" for different layout + if constexpr(is_same_v || + is_same_v || + is_same_v) + { + wei_desc = + f_nhwc_host_tensor_descriptor(params.K_, params.C_, params.filter_spatial_lengths_); + } + else if constexpr(is_same_v || + is_same_v || + is_same_v) + { + wei_desc = + f_nchw_host_tensor_descriptor(params.K_, params.C_, params.filter_spatial_lengths_); + } + + // FIXME: properly implement "make host descriptor" for different layout + if constexpr(is_same_v || + is_same_v || + is_same_v) + { + 
out_desc = + f_nhwc_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths()); + } + else if constexpr(is_same_v || + is_same_v || + is_same_v) + { + out_desc = + f_nchw_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths()); + } + + Tensor input(in_desc); + Tensor weight(wei_desc); + Tensor host_output(out_desc); + Tensor device_output(out_desc); + + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weight: " << weight.mDesc << std::endl; + std::cout << "output: " << host_output.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + weight.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + weight.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + using InElementOp = ck::tensor_operation::element_wise::PassThrough; + using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weight.mData.data()); + + using DeviceOp = ck::tensor_operation::device::DeviceConvFwd; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // run reference op + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd{}; + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(input, + weight, + host_output, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op); + + // init host output to zero + host_output.SetZero(); + + ref_invoker.Run(ref_argument); + } + + std::string best_op_name; + float best_avg_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device op instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + params.GetOutputSpatialLengths(), + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init output to zero before profiling next kernel + out_device_buf.SetZero(); + + std::string op_name = op_ptr->GetTypeString(); + + float avg_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = params.GetFlops(); + std::size_t num_btype = params.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / 
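/* Editor's note (not part of this patch): avg_time is in milliseconds, so
   flop / 1e9 / ms is GFLOP per millisecond, i.e. TFLOPS, and bytes / 1e6 / ms
   is MB per millisecond, i.e. GB/s. Worked example: 2e11 flop finishing in
   1.6 ms gives 2e11 / 1e9 / 1.6 = 125 TFLOPS. The expression continues below: */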
avg_time; + + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + out_device_buf.FromDevice(device_output.mData.data()); + + pass = pass & ck::utils::check_err(device_output.mData, host_output.mData); + + if(do_log) + { + LogRangeAsType(std::cout << "input : ", input.mData, ",") << std::endl; + LogRangeAsType(std::cout << "weight: ", weight.mData, ",") << std::endl; + LogRangeAsType(std::cout << "host_output : ", host_output.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "device_output: ", device_output.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + } + + std::cout << "Best configuration parameters:" + << "\nname: " << best_op_name << "\navg_time: " << best_avg_time + << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; + + return 0; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_convnd_bwd_data_impl.hpp b/profiler/include/profile_convnd_bwd_data_impl.hpp new file mode 100644 index 00000000..78dbf2a1 --- /dev/null +++ b/profiler/include/profile_convnd_bwd_data_impl.hpp @@ -0,0 +1,485 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" + +using F16 = ck::half_t; +using F32 = float; +using BF16 = ck::bhalf_t; +using INT8 = int8_t; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using DeviceConvBwdDataNoOpPtr = + DeviceConvBwdDataPtr; +void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances( + std::vector&); +void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances( + std::vector&); +void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances( + std::vector&); +void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances( + std::vector&); + +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector&); +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector&); +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances( + std::vector&); +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances( + std::vector&); + +void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances( + std::vector&); +void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances( + std::vector&); +void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances( + std::vector&); +void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances( + std::vector&); +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { +using 
DeviceConvBwdDataNoOpPtr = ck::tensor_operation::device::instance::DeviceConvBwdDataNoOpPtr; + +template +HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector& dims, + int num_dim_spatial = 2) +{ + namespace tl = ck::tensor_layout::convolution; + + switch(num_dim_spatial) + { + case 3: { + return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{}); + } + case 2: { + return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{}); + } + case 1: { + return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{}); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} +template +HostTensorDescriptor get_filters_host_tensor_descriptor(const std::vector& dims, + int num_dim_spatial = 2) +{ + namespace tl = ck::tensor_layout::convolution; + + switch(num_dim_spatial) + { + case 3: { + return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{}); + } + case 2: { + return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{}); + } + case 1: { + return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{}); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} +template +HostTensorDescriptor get_output_host_ensor_descriptor(const std::vector& dims, + int num_dim_spatial = 2) +{ + namespace tl = ck::tensor_layout::convolution; + + switch(num_dim_spatial) + { + case 3: { + return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{}); + } + case 2: { + return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{}); + } + case 1: { + return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{}); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} +template +void get_device_conv_bwd_data_op_ptr( + InDataType, WeiDataType, OutDataType, std::vector&, int) +{ + std::cout << "can not find device conv bwd data" << std::endl; + exit(1); +} +template <> +void get_device_conv_bwd_data_op_ptr( + F32, F32, F32, std::vector& conv_ptrs, int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 1: + ck::tensor_operation::device::instance:: + add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs); + break; + case 2: + ck::tensor_operation::device::instance:: + add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); + break; + case 3: + ck::tensor_operation::device::instance:: + add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs); + break; + default: break; + } +} +template <> +void get_device_conv_bwd_data_op_ptr( + F16, F16, F16, std::vector& conv_ptrs, int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 1: + ck::tensor_operation::device::instance:: + add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs); + break; + case 2: + ck::tensor_operation::device::instance:: + add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); + break; + case 3: + ck::tensor_operation::device::instance:: + add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs); + break; + default: break; + } +} +template <> +void get_device_conv_bwd_data_op_ptr( + BF16, BF16, BF16, std::vector& conv_ptrs, int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 1: + ck::tensor_operation::device::instance:: + add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs); + break; + case 2: + ck::tensor_operation::device::instance:: + 
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); + break; + case 3: + ck::tensor_operation::device::instance:: + add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs); + break; + default: break; + } +} +template <> +void get_device_conv_bwd_data_op_ptr( + INT8, INT8, INT8, std::vector& conv_ptrs, int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 1: + ck::tensor_operation::device::instance:: + add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(conv_ptrs); + break; + case 2: + ck::tensor_operation::device::instance:: + add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs); + break; + case 3: + ck::tensor_operation::device::instance:: + add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(conv_ptrs); + break; + default: break; + } +} + +template +static bool check_out(const Tensor& ref, const Tensor& result) +{ + float max_diff = 1e-6; + + for(std::size_t i = 0; i < ref.mData.size(); ++i) + { + float diff = std::abs(double(ref.mData[i]) - double(result.mData[i])); + if(max_diff < diff) + { + return false; + } + } + return true; +} +template +void show_data_nhwc_layout(Tensor& nhwc) +{ + std::cout << "["; + for(int n = 0; n < ck::type_convert(nhwc.mDesc.GetLengths()[0]); n++) + { + std::cout << "["; + for(int hi = 0; hi < ck::type_convert(nhwc.mDesc.GetLengths()[2]); hi++) + { + std::cout << "["; + for(int wi = 0; wi < ck::type_convert(nhwc.mDesc.GetLengths()[3]); wi++) + { + std::cout << "["; + for(int c = 0; c < ck::type_convert(nhwc.mDesc.GetLengths()[1]); c++) + { + std::cout << static_cast(nhwc(n, c, hi, wi)) << " "; + } + std::cout << "]"; + } + std::cout << "]"; + } + std::cout << "]"; + } + std::cout << "]"; +} + +template +bool profile_convnd_bwd_data_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + ck::index_t N, + ck::index_t K, + ck::index_t C, + const std::vector& input_spatial_lengths, + const std::vector& filter_spatial_lengths, + const std::vector& output_spatial_lengths, + const std::vector& conv_filter_strides, + const std::vector& conv_filter_dilations, + const std::vector& input_left_pads, + const std::vector& input_right_pads) +{ + using InElementOp = ck::tensor_operation::element_wise::PassThrough; + using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + std::vector input_dims{static_cast(N), static_cast(C)}; + input_dims.insert( + std::end(input_dims), std::begin(input_spatial_lengths), std::end(input_spatial_lengths)); + + std::vector filter_dims{static_cast(K), static_cast(C)}; + filter_dims.insert(std::end(filter_dims), + std::begin(filter_spatial_lengths), + std::end(filter_spatial_lengths)); + + std::vector output_dims{static_cast(N), static_cast(K)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor input_host_result( + get_input_host_tensor_descriptor(input_dims, NDimSpatial)); + Tensor input_device_result( + get_input_host_tensor_descriptor(input_dims, NDimSpatial)); + Tensor weights( + get_filters_host_tensor_descriptor(filter_dims, NDimSpatial)); + Tensor output( + get_output_host_ensor_descriptor(output_dims, NDimSpatial)); + + std::cout << "input: " << input_host_result.mDesc << std::endl; + std::cout << "weights: " << weights.mDesc 
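/* Editor's note (not part of this patch): tensor names keep the forward-pass
   convention, so for backward-data the "output" gradient and the weights are the
   operands that are read, and the "input" gradient is what each instance writes:

       in_grad(n, c, hi, wi) = sum over k, y, x of
           out_grad(n, k, ho, wo) * wei(k, c, y, x),
       where hi = ho * stride_h + y * dilation_h - pad_h (and likewise for wi),

   which is presumably why in_device_buf.SetZero() is called before profiling:
   instances are assumed to accumulate into the input-gradient buffer rather
   than overwrite it. */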
<< std::endl; + std::cout << "output: " << output.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + output.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + weights.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + output.GenerateTensorValue(GeneratorTensor_1{1}); + weights.GenerateTensorValue(GeneratorTensor_1{1}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * input_device_result.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace()); + + out_device_buf.ToDevice(output.mData.data()); + wei_device_buf.ToDevice(weights.mData.data()); + + // reset input to zero + in_device_buf.SetZero(); + + if(do_verification) + { + auto RunReference = [&](auto& ref_conv) { + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(input_host_result, + weights, + output, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + ref_invoker.Run(ref_argument); + }; + + auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData(); + RunReference(ref_conv); + } + + // add device Conv instances + std::vector conv_ptrs; + get_device_conv_bwd_data_op_ptr( + InDataType{}, WeiDataType{}, OutDataType{}, conv_ptrs, NDimSpatial); + + if(conv_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device Conv instance found"); + } + + std::string best_conv_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device Conv instances + bool success = true; + for(auto& conv_ptr : conv_ptrs) + { + auto argument_ptr = conv_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + auto invoker_ptr = conv_ptr->MakeInvokerPointer(); + + if(conv_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string conv_name = conv_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = + ck::utils::conv::get_flops(N, C, K, filter_spatial_lengths, output_spatial_lengths); + std::size_t num_btype = + ck::utils::conv::get_btype( + N, C, K, input_spatial_lengths, filter_spatial_lengths, output_spatial_lengths); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << std::endl; + + if(tflops > best_tflops) + { + best_conv_name = conv_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + in_device_buf.FromDevice(input_device_result.mData.data()); + + if(!check_out(input_host_result, input_device_result)) + { + std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl; + + success = false; + } + else + { + std::cout << "Pass Info: " << conv_ptr->GetTypeString() << std::endl; + } + + check_error(input_host_result, input_device_result); + + if(do_log) + { + std::cout << "in : "; + show_data_nhwc_layout(output); + std::cout << 
std::endl; + + std::cout << "wei: "; + show_data_nhwc_layout(weights); + std::cout << std::endl; + + std::cout << "out_host : "; + show_data_nhwc_layout(input_host_result); + std::cout << std::endl; + + std::cout << "out_device: "; + show_data_nhwc_layout(input_device_result); + std::cout << std::endl; + } + } + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl; + return success; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_convnd_bwd_weight_impl.hpp b/profiler/include/profile_convnd_bwd_weight_impl.hpp new file mode 100644 index 00000000..5660f51e --- /dev/null +++ b/profiler/include/profile_convnd_bwd_weight_impl.hpp @@ -0,0 +1,478 @@ +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_backward_weight.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp" + +using F16 = ck::half_t; +using F32 = float; +using BF16 = ck::bhalf_t; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using DeviceConvndBwdWeightNoOpPtr = + DeviceConvBwdWeightPtr; + +void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances( + std::vector&); +void add_device_convnd_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector&); +void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances( + std::vector&); + +void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances( + std::vector&); +void add_device_convnd_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector&); +void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances( + std::vector&); + +void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_instances( + std::vector&); +void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_instances( + std::vector&); +void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_instances( + std::vector&); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +using DeviceConvndBwdWeightNoOpPtr = + ck::tensor_operation::device::instance::DeviceConvndBwdWeightNoOpPtr; + +template +HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector& dims, + int num_dim_spatial = 2) +{ + namespace tl = ck::tensor_layout::convolution; + + switch(num_dim_spatial) + { + case 3: { + return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{}); + } + case 2: { + return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{}); + } + case 1: { + return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{}); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +template +HostTensorDescriptor get_filters_host_tensor_descriptor(const std::vector& dims, + int num_dim_spatial = 2) +{ + namespace tl = ck::tensor_layout::convolution; + + switch(num_dim_spatial) + { + case 3: { + return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{}); + } + case 2: { + return 
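/* Editor's note (not part of this patch): these helpers translate the runtime
   spatial rank into a compile-time layout tag; each case is assumed to forward
   to the same utility with a different tag, along the lines of (tags
   illustrative; the angle-bracketed template arguments were lost in this
   rendering of the diff):

       case 3: return get_host_tensor_descriptor(dims, tl::KZYXC{});
       case 2: return get_host_tensor_descriptor(dims, tl::KYXC{});
       case 1: return get_host_tensor_descriptor(dims, tl::KXC{});

   Note the output-descriptor helper really is spelled
   get_output_host_ensor_descriptor (missing 't') in the patch, and its call
   sites use the same spelling, so the code compiles as written. */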
ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{}); + } + case 1: { + return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{}); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +template +HostTensorDescriptor get_output_host_ensor_descriptor(const std::vector& dims, + int num_dim_spatial = 2) +{ + namespace tl = ck::tensor_layout::convolution; + + switch(num_dim_spatial) + { + case 3: { + return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{}); + } + case 2: { + return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{}); + } + case 1: { + return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{}); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +template +void get_device_conv_bwd_weight_op_ptr( + InDataType, WeiDataType, OutDataType, std::vector&, int) +{ + std::cout << "can not find device conv bwd weight" << std::endl; + exit(1); +} + +template <> +void get_device_conv_bwd_weight_op_ptr( + F32, F32, F32, std::vector& conv_ptrs, int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 1: + ck::tensor_operation::device::instance:: + add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs); + break; + case 2: + ck::tensor_operation::device::instance:: + add_device_convnd_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); + break; + case 3: + ck::tensor_operation::device::instance:: + add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs); + break; + default: break; + } +} + +template <> +void get_device_conv_bwd_weight_op_ptr( + F16, F16, F16, std::vector& conv_ptrs, int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 1: + ck::tensor_operation::device::instance:: + add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs); + break; + case 2: + ck::tensor_operation::device::instance:: + add_device_convnd_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); + break; + case 3: + ck::tensor_operation::device::instance:: + add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs); + break; + default: break; + } +} + +template <> +void get_device_conv_bwd_weight_op_ptr( + BF16, BF16, BF16, std::vector& conv_ptrs, int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 1: + ck::tensor_operation::device::instance:: + add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs); + break; + case 2: + ck::tensor_operation::device::instance:: + add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); + break; + case 3: + ck::tensor_operation::device::instance:: + add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs); + break; + default: break; + } +} + +template +void show_data_nhwc_layout(Tensor& nhwc) +{ + std::cout << "["; + for(int n = 0; n < ck::type_convert(nhwc.mDesc.GetLengths()[0]); n++) + { + std::cout << "["; + for(int hi = 0; hi < ck::type_convert(nhwc.mDesc.GetLengths()[2]); hi++) + { + std::cout << "["; + for(int wi = 0; wi < ck::type_convert(nhwc.mDesc.GetLengths()[3]); wi++) + { + std::cout << "["; + for(int c = 0; c < ck::type_convert(nhwc.mDesc.GetLengths()[1]); c++) + { + std::cout << static_cast(nhwc(n, c, hi, wi)) << " "; + } + std::cout << "]"; + } + std::cout << "]"; + } + std::cout << "]"; + } + std::cout << "]"; +} + +template +bool profile_convnd_bwd_weight_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + ck::index_t N, + 
ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t split_k) +{ + using InElementOp = ck::tensor_operation::element_wise::PassThrough; + using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + std::vector input_dims{static_cast(N), static_cast(C)}; + input_dims.insert( + std::end(input_dims), std::begin(input_spatial_lengths), std::end(input_spatial_lengths)); + + std::vector filter_dims{static_cast(K), static_cast(C)}; + filter_dims.insert(std::end(filter_dims), + std::begin(filter_spatial_lengths), + std::end(filter_spatial_lengths)); + + std::vector output_dims{static_cast(N), static_cast(K)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor input(get_input_host_tensor_descriptor(input_dims, NDimSpatial)); + Tensor weights_host_result( + get_filters_host_tensor_descriptor(filter_dims, NDimSpatial)); + Tensor weights_device_result( + get_filters_host_tensor_descriptor(filter_dims, NDimSpatial)); + Tensor output( + get_output_host_ensor_descriptor(output_dims, NDimSpatial)); + + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weights: " << weights_host_result.mDesc << std::endl; + std::cout << "output: " << output.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + input.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + output.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + input.GenerateTensorValue(GeneratorTensor_1{1}); + output.GenerateTensorValue(GeneratorTensor_1{1}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weights_device_result.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(input.mData.data()); + out_device_buf.ToDevice(output.mData.data()); + + // reset input to zero + wei_device_buf.SetZero(); + + if(do_verification) + { + auto RunReference = [&](auto& ref_conv) { + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(input, + weights_host_result, + output, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + ref_invoker.Run(ref_argument); + }; + + auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdWeight(); + RunReference(ref_conv); + } + + // add device Conv instances + std::vector conv_ptrs; + get_device_conv_bwd_weight_op_ptr( + InDataType{}, WeiDataType{}, OutDataType{}, conv_ptrs, NDimSpatial); + + if(conv_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! 
no device Conv instance found"); + } + + std::string best_conv_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device Conv instances + bool success = true; + for(auto& conv_ptr : conv_ptrs) + { + // using atomic, so need to reset input, setzero is done in invoker + // if(split_k > 1) + //{ + // wei_device_buf.SetZero(); + //} + + auto argument_ptr = conv_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op, + split_k); + + if(!conv_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::cout << "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem" + << std::endl; + continue; + } + + auto invoker_ptr = conv_ptr->MakeInvokerPointer(); + std::string conv_name = conv_ptr->GetTypeString(); + float ave_time = 0; + + if(std::is_same::value && split_k > 1) + { + // alloc work space + size_t bwd_weight_workspace_size = conv_ptr->GetWorkSpaceSize(argument_ptr.get()); + if(bwd_weight_workspace_size <= 0) + { + printf("wrong work space size\n"); + exit(1); + } + DeviceMem wei_work_space_device_buf(bwd_weight_workspace_size); + wei_work_space_device_buf.SetZero(); + conv_ptr->SetWorkSpacePointer(argument_ptr.get(), + wei_work_space_device_buf.GetDeviceBuffer()); + ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + } + else + { + ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + } + + std::size_t flop = + ck::utils::conv::get_flops(N, C, K, filter_spatial_lengths, output_spatial_lengths); + std::size_t num_btype = ck::utils::conv::get_btype( + N, C, K, input_spatial_lengths, filter_spatial_lengths, output_spatial_lengths); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << std::endl; + + if(tflops > best_tflops) + { + best_conv_name = conv_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + wei_device_buf.FromDevice(weights_device_result.mData.data()); + + float max_error = check_error(weights_host_result, weights_device_result); + + if(max_error > 8) + { + std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl; + + success = false; + } + else + { + std::cout << "Pass Info: " << conv_ptr->GetTypeString() << std::endl; + } + + check_error(weights_host_result, weights_device_result); + + if(do_log) + { + std::cout << "in : "; + show_data_nhwc_layout(output); + std::cout << std::endl; + + std::cout << "wei: "; + show_data_nhwc_layout(weights_host_result); + std::cout << std::endl; + + std::cout << "out : "; + show_data_nhwc_layout(input); + std::cout << std::endl; + + std::cout << "wei_device: "; + show_data_nhwc_layout(weights_device_result); + std::cout << std::endl; + } + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl; + return success; +} + +} // namespace profiler +} // namespace ck diff --git 
a/profiler/include/profile_gemm_add_add_fastgelu_impl.hpp b/profiler/include/profile_gemm_add_add_fastgelu_impl.hpp new file mode 100644 index 00000000..e2f2b643 --- /dev/null +++ b/profiler/include/profile_gemm_add_add_fastgelu_impl.hpp @@ -0,0 +1,240 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template // assume Ds and E have same layout +bool profile_gemm_add_add_fastgelu_impl(int do_verification, + int init_method, + bool /*do_log*/, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideD0, + int StrideD1, + int StrideE) +{ + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, DELayout{})); + Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, DELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, DELayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, DELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; + + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using CDEElementOp = AddAddFastGelu; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, + BLayout, + DELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + 
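/* Editor's note (not part of this patch): DeviceGemmMultipleD folds the extra
   D-tensor operands into the GEMM epilogue; with AddAddFastGelu each output
   element is

       e = fastgelu(c + d0 + d1),
       fastgelu(x) ~ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),

   matching the host-side reference loop below, which runs a plain GEMM into
   c_m_n and then applies cde_element_op(e, c, d0, d1) per element. */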
ck::tensor_operation::element_wise::AddAddFastGelu>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // run reference + if(do_verification) + { + Tensor c_m_n(HostTensorDescriptor( + std::vector{static_cast(M), static_cast(N)})); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n)); + } + } + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem d0_m_n_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpace()); + DeviceMem d1_m_n_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpace()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d0_m_n_device_buf.ToDevice(d0_m_n.mData.data()); + d1_m_n_device_buf.ToDevice(d1_m_n.mData.data()); + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + bool pass = true; + + // profile device operation instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer( + a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d0_m_n_device_buf.GetDeviceBuffer(), + d1_m_n_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD0, StrideD1}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init E to zero before profiling a kernel + e_device_buf.SetZero(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + pass = pass && + ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData); + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git 
a/profiler/include/profile_gemm_bias_add_reduce_impl.hpp b/profiler/include/profile_gemm_bias_add_reduce_impl.hpp new file mode 100644 index 00000000..ac0ee0bd --- /dev/null +++ b/profiler/include/profile_gemm_bias_add_reduce_impl.hpp @@ -0,0 +1,388 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using F16 = ck::half_t; +using ReducePtrsGlobal = ck::Tuple; +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using DeviceGemmBiasAddReduceNoOpPtr = + ck::tensor_operation::device::DeviceGemmReducePtr<1, ReducePtrsGlobal::Size()>; + +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( + std::vector&); + +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( + std::vector&); + +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances( + std::vector&); + +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( + std::vector&); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +void profile_gemm_bias_add_reduce_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC, + int StrideD0) +{ + auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor(std::vector({len}), + std::vector({stride})); + }; + + auto f_host_tensor_descriptor2d = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{})); + + Tensor c_m_n_host_result(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + Tensor bias_n(f_host_tensor_descriptor1d(N, 1)); + Tensor d0_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + Tensor reduce0_m_host_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + Tensor reduce1_m_host_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + + Tensor c_m_n_device_result(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + Tensor reduce0_m_device_result( + 
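/* Editor's note (not part of this patch): the two reductions produce per-row
   statistics of the post-epilogue output. With ReduceOp = Add and the in/out
   element ops selected below (PassThrough vs UnarySquare on the way in,
   UnaryDivide{N} on the way out), they compute

       reduce0(m) = (1/N) * sum_n e(m, n)        // mean
       reduce1(m) = (1/N) * sum_n e(m, n)^2      // mean of squares

   i.e. the running statistics a following layernorm-style fusion would
   typically consume. */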
HostTensorDescriptor(std::vector({static_cast(M)}))); + Tensor reduce1_m_device_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "reduce0_m: " << reduce0_m_host_result.mDesc << std::endl; + std::cout << "reduce1_m: " << reduce1_m_host_result.mDesc << std::endl; + + std::size_t num_thread = 1; + switch(init_method) + { + case 0: break; + case 1: + std::srand(0); + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + bias_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + std::srand(0); + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + bias_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + } + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using CElementOp = PassThrough; + using D0ElementOp = PassThrough; + using ReduceOp0 = ck::reduce::Add; + using ReduceOp1 = ck::reduce::Add; + using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide; + using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough; + using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + std::array gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op}; + + auto d0_element_op = D0ElementOp{}; + const auto reduce0_op = ReduceOp0{}; + const auto reduce1_op = ReduceOp1{}; + + auto passthrough = UnaryIdenticElementOp{}; + auto square = UnarySquareElementOp{}; + auto div = UnaryDivElementOp{N}; + std::array reduce_in_element_ops = {&passthrough, &square}; + std::array reduce_out_element_ops = {&div, &div}; + + if(do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + using ReduceAccDataType = ReduceDataType; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + for(int n = 0; n < N; ++n) + { + ReduceAccDataType acc = static_cast(c_m_n_host_result(m, n)) + + static_cast(bias_n(n)); + + ReduceAccDataType d0 = static_cast(d0_m_n(m, n)); + c_element_op(acc, acc); + d0_element_op(d0, d0); + acc += d0; + c_m_n_host_result(m, n) = static_cast(acc); + } + + for(int m = 0; m < M; ++m) + { + auto reduce0_acc = reduce0_op.GetIdentityValue(); + auto reduce1_acc = reduce1_op.GetIdentityValue(); + + for(int n = 0; n < N; ++n) + { + ReduceAccDataType d0_val = + ck::type_convert(c_m_n_host_result(m, n)); + ReduceAccDataType d1_val; + + square(d1_val, d0_val); + reduce0_op(reduce0_acc, d0_val); + reduce1_op(reduce1_acc, d1_val); + } + + div(reduce0_acc, reduce0_acc); + div(reduce1_acc, reduce1_acc); + reduce0_m_host_result(m) = ck::type_convert(reduce0_acc); + reduce1_m_host_result(m) = ck::type_convert(reduce1_acc); + } + } + + DeviceMem 
a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + DeviceMem bias_device_buf(sizeof(BiasDataType) * bias_n.mDesc.GetElementSpace()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpace()); + DeviceMem reduce0_device_buf(sizeof(ReduceDataType) * + reduce0_m_device_result.mDesc.GetElementSpace()); + DeviceMem reduce1_device_buf(sizeof(ReduceDataType) * + reduce1_m_device_result.mDesc.GetElementSpace()); + + std::array p_reduces = {reduce0_device_buf.GetDeviceBuffer(), + reduce1_device_buf.GetDeviceBuffer()}; + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + bias_device_buf.ToDevice(bias_n.mData.data()); + d0_device_buf.ToDevice(d0_m_n.mData.data()); + + // add device GEMM instances + std::vector gemm_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( + gemm_ptrs); + } + } + + if(gemm_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! 
no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device GEMM instances + for(auto& gemm_ptr : gemm_ptrs) + { + auto argument_ptr = gemm_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + bias_device_buf.GetDeviceBuffer(), + {d0_device_buf.GetDeviceBuffer()}, + c_device_buf.GetDeviceBuffer(), + p_reduces, + M, + N, + K, + StrideA, + StrideB, + StrideC, + {StrideD0}, + gemm_element_ops, + {&d0_element_op}, + reduce_in_element_ops, + reduce_out_element_ops); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + // init DO, D1 to 0 + reduce0_device_buf.SetZero(); + reduce1_device_buf.SetZero(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::string gemm_name = gemm_ptr->GetTypeString(); + + std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N; + + std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * M * N + sizeof(BiasDataType) * M * N + + sizeof(D0DataType) * M * N + sizeof(ReduceDataType) * M + + sizeof(ReduceDataType) * M; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm_name << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + reduce0_device_buf.FromDevice(reduce0_m_device_result.mData.data()); + reduce1_device_buf.FromDevice(reduce1_m_device_result.mData.data()); + + ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + ck::utils::check_err(reduce0_m_device_result.mData, reduce0_m_host_result.mData); + ck::utils::check_err(reduce1_m_device_result.mData, reduce1_m_host_result.mData); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host: ", c_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "d0_host: ", reduce0_m_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "d0_device: ", reduce0_m_device_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "d1_host: ", reduce1_m_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "d1_device: ", reduce1_m_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << "does not support this GEMM problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_gemm_bilinear_impl.hpp b/profiler/include/profile_gemm_bilinear_impl.hpp new file mode 100644 index 00000000..d138dda1 --- /dev/null +++ b/profiler/include/profile_gemm_bilinear_impl.hpp @@ -0,0 +1,233 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
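+//
+// profile_gemm_bilinear_impl() enumerates the DeviceGemmMultipleD instances
+// registered for the requested layouts/datatypes, times each kernel, and
+// verifies E = alpha * (A * B) + beta * D against a host reference GEMM
+// followed by the Bilinear element-wise op.
+//
+// Calling sketch (hypothetical; the template-parameter list is elided in this
+// patch view and is assumed here to be <ADataType, BDataType, AccDataType,
+// DDataType, EDataType, ALayout, BLayout, DELayout>):
+//
+//   using Row = ck::tensor_layout::gemm::RowMajor;
+//   using F16 = ck::half_t;
+//   bool ok = ck::profiler::profile_gemm_bilinear_impl<F16, F16, float, F16,
+//                                                      F16, Row, Row, Row>(
+//       1 /*verify*/, 1 /*init*/, false /*log*/, true /*time_kernel*/,
+//       1024, 1024, 1024,       // M, N, K
+//       1024, 1024, 1024, 1024, // StrideA, StrideB, StrideD, StrideE
+//       1.0f, 1.0f);            // alpha, beta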
+ +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template // assume Ds and E have same layout +bool profile_gemm_bilinear_impl(int do_verification, + int init_method, + bool /*do_log*/, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideD, + int StrideE, + float alpha, + float beta) +{ + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, DELayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, DELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using Bilinear = ck::tensor_operation::element_wise::Bilinear; + + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using CDEElementOp = Bilinear; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{alpha, beta}; + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, + BLayout, + DELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Bilinear>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // run reference + if(do_verification) + { + Tensor c_m_n(HostTensorDescriptor( + std::vector{static_cast(M), static_cast(N)})); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, 
c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem d_m_n_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpace()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_m_n_device_buf.ToDevice(d_m_n.mData.data()); + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + bool pass = true; + + // profile device operation instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer( + a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_m_n_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init E to zero before profiling a kernel + e_device_buf.SetZero(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + pass = pass && + ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData); + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_gemm_impl.hpp b/profiler/include/profile_gemm_impl.hpp new file mode 100644 index 00000000..6bfd0ffe --- /dev/null +++ b/profiler/include/profile_gemm_impl.hpp @@ -0,0 +1,254 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
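+//
+// profile_gemm_impl() runs every DeviceGemm instance registered for one
+// (datatype, layout) combination, reports time/TFLOPS/bandwidth per instance,
+// and returns 0 only when all supported instances pass verification.
+// The throughput figures below assume avg_time is in milliseconds:
+//
+//   flop   = 2 * M * N * K                (one multiply-add = 2 flops)
+//   tflops = (flop / 1e9) / avg_time_ms   (Gflop per ms == Tflop per s)
+//
+// and gb_per_sec counts one read each of A and B plus one write of C.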
+ +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template +int profile_gemm_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC) +{ + bool pass = true; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + + using DeviceOp = ck::tensor_operation::device::DeviceGemm; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // Run reference op + if(do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_op = ReferenceGemmInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + auto ref_argument = ref_op.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + } + + std::string best_op_name; + float best_avg_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device op instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = + 
op_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init C to zero before profiling next kernel + c_device_buf.SetZero(); + + std::string op_name = op_ptr->GetTypeString(); + + float avg_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass = + pass & ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host : ", c_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + } + + if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f32"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = bf16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = int8"; + } + + if constexpr(is_same::value) + { + std::cout << " ALayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " ALayout = ColumnMajor"; + } + + if constexpr(is_same::value) + { + std::cout << " BLayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " BLayout = ColumnMajor"; + } + + std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA + << " StrideB = " << StrideB << " StrideC = " << StrideC << " : " << best_avg_time + << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + return pass ? 0 : 1; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_gemm_reduce_impl.hpp b/profiler/include/profile_gemm_reduce_impl.hpp new file mode 100644 index 00000000..ba2bb6c8 --- /dev/null +++ b/profiler/include/profile_gemm_reduce_impl.hpp @@ -0,0 +1,356 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
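+//
+// profile_gemm_reduce_impl() profiles kernels fusing C = A * B with two
+// per-row reductions, reproduced by the host reference below as
+//
+//   reduce0(m) = (1/N) * sum_n C(m, n)      (row mean)
+//   reduce1(m) = (1/N) * sum_n C(m, n)^2    (row mean of squares)
+//
+// i.e. the first and second moments used for normalization statistics.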
+ +#pragma once + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using F16 = ck::half_t; +using ReducePtrsGlobal = ck::Tuple; +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using DeviceGemmReduceNoOpPtr = + ck::tensor_operation::device::DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>; + +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances( + std::vector&); + +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances( + std::vector&); + +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances( + std::vector&); + +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances( + std::vector&); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +bool profile_gemm_reduce_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC) +{ + bool pass = true; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor reduce0_m_host_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + Tensor reduce1_m_host_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor reduce0_m_device_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + Tensor reduce1_m_device_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "reduce0_m: " << reduce0_m_host_result.mDesc << std::endl; + std::cout << "reduce1_m: " << reduce1_m_host_result.mDesc << std::endl; + + std::size_t num_thread = 1; + switch(init_method) + { + case 0: break; + case 1: + std::srand(0); + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + std::srand(0); + 
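+        // default: real-valued init, A in [0.0, 1.0] and B in [-0.5, 0.5];
+        // init_method == 1 above uses small integers in [-5, 5] instead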
a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + using ReduceOp0 = ck::reduce::Add; + using ReduceOp1 = ck::reduce::Add; + using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough; + using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; + using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + std::array gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op}; + + const auto reduce0_op = ReduceOp0{}; + const auto reduce1_op = ReduceOp1{}; + + auto passthrough = UnaryIdenticElementOp{}; + auto square = UnarySquareElementOp{}; + auto div = UnaryDivElementOp{N}; + std::array reduce_in_element_ops = {&passthrough, &square}; + std::array reduce_out_element_ops = {&div, &div}; + + if(do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + using ReduceAccDataType = ReduceDataType; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + auto reduce0_acc = reduce0_op.GetIdentityValue(); + auto reduce1_acc = reduce1_op.GetIdentityValue(); + + for(int n = 0; n < N; ++n) + { + ReduceAccDataType d0_val = + ck::type_convert(c_m_n_host_result(m, n)); + ReduceAccDataType d1_val; + + square(d1_val, d0_val); + reduce0_op(reduce0_acc, d0_val); + reduce1_op(reduce1_acc, d1_val); + } + + div(reduce0_acc, reduce0_acc); + div(reduce1_acc, reduce1_acc); + reduce0_m_host_result(m) = ck::type_convert(reduce0_acc); + reduce1_m_host_result(m) = ck::type_convert(reduce1_acc); + } + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + DeviceMem reduce0_device_buf(sizeof(ReduceDataType) * + reduce0_m_device_result.mDesc.GetElementSpace()); + DeviceMem reduce1_device_buf(sizeof(ReduceDataType) * + reduce1_m_device_result.mDesc.GetElementSpace()); + + std::array p_reduces = {reduce0_device_buf.GetDeviceBuffer(), + reduce1_device_buf.GetDeviceBuffer()}; + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + + // add device GEMM instances + std::vector gemm_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + 
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances( + gemm_ptrs); + } + } + + if(gemm_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device GEMM instances + for(auto& gemm_ptr : gemm_ptrs) + { + auto argument_ptr = gemm_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + nullptr, + {}, + c_device_buf.GetDeviceBuffer(), + p_reduces, + M, + N, + K, + StrideA, + StrideB, + StrideC, + {}, + gemm_element_ops, + {}, + reduce_in_element_ops, + reduce_out_element_ops); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + // init DO, D1 to 0 + reduce0_device_buf.SetZero(); + reduce1_device_buf.SetZero(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::string gemm_name = gemm_ptr->GetTypeString(); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * M * N + sizeof(CDataType) * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm_name << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + reduce0_device_buf.FromDevice(reduce0_m_device_result.mData.data()); + reduce1_device_buf.FromDevice(reduce1_m_device_result.mData.data()); + + ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + ck::utils::check_err(reduce0_m_device_result.mData, reduce0_m_host_result.mData); + ck::utils::check_err(reduce1_m_device_result.mData, reduce1_m_host_result.mData); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host: ", c_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "d0_host: ", reduce0_m_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "d0_device: ", reduce0_m_device_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "d1_host: ", reduce1_m_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "d1_device: ", reduce1_m_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << "does not support this GEMM problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_gemm_splitk_impl.hpp b/profiler/include/profile_gemm_splitk_impl.hpp new file mode 100644 index 00000000..7a820f41 --- /dev/null +++ 
b/profiler/include/profile_gemm_splitk_impl.hpp @@ -0,0 +1,257 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_gemm_splitk_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC, + int KBatch) +{ + bool pass = true; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + c_device_buf.ToDevice(c_m_n_device_result.mData.data()); + + using DeviceOp = ck::tensor_operation::device::DeviceGemmSplitK; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // Run reference GEMM + if(do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + 
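+        // The host reference accumulates over the full K extent in one pass;
+        // split-K kernels partition K into KBatch slices and combine partial
+        // results, so small rounding differences against this reference are
+        // expected and absorbed by check_err's tolerances.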
ref_invoker.Run(ref_argument); + } + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device GEMM instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + KBatch); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init C to zero before profiling next kernel + c_device_buf.SetZero(); + + std::string op_name = op_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass = + pass & ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host : ", c_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + } + + if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f32"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = bf16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = int8"; + } + + if constexpr(is_same::value) + { + std::cout << " ALayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " ALayout = ColumnMajor"; + } + + if constexpr(is_same::value) + { + std::cout << " BLayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " BLayout = ColumnMajor"; + } + + std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA + << " StrideB = " << StrideB << " StrideC = " << StrideC << " : " << best_ave_time + << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_grouped_gemm_impl.hpp b/profiler/include/profile_grouped_gemm_impl.hpp new file mode 100644 index 00000000..80c7b89e --- /dev/null +++ b/profiler/include/profile_grouped_gemm_impl.hpp @@ -0,0 +1,325 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
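+//
+// profile_grouped_gemm_impl() times kernels that run group_count independent
+// GEMMs in one launch; each group carries its own M/N/K and strides, so the
+// Ms/Ns/Ks/StrideAs/StrideBs/StrideCs vectors must all have the same length
+// (checked below).
+//
+// Calling sketch (hypothetical values; template arguments elided because the
+// parameter list is not visible in this patch view, and only f16 instances
+// are registered):
+//
+//   std::vector<int> Ms{256, 512}, Ns{128, 128}, Ks{64, 256};
+//   // row-major A/B/C: StrideA = K, StrideB = N, StrideC = N per group
+//   ck::profiler::profile_grouped_gemm_impl<...>(
+//       1, 1, false, true, Ms, Ns, Ks, Ks, Ns, Ns);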
+ +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using DeviceGroupedGemmNoOpPtr = ck::tensor_operation::device::DeviceGroupedGemmPtr< + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough>; + +void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( + std::vector&); +void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( + std::vector&); +void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances( + std::vector&); +void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances( + std::vector&); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +void profile_grouped_gemm_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideCs) +{ + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + std::size_t group_count = Ms.size(); + + if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() && + group_count == StrideBs.size() && group_count == StrideCs.size())) + { + throw std::runtime_error("wrong! 
inconsistent M/N/Ks, StrideA/B/Cs size\n"); + } + + std::vector> a_m_k; + std::vector> b_k_n; + std::vector> c_m_n_device_results; + + for(std::size_t i = 0; i < group_count; i++) + { + a_m_k.push_back( + Tensor(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{}))); + b_k_n.push_back( + Tensor(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{}))); + + c_m_n_device_results.push_back( + Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); + + std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i + << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i + << "]:" << c_m_n_device_results[i].mDesc << std::endl; + + std::size_t num_thread = 1; + switch(init_method) + { + case 0: break; + case 1: + a_m_k[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b_k_n[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + a_m_k[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + b_k_n[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + } + + c_m_n_device_results[i].GenerateTensorValue(GeneratorTensor_0{}, num_thread); + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + // if(do_verification) + // { + + // } + + using DeviceMemPtr = std::unique_ptr; + std::vector a_device_buf, b_device_buf, c_device_buf; + + a_device_buf.reserve(group_count); + b_device_buf.reserve(group_count); + c_device_buf.reserve(group_count); + + std::vector p_a, p_b; + std::vector p_c; + + p_a.reserve(group_count); + p_b.reserve(group_count); + p_c.reserve(group_count); + + std::vector gemm_shapes; + + gemm_shapes.reserve(group_count); + + for(std::size_t i = 0; i < group_count; i++) + { + a_device_buf.emplace_back( + std::make_unique(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpace())); + b_device_buf.emplace_back( + std::make_unique(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpace())); + + c_device_buf.emplace_back(std::make_unique( + sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpace())); + + a_device_buf[i]->ToDevice(a_m_k[i].mData.data()); + b_device_buf[i]->ToDevice(b_k_n[i].mData.data()); + c_device_buf[i]->ToDevice(c_m_n_device_results[i].mData.data()); + + gemm_shapes.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i]}); + + p_a.push_back(a_device_buf[i]->GetDeviceBuffer()); + p_b.push_back(b_device_buf[i]->GetDeviceBuffer()); + p_c.push_back(c_device_buf[i]->GetDeviceBuffer()); + } + + // add device GEMM instances + std::vector gemm_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + 
is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); + } + } + + if(gemm_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device GEMM instances + for(auto& gemm_ptr : gemm_ptrs) + { + auto argument_ptr = + gemm_ptr->MakeArgumentPointer(p_a, + p_b, + p_c, + gemm_shapes, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + DeviceMem gemm_desc_workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get())); + + gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer()); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string gemm_name = gemm_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = 0, num_btype = 0; + for(std::size_t i = 0; i < gemm_shapes.size(); i++) + { + flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i]; + + num_btype += sizeof(ADataType) * Ms[i] * Ks[i] + sizeof(BDataType) * Ks[i] * Ns[i] + + sizeof(CDataType) * Ms[i] * Ns[i]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << gemm_name << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + for(std::size_t i = 0; i < gemm_shapes.size(); i++) + { + + c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data()); + + Tensor c_m_n_host_result( + f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})); + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_m_k[i], + b_k_n[i], + c_m_n_host_result, + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + ck::utils::check_err(c_m_n_device_results[i].mData, c_m_n_host_result.mData); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k[i].mData, ",") + << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n[i].mData, ",") << std::endl; + LogRangeAsType( + std::cout << "c_device: ", c_m_n_device_results[i].mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_host : ", c_m_n_host_result.mData, ",") + << std::endl; + } + } + } + } + else + { + std::cout << "does not support this GEMM problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; +} // namespace profiler + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_normalization_impl.hpp b/profiler/include/profile_normalization_impl.hpp new file mode 100644 index 00000000..ac7f8218 --- /dev/null +++ b/profiler/include/profile_normalization_impl.hpp @@ -0,0 +1,243 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
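+//
+// profile_normalization_impl() currently wires up softmax instances only;
+// NormType::LAYERNORM and BATCHNORM are declared but have no dispatch branch
+// below. Each softmax instance computes, over the dims in reduce_dims,
+//
+//   out = alpha * softmax(in) + beta * out
+//
+// Instances are ranked by lowest average time (a bandwidth figure is printed
+// instead of TFLOPS); the byte model counts one read of `in`, one write of
+// `out`, plus an extra read of `out` when beta != 0.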
+ +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_softmax_f16_f16_rank3_instances(std::vector&); +void add_device_softmax_f16_f16_rank4_instances(std::vector&); + +void add_device_softmax_f32_f32_rank3_instances(std::vector&); +void add_device_softmax_f32_f32_rank4_instances(std::vector&); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +enum struct NormType +{ + LAYERNORM, + BATCHNORM, + SOFTMAX, +}; + +enum struct NormDataType +{ + F32_F32, // in, out + F16_F16, + BF16_BF16, + INT8_INT8, +}; + +// clang-format off +template std::string type_to_string(); +template <> std::string type_to_string() { return "f32"; } +template <> std::string type_to_string() { return "f16"; } +template <> std::string type_to_string() { return "bf16"; } +template <> std::string type_to_string() { return "int8"; } +template <> std::string type_to_string() { return "int32"; } +// clang-format on + +template +void profile_normalization_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + std::vector in_length, + std::vector in_strides, + std::vector reduce_dims, + AccDataType alpha, + AccDataType beta, + NormType norm_type) +{ + Tensor in = in_strides.empty() ? Tensor(in_length) + : Tensor(in_length, in_strides); + Tensor out(in.mDesc); + + switch(init_method) + { + // case 0: break; + case 0: + in.GenerateTensorValue(GeneratorTensor_1{}); + out.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + out.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + Tensor out_ref(out); + + DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpace()); + DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpace()); + in_dev.ToDevice(in.mData.data()); + out_dev.ToDevice(out.mData.data()); + + std::vector i_in_lengths(in.mDesc.GetLengths().begin(), in.mDesc.GetLengths().end()); + std::vector i_in_strides(in.mDesc.GetStrides().begin(), in.mDesc.GetStrides().end()); + + // add device normalization instances + std::vector instances; + + if(norm_type == NormType::SOFTMAX) + { + if constexpr(is_same::value && is_same::value && + is_same::value) + { + if(in_length.size() == 3) + tensor_operation::device::instance::add_device_softmax_f16_f16_rank3_instances( + instances); + + if(in_length.size() == 4) + tensor_operation::device::instance::add_device_softmax_f16_f16_rank4_instances( + instances); + } + else if constexpr(is_same::value && is_same::value && + is_same::value) + { + if(in_length.size() == 3) + tensor_operation::device::instance::add_device_softmax_f32_f32_rank3_instances( + instances); + + if(in_length.size() == 4) + tensor_operation::device::instance::add_device_softmax_f32_f32_rank4_instances( + instances); + } + } + + if(instances.size() <= 0) + { + throw std::runtime_error("wrong! 
no device normalization instance found"); + } + + std::string best_instance_name; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + for(auto& inst_ptr : instances) + { + // Is this user's responsibility to check if problem mismatches kernel instance (ie. rank 3 + // problem to rank 4 kernel) other than invoking IsSupportedArgument()? + if(!(inst_ptr->GetRank() == static_cast(i_in_lengths.size()) && + inst_ptr->GetNumReduceDim() == static_cast(reduce_dims.size()))) + { + continue; + } + + auto argument_ptr = inst_ptr->MakeArgumentPointer(i_in_lengths, + i_in_strides, + reduce_dims, + &alpha, + &beta, + in_dev.GetDeviceBuffer(), + out_dev.GetDeviceBuffer()); + + if(!inst_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: "; + LogRange(std::cout << "input lengths = [", in_length, ", ") + << "], " + << "scaler = [" << alpha << ", " << beta << "]." << std::endl; + return; + } + + auto invoker_ptr = inst_ptr->MakeInvokerPointer(); + + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t num_bytes = + in.mDesc.GetElementSize() * sizeof(InDataType) + + (beta == 0.0f ? 1 : 2) * out.mDesc.GetElementSize() * sizeof(OutDataType); + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " + << inst_ptr->GetTypeString() << std::endl; + + if(avg_time < best_avg_time) + { + best_instance_name = inst_ptr->GetTypeString(); + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + // TODO: factory method to dynamically switch between different reference normalizations + using ReferenceFactory = + tensor_operation::host::ReferenceSoftmax; + + ReferenceFactory{}.MakeInvoker().Run({in, out_ref, alpha, beta, reduce_dims}); + + out_dev.FromDevice(out.mData.data()); + + bool pass; + if(std::is_same::value) + { + pass = ck::utils::check_err( + out.mData, out_ref.mData, "Error: Incorrect results!", 0, 1); + if(do_log) + { + LogRangeAsType(std::cout << "in : ", in.mData, ",") << std::endl; + LogRangeAsType(std::cout << "out_ref : ", out_ref.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "out : ", out.mData, ",") << std::endl; + } + } + else + { + pass = ck::utils::check_err(out.mData, out_ref.mData); + if(do_log) + { + LogRangeAsType(std::cout << "in : ", in.mData, ",") << std::endl; + LogRangeAsType(std::cout << "out_ref : ", out_ref.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "out : ", out.mData, ",") << std::endl; + } + } + + if(!pass) + { + std::cout << inst_ptr->GetTypeString() << " failed verification: "; + LogRange(std::cout << "input lengths = [", in_length, ", ") + << "], " + << "scaler = [" << alpha << ", " << beta << "]." 
<< std::endl; + } + } + } + std::cout << "Best Perf for datatype = " << type_to_string() << "_" + << type_to_string() << ", "; + LogRange(std::cout << "length = ", i_in_lengths, ",") << ", "; + LogRange(std::cout << "stride = ", i_in_strides, ",") << ", "; + LogRange(std::cout << "reduce dims ", reduce_dims, ",") << ", "; + std::cout << "alpha = " << alpha << ", " + << "beta = " << beta << ", " << best_avg_time << " ms, " << best_gb_per_sec + << " GB/s, " << best_instance_name << std::endl; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_reduce_impl.hpp b/profiler/include/profile_reduce_impl.hpp new file mode 100644 index 00000000..5e801770 --- /dev/null +++ b/profiler/include/profile_reduce_impl.hpp @@ -0,0 +1,505 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_reduction.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +struct ReduceDescription +{ + static constexpr int Rank_ = Rank; + static constexpr int NumReduceDim_ = NumReduceDim; + static constexpr int ReduceOpId_ = ReduceOpId; + static constexpr int PropagateNan_ = PropagateNan; + static constexpr int UseIndex_ = UseIndex; +}; + +using reduce_description_instances = + std::tuple, // for ADD + ReduceDescription<4, 4, 0, false, false>, + ReduceDescription<4, 1, 0, false, false>, + ReduceDescription<2, 1, 0, false, false>, + + ReduceDescription<4, 3, 5, false, false>, // for AVG + ReduceDescription<4, 4, 5, false, false>, + ReduceDescription<4, 1, 5, false, false>, + ReduceDescription<2, 1, 5, false, false>, + + ReduceDescription<4, 3, 7, false, false>, // for NORM2 + ReduceDescription<4, 4, 7, false, false>, + ReduceDescription<4, 1, 7, false, false>, + ReduceDescription<2, 1, 7, false, false>, + + ReduceDescription<4, 3, 2, false, false>, // for MIN + ReduceDescription<4, 4, 2, false, false>, + ReduceDescription<4, 1, 2, false, false>, + ReduceDescription<2, 1, 2, false, false>, + ReduceDescription<4, 3, 3, false, false>, // for MAX + ReduceDescription<4, 4, 3, false, false>, + ReduceDescription<4, 1, 3, false, false>, + ReduceDescription<2, 1, 3, false, false>, + ReduceDescription<4, 3, 4, false, false>, // for AMAX + ReduceDescription<4, 4, 4, false, false>, + ReduceDescription<4, 1, 4, false, false>, + ReduceDescription<2, 1, 4, false, false>, + + ReduceDescription<4, 3, 2, false, true>, // for MIN + ReduceDescription<4, 4, 2, false, true>, + ReduceDescription<4, 1, 2, false, true>, + ReduceDescription<2, 1, 2, false, true>, + ReduceDescription<4, 3, 3, false, true>, // for MAX + ReduceDescription<4, 4, 3, false, true>, + ReduceDescription<4, 1, 3, false, true>, + ReduceDescription<2, 1, 3, false, true>, + ReduceDescription<4, 3, 4, false, true>, // for AMAX + ReduceDescription<4, 4, 4, false, true>, + ReduceDescription<4, 1, 4, false, true>, + ReduceDescription<2, 1, 4, false, true>>; + +template +bool description_match(const DescriptionType& description, + int Rank, + const std::vector& reduceDims, + ReduceTensorOp ReduceOpId, 
+                       bool PropagateNan,
+                       bool UseIndex)
+{
+    if(description.Rank_ != Rank || description.ReduceOpId_ != static_cast<int>(ReduceOpId) ||
+       description.PropagateNan_ != static_cast<int>(PropagateNan) ||
+       description.UseIndex_ != static_cast<int>(UseIndex))
+        return (false);
+
+    if(DescriptionType::NumReduceDim_ != reduceDims.size())
+        return (false);
+
+    bool result = true;
+
+    return (result);
+};
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
+
+namespace ck {
+namespace profiler {
+
+template <int Rank, int NumReduceDim>
+static inline std::vector<int> get_invariant_dims(const std::vector<int>& reduceDims)
+{
+    assert(NumReduceDim == reduceDims.size());
+
+    int reduceFlag = 0;
+
+    // flag the bits for the reduceDims
+    for(int i = 0; i < NumReduceDim; i++)
+    {
+        reduceFlag |= 1 << reduceDims[i];
+    };
+
+    std::vector<int> invariantDims;
+
+    // collect invariant dimensions
+    for(int i = 0; i < Rank; i++)
+        if((reduceFlag & (1 << i)) == 0)
+        {
+            invariantDims.push_back(i);
+        };
+
+    return invariantDims;
+};
+
+template <typename InDataType,
+          typename AccDataType,
+          typename OutDataType,
+          int Rank,
+          int NumReduceDim,
+          ReduceTensorOp ReduceOpId,
+          bool PropagateNan,
+          bool UseIndex>
+bool profile_reduce_impl_impl(bool do_verification,
+                              int init_method,
+                              bool do_dumpout,
+                              bool time_kernel,
+                              const std::vector<size_t>& inLengths,
+                              const std::vector<int>& reduceDims,
+                              float alpha,
+                              float beta)
+{
+    using namespace ck::tensor_operation::device;
+    using namespace ck::tensor_operation::device::instance;
+    using ck::host_common::dumpBufferToFile;
+
+    constexpr bool op_support_indices =
+        (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
+         ReduceOpId == ReduceTensorOp::AMAX);
+
+    constexpr bool OutputIndex = (op_support_indices && UseIndex);
+
+    constexpr bool out_support_atomic_add = std::is_same<OutDataType, float>::value;
+    constexpr bool op_support_atomic_add =
+        !op_support_indices && ReduceOpId != ReduceTensorOp::NORM2;
+    constexpr bool use_atomic_add = (out_support_atomic_add && op_support_atomic_add);
+
+    // 1) If InDataType is half_t, must use half_t as AccDataType for indexable reduction operations
+    // 2) If InDataType is half_t, must use float as AccDataType for non-indexable reduction
+    // operations
+    constexpr bool invalid_reduce_1 =
+        std::is_same<InDataType, half_t>::value &&
+        ((!op_support_indices && !std::is_same<AccDataType, float>::value) ||
+         (op_support_indices && !std::is_same<AccDataType, half_t>::value));
+
+    // 1) If InDataType is float, must use float as AccDataType for indexable reduction operations
+    constexpr bool invalid_reduce_2 =
+        std::is_same<InDataType, float>::value &&
+        (op_support_indices && !std::is_same<AccDataType, float>::value);
+
+    // 1) The indices can only be used when the reduction operation is indexable
+    constexpr bool invalid_reduce_3 = (!op_support_indices && UseIndex);
+
+    // 1) If InDataType is int8_t, must use int8_t as AccDataType for indexable reduction operations
+    // 2) If InDataType is int8_t, must use int32_t as AccDataType for non-indexable reduction
+    // operations
+    constexpr bool invalid_reduce_4 =
+        std::is_same<InDataType, int8_t>::value &&
+        ((!op_support_indices && !std::is_same<AccDataType, int32_t>::value) ||
+         (op_support_indices && !std::is_same<AccDataType, int8_t>::value));
+
+    // 1) If InDataType is int8_t, the supported operation must be either indexable operations or
+    // ADD/AVG
+    constexpr bool invalid_reduce_5 = std::is_same<InDataType, int8_t>::value &&
+                                      (!op_support_indices && ReduceOpId != ReduceTensorOp::ADD &&
+                                       ReduceOpId != ReduceTensorOp::AVG);
+
+    // 1) If InDataType is bhalf_t, must use float as AccDataType for all reduction operations
+    constexpr bool invalid_reduce_6 =
+        std::is_same<InDataType, bhalf_t>::value && !std::is_same<AccDataType, float>::value;
+
+    constexpr bool invalid_reduce = (invalid_reduce_1 || invalid_reduce_2 || invalid_reduce_3 ||
+                                     invalid_reduce_4 || invalid_reduce_5 ||
invalid_reduce_6);
+
+    bool pass = true;
+
+    if constexpr(!invalid_reduce)
+    {
+        Tensor<InDataType> in(inLengths);
+
+        std::vector<size_t> outLengths;
+
+        const auto invariantDims = get_invariant_dims<Rank, NumReduceDim>(reduceDims);
+
+        if(reduceDims.size() == Rank)
+            outLengths.push_back(1);
+        else
+            for(auto dim : invariantDims)
+                outLengths.push_back(inLengths[dim]);
+
+        Tensor<OutDataType> out_ref(outLengths);
+        Tensor<OutDataType> out(outLengths);
+        Tensor<int> out_indices_ref(outLengths);
+        Tensor<int> out_indices(outLengths);
+
+        auto inStrides  = in.mDesc.GetStrides();
+        auto outStrides = out.mDesc.GetStrides();
+
+        size_t invariant_total_length = out.mDesc.GetElementSize();
+        size_t reduce_total_length    = in.mDesc.GetElementSize() / invariant_total_length;
+
+        std::size_t num_thread = 1;
+
+        if(do_verification)
+        {
+            switch(init_method)
+            {
+            case 0: break;
+            case 1:
+                in.GenerateTensorValue(GeneratorTensor_1<InDataType>{1}, num_thread);
+                if(beta != 0.0f)
+                    out_ref.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1}, num_thread);
+                break;
+            case 2:
+                in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}, num_thread);
+                if(beta != 0.0f)
+                    out_ref.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5}, num_thread);
+                break;
+            default:
+                in.GenerateTensorValue(GeneratorTensor_3<InDataType>{-5.0, 5.0}, num_thread);
+                if(beta != 0.0f)
+                    out_ref.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-5.0, 5.0},
+                                                num_thread);
+            }
+
+            if(beta != 0.0f)
+                for(size_t i = 0; i < out_ref.mDesc.GetElementSpace(); i++)
+                    out.mData[i] = out_ref.mData[i];
+        };
+
+        // these buffers are usually provided by the user application
+        DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpace());
+        DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpace());
+
+        in_dev.ToDevice(in.mData.data());
+
+        if(beta != 0.0f)
+            out_dev.ToDevice(out.mData.data());
+
+        size_t indicesSizeInBytes = OutputIndex ? out.mDesc.GetElementSize() * sizeof(int) : 0;
+
+        DeviceMem out_indices_dev(indicesSizeInBytes);
+
+        float best_avg_time   = 0;
+        float best_gb_per_sec = 0;
+
+        using InElementwiseOperation =
+            typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
+        using AccElementwiseOperation =
+            typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
+
+        using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
+
+        InElementwiseOperation in_elementwise_op;
+        AccElementwiseOperation acc_elementwise_op;
+
+        std::tie(in_elementwise_op, acc_elementwise_op) =
+            reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
+                static_cast<int32_t>(reduce_total_length));
+
+        using DeviceReduceInstPtr0 =
+            DeviceReducePtr<InElementwiseOperation, AccElementwiseOperation>;
+
+        std::vector<DeviceReduceInstPtr0> reduce0_ptrs;
+
+        add_device_reduce_instance_threadwise<InDataType,
+                                              AccDataType,
+                                              OutDataType,
+                                              Rank,
+                                              NumReduceDim,
+                                              ReduceOpId,
+                                              PropagateNan,
+                                              UseIndex>(reduce0_ptrs);
+
+        add_device_reduce_instance_blockwise<InDataType,
+                                             AccDataType,
+                                             OutDataType,
+                                             Rank,
+                                             NumReduceDim,
+                                             ReduceOpId,
+                                             PropagateNan,
+                                             UseIndex>(reduce0_ptrs);
+
+        if constexpr(use_atomic_add)
+        {
+            add_device_reduce_instance_multiblock_atomic_add<InDataType,
+                                                             AccDataType,
+                                                             OutDataType,
+                                                             Rank,
+                                                             NumReduceDim,
+                                                             ReduceOpId,
+                                                             PropagateNan,
+                                                             UseIndex>(reduce0_ptrs);
+        }
+
+        if(reduce0_ptrs.empty())
+        {
+            throw std::runtime_error("Wrong! 
No device REDUCE instance found"); + }; + + if(do_verification) + { + ReductionHost + hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims); + + hostReduce.Run(alpha, + in.mData.data(), + beta, + out_ref.mData.data(), + out_indices_ref.mData.data(), + in_elementwise_op, + acc_elementwise_op); + }; + + std::vector i_inLengths; + std::vector i_inStrides; + std::vector i_outLengths; + std::vector i_outStrides; + + i_inLengths.assign(inLengths.begin(), inLengths.end()); + i_inStrides.assign(inStrides.begin(), inStrides.end()); + i_outLengths.assign(outLengths.begin(), outLengths.end()); + i_outStrides.assign(outStrides.begin(), outStrides.end()); + + for(auto& reduce_ptr : reduce0_ptrs) + { + auto argument_ptr = reduce_ptr->MakeArgumentPointer(i_inLengths, + i_inStrides, + i_outLengths, + i_outStrides, + reduceDims, + alpha, + beta, + in_dev.GetDeviceBuffer(), + nullptr, + out_dev.GetDeviceBuffer(), + out_indices_dev.GetDeviceBuffer(), + in_elementwise_op, + acc_elementwise_op); + + if(!reduce_ptr->IsSupportedArgument(argument_ptr.get())) + continue; + + std::string reduce_name = reduce_ptr->GetTypeString(); + + auto invoker_ptr = reduce_ptr->MakeInvokerPointer(); + + float avg_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t num_bytes = + invariant_total_length * reduce_total_length * sizeof(InDataType) + + invariant_total_length * sizeof(OutDataType); + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + if(time_kernel) + std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " + << reduce_name << std::endl; + + if(gb_per_sec > best_gb_per_sec) + { + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + bool single_pass; + + out_dev.FromDevice(out.mData.data()); + single_pass = ck::utils::check_err(out.mData, out_ref.mData); + + if(OutputIndex) + { + out_indices_dev.FromDevice(out_indices.mData.data()); + single_pass = single_pass && + ck::utils::check_err(out_indices.mData, out_indices_ref.mData); + }; + + if(!single_pass) + { + std::cout << "Fail Info: " << reduce_ptr->GetTypeString() << std::endl; + } + + pass = pass && single_pass; + }; + + if(do_dumpout) + { + dumpBufferToFile("dump_in.bin", in.mData.data(), in.mDesc.GetElementSize()); + dumpBufferToFile("dump_out.bin", out.mData.data(), out.mDesc.GetElementSize()); + dumpBufferToFile( + "dump_out_host.bin", out_ref.mData.data(), out_ref.mDesc.GetElementSize()); + if(OutputIndex) + { + dumpBufferToFile("dump_indices.bin", + out_indices.mData.data(), + out_indices.mDesc.GetElementSize()); + dumpBufferToFile("dump_indices_host.bin", + out_indices_ref.mData.data(), + out_indices_ref.mDesc.GetElementSize()); + }; + }; + }; + + if(time_kernel) + std::cout << "Best Perf: " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s" + << std::endl; + } + else + { + std::cout << "The requested reduction operation is not supported, please check !!!" 
+                  << std::endl;
+    };
+
+    return pass;
+};
+
+template <typename InDataType, typename AccDataType, typename OutDataType>
+bool profile_reduce_impl(bool do_verification,
+                         int init_method,
+                         bool do_dumpout,
+                         bool time_kernel,
+                         const std::vector<size_t>& inLengths,
+                         const std::vector<int>& reduceDims,
+                         ReduceTensorOp ReduceOpId,
+                         bool PropagateNan,
+                         bool UseIndex,
+                         float alpha,
+                         float beta)
+{
+    bool matched = false;
+    bool pass    = true;
+
+    using tuple_of_description_instances =
+        tensor_operation::device::instance::reduce_description_instances;
+
+    const auto tuple_object = tuple_of_description_instances{};
+
+    static_for<0, std::tuple_size<tuple_of_description_instances>::value, 1>{}([&](auto i) {
+        if(matched)
+            return;
+
+        using descType = remove_cvref_t<decltype(std::get<i>(tuple_object))>;
+
+        if(!description_match(
+               descType{}, inLengths.size(), reduceDims, ReduceOpId, PropagateNan, UseIndex))
+            return;
+
+        pass = pass &&
+               profile_reduce_impl_impl<InDataType,
+                                        AccDataType,
+                                        OutDataType,
+                                        descType::Rank_,
+                                        descType::NumReduceDim_,
+                                        static_cast<ReduceTensorOp>(descType::ReduceOpId_),
+                                        static_cast<bool>(descType::PropagateNan_),
+                                        static_cast<bool>(descType::UseIndex_)>(do_verification,
+                                                                                init_method,
+                                                                                do_dumpout,
+                                                                                time_kernel,
+                                                                                inLengths,
+                                                                                reduceDims,
+                                                                                alpha,
+                                                                                beta);
+
+        matched = true;
+    });
+
+    return pass;
+};
+
+} // namespace profiler
+} // namespace ck
diff --git a/profiler/src/profile_batched_gemm.cpp b/profiler/src/profile_batched_gemm.cpp
new file mode 100644
index 00000000..7c4e2f7b
--- /dev/null
+++ b/profiler/src/profile_batched_gemm.cpp
@@ -0,0 +1,197 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include
+#include
+#include
+#include
+#include
+
+#include "profiler/include/profile_batched_gemm_impl.hpp"
+
+enum struct GemmMatrixLayout
+{
+    MK_KN_MN, // 0
+    MK_NK_MN, // 1
+    KM_KN_MN, // 2
+    KM_NK_MN, // 3
+};
+
+enum struct GemmDataType
+{
+    F32_F32_F32,    // 0
+    F16_F16_F16,    // 1
+    BF16_BF16_BF16, // 2
+    INT8_INT8_INT8, // 3
+};
+
+int profile_batched_gemm(int argc, char* argv[])
+{
+    if(argc != 18)
+    {
+        // clang-format off
+        printf("arg1: tensor operation (batched_gemm: Batched GEMM)\n");
+        printf("arg2: data type (0: fp32; 1: fp16, 2: bf16, 3: int8)\n");
+        printf("arg3: matrix layout (0: A[g, m, k] * B[g, k, n] = C[g, m, n];\n");
+        printf("                     1: A[g, m, k] * B[g, n, k] = C[g, m, n];\n");
+        printf("                     2: A[g, k, m] * B[g, k, n] = C[g, m, n];\n");
+        printf("                     3: A[g, k, m] * B[g, n, k] = C[g, m, n])\n");
+        printf("arg4: verification (0: no; 1: yes)\n");
+        printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
+        printf("arg6: print tensor value (0: no; 1: yes)\n");
+        printf("arg7: time kernel (0=no, 1=yes)\n");
+        printf("arg8 to 17: M, N, K, StrideA, StrideB, StrideC, BatchStrideA, BatchStrideB, BatchStrideC, BatchCount\n");
+        // clang-format on
+        exit(1);
+    }
+
+    const auto data_type       = static_cast<GemmDataType>(std::stoi(argv[2]));
+    const auto layout          = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
+    const bool do_verification = std::stoi(argv[4]);
+    const int init_method      = std::stoi(argv[5]);
+    const bool do_log          = std::stoi(argv[6]);
+    const bool time_kernel     = std::stoi(argv[7]);
+
+    const int M = std::stoi(argv[8]);
+    const int N = std::stoi(argv[9]);
+    const int K = std::stoi(argv[10]);
+
+    const int StrideA = std::stoi(argv[11]);
+    const int StrideB = std::stoi(argv[12]);
+    const int StrideC = std::stoi(argv[13]);
+
+    const int BatchStrideA = std::stoi(argv[14]);
+    const int BatchStrideB = std::stoi(argv[15]);
+    const int BatchStrideC = std::stoi(argv[16]);
+
+    const int BatchCount = std::stoi(argv[17]);
+
+    using F32  = float;
+    using F16  = ck::half_t;
+    using BF16 = ck::bhalf_t;
+    using INT8 = int8_t;
+
+    using Row = 
ck::tensor_layout::gemm::RowMajor;
+    using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+    auto profile = [&](auto a_type,
+                       auto b_type,
+                       auto c_type,
+                       auto a_layout,
+                       auto b_layout,
+                       auto c_layout) {
+        using ADataType = decltype(a_type);
+        using BDataType = decltype(b_type);
+        using CDataType = decltype(c_type);
+
+        using ALayout = decltype(a_layout);
+        using BLayout = decltype(b_layout);
+        using CLayout = decltype(c_layout);
+
+        const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
+        const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
+        const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;
+
+        const int StrideA_ = (StrideA < 0) ? DefaultStrideA : StrideA;
+        const int StrideB_ = (StrideB < 0) ? DefaultStrideB : StrideB;
+        const int StrideC_ = (StrideC < 0) ? DefaultStrideC : StrideC;
+
+        const int DefaultBatchStrideA = (ck::is_same_v<ALayout, Row> ? M : K) * StrideA_;
+        const int DefaultBatchStrideB = (ck::is_same_v<BLayout, Row> ? K : N) * StrideB_;
+        const int DefaultBatchStrideC = (ck::is_same_v<CLayout, Row> ? M : N) * StrideC_;
+
+        const int BatchStrideA_ = (BatchStrideA < 0) ? DefaultBatchStrideA : BatchStrideA;
+        const int BatchStrideB_ = (BatchStrideB < 0) ? DefaultBatchStrideB : BatchStrideB;
+        const int BatchStrideC_ = (BatchStrideC < 0) ? DefaultBatchStrideC : BatchStrideC;
+
+        bool pass = ck::profiler::
+            profile_batched_gemm_impl<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(
+                do_verification,
+                init_method,
+                do_log,
+                time_kernel,
+                M,
+                N,
+                K,
+                BatchStrideA_,
+                BatchStrideB_,
+                BatchStrideC_,
+                StrideA_,
+                StrideB_,
+                StrideC_,
+                BatchCount);
+
+        return pass ? 0 : 1;
+    };
+
+    if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
+    {
+        return profile(F32{}, F32{}, F32{}, Row{}, Row{}, Row{});
+    }
+    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
+    {
+        return profile(F32{}, F32{}, F32{}, Row{}, Col{}, Row{});
+    }
+    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
+    {
+        return profile(F32{}, F32{}, F32{}, Col{}, Row{}, Row{});
+    }
+    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
+    {
+        return profile(F32{}, F32{}, F32{}, Col{}, Col{}, Row{});
+    }
+    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
+    {
+        return profile(F16{}, F16{}, F16{}, Row{}, Row{}, Row{});
+    }
+    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
+    {
+        return profile(F16{}, F16{}, F16{}, Row{}, Col{}, Row{});
+    }
+    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
+    {
+        return profile(F16{}, F16{}, F16{}, Col{}, Row{}, Row{});
+    }
+    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
+    {
+        return profile(F16{}, F16{}, F16{}, Col{}, Col{}, Row{});
+    }
+    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
+    {
+        return profile(BF16{}, BF16{}, BF16{}, Row{}, Row{}, Row{});
+    }
+    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
+    {
+        return profile(BF16{}, BF16{}, BF16{}, Row{}, Col{}, Row{});
+    }
+    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN)
+    {
+        return profile(BF16{}, BF16{}, BF16{}, Col{}, Row{}, Row{});
+    }
+    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN)
+    {
+        return profile(BF16{}, BF16{}, BF16{}, Col{}, Col{}, Row{});
+    }
+    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN)
+    {
+        return 
profile(INT8{}, INT8{}, INT8{}, Row{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(INT8{}, INT8{}, INT8{}, Row{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN) + { + return profile(INT8{}, INT8{}, INT8{}, Col{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN) + { + return profile(INT8{}, INT8{}, INT8{}, Col{}, Col{}, Row{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; + } +} diff --git a/profiler/src/profile_batched_gemm_reduce.cpp b/profiler/src/profile_batched_gemm_reduce.cpp new file mode 100644 index 00000000..7c518e97 --- /dev/null +++ b/profiler/src/profile_batched_gemm_reduce.cpp @@ -0,0 +1,154 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/include/profile_batched_gemm_reduce_impl.hpp" + +int profile_batched_gemm_reduce(int argc, char* argv[]) +{ + enum struct GemmMatrixLayout + { + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 + }; + + enum struct GemmReduceDataType + { + F32_F32_F32_F32_F32, // 0 + F16_F16_F16_F32_F32, // 1 + }; + + if(!(argc == 15 || argc == 16)) + { + printf("arg1: tensor operation (batched_gemm: BatchedGEMM+Reduce)\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); + printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); + printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=n0, 1=yes)\n"); + printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideC, BatchCount\n"); + printf("arg15: split k into mulitiple batch\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideC = std::stoi(argv[13]); + + const int BatchCount = std::stoi(argv[14]); + + if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_batched_gemm_reduce_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + BatchCount); + } + else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_batched_gemm_reduce_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? 
N : StrideC, + BatchCount); + } + else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout::KM_KN_MN) + { + ck::profiler::profile_batched_gemm_reduce_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + BatchCount); + } + else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout::KM_NK_MN) + { + ck::profiler::profile_batched_gemm_reduce_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC, + BatchCount); + } + else + { + throw std::runtime_error("wrong! this data_type & layout is not implemented"); + } + + return 0; +} diff --git a/profiler/src/profile_conv_bwd_weight.cpp b/profiler/src/profile_conv_bwd_weight.cpp new file mode 100644 index 00000000..989c4808 --- /dev/null +++ b/profiler/src/profile_conv_bwd_weight.cpp @@ -0,0 +1,148 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/include/profile_conv_bwd_weight_impl.hpp" + +enum struct ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 +}; + +enum struct ConvInputLayout +{ + NCHW, // 0 + NHWC, // 1 +}; + +enum struct ConvWeightLayout +{ + KCYX, // 0 + KYXC, // 1 +}; + +enum struct ConvOutputLayout +{ + NKHW, // 0 + NHWK, // 1 +}; + +int profile_conv_bwd_weight(int argc, char* argv[]) +{ + if(argc != 26) + { + printf("arg1: tensor operation (conv_fwd: ForwardConvolution)\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: input tensor layout (0: NCHW; 1: NHWC)\n"); + printf("arg4: weight tensor layout (0: KCYX; 1: KYXC)\n"); + printf("arg5: output tensor layout (0: NKHW; 1: NHWK)\n"); + printf("arg6: verification (0: no; 1: yes)\n"); + printf("arg7: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg8: print tensor value (0: no; 1: yes)\n"); + printf("arg9: run kernel # of times (>1)\n"); + printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + printf("arg25: split k (>=1)\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto in_layout = static_cast(std::stoi(argv[3])); + const auto wei_layout = static_cast(std::stoi(argv[4])); + const auto out_layout = static_cast(std::stoi(argv[5])); + const bool do_verification = std::stoi(argv[6]); + const int init_method = std::stoi(argv[7]); + const bool do_log = std::stoi(argv[8]); + const bool time_kernel = std::stoi(argv[9]); + + const ck::index_t N = std::stoi(argv[10]); + const ck::index_t K = std::stoi(argv[11]); + const ck::index_t C = std::stoi(argv[12]); + const ck::index_t Y = std::stoi(argv[13]); + const ck::index_t X = std::stoi(argv[14]); + const ck::index_t Hi = std::stoi(argv[15]); + const ck::index_t Wi = std::stoi(argv[16]); + + const ck::index_t conv_stride_h = std::stoi(argv[17]); + const ck::index_t conv_stride_w = std::stoi(argv[18]); + const ck::index_t conv_dilation_h = std::stoi(argv[19]); + const ck::index_t conv_dilation_w = std::stoi(argv[20]); + const ck::index_t in_left_pad_h = std::stoi(argv[21]); + const ck::index_t in_left_pad_w = std::stoi(argv[22]); + const ck::index_t in_right_pad_h = std::stoi(argv[23]); + const ck::index_t in_right_pad_w = 
std::stoi(argv[24]); + ck::index_t split_k = std::stoi(argv[25]); + split_k = std::max(1, split_k); + + const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1; + const ck::index_t XEff = (X - 1) * conv_dilation_w + 1; + + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + + if(data_type == ConvDataType::F32_F32_F32 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + ck::profiler::profile_conv_bwd_weight_impl<2, + float, + float, + float, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + do_verification, + init_method, + do_log, + time_kernel, + N, + K, + C, + std::vector{Hi, Wi}, + std::vector{Y, X}, + std::vector{Ho, Wo}, + std::vector{conv_stride_h, conv_stride_w}, + std::vector{conv_dilation_h, conv_dilation_w}, + std::vector{in_left_pad_h, in_left_pad_w}, + std::vector{in_right_pad_h, in_right_pad_w}, + split_k); + } + else if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + ck::profiler::profile_conv_bwd_weight_impl<2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + do_verification, + init_method, + do_log, + time_kernel, + N, + K, + C, + std::vector{Hi, Wi}, + std::vector{Y, X}, + std::vector{Ho, Wo}, + std::vector{conv_stride_h, conv_stride_w}, + std::vector{conv_dilation_h, conv_dilation_w}, + std::vector{in_left_pad_h, in_left_pad_w}, + std::vector{in_right_pad_h, in_right_pad_w}, + split_k); + } + else + { + throw std::runtime_error("wrong! this Conv data_type & layout is not implemented"); + } + + return 0; +} diff --git a/profiler/src/profile_conv_fwd.cpp b/profiler/src/profile_conv_fwd.cpp new file mode 100644 index 00000000..d888519a --- /dev/null +++ b/profiler/src/profile_conv_fwd.cpp @@ -0,0 +1,244 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
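[Editor's aside, not part of the patch: every convolution driver in this series derives the output spatial lengths with the same arithmetic seen above, first the dilated filter extent (YEff/XEff), then integer division of the padded input by the stride. A minimal standalone sketch of that formula; the helper name is illustrative only.]

#include <cassert>

// conv_out_length is a hypothetical helper; the drivers inline this arithmetic.
int conv_out_length(int in_len, int filter_len, int stride, int dilation, int left_pad, int right_pad)
{
    const int eff = (filter_len - 1) * dilation + 1; // dilated filter extent, i.e. YEff/XEff
    return (in_len + left_pad + right_pad - eff) / stride + 1;
}

int main()
{
    // Hi = 28, Y = 3, stride 1, dilation 1, pad 1/1 -> Ho = 28 ("same" padding)
    assert(conv_out_length(28, 3, 1, 1, 1, 1) == 28);
    // Hi = 28, Y = 3, stride 2, dilation 1, pad 1/1 -> Ho = 14
    assert(conv_out_length(28, 3, 2, 1, 1, 1) == 14);
    return 0;
}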
+ +#include +#include +#include +#include + +#include "profiler/include/profile_conv_fwd_impl.hpp" + +enum struct ConvLayout +{ + NCHW_KYXC_NKHW, // 0 + NHWC_KYXC_NHWK, // 1 +}; + +enum struct ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 +}; + +static void print_helper_msg() +{ + // clang-format-off + std::cout << "arg1: tensor operation (conv_fwd: ForwardConvolution)\n" + << "arg2: data type (0: fp32; 1: fp16, 2: bf16, 3: int8)\n" + << "arg3: tensor layout (0: Input[N, C, Hi, Wi] * Weight[K, C, Y, X] = Output[N, K, " + "Ho, Wo]\n" + << " 1: Input[N, Hi, Wi, C] * Weight[K, Y, X, C] = Output[N, Ho, " + "Wo, K])\n" + << "arg4: verification (0: no, 1: yes)\n" + << "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n" + << "arg6: print tensor value (0: no; 1: yes)\n" + << "arg7: time kernel (0: no, 1: yes)\n" + << "arg8: N spatial dimensions\n" + << "Following arguments (depending on number of spatial dims):\n" + << " N, K, C, \n" + << " , (ie Y, X for 2D)\n" + << " , (ie Hi, Wi for 2D)\n" + << " , (ie Sy, Sx for 2D)\n" + << " , (ie Dy, Dx for 2D)\n" + << " , (ie LeftPy, LeftPx for 2D)\n" + << " , (ie RightPy, RightPx for 2D)\n" + << std::endl; + // clang-format-on +} + +ck::tensor_operation::device::ConvParams +parse_conv_params(int num_dim_spatial, int arg_idx, char* const argv[]) +{ + const ck::index_t N = std::stoi(argv[arg_idx++]); + const ck::index_t K = std::stoi(argv[arg_idx++]); + const ck::index_t C = std::stoi(argv[arg_idx++]); + + std::vector filter_spatial_lengths(num_dim_spatial); + std::vector input_spatial_lengths(num_dim_spatial); + std::vector conv_filter_strides(num_dim_spatial); + std::vector conv_filter_dilations(num_dim_spatial); + std::vector input_left_pads(num_dim_spatial); + std::vector input_right_pads(num_dim_spatial); + + for(int i = 0; i < num_dim_spatial; ++i) + { + filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + } + + for(int i = 0; i < num_dim_spatial; ++i) + { + input_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + } + + for(int i = 0; i < num_dim_spatial; ++i) + { + conv_filter_strides[i] = std::stoi(argv[arg_idx++]); + } + + for(int i = 0; i < num_dim_spatial; ++i) + { + conv_filter_dilations[i] = std::stoi(argv[arg_idx++]); + } + + for(int i = 0; i < num_dim_spatial; ++i) + { + input_left_pads[i] = std::stoi(argv[arg_idx++]); + } + + for(int i = 0; i < num_dim_spatial; ++i) + { + input_right_pads[i] = std::stoi(argv[arg_idx++]); + } + + return ck::tensor_operation::device::ConvParams{num_dim_spatial, + N, + K, + C, + filter_spatial_lengths, + input_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; +} + +int profile_conv_fwd(int argc, char* argv[]) +{ + // 8 for control, 1 for num_dim_spatial + if(argc < 9) + { + print_helper_msg(); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + const int num_dim_spatial = std::stoi(argv[8]); + + // 8 for control, 1 for num_dim_spatial, 3 for N/K/C, and 6 * num_dim_spatial + if(argc != 8 + 4 + 6 * num_dim_spatial) + { + print_helper_msg(); + exit(1); + } + + const auto params = parse_conv_params(num_dim_spatial, 9, argv); + + using F32 = float; + using F16 = ck::half_t; + using BF16 = ck::bhalf_t; + using INT8 = 
int8_t; + + using NWC = ck::tensor_layout::convolution::NWC; + using NHWC = ck::tensor_layout::convolution::NHWC; + using NDHWC = ck::tensor_layout::convolution::NDHWC; + + using KXC = ck::tensor_layout::convolution::KXC; + using KYXC = ck::tensor_layout::convolution::KYXC; + using KZYXC = ck::tensor_layout::convolution::KZYXC; + + using NWK = ck::tensor_layout::convolution::NWK; + using NHWK = ck::tensor_layout::convolution::NHWK; + using NDHWK = ck::tensor_layout::convolution::NDHWK; + + constexpr auto I1 = ck::Number<1>{}; + constexpr auto I2 = ck::Number<2>{}; + constexpr auto I3 = ck::Number<3>{}; + + auto profile = [&](auto num_dim_spatial_tmp, + auto in_type, + auto wei_type, + auto out_type, + auto in_layout, + auto wei_layout, + auto out_layout) { + constexpr ck::index_t NumDimSpatial = num_dim_spatial_tmp.value; + + using InDataType = decltype(in_type); + using WeiDataType = decltype(wei_type); + using OutDataType = decltype(out_type); + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using OutLayout = decltype(out_layout); + + bool pass = ck::profiler::profile_conv_fwd_impl( + do_verification, init_method, do_log, time_kernel, params); + + return pass ? 0 : 1; + }; + + if(num_dim_spatial == 1 && layout == ConvLayout::NHWC_KYXC_NHWK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I1, NWC{}, KXC{}, NWK{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I1, NWC{}, KXC{}, NWK{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I1, NWC{}, KXC{}, NWK{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::INT8_INT8_INT8) + { + return profile(I1, NWC{}, KXC{}, NWK{}, INT8{}, INT8{}, INT8{}); + } + } + else if(num_dim_spatial == 2 && layout == ConvLayout::NHWC_KYXC_NHWK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I2, NHWC{}, KYXC{}, NHWK{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I2, NHWC{}, KYXC{}, NHWK{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I2, NHWC{}, KYXC{}, NHWK{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::INT8_INT8_INT8) + { + return profile(I2, NHWC{}, KYXC{}, NHWK{}, INT8{}, INT8{}, INT8{}); + } + } + else if(num_dim_spatial == 3 && layout == ConvLayout::NHWC_KYXC_NHWK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::INT8_INT8_INT8) + { + return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, INT8{}, INT8{}, INT8{}); + } + } + + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; +} diff --git a/profiler/src/profile_conv_fwd_bias_relu.cpp b/profiler/src/profile_conv_fwd_bias_relu.cpp new file mode 100644 index 00000000..91f4836a --- /dev/null +++ b/profiler/src/profile_conv_fwd_bias_relu.cpp @@ -0,0 +1,116 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
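[Editor's aside, not part of the patch: the `profile` generic lambda above turns runtime enum values into compile-time template arguments by passing empty tag objects (`F16{}`, `NHWC{}`, `ck::Number<2>{}`) and recovering the types with `decltype`. A minimal sketch of that dispatch pattern under illustrative names:]

#include <iostream>
#include <type_traits>

struct RowMajor {};    // stand-ins for the ck layout tag types
struct ColumnMajor {};

int main()
{
    // The generic lambda sees only values; decltype recovers the tag types,
    // which then drive compile-time template instantiation.
    auto run = [](auto elem_tag, auto layout_tag) {
        using Elem   = decltype(elem_tag);
        using Layout = decltype(layout_tag);
        std::cout << "leading dim follows "
                  << (std::is_same_v<Layout, RowMajor> ? "columns" : "rows")
                  << ", element size " << sizeof(Elem) << " bytes\n";
        return 0;
    };

    int layout = 1; // would come from argv in the real drivers
    return (layout == 0) ? run(float{}, RowMajor{}) : run(float{}, ColumnMajor{});
}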
+ +#include +#include +#include +#include + +#include "profiler/include/profile_conv_fwd_bias_relu_impl.hpp" + +enum struct ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 +}; + +enum struct ConvInputLayout +{ + NCHW, // 0 + NHWC, // 1 +}; + +enum struct ConvWeightLayout +{ + KCYX, // 0 + KYXC, // 1 +}; + +enum struct ConvOutputLayout +{ + NKHW, // 0 + NHWK, // 1 +}; + +int profile_conv_fwd_bias_relu(int argc, char* argv[]) +{ + if(argc != 25) + { + printf("arg1: tensor operation (conv_fwd_bias_relu: ForwardConvolution+Bias+ReLu)\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: input tensor layout (0: NCHW; 1: NHWC)\n"); + printf("arg4: weight tensor layout (0: KCYX; 1: KYXC)\n"); + printf("arg5: output tensor layout (0: NKHW; 1: NHWK)\n"); + printf("arg6: verification (0: no; 1: yes)\n"); + printf("arg7: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg8: print tensor value (0: no; 1: yes)\n"); + printf("arg9: time kernel (0=n0, 1=yes)\n"); + printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto in_layout = static_cast(std::stoi(argv[3])); + const auto wei_layout = static_cast(std::stoi(argv[4])); + const auto out_layout = static_cast(std::stoi(argv[5])); + const bool do_verification = std::stoi(argv[6]); + const int init_method = std::stoi(argv[7]); + const bool do_log = std::stoi(argv[8]); + const bool time_kernel = std::stoi(argv[9]); + + const ck::index_t N = std::stoi(argv[10]); + const ck::index_t K = std::stoi(argv[11]); + const ck::index_t C = std::stoi(argv[12]); + const ck::index_t Y = std::stoi(argv[13]); + const ck::index_t X = std::stoi(argv[14]); + const ck::index_t Hi = std::stoi(argv[15]); + const ck::index_t Wi = std::stoi(argv[16]); + + const ck::index_t conv_stride_h = std::stoi(argv[17]); + const ck::index_t conv_stride_w = std::stoi(argv[18]); + const ck::index_t conv_dilation_h = std::stoi(argv[19]); + const ck::index_t conv_dilation_w = std::stoi(argv[20]); + const ck::index_t in_left_pad_h = std::stoi(argv[21]); + const ck::index_t in_left_pad_w = std::stoi(argv[22]); + const ck::index_t in_right_pad_h = std::stoi(argv[23]); + const ck::index_t in_right_pad_w = std::stoi(argv[24]); + + const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1; + const ck::index_t XEff = (X - 1) * conv_dilation_w + 1; + + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + + if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + ck::profiler::profile_conv_fwd_bias_relu_impl<2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + do_verification, + init_method, + do_log, + time_kernel, + N, + K, + C, + std::vector{Hi, Wi}, + std::vector{Y, X}, + std::vector{Ho, Wo}, + std::vector{conv_stride_h, conv_stride_w}, + std::vector{conv_dilation_h, conv_dilation_w}, + std::vector{in_left_pad_h, in_left_pad_w}, + std::vector{in_right_pad_h, in_right_pad_w}); + } + else + { + throw std::runtime_error("wrong! 
data_type & layout for this operator is not implemented"); + } + + return 0; +} diff --git a/profiler/src/profile_conv_fwd_bias_relu_add.cpp b/profiler/src/profile_conv_fwd_bias_relu_add.cpp new file mode 100644 index 00000000..5cc6faba --- /dev/null +++ b/profiler/src/profile_conv_fwd_bias_relu_add.cpp @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp" + +enum struct ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 +}; + +enum struct ConvInputLayout +{ + NCHW, // 0 + NHWC, // 1 +}; + +enum struct ConvWeightLayout +{ + KCYX, // 0 + KYXC, // 1 +}; + +enum struct ConvOutputLayout +{ + NKHW, // 0 + NHWK, // 1 +}; + +int profile_conv_fwd_bias_relu_add(int argc, char* argv[]) +{ + if(argc != 25) + { + printf( + "arg1: tensor operation (conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLu+Add)\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: input tensor layout (0: NCHW; 1: NHWC)\n"); + printf("arg4: weight tensor layout (0: KCYX; 1: KYXC)\n"); + printf("arg5: output tensor layout (0: NKHW; 1: NHWK)\n"); + printf("arg6: verification (0: no; 1: yes)\n"); + printf("arg7: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg8: print tensor value (0: no; 1: yes)\n"); + printf("arg9: time kernel (0=n0, 1=yes)\n"); + printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto in_layout = static_cast(std::stoi(argv[3])); + const auto wei_layout = static_cast(std::stoi(argv[4])); + const auto out_layout = static_cast(std::stoi(argv[5])); + const bool do_verification = std::stoi(argv[6]); + const int init_method = std::stoi(argv[7]); + const bool do_log = std::stoi(argv[8]); + const bool time_kernel = std::stoi(argv[9]); + + const ck::index_t N = std::stoi(argv[10]); + const ck::index_t K = std::stoi(argv[11]); + const ck::index_t C = std::stoi(argv[12]); + const ck::index_t Y = std::stoi(argv[13]); + const ck::index_t X = std::stoi(argv[14]); + const ck::index_t Hi = std::stoi(argv[15]); + const ck::index_t Wi = std::stoi(argv[16]); + + const ck::index_t conv_stride_h = std::stoi(argv[17]); + const ck::index_t conv_stride_w = std::stoi(argv[18]); + const ck::index_t conv_dilation_h = std::stoi(argv[19]); + const ck::index_t conv_dilation_w = std::stoi(argv[20]); + const ck::index_t in_left_pad_h = std::stoi(argv[21]); + const ck::index_t in_left_pad_w = std::stoi(argv[22]); + const ck::index_t in_right_pad_h = std::stoi(argv[23]); + const ck::index_t in_right_pad_w = std::stoi(argv[24]); + + const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1; + const ck::index_t XEff = (X - 1) * conv_dilation_w + 1; + + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + + if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + ck::profiler::profile_conv_fwd_bias_relu_add_impl<2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + do_verification, + init_method, + do_log, + time_kernel, + N, + 
K, + C, + std::vector{Hi, Wi}, + std::vector{Y, X}, + std::vector{Ho, Wo}, + std::vector{conv_stride_h, conv_stride_w}, + std::vector{conv_dilation_h, conv_dilation_w}, + std::vector{in_left_pad_h, in_left_pad_w}, + std::vector{in_right_pad_h, in_right_pad_w}); + } + else + { + throw std::runtime_error("wrong! data_type & layout for this operator is not implemented"); + } + + return 0; +} diff --git a/profiler/src/profile_convnd_bwd_data.cpp b/profiler/src/profile_convnd_bwd_data.cpp new file mode 100644 index 00000000..7c387d37 --- /dev/null +++ b/profiler/src/profile_convnd_bwd_data.cpp @@ -0,0 +1,229 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/include/profile_convnd_bwd_data_impl.hpp" + +namespace { + +enum struct ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 +}; + +enum struct ConvInputLayout +{ + NCHW, // 0 + NHWC, // 1 +}; + +enum struct ConvWeightLayout +{ + KCYX, // 0 + KYXC, // 1 +}; + +enum struct ConvOutputLayout +{ + NKHW, // 0 + NHWK, // 1 +}; +ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[], int arg_idx) +{ + // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) + ck::utils::conv::ConvParams params; + + params.num_dim_spatial_ = num_dim_spatial; + params.N_ = std::stoi(argv[arg_idx++]); + params.K_ = std::stoi(argv[arg_idx++]); + params.C_ = std::stoi(argv[arg_idx++]); + + params.filter_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.input_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_strides_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_dilations_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]); + } + params.input_left_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_left_pads_[i] = std::stoi(argv[arg_idx++]); + } + params.input_right_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_right_pads_[i] = std::stoi(argv[arg_idx++]); + } + + return params; +} + +} // namespace + +int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial) +{ + const int preParams = 10; + int conv_args = 3 + num_dim_spatial * 6; + int cmdline_nargs = conv_args + preParams; + if(cmdline_nargs != argc) + { + printf("arg1: tensor operation (conv[1|2|3]d_bwd_data: BackwardConvolution)\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: input tensor layout (0: NCHW; 1: NHWC)\n"); + printf("arg4: weight tensor layout (0: KCYX; 1: KYXC)\n"); + printf("arg5: output tensor layout (0: NKHW; 1: NHWK)\n"); + printf("arg6: verification (0: no; 1: yes)\n"); + printf("arg7: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg8: print tensor value (0: no; 1: yes)\n"); + printf("arg9: time kernel (0=n0, 1=yes)\n"); + printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + return 1; + } + + const auto data_type = 
static_cast(std::stoi(argv[2])); + const auto in_layout = static_cast(std::stoi(argv[3])); + const auto wei_layout = static_cast(std::stoi(argv[4])); + const auto out_layout = static_cast(std::stoi(argv[5])); + const bool do_verification = std::stoi(argv[6]); + const int init_method = std::stoi(argv[7]); + const bool do_log = std::stoi(argv[8]); + const bool time_kernel = std::stoi(argv[9]); + + ck::utils::conv::ConvParams params = parse_conv_params(num_dim_spatial, argv, preParams); + + auto Run = [&](auto input_type, auto wei_type, auto out_type, auto acc_type) { + using InDataType = decltype(input_type); + using WeiDataType = decltype(wei_type); + using OutDataType = decltype(out_type); + using AccDataType = decltype(acc_type); + + switch(num_dim_spatial) + { + case 1: + ck::profiler::profile_convnd_bwd_data_impl<1, + InDataType, + WeiDataType, + OutDataType, + AccDataType, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>( + do_verification, + init_method, + do_log, + time_kernel, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + params.GetOutputSpatialLengths(), + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_); + break; + + case 2: + ck::profiler::profile_convnd_bwd_data_impl<2, + InDataType, + WeiDataType, + OutDataType, + AccDataType, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + do_verification, + init_method, + do_log, + time_kernel, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + params.GetOutputSpatialLengths(), + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_); + break; + + case 3: + ck::profiler::profile_convnd_bwd_data_impl<3, + InDataType, + WeiDataType, + OutDataType, + AccDataType, + ck::tensor_layout::convolution::NDHWC, + ck::tensor_layout::convolution::KZYXC, + ck::tensor_layout::convolution::NDHWK>( + do_verification, + init_method, + do_log, + time_kernel, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + params.GetOutputSpatialLengths(), + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_); + break; + + default: break; + } + }; + if(data_type == ConvDataType::F32_F32_F32 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + Run(float{}, float{}, float{}, float{}); + } + else if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + Run(ck::half_t{}, ck::half_t{}, ck::half_t{}, float{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + Run(ck::bhalf_t{}, ck::bhalf_t{}, ck::bhalf_t{}, float{}); + } + else if(data_type == ConvDataType::INT8_INT8_INT8 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + Run(int8_t{}, int8_t{}, int8_t{}, int32_t{}); + } + else + { + std::cout << "wrong! 
this Conv data_type & layout is not implemented" << std::endl; + return 1; + } + + return 0; +} diff --git a/profiler/src/profile_convnd_bwd_weight.cpp b/profiler/src/profile_convnd_bwd_weight.cpp new file mode 100644 index 00000000..741d9ac6 --- /dev/null +++ b/profiler/src/profile_convnd_bwd_weight.cpp @@ -0,0 +1,226 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/include/profile_convnd_bwd_weight_impl.hpp" + +namespace { + +enum struct ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 +}; + +enum struct ConvInputLayout +{ + NCHW, // 0 + NHWC, // 1 +}; + +enum struct ConvWeightLayout +{ + KCYX, // 0 + KYXC, // 1 +}; + +enum struct ConvOutputLayout +{ + NKHW, // 0 + NHWK, // 1 +}; +ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[], int arg_idx) +{ + // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) + ck::utils::conv::ConvParams params; + + params.num_dim_spatial_ = num_dim_spatial; + params.N_ = std::stoi(argv[arg_idx++]); + params.K_ = std::stoi(argv[arg_idx++]); + params.C_ = std::stoi(argv[arg_idx++]); + + params.filter_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.input_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_strides_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_dilations_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]); + } + params.input_left_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_left_pads_[i] = std::stoi(argv[arg_idx++]); + } + params.input_right_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_right_pads_[i] = std::stoi(argv[arg_idx++]); + } + + return params; +} + +} // namespace + +int profile_convnd_bwd_weight(int argc, char* argv[], int num_dim_spatial) +{ + const int preParams = 11; + int conv_args = 3 + num_dim_spatial * 6; + int cmdline_nargs = conv_args + preParams; + if(cmdline_nargs != argc) + { + printf("arg1: tensor operation (convnd[1|2|3]d_bwd_weight: BackwardConvolution)\n"); + printf("arg2: data type (0: fp32; 1: fp16, 2: bf16)\n"); + printf("arg3: input tensor layout (0: NCHW; 1: NHWC)\n"); + printf("arg4: weight tensor layout (0: KCYX; 1: KYXC)\n"); + printf("arg5: output tensor layout (0: NKHW; 1: NHWK)\n"); + printf("arg6: verification (0: no; 1: yes)\n"); + printf("arg7: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg8: print tensor value (0: no; 1: yes)\n"); + printf("arg9: time kernel (0=n0, 1=yes)\n"); + printf("arg10: splitk\n"); + printf("arg11 to 25: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + return 1; + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto in_layout = static_cast(std::stoi(argv[3])); + const auto wei_layout = static_cast(std::stoi(argv[4])); + const auto out_layout = static_cast(std::stoi(argv[5])); + const bool do_verification = std::stoi(argv[6]); + const int 
init_method = std::stoi(argv[7]); + const bool do_log = std::stoi(argv[8]); + const bool time_kernel = std::stoi(argv[9]); + + ck::index_t split_k = std::stoi(argv[10]); + split_k = std::max(1, split_k); + + ck::utils::conv::ConvParams params = parse_conv_params(num_dim_spatial, argv, preParams); + + auto Run = [&](auto input_type, auto wei_type, auto out_type) { + using InDataType = decltype(input_type); + using WeiDataType = decltype(wei_type); + using OutDataType = decltype(out_type); + + switch(num_dim_spatial) + { + case 1: + ck::profiler::profile_convnd_bwd_weight_impl<1, + InDataType, + WeiDataType, + OutDataType, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>( + do_verification, + init_method, + do_log, + time_kernel, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + params.GetOutputSpatialLengths(), + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + split_k); + break; + + case 2: + ck::profiler::profile_convnd_bwd_weight_impl<2, + InDataType, + WeiDataType, + OutDataType, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + do_verification, + init_method, + do_log, + time_kernel, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + params.GetOutputSpatialLengths(), + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + split_k); + break; + + case 3: + ck::profiler::profile_convnd_bwd_weight_impl<3, + InDataType, + WeiDataType, + OutDataType, + ck::tensor_layout::convolution::NDHWC, + ck::tensor_layout::convolution::KZYXC, + ck::tensor_layout::convolution::NDHWK>( + do_verification, + init_method, + do_log, + time_kernel, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + params.GetOutputSpatialLengths(), + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + split_k); + break; + + default: break; + } + }; + if(data_type == ConvDataType::F32_F32_F32 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + Run(float{}, float{}, float{}); + } + else if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + Run(ck::half_t{}, ck::half_t{}, ck::half_t{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + Run(ck::bhalf_t{}, ck::bhalf_t{}, ck::bhalf_t{}); + } + else + { + std::cout << "wrong! this Conv data_type & layout is not implemented" << std::endl; + return 1; + } + + return 0; +} diff --git a/profiler/src/profile_gemm.cpp b/profiler/src/profile_gemm.cpp new file mode 100644 index 00000000..f53f4781 --- /dev/null +++ b/profiler/src/profile_gemm.cpp @@ -0,0 +1,186 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
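[Editor's aside, not part of the patch: the convnd drivers above validate argc as preParams + 3 + 6 * num_dim_spatial, i.e. the fixed leading arguments plus (N, K, C) and six per-dimension groups (filter, input, strides, dilations, left pads, right pads). A small check of that arithmetic; preParams is 11 for bwd-weight (which adds a splitk argument) and 10 for bwd-data:]

#include <cassert>

int expected_argc(int pre_params, int num_dim_spatial)
{
    return pre_params + 3 + 6 * num_dim_spatial;
}

int main()
{
    assert(expected_argc(11, 2) == 26); // conv2d_bwd_weight: argv[0] + args 1..25
    assert(expected_argc(11, 3) == 32); // conv3d_bwd_weight
    assert(expected_argc(10, 2) == 25); // conv2d_bwd_data: argv[0] + args 1..24
    return 0;
}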
+ +#include +#include +#include +#include + +#include "profiler/include/profile_gemm_impl.hpp" + +enum struct GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 +}; + +enum struct GemmDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 +}; + +static void print_helper_msg() +{ + std::cout << "arg1: tensor operation (gemm: GEMM)\n" + << "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n" + << "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n" + << " 1: A[m, k] * B[n, k] = C[m, n];\n" + << " 2: A[k, m] * B[k, n] = C[m, n];\n" + << " 3: A[k, m] * B[n, k] = C[m, n])\n" + << "arg4: verification (0: no; 1: yes)\n" + << "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n" + << "arg6: print tensor value (0: no; 1: yes)\n" + << "arg7: time kernel (0: no, 1: yes)\n" + << "arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n" + << std::endl; +} + +int profile_gemm(int argc, char* argv[]) +{ + if(argc != 14) + { + print_helper_msg(); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideC = std::stoi(argv[13]); + + using F32 = float; + using F16 = ck::half_t; + using BF16 = ck::bhalf_t; + using INT8 = int8_t; + using INT32 = int32_t; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + auto profile = [&](auto a_type, + auto b_type, + auto acc_type, + auto c_type, + auto a_layout, + auto b_layout, + auto c_layout) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using AccDataType = decltype(acc_type); + using CDataType = decltype(c_type); + + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using CLayout = decltype(c_layout); + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideC = ck::is_same_v ? N : M; + + bool pass = + ck::profiler::profile_gemm_impl(do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? DefaultStrideA : StrideA, + (StrideB < 0) ? DefaultStrideB : StrideB, + (StrideC < 0) ? DefaultStrideC : StrideC); + + return pass ? 
0 : 1; + }; + + if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(Row{}, Row{}, Row{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(Row{}, Col{}, Row{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN) + { + return profile(Col{}, Row{}, Row{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) + { + return profile(Col{}, Col{}, Row{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(Row{}, Row{}, Row{}, F16{}, F16{}, F32{}, F16{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(Row{}, Col{}, Row{}, F16{}, F16{}, F32{}, F16{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) + { + return profile(Col{}, Row{}, Row{}, F16{}, F16{}, F32{}, F16{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) + { + return profile(Col{}, Col{}, Row{}, F16{}, F16{}, F32{}, F16{}); + } + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(Row{}, Row{}, Row{}, BF16{}, BF16{}, F32{}, BF16{}); + } + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(Row{}, Col{}, Row{}, BF16{}, BF16{}, F32{}, BF16{}); + } + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN) + { + return profile(Col{}, Row{}, Row{}, BF16{}, BF16{}, F32{}, BF16{}); + } + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN) + { + return profile(Col{}, Col{}, Row{}, BF16{}, BF16{}, F32{}, BF16{}); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(Row{}, Row{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{}); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(Row{}, Col{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{}); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN) + { + return profile(Col{}, Row{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{}); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN) + { + return profile(Col{}, Col{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; + } +} diff --git a/profiler/src/profile_gemm_add_add_fastgelu.cpp b/profiler/src/profile_gemm_add_add_fastgelu.cpp new file mode 100644 index 00000000..a381222c --- /dev/null +++ b/profiler/src/profile_gemm_add_add_fastgelu.cpp @@ -0,0 +1,146 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
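[Editor's aside, not part of the patch: when a stride argument is negative, the GEMM drivers above substitute a packed default, the column count for a row-major matrix and the row count for a column-major one. A minimal sketch of that rule with illustrative names:]

#include <cassert>

struct Dims { int rows, cols; };

// Leading dimension of a densely packed matrix.
int default_stride(Dims d, bool row_major) { return row_major ? d.cols : d.rows; }

int main()
{
    const int M = 64, N = 128, K = 32;
    assert(default_stride({M, K}, true)  == K); // A[m, k] row-major    -> StrideA = K
    assert(default_stride({K, N}, true)  == N); // B[k, n] row-major    -> StrideB = N
    assert(default_stride({M, K}, false) == M); // A[k, m] column-major -> StrideA = M
    return 0;
}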
+
+#include
+#include
+#include
+#include
+
+#include "profiler/include/profile_gemm_add_add_fastgelu_impl.hpp"
+
+int profile_gemm_add_add_fastgelu(int argc, char* argv[])
+{
+    enum struct MatrixLayout
+    {
+        MK_KN_MN_MN_MN, // 0
+        MK_NK_MN_MN_MN, // 1
+        KM_KN_MN_MN_MN, // 2
+        KM_NK_MN_MN_MN, // 3
+    };
+
+    enum struct MatrixDataType
+    {
+        F32_F32_F32_F32_F32,      // 0
+        F16_F16_F16_F16_F16,      // 1
+        BF16_BF16_BF16_BF16_BF16, // 2
+        INT8_INT8_INT8_INT8_INT8, // 3
+    };
+
+    if(argc != 16)
+    {
+        // clang-format off
+        printf("arg1: tensor operation (gemm_add_add_fastgelu: GEMM+Add+Add+FastGeLU)\n");
+        printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n");
+        printf("arg3: matrix layout (0: E[m, n] = FastGeLU(A[m, k] * B[k, n] + D0[m, n] + D1[m, n]);\n");
+        printf("                     1: E[m, n] = FastGeLU(A[m, k] * B[n, k] + D0[m, n] + D1[m, n]);\n");
+        printf("                     2: E[m, n] = FastGeLU(A[k, m] * B[k, n] + D0[m, n] + D1[m, n]);\n");
+        printf("                     3: E[m, n] = FastGeLU(A[k, m] * B[n, k] + D0[m, n] + D1[m, n]))\n");
+        printf("arg4: verification (0: no; 1: yes)\n");
+        printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
+        printf("arg6: print tensor value (0: no; 1: yes)\n");
+        printf("arg7: time kernel (0=no, 1=yes)\n");
+        printf("arg8 to 15: M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE\n");
+        // clang-format on
+        exit(1);
+    }
+
+    const auto data_type       = static_cast<MatrixDataType>(std::stoi(argv[2]));
+    const auto layout          = static_cast<MatrixLayout>(std::stoi(argv[3]));
+    const bool do_verification = std::stoi(argv[4]);
+    const int init_method      = std::stoi(argv[5]);
+    const bool do_log          = std::stoi(argv[6]);
+    const bool time_kernel     = std::stoi(argv[7]);
+
+    const int M = std::stoi(argv[8]);
+    const int N = std::stoi(argv[9]);
+    const int K = std::stoi(argv[10]);
+
+    const int StrideA  = std::stoi(argv[11]);
+    const int StrideB  = std::stoi(argv[12]);
+    const int StrideD0 = std::stoi(argv[13]);
+    const int StrideD1 = std::stoi(argv[14]);
+    const int StrideE  = std::stoi(argv[15]);
+
+    using F16 = ck::half_t;
+    using F32 = float;
+
+    using Row = ck::tensor_layout::gemm::RowMajor;
+    using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+    auto profile = [&](auto a_type,
+                       auto b_type,
+                       auto acc_type,
+                       auto d0_type,
+                       auto d1_type,
+                       auto e_type,
+                       auto a_layout,
+                       auto b_layout,
+                       auto de_layout) {
+        using ADataType   = decltype(a_type);
+        using BDataType   = decltype(b_type);
+        using AccDataType = decltype(acc_type);
+        using D0DataType  = decltype(d0_type);
+        using D1DataType  = decltype(d1_type);
+        using EDataType   = decltype(e_type);
+
+        using ALayout  = decltype(a_layout);
+        using BLayout  = decltype(b_layout);
+        using DELayout = decltype(de_layout);
+
+        const int DefaultStrideA  = ck::is_same_v<ALayout, Row> ? K : M;
+        const int DefaultStrideB  = ck::is_same_v<BLayout, Row> ? N : K;
+        const int DefaultStrideD0 = ck::is_same_v<DELayout, Row> ? N : M;
+        const int DefaultStrideD1 = ck::is_same_v<DELayout, Row> ? N : M;
+        const int DefaultStrideE  = ck::is_same_v<DELayout, Row> ? N : M;
+
+        bool pass = ck::profiler::profile_gemm_add_add_fastgelu_impl<ADataType,
+                                                                     BDataType,
+                                                                     AccDataType,
+                                                                     D0DataType,
+                                                                     D1DataType,
+                                                                     EDataType,
+                                                                     ALayout,
+                                                                     BLayout,
+                                                                     DELayout>(
+            do_verification,
+            init_method,
+            do_log,
+            time_kernel,
+            M,
+            N,
+            K,
+            (StrideA < 0) ? DefaultStrideA : StrideA,
+            (StrideB < 0) ? DefaultStrideB : StrideB,
+            (StrideD0 < 0) ? DefaultStrideD0 : StrideD0,
+            (StrideD1 < 0) ? DefaultStrideD1 : StrideD1,
+            (StrideE < 0) ? DefaultStrideE : StrideE);
+
+        return pass ? 
0 : 1;
+    };
+
+    if(data_type == MatrixDataType::F16_F16_F16_F16_F16 && layout == MatrixLayout::MK_KN_MN_MN_MN)
+    {
+        return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Row{}, Row{}, Row{});
+    }
+    else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 &&
+            layout == MatrixLayout::MK_NK_MN_MN_MN)
+    {
+        return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Row{}, Col{}, Row{});
+    }
+    else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 &&
+            layout == MatrixLayout::KM_KN_MN_MN_MN)
+    {
+        return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Col{}, Row{}, Row{});
+    }
+    else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 &&
+            layout == MatrixLayout::KM_NK_MN_MN_MN)
+    {
+        return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Col{}, Col{}, Row{});
+    }
+    else
+    {
+        std::cout << "this data_type & layout is not implemented" << std::endl;
+
+        return 1;
+    }
+}
diff --git a/profiler/src/profile_gemm_bias_add_reduce.cpp b/profiler/src/profile_gemm_bias_add_reduce.cpp
new file mode 100644
index 00000000..bc267570
--- /dev/null
+++ b/profiler/src/profile_gemm_bias_add_reduce.cpp
@@ -0,0 +1,161 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include
+#include
+#include
+#include
+
+#include "profiler/include/profile_gemm_bias_add_reduce_impl.hpp"
+
+int profile_gemm_bias_add_reduce(int argc, char* argv[])
+{
+    enum struct GemmMatrixLayout
+    {
+        MK_KN_MN, // 0
+        MK_NK_MN, // 1
+        KM_KN_MN, // 2
+        KM_NK_MN, // 3
+    };
+
+    enum struct GemmReduceDataType
+    {
+        F32_F32_F32_F32_F32_F32_F32, // 0
+        F16_F16_F16_F16_F16_F32_F32, // 1
+    };
+
+    if(!(argc == 14 || argc == 15))
+    {
+        printf("arg1: tensor operation (gemm_bias_add_reduce: GEMM+Bias+Add+Reduce)\n");
+        printf("arg2: data type (0: fp32; 1: fp16)\n");
+        printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
+        printf("                     1: A[m, k] * B[n, k] = C[m, n];\n");
+        printf("                     2: A[k, m] * B[k, n] = C[m, n];\n");
+        printf("                     3: A[k, m] * B[n, k] = C[m, n])\n");
+        printf("arg4: verification (0: no; 1: yes)\n");
+        printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
+        printf("arg6: print tensor value (0: no; 1: yes)\n");
+        printf("arg7: time kernel (0=no, 1=yes)\n");
+        printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideC, StrideC1 (StrideC1 is optional)\n");
+        exit(1);
+    }
+
+    const auto data_type       = static_cast<GemmReduceDataType>(std::stoi(argv[2]));
+    const auto layout          = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
+    const bool do_verification = std::stoi(argv[4]);
+    const int init_method      = std::stoi(argv[5]);
+    const bool do_log          = std::stoi(argv[6]);
+    const bool time_kernel     = std::stoi(argv[7]);
+
+    const int M = std::stoi(argv[8]);
+    const int N = std::stoi(argv[9]);
+    const int K = std::stoi(argv[10]);
+
+    const int StrideA = std::stoi(argv[11]);
+    const int StrideB = std::stoi(argv[12]);
+    const int StrideC = std::stoi(argv[13]);
+    // StrideC1 is optional; fall back to the default when argc == 14
+    const int StrideC1 = (argc == 15) ? std::stoi(argv[14]) : -1;
+
+    if(data_type == GemmReduceDataType::F16_F16_F16_F16_F16_F32_F32 &&
+       layout == GemmMatrixLayout::MK_KN_MN)
+    {
+        ck::profiler::profile_gemm_bias_add_reduce_impl(
+            do_verification,
+            init_method,
+            do_log,
+            time_kernel,
+            M,
+            N,
+            K,
+            (StrideA < 0) ? K : StrideA,
+            (StrideB < 0) ? N : StrideB,
+            (StrideC < 0) ? N : StrideC,
+            (StrideC1 < 0) ? N : StrideC1);
+    }
+    else if(data_type == GemmReduceDataType::F16_F16_F16_F16_F16_F32_F32 &&
+            layout == GemmMatrixLayout::MK_NK_MN)
+    {
+        ck::profiler::profile_gemm_bias_add_reduce_impl(
+            do_verification,
+            init_method,
+            do_log,
+            time_kernel,
+            M,
+            N,
+            K,
+            (StrideA < 0) ? 
K : StrideA,
+            (StrideB < 0) ? K : StrideB,
+            (StrideC < 0) ? N : StrideC,
+            (StrideC1 < 0) ? N : StrideC1);
+    }
+    else if(data_type == GemmReduceDataType::F16_F16_F16_F16_F16_F32_F32 &&
+            layout == GemmMatrixLayout::KM_KN_MN)
+    {
+        ck::profiler::profile_gemm_bias_add_reduce_impl(
+            do_verification,
+            init_method,
+            do_log,
+            time_kernel,
+            M,
+            N,
+            K,
+            (StrideA < 0) ? M : StrideA,
+            (StrideB < 0) ? N : StrideB,
+            (StrideC < 0) ? N : StrideC,
+            (StrideC1 < 0) ? N : StrideC1);
+    }
+    else if(data_type == GemmReduceDataType::F16_F16_F16_F16_F16_F32_F32 &&
+            layout == GemmMatrixLayout::KM_NK_MN)
+    {
+        ck::profiler::profile_gemm_bias_add_reduce_impl(
+            do_verification,
+            init_method,
+            do_log,
+            time_kernel,
+            M,
+            N,
+            K,
+            (StrideA < 0) ? M : StrideA,
+            (StrideB < 0) ? K : StrideB,
+            (StrideC < 0) ? N : StrideC,
+            (StrideC1 < 0) ? N : StrideC1);
+    }
+    else
+    {
+        throw std::runtime_error("wrong! this data_type & layout is not implemented");
+    }
+
+    return 0;
+}
diff --git a/profiler/src/profile_gemm_bilinear.cpp b/profiler/src/profile_gemm_bilinear.cpp
new file mode 100644
index 00000000..14c57789
--- /dev/null
+++ b/profiler/src/profile_gemm_bilinear.cpp
@@ -0,0 +1,143 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include
+#include
+#include
+#include
+
+#include "profiler/include/profile_gemm_bilinear_impl.hpp"
+
+int profile_gemm_bilinear(int argc, char* argv[])
+{
+    enum struct MatrixLayout
+    {
+        MK_KN_MN_MN, // 0
+        MK_NK_MN_MN, // 1
+        KM_KN_MN_MN, // 2
+        KM_NK_MN_MN, // 3
+    };
+
+    enum struct MatrixDataType
+    {
+        F32_F32_F32_F32,     // 0
+        F16_F16_F16_F16,     // 1
+        BF16_BF16_BF16_BF16, // 2
+        INT8_INT8_INT8_INT8, // 3
+    };
+
+    if(argc != 17)
+    {
+        // clang-format off
+        printf("arg1: tensor operation (gemm_bilinear: GEMM+Bilinear)\n");
+        printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n");
+        printf("arg3: matrix layout (0: E[m, n] = alpha * A[m, k] * B[k, n] + beta * D[m, n];\n");
+        printf("                     1: E[m, n] = alpha * A[m, k] * B[n, k] + beta * D[m, n];\n");
+        printf("                     2: E[m, n] = alpha * A[k, m] * B[k, n] + beta * D[m, n];\n");
+        printf("                     3: E[m, n] = alpha * A[k, m] * B[n, k] + beta * D[m, n])\n");
+        printf("arg4: verification (0: no; 1: yes)\n");
+        printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
+        printf("arg6: print tensor value (0: no; 1: yes)\n");
+        printf("arg7: time kernel (0=no, 1=yes)\n");
+        printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideD, StrideE\n");
+        printf("arg15 to 16: alpha, beta\n");
+        // clang-format on
+        exit(1);
+    }
+
+    const auto data_type       = static_cast<MatrixDataType>(std::stoi(argv[2]));
+    const auto layout          = static_cast<MatrixLayout>(std::stoi(argv[3]));
+    const bool do_verification = std::stoi(argv[4]);
+    const int init_method      = std::stoi(argv[5]);
+    const bool do_log          = std::stoi(argv[6]);
+    const bool time_kernel     = std::stoi(argv[7]);
+
+    const int M = std::stoi(argv[8]);
+    const int N = std::stoi(argv[9]);
+    const int K = std::stoi(argv[10]);
+
+    const int StrideA = std::stoi(argv[11]);
+    const int StrideB = std::stoi(argv[12]);
+    const int StrideD = std::stoi(argv[13]);
+    const int StrideE = std::stoi(argv[14]);
+
+    const float alpha = std::stof(argv[15]);
+    const float beta  = std::stof(argv[16]);
+
+    using F16 = ck::half_t;
+    using F32 = float;
+
+    using Row = ck::tensor_layout::gemm::RowMajor;
+    using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+    auto profile = [&](auto a_type,
+                       auto b_type,
+                       auto acc_type,
+                       auto d_type,
+                       auto e_type,
+                       auto a_layout,
+                       auto 
b_layout,
+                       auto de_layout) {
+        using ADataType   = decltype(a_type);
+        using BDataType   = decltype(b_type);
+        using AccDataType = decltype(acc_type);
+        using DDataType   = decltype(d_type);
+        using EDataType   = decltype(e_type);
+
+        using ALayout  = decltype(a_layout);
+        using BLayout  = decltype(b_layout);
+        using DELayout = decltype(de_layout);
+
+        const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
+        const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
+        const int DefaultStrideD = ck::is_same_v<DELayout, Row> ? N : M;
+        const int DefaultStrideE = ck::is_same_v<DELayout, Row> ? N : M;
+
+        bool pass = ck::profiler::profile_gemm_bilinear_impl<ADataType,
+                                                             BDataType,
+                                                             AccDataType,
+                                                             DDataType,
+                                                             EDataType,
+                                                             ALayout,
+                                                             BLayout,
+                                                             DELayout>(
+            do_verification,
+            init_method,
+            do_log,
+            time_kernel,
+            M,
+            N,
+            K,
+            (StrideA < 0) ? DefaultStrideA : StrideA,
+            (StrideB < 0) ? DefaultStrideB : StrideB,
+            (StrideD < 0) ? DefaultStrideD : StrideD,
+            (StrideE < 0) ? DefaultStrideE : StrideE,
+            alpha,
+            beta);
+
+        return pass ? 0 : 1;
+    };
+
+    if(data_type == MatrixDataType::F16_F16_F16_F16 && layout == MatrixLayout::MK_KN_MN_MN)
+    {
+        return profile(F16{}, F16{}, F32{}, F16{}, F16{}, Row{}, Row{}, Row{});
+    }
+    else if(data_type == MatrixDataType::F16_F16_F16_F16 && layout == MatrixLayout::MK_NK_MN_MN)
+    {
+        return profile(F16{}, F16{}, F32{}, F16{}, F16{}, Row{}, Col{}, Row{});
+    }
+    else if(data_type == MatrixDataType::F16_F16_F16_F16 && layout == MatrixLayout::KM_KN_MN_MN)
+    {
+        return profile(F16{}, F16{}, F32{}, F16{}, F16{}, Col{}, Row{}, Row{});
+    }
+    else if(data_type == MatrixDataType::F16_F16_F16_F16 && layout == MatrixLayout::KM_NK_MN_MN)
+    {
+        return profile(F16{}, F16{}, F32{}, F16{}, F16{}, Col{}, Col{}, Row{});
+    }
+    else
+    {
+        std::cout << "this data_type & layout is not implemented" << std::endl;
+
+        return 1;
+    }
+}
diff --git a/profiler/src/profile_gemm_reduce.cpp b/profiler/src/profile_gemm_reduce.cpp
new file mode 100644
index 00000000..476943c8
--- /dev/null
+++ b/profiler/src/profile_gemm_reduce.cpp
@@ -0,0 +1,148 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
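+// Usage sketch (binary name and sizes illustrative; negative strides select the
+// layout-dependent defaults below):
+//     ./ckProfiler gemm_reduce 1 1 1 1 0 1 1024 1024 1024 -1 -1 -1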
+
+#include
+#include
+#include
+#include
+
+#include "profiler/include/profile_gemm_reduce_impl.hpp"
+
+int profile_gemm_reduce(int argc, char* argv[])
+{
+    enum struct GemmMatrixLayout
+    {
+        MK_KN_MN, // 0
+        MK_NK_MN, // 1
+        KM_KN_MN, // 2
+        KM_NK_MN, // 3
+    };
+
+    enum struct GemmReduceDataType
+    {
+        F32_F32_F32_F32_F32, // 0
+        F16_F16_F16_F32_F32, // 1
+    };
+
+    if(!(argc == 14 || argc == 15))
+    {
+        printf("arg1: tensor operation (gemm_reduce: GEMM+Reduce)\n");
+        printf("arg2: data type (0: fp32; 1: fp16)\n");
+        printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
+        printf("                     1: A[m, k] * B[n, k] = C[m, n];\n");
+        printf("                     2: A[k, m] * B[k, n] = C[m, n];\n");
+        printf("                     3: A[k, m] * B[n, k] = C[m, n])\n");
+        printf("arg4: verification (0: no; 1: yes)\n");
+        printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
+        printf("arg6: print tensor value (0: no; 1: yes)\n");
+        printf("arg7: time kernel (0=no, 1=yes)\n");
+        printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
+        printf("arg14: split K into multiple batches\n");
+        exit(1);
+    }
+
+    const auto data_type       = static_cast<GemmReduceDataType>(std::stoi(argv[2]));
+    const auto layout          = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
+    const bool do_verification = std::stoi(argv[4]);
+    const int init_method      = std::stoi(argv[5]);
+    const bool do_log          = std::stoi(argv[6]);
+    const bool time_kernel     = std::stoi(argv[7]);
+
+    const int M = std::stoi(argv[8]);
+    const int N = std::stoi(argv[9]);
+    const int K = std::stoi(argv[10]);
+
+    const int StrideA = std::stoi(argv[11]);
+    const int StrideB = std::stoi(argv[12]);
+    const int StrideC = std::stoi(argv[13]);
+
+    if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
+    {
+        ck::profiler::profile_gemm_reduce_impl(
+            do_verification,
+            init_method,
+            do_log,
+            time_kernel,
+            M,
+            N,
+            K,
+            (StrideA < 0) ? K : StrideA,
+            (StrideB < 0) ? N : StrideB,
+            (StrideC < 0) ? N : StrideC);
+    }
+    else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 &&
+            layout == GemmMatrixLayout::MK_NK_MN)
+    {
+        ck::profiler::profile_gemm_reduce_impl(
+            do_verification,
+            init_method,
+            do_log,
+            time_kernel,
+            M,
+            N,
+            K,
+            (StrideA < 0) ? K : StrideA,
+            (StrideB < 0) ? K : StrideB,
+            (StrideC < 0) ? N : StrideC);
+    }
+    else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 &&
+            layout == GemmMatrixLayout::KM_KN_MN)
+    {
+        ck::profiler::profile_gemm_reduce_impl(
+            do_verification,
+            init_method,
+            do_log,
+            time_kernel,
+            M,
+            N,
+            K,
+            (StrideA < 0) ? M : StrideA,
+            (StrideB < 0) ? N : StrideB,
+            (StrideC < 0) ? N : StrideC);
+    }
+    else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 &&
+            layout == GemmMatrixLayout::KM_NK_MN)
+    {
+        ck::profiler::profile_gemm_reduce_impl(
+            do_verification,
+            init_method,
+            do_log,
+            time_kernel,
+            M,
+            N,
+            K,
+            (StrideA < 0) ? M : StrideA,
+            (StrideB < 0) ? K : StrideB,
+            (StrideC < 0) ? N : StrideC);
+    }
+    else
+    {
+        throw std::runtime_error("wrong! this data_type & layout is not implemented");
+    }
+
+    return 0;
+}
diff --git a/profiler/src/profile_gemm_splitk.cpp b/profiler/src/profile_gemm_splitk.cpp
new file mode 100644
index 00000000..fff023c8
--- /dev/null
+++ b/profiler/src/profile_gemm_splitk.cpp
@@ -0,0 +1,148 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
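+// Usage sketch (binary name and sizes illustrative; the final argument is KBatch):
+//     ./ckProfiler gemm_splitk 1 1 1 1 0 1 1024 1024 4096 -1 -1 -1 4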
+
+#include
+#include
+#include
+#include
+
+#include "profiler/include/profile_gemm_splitk_impl.hpp"
+
+enum struct GemmMatrixLayout
+{
+    MK_KN_MN, // 0
+    MK_NK_MN, // 1
+    KM_KN_MN, // 2
+    KM_NK_MN, // 3
+};
+
+enum struct GemmDataType
+{
+    F32_F32_F32,    // 0
+    F16_F16_F16,    // 1
+    BF16_BF16_BF16, // 2
+    INT8_INT8_INT8, // 3
+};
+
+int profile_gemm_splitk(int argc, char* argv[])
+{
+    if(argc != 15)
+    {
+        printf("arg1: tensor operation (gemm_splitk: Split-K GEMM)\n");
+        printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n");
+        printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
+        printf("                     1: A[m, k] * B[n, k] = C[m, n];\n");
+        printf("                     2: A[k, m] * B[k, n] = C[m, n];\n");
+        printf("                     3: A[k, m] * B[n, k] = C[m, n])\n");
+        printf("arg4: verification (0: no; 1: yes)\n");
+        printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
+        printf("arg6: print tensor value (0: no; 1: yes)\n");
+        printf("arg7: time kernel (0=no, 1=yes)\n");
+        printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
+        printf("arg14: split K into multiple batches\n");
+        exit(1);
+    }
+
+    const auto data_type       = static_cast<GemmDataType>(std::stoi(argv[2]));
+    const auto layout          = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
+    const bool do_verification = std::stoi(argv[4]);
+    const int init_method      = std::stoi(argv[5]);
+    const bool do_log          = std::stoi(argv[6]);
+    const bool time_kernel     = std::stoi(argv[7]);
+
+    const int M = std::stoi(argv[8]);
+    const int N = std::stoi(argv[9]);
+    const int K = std::stoi(argv[10]);
+
+    const int StrideA = std::stoi(argv[11]);
+    const int StrideB = std::stoi(argv[12]);
+    const int StrideC = std::stoi(argv[13]);
+    const int KBatch  = std::stoi(argv[14]);
+
+    using F32 = float;
+    using F16 = ck::half_t;
+
+    using Row = ck::tensor_layout::gemm::RowMajor;
+    using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+    auto profile = [&](auto a_type,
+                       auto b_type,
+                       auto acc_type,
+                       auto c_type,
+                       auto a_layout,
+                       auto b_layout,
+                       auto c_layout) {
+        using ADataType   = decltype(a_type);
+        using BDataType   = decltype(b_type);
+        using AccDataType = decltype(acc_type);
+        using CDataType   = decltype(c_type);
+
+        using ALayout = decltype(a_layout);
+        using BLayout = decltype(b_layout);
+        using CLayout = decltype(c_layout);
+
+        const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
+        const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
+        const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;
+
+        bool pass = ck::profiler::profile_gemm_splitk_impl<ADataType,
+                                                           BDataType,
+                                                           AccDataType,
+                                                           CDataType,
+                                                           ALayout,
+                                                           BLayout,
+                                                           CLayout>(
+            do_verification,
+            init_method,
+            do_log,
+            time_kernel,
+            M,
+            N,
+            K,
+            (StrideA < 0) ? DefaultStrideA : StrideA,
+            (StrideB < 0) ? DefaultStrideB : StrideB,
+            (StrideC < 0) ? DefaultStrideC : StrideC,
+            KBatch);
+
+        return pass ? 
0 : 1;
+    };
+
+    if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
+    {
+        return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Row{}, Row{});
+    }
+    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
+    {
+        return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Col{}, Row{});
+    }
+    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
+    {
+        return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Row{}, Row{});
+    }
+    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
+    {
+        return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Col{}, Row{});
+    }
+    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
+    {
+        return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{});
+    }
+    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
+    {
+        return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
+    }
+    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
+    {
+        return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{});
+    }
+    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
+    {
+        return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{});
+    }
+    else
+    {
+        std::cout << "this data_type & layout is not implemented" << std::endl;
+
+        return 1;
+    }
+}
diff --git a/profiler/src/profile_grouped_gemm.cpp b/profiler/src/profile_grouped_gemm.cpp
new file mode 100644
index 00000000..a51505ae
--- /dev/null
+++ b/profiler/src/profile_grouped_gemm.cpp
@@ -0,0 +1,163 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include
+#include
+#include
+#include
+
+#include "profiler/include/profile_grouped_gemm_impl.hpp"
+
+enum struct GemmMatrixLayout
+{
+    MK_KN_MN, // 0
+    MK_NK_MN, // 1
+    KM_KN_MN, // 2
+    KM_NK_MN, // 3
+    MK_KN_NM, // 4
+    MK_NK_NM, // 5
+    KM_KN_NM, // 6
+    KM_NK_NM, // 7
+};
+
+enum struct GemmDataType
+{
+    F32_F32_F32,    // 0
+    F16_F16_F16,    // 1
+    BF16_BF16_BF16, // 2
+    INT8_INT8_INT8, // 3
+};
+
+std::vector<int> argToIntArray(char* input)
+{
+    std::vector<int> out;
+
+    std::istringstream in(input);
+
+    std::string item;
+
+    while(std::getline(in, item, ','))
+    {
+        out.push_back(std::stoi(item));
+    }
+
+    return out;
+}
+
+int profile_grouped_gemm(int argc, char* argv[])
+{
+    if(argc != 14)
+    {
+        printf("arg1: tensor operation (grouped_gemm: Grouped GEMM)\n");
+        printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n");
+        printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
+        printf("                     1: A[m, k] * B[n, k] = C[m, n];\n");
+        printf("                     2: A[k, m] * B[k, n] = C[m, n];\n");
+        printf("                     3: A[k, m] * B[n, k] = C[m, n])\n");
+        printf("arg4: verification (0: no; 1: yes)\n");
+        printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
+        printf("arg6: print tensor value (0: no; 1: yes)\n");
+        printf("arg7: time kernel (0=no, 1=yes)\n");
+        printf("arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
+               "64,64 64,64 128,128)\n");
+        exit(1);
+    }
+
+    const auto data_type       = static_cast<GemmDataType>(std::stoi(argv[2]));
+    const auto layout          = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
+    const bool do_verification = std::stoi(argv[4]);
+    const int init_method      = std::stoi(argv[5]);
+    const bool do_log          = std::stoi(argv[6]);
+    const bool time_kernel     = std::stoi(argv[7]);
+
+    const auto Ms = argToIntArray(argv[8]);
+    const auto Ns = 
argToIntArray(argv[9]);
+    const auto Ks = argToIntArray(argv[10]);
+
+    const auto StrideAs = argToIntArray(argv[11]);
+    const auto StrideBs = argToIntArray(argv[12]);
+    const auto StrideCs = argToIntArray(argv[13]);
+
+    if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
+    {
+        ck::profiler::profile_grouped_gemm_impl(do_verification,
+                                                init_method,
+                                                do_log,
+                                                time_kernel,
+                                                Ms,
+                                                Ns,
+                                                Ks,
+                                                StrideAs,
+                                                StrideBs,
+                                                StrideCs);
+    }
+    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
+    {
+        ck::profiler::profile_grouped_gemm_impl(do_verification,
+                                                init_method,
+                                                do_log,
+                                                time_kernel,
+                                                Ms,
+                                                Ns,
+                                                Ks,
+                                                StrideAs,
+                                                StrideBs,
+                                                StrideCs);
+    }
+    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
+    {
+        ck::profiler::profile_grouped_gemm_impl(do_verification,
+                                                init_method,
+                                                do_log,
+                                                time_kernel,
+                                                Ms,
+                                                Ns,
+                                                Ks,
+                                                StrideAs,
+                                                StrideBs,
+                                                StrideCs);
+    }
+    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
+    {
+        ck::profiler::profile_grouped_gemm_impl(do_verification,
+                                                init_method,
+                                                do_log,
+                                                time_kernel,
+                                                Ms,
+                                                Ns,
+                                                Ks,
+                                                StrideAs,
+                                                StrideBs,
+                                                StrideCs);
+    }
+    else
+    {
+        throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
+    }
+
+    return 0;
+}
diff --git a/profiler/src/profile_normalization.cpp b/profiler/src/profile_normalization.cpp
new file mode 100644
index 00000000..277a78a6
--- /dev/null
+++ b/profiler/src/profile_normalization.cpp
@@ -0,0 +1,134 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include
+#include
+#include
+
+#include "profiler/include/profile_normalization_impl.hpp"
+
+using ck::index_t;
+using ck::profiler::NormDataType;
+using ck::profiler::NormType;
+
+struct ArgParser
+{
+    std::unordered_map<std::string, NormType> norm_dict = {{"layernorm", NormType::LAYERNORM},
+                                                           {"batchnorm", NormType::BATCHNORM},
+                                                           {"softmax", NormType::SOFTMAX}};
+
+    std::unordered_map<std::string, std::vector<index_t>> long_opts = {
+        {"length", {}}, {"stride", {}}, {"reduce", {}}, {"alpha", {}}, {"beta", {}}};
+
+    bool parse_opt(int argc, char* argv[], const std::string& key, int i)
+    {
+        if(std::string("--") + key == argv[i])
+        {
+            int pos = i;
+            while(++i < argc && argv[i][0] != '-') {}
+            int end = i;
+            for(int j = pos + 1; j < end; j++)
+            {
+                long_opts[key].push_back(std::stoi(argv[j]));
+            }
+            return true;
+        }
+        return false;
+    }
+
+    void operator()(int argc, char* argv[])
+    {
+        for(auto& kv : long_opts)
+        {
+            for(int i = 1; i < argc; i++)
+            {
+                if(parse_opt(argc, argv, kv.first, i))
+                    break;
+            }
+        }
+    }
+};
+
+void print_help()
+{
+    std::cout << "arg1: tensor operation (layernorm/batchnorm/softmax)\n"
+              << "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"
+              << "arg3: verification (0: no; 1: yes)\n"
+              << "arg4: initialization (0: no init; 1: integer value; 2: decimal value)\n"
+              << "arg5: print tensor value (0: no; 1: yes)\n"
+              << "arg6: time kernel (0=no, 1=yes)\n"
+              << "--length: tensor extents (e.g., --length 8 4 256)\n"
+              << "--stride: tensor strides (e.g., --stride 1024 256 1)\n"
+              << "--reduce: to-reduce dimensions (e.g., --reduce 2)\n"
+              << "--alpha: alpha scaling value\n"
+              << "--beta: beta scaling value\n"
+              << std::endl;
+}
+
+int profile_normalization(int argc, char* argv[])
+{
+    if(argc <= 2)
+    {
+        print_help();
+        return 0;
+    }
+
+    ArgParser arg_parser;
+
+    // short unnamed options
+    const NormType norm_type     = arg_parser.norm_dict[argv[1]];
+    const NormDataType data_type = 
static_cast<NormDataType>(std::stoi(argv[2]));
+    const bool do_verification = std::stoi(argv[3]);
+    const int init_method      = std::stoi(argv[4]);
+    const bool do_log          = std::stoi(argv[5]);
+    const bool time_kernel     = std::stoi(argv[6]);
+
+    // parse the long options
+    arg_parser(argc, argv);
+    const std::vector<index_t> length = arg_parser.long_opts["length"];
+    const std::vector<index_t> stride = arg_parser.long_opts["stride"];
+    const std::vector<index_t> reduce = arg_parser.long_opts["reduce"];
+    const index_t alpha =
+        arg_parser.long_opts["alpha"].empty() ? 1 : arg_parser.long_opts["alpha"][0];
+    const index_t beta = arg_parser.long_opts["beta"].empty() ? 0 : arg_parser.long_opts["beta"][0];
+
+    if(data_type == NormDataType::F16_F16)
+    {
+        ck::profiler::profile_normalization_impl(do_verification,
+                                                 init_method,
+                                                 do_log,
+                                                 time_kernel,
+                                                 length,
+                                                 stride,
+                                                 reduce,
+                                                 float(alpha),
+                                                 float(beta),
+                                                 norm_type);
+    }
+    else if(data_type == NormDataType::F32_F32)
+    {
+        ck::profiler::profile_normalization_impl(do_verification,
+                                                 init_method,
+                                                 do_log,
+                                                 time_kernel,
+                                                 length,
+                                                 stride,
+                                                 reduce,
+                                                 float(alpha),
+                                                 float(beta),
+                                                 norm_type);
+    }
+    else
+    {
+        throw std::runtime_error("not implemented yet");
+    }
+
+    return 0;
+}
+
+// hijack main() for quick debugging
+// int main(int argc, char* argv[])
+// {
+//     profile_normalization(argc, argv);
+//     return 0;
+// }
diff --git a/profiler/src/profile_reduce.cpp b/profiler/src/profile_reduce.cpp
new file mode 100644
index 00000000..1ec2a6d6
--- /dev/null
+++ b/profiler/src/profile_reduce.cpp
@@ -0,0 +1,431 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "ck/utility/reduction_enums.hpp"
+
+#include "ck/library/utility/host_common_util.hpp"
+
+#include "profiler/include/profile_reduce_impl.hpp"
+#include "profiler/include/data_type_enum.hpp"
+
+using namespace std;
+
+using ck::ReduceTensorOp;
+
+static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'},
+                                       {"reduceDims", required_argument, nullptr, 'R'},
+                                       {"reduceOp", required_argument, nullptr, 'O'},
+                                       {"compType", required_argument, nullptr, 'C'},
+                                       {"outType", required_argument, nullptr, 'W'},
+                                       {"nanOpt", required_argument, nullptr, 'N'},
+                                       {"indicesOpt", required_argument, nullptr, 'I'},
+                                       {"scales", required_argument, nullptr, 'S'},
+                                       {"half", no_argument, nullptr, '?'},
+                                       {"double", no_argument, nullptr, '?'},
+                                       {"int8", no_argument, nullptr, '?'},
+                                       {"bf16", no_argument, nullptr, '?'},
+                                       {"dumpout", required_argument, nullptr, 'o'},
+                                       {"verify", required_argument, nullptr, 'v'},
+                                       {"help", no_argument, nullptr, '?'},
+                                       {nullptr, 0, nullptr, 0}};
+
+static void check_reduce_dims(const int rank, const std::vector<int>& reduceDims)
+{
+    for(auto dim : reduceDims)
+    {
+        if(dim < 0 || dim >= rank)
+            throw std::runtime_error("Invalid dimension index specified for Reducing");
+    };
+
+    unsigned int flag = 0;
+
+    for(auto dim : reduceDims)
+    {
+        if(flag & (0x1 << dim))
+            throw std::runtime_error("All toReduce dimensions should be different!");
+        flag = flag | (0x1 << dim);
+    };
+};
+
+class ReduceProfilerArgs
+{
+    private:
+    int option_index = 0;
+
+    public:
+    bool use_half   = false;
+    bool use_double = false;
+    bool use_int8   = false;
+    bool use_bf16   = false;
+
+    std::vector<size_t> inLengths;
+    std::vector<size_t> outLengths;
+    std::vector<int> reduceDims;
+
+    std::vector<float> scales;
+
+    ReduceTensorOp reduceOp     = ReduceTensorOp::ADD;
+    ck::DataTypeEnum compTypeId = ck::DataTypeEnum::Float;
+    
ck::DataTypeEnum outTypeId = ck::DataTypeEnum::Float;
+
+    bool compType_assigned = false;
+    bool outType_assigned  = false;
+
+    int nanOpt           = 0;
+    int indicesOpt       = 0;
+    bool do_verification = false;
+    bool do_dumpout      = false;
+
+    int init_method;
+    bool time_kernel;
+
+    ReduceProfilerArgs()  = default;
+    ~ReduceProfilerArgs() = default;
+
+    void show_usage(const char* cmd)
+    {
+        std::cout << "Usage of " << cmd << std::endl;
+        std::cout << "--inLengths or -D, comma separated list of input tensor dimension lengths"
+                  << std::endl;
+        std::cout << "--reduceDims or -R, comma separated list of to-reduce dimensions"
+                  << std::endl;
+        std::cout << "--reduceOp or -O, enum value indicating the reduction operation"
+                  << std::endl;
+        std::cout << "--compType or -C, enum value indicating the type of accumulated values used "
+                     "during the reduction"
+                  << std::endl;
+        std::cout << "--outType or -W, optional enum value indicating the type of the reduced "
+                     "output, which could be float when the input data is half"
+                  << std::endl;
+        std::cout << "--nanOpt or -N, 1/0 value indicating whether to use NaN propagation"
+                  << std::endl;
+        std::cout << "--indicesOpt or -I, 1/0 value indicating whether to compute indices in the "
+                     "reduction"
+                  << std::endl;
+        std::cout << "--scales or -S, comma separated two float values for alpha and beta"
+                  << std::endl;
+        std::cout << "--half, use fp16 for the input and output tensor data types" << std::endl;
+        std::cout << "--double, use fp64 for the input and output tensor data types" << std::endl;
+        std::cout << "--int8, use int8 for the input and output tensor data types" << std::endl;
+        std::cout << "--bf16, use bfloat16 for the input and output tensor data types" << std::endl;
+        std::cout << "--verify or -v, 1/0 to indicate whether to verify the reduction result by "
+                     "comparing with the host-based reduction"
+                  << std::endl;
+        std::cout << "--dumpout or -o, 1/0 to indicate whether to save the reduction result to "
+                     "files for further analysis"
+                  << std::endl;
+    };
+
+    int processArgs(int argc, char* argv[])
+    {
+        using ck::host_common::getTypeValuesFromString;
+
+        int ch;
+
+        optind++; // to skip the "reduce" module name
+
+        while(1)
+        {
+            ch = getopt_long(argc, argv, "D:R:O:C:W:N:I:S:v:o:", long_options, &option_index);
+            if(ch == -1)
+                break;
+            switch(ch)
+            {
+            case 'D':
+                if(!optarg)
+                    throw std::runtime_error("Invalid option format!");
+
+                inLengths = getTypeValuesFromString<size_t>(optarg);
+                break;
+            case 'R':
+                if(!optarg)
+                    throw std::runtime_error("Invalid option format!");
+
+                reduceDims = getTypeValuesFromString<int>(optarg);
+                break;
+            case 'O':
+                if(!optarg)
+                    throw std::runtime_error("Invalid option format!");
+
+                reduceOp = static_cast<ReduceTensorOp>(std::atoi(optarg));
+                break;
+            case 'C':
+                if(!optarg)
+                    throw std::runtime_error("Invalid option format!");
+
+                compTypeId        = static_cast<ck::DataTypeEnum>(std::atoi(optarg));
+                compType_assigned = true;
+                break;
+            case 'W':
+                if(!optarg)
+                    throw std::runtime_error("Invalid option format!");
+
+                outTypeId        = static_cast<ck::DataTypeEnum>(std::atoi(optarg));
+                outType_assigned = true;
+                break;
+            case 'N':
+                if(!optarg)
+                    throw std::runtime_error("Invalid option format!");
+
+                nanOpt = std::atoi(optarg);
+                break;
+            case 'I':
+                if(!optarg)
+                    throw std::runtime_error("Invalid option format!");
+
+                indicesOpt = std::atoi(optarg);
+                break;
+            case 'S':
+                if(!optarg)
+                    throw std::runtime_error("Invalid option format!");
+
+                scales = getTypeValuesFromString<float>(optarg);
+
+                if(scales.size() != 2)
+                    throw std::runtime_error("Invalid option format!");
+                break;
+            
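+            // The long-only flags (--half/--double/--int8/--bf16/--help) are registered
+            // in long_options with val '?', so getopt_long returns '?' for them and the
+            // case below disambiguates them by option name.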
case 'v':
+                if(!optarg)
+                    throw std::runtime_error("Invalid option format!");
+
+                do_verification = static_cast<bool>(std::atoi(optarg));
+                break;
+            case 'o':
+                if(!optarg)
+                    throw std::runtime_error("Invalid option format!");
+
+                do_dumpout = static_cast<bool>(std::atoi(optarg));
+                break;
+            case '?':
+                if(std::string(long_options[option_index].name) == "half")
+                    use_half = true;
+                else if(std::string(long_options[option_index].name) == "double")
+                    use_double = true;
+                else if(std::string(long_options[option_index].name) == "int8")
+                    use_int8 = true;
+                else if(std::string(long_options[option_index].name) == "bf16")
+                    use_bf16 = true;
+                else if(std::string(long_options[option_index].name) == "help")
+                {
+                    show_usage(argv[0]);
+                    return (-1);
+                };
+                break;
+
+            default:
+                show_usage(argv[0]);
+                std::cerr << "Invalid cmd-line options!" << std::endl;
+                return (-1);
+            };
+        };
+
+        if(optind + 2 > argc)
+            throw std::runtime_error("Invalid cmd-line arguments, more arguments are needed!");
+
+        init_method = std::atoi(argv[optind++]);
+        time_kernel = static_cast<bool>(std::atoi(argv[optind]));
+
+        if(scales.empty())
+        {
+            scales.push_back(1.0f);
+            scales.push_back(0.0f);
+        };
+
+        if(reduceOp == ReduceTensorOp::MIN || reduceOp == ReduceTensorOp::MAX ||
+           reduceOp == ReduceTensorOp::AMAX)
+        {
+            // for indexable operations, no need to assign compType and outType, just let them be
+            // same as inType
+            compType_assigned = false;
+            outType_assigned  = false;
+        };
+
+        return (0);
+    };
+
+}; // end of class ReduceProfilerArgs
+
+int profile_reduce(int argc, char* argv[])
+{
+    using ck::DataTypeEnum;
+    using ck::profiler::profile_reduce_impl;
+
+    ReduceProfilerArgs args;
+
+    if(args.processArgs(argc, argv) < 0)
+        return (-1);
+
+    int rank = args.inLengths.size();
+
+    check_reduce_dims(rank, args.reduceDims);
+
+    if(args.reduceOp == ReduceTensorOp::MUL || args.reduceOp == ReduceTensorOp::NORM1)
+        throw std::runtime_error("MUL and NORM1 are not supported by composable kernel!");
+
+    if(args.use_half)
+    {
+        if(!args.compType_assigned)
+            args.compTypeId = DataTypeEnum::Half;
+
+        if(args.outType_assigned &&
+           (args.outTypeId != DataTypeEnum::Half && args.outTypeId != DataTypeEnum::Float))
+            args.outTypeId = DataTypeEnum::Float;
+
+        if(!args.outType_assigned)
+            args.outTypeId = DataTypeEnum::Half;
+
+        if(args.compTypeId == DataTypeEnum::Half)
+        {
+            profile_reduce_impl(
+                args.do_verification,
+                args.init_method,
+                args.do_dumpout,
+                args.time_kernel,
+                args.inLengths,
+                args.reduceDims,
+                args.reduceOp,
+                static_cast(args.nanOpt),
+                static_cast(args.indicesOpt),
+                args.scales[0],
+                args.scales[1]);
+        }
+        else if(args.compTypeId == DataTypeEnum::Float)
+        {
+            profile_reduce_impl(args.do_verification,
+                                args.init_method,
+                                args.do_dumpout,
+                                args.time_kernel,
+                                args.inLengths,
+                                args.reduceDims,
+                                args.reduceOp,
+                                static_cast(args.nanOpt),
+                                static_cast(args.indicesOpt),
+                                args.scales[0],
+                                args.scales[1]);
+        }
+        else
+            throw std::runtime_error("Invalid compType assignment!");
+    }
+    else if(args.use_double)
+    {
+        profile_reduce_impl(args.do_verification,
+                            args.init_method,
+                            args.do_dumpout,
+                            args.time_kernel,
+                            args.inLengths,
+                            args.reduceDims,
+                            args.reduceOp,
+                            static_cast(args.nanOpt),
+                            static_cast(args.indicesOpt),
+                            args.scales[0],
+                            args.scales[1]);
+    }
+    else if(args.use_int8)
+    {
+        if(!args.compType_assigned)
+            args.compTypeId = DataTypeEnum::Int8;
+
+        if(args.outType_assigned &&
+           (args.outTypeId != DataTypeEnum::Int8 && args.outTypeId != DataTypeEnum::Int32))
+            args.outTypeId = DataTypeEnum::Int32;
+
+        if(!args.outType_assigned)
+            
args.outTypeId = DataTypeEnum::Int8; + + if(args.compTypeId == DataTypeEnum::Int8) + { + profile_reduce_impl(args.do_verification, + args.init_method, + args.do_dumpout, + args.time_kernel, + args.inLengths, + args.reduceDims, + args.reduceOp, + static_cast(args.nanOpt), + static_cast(args.indicesOpt), + args.scales[0], + args.scales[1]); + } + else if(args.compTypeId == DataTypeEnum::Int32) + { + profile_reduce_impl(args.do_verification, + args.init_method, + args.do_dumpout, + args.time_kernel, + args.inLengths, + args.reduceDims, + args.reduceOp, + static_cast(args.nanOpt), + static_cast(args.indicesOpt), + args.scales[0], + args.scales[1]); + } + else + throw std::runtime_error("Invalid compType assignment!"); + } + else if(args.use_bf16) + { + if(args.outType_assigned && + (args.outTypeId != DataTypeEnum::BFloat16 && args.outTypeId != DataTypeEnum::Float)) + args.outTypeId = DataTypeEnum::Float; + + if(!args.outType_assigned) + args.outTypeId = DataTypeEnum::BFloat16; + + profile_reduce_impl(args.do_verification, + args.init_method, + args.do_dumpout, + args.time_kernel, + args.inLengths, + args.reduceDims, + args.reduceOp, + static_cast(args.nanOpt), + static_cast(args.indicesOpt), + args.scales[0], + args.scales[1]); + } + else + { + if(args.compTypeId == DataTypeEnum::Float) + { + profile_reduce_impl(args.do_verification, + args.init_method, + args.do_dumpout, + args.time_kernel, + args.inLengths, + args.reduceDims, + args.reduceOp, + static_cast(args.nanOpt), + static_cast(args.indicesOpt), + args.scales[0], + args.scales[1]); + } + else if(args.compTypeId == DataTypeEnum::Double) + { + profile_reduce_impl(args.do_verification, + args.init_method, + args.do_dumpout, + args.time_kernel, + args.inLengths, + args.reduceDims, + args.reduceOp, + static_cast(args.nanOpt), + static_cast(args.indicesOpt), + args.scales[0], + args.scales[1]); + } + else + throw std::runtime_error("Invalid compType assignment!"); + }; + + return (0); +}; diff --git a/profiler/src/profiler.cpp b/profiler/src/profiler.cpp new file mode 100644 index 00000000..bc104021 --- /dev/null +++ b/profiler/src/profiler.cpp @@ -0,0 +1,147 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
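+// Dispatch sketch: argv[1] selects the operation and argc/argv are forwarded
+// unchanged to the per-operation entry points declared below, e.g.
+// (binary name and sizes illustrative):
+//     ./ckProfiler gemm 1 1 1 1 0 1 3840 4096 4096 -1 -1 -1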
+
+#include
+
+int profile_gemm(int, char*[]);
+int profile_gemm_splitk(int, char*[]);
+int profile_gemm_bilinear(int, char*[]);
+int profile_gemm_add_add_fastgelu(int, char*[]);
+int profile_gemm_reduce(int, char*[]);
+int profile_gemm_bias_add_reduce(int, char*[]);
+int profile_batched_gemm(int, char*[]);
+int profile_batched_gemm_reduce(int, char*[]);
+int profile_grouped_gemm(int, char*[]);
+int profile_conv_fwd(int, char*[]);
+int profile_conv_fwd_bias_relu(int, char*[]);
+int profile_conv_fwd_bias_relu_add(int, char*[]);
+int profile_convnd_bwd_data(int, char*[], int);
+int profile_conv_bwd_weight(int, char*[]);
+int profile_convnd_bwd_weight(int, char*[], int);
+int profile_normalization(int, char*[]);
+int profile_reduce(int, char*[]);
+
+static void print_helper_message()
+{
+    // clang-format off
+    printf("arg1: tensor operation (gemm: GEMM\n"
+           "                        gemm_splitk: Split-K GEMM\n"
+           "                        gemm_bilinear: GEMM+Bilinear\n"
+           "                        gemm_add_add_fastgelu: GEMM+Add+Add+FastGeLU\n"
+           "                        gemm_reduce: GEMM+Reduce\n"
+           "                        gemm_bias_add_reduce: GEMM+Bias+Add+Reduce\n"
+           "                        batched_gemm: Batched GEMM\n"
+           "                        batched_gemm_reduce: Batched GEMM+Reduce\n"
+           "                        grouped_gemm: Grouped GEMM\n"
+           "                        conv_fwd: ForwardConvolution\n"
+           "                        conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU\n"
+           "                        conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add\n"
+           "                        conv1d_bwd_data: BackwardConvolution data 1 dim\n"
+           "                        conv2d_bwd_data: BackwardConvolution data 2 dim\n"
+           "                        conv3d_bwd_data: BackwardConvolution data 3 dim\n"
+           "                        conv2d_bwd_weight: Backward Weight Convolution 2d\n"
+           "                        convnd1d_bwd_weight: Backward Weight Convolution 1d\n"
+           "                        convnd2d_bwd_weight: Backward Weight Convolution 2d\n"
+           "                        convnd3d_bwd_weight: Backward Weight Convolution 3d\n"
+           "                        reduce: Reduce\n"
+           "                        layernorm/batchnorm/softmax: Normalization)\n");
+    // clang-format on
+}
+
+int main(int argc, char* argv[])
+{
+    if(argc == 1)
+    {
+        print_helper_message();
+
+        return 0;
+    }
+
+    if(strcmp(argv[1], "gemm") == 0)
+    {
+        return profile_gemm(argc, argv);
+    }
+    else if(strcmp(argv[1], "gemm_splitk") == 0)
+    {
+        return profile_gemm_splitk(argc, argv);
+    }
+    else if(strcmp(argv[1], "gemm_bilinear") == 0)
+    {
+        return profile_gemm_bilinear(argc, argv);
+    }
+    else if(strcmp(argv[1], "gemm_add_add_fastgelu") == 0)
+    {
+        return profile_gemm_add_add_fastgelu(argc, argv);
+    }
+    else if(strcmp(argv[1], "gemm_reduce") == 0)
+    {
+        return profile_gemm_reduce(argc, argv);
+    }
+    else if(strcmp(argv[1], "gemm_bias_add_reduce") == 0)
+    {
+        return profile_gemm_bias_add_reduce(argc, argv);
+    }
+    else if(strcmp(argv[1], "batched_gemm") == 0)
+    {
+        return profile_batched_gemm(argc, argv);
+    }
+    else if(strcmp(argv[1], "batched_gemm_reduce") == 0)
+    {
+        return profile_batched_gemm_reduce(argc, argv);
+    }
+    else if(strcmp(argv[1], "grouped_gemm") == 0)
+    {
+        return profile_grouped_gemm(argc, argv);
+    }
+    else if(strcmp(argv[1], "conv_fwd") == 0)
+    {
+        return profile_conv_fwd(argc, argv);
+    }
+    else if(strcmp(argv[1], "conv_fwd_bias_relu") == 0)
+    {
+        return profile_conv_fwd_bias_relu(argc, argv);
+    }
+    else if(strcmp(argv[1], "conv_fwd_bias_relu_add") == 0)
+    {
+        return profile_conv_fwd_bias_relu_add(argc, argv);
+    }
+    else if(strcmp(argv[1], "conv1d_bwd_data") == 0)
+    {
+        return profile_convnd_bwd_data(argc, argv, 1);
+    }
+    else if(strcmp(argv[1], "conv2d_bwd_data") == 0)
+    {
+        return profile_convnd_bwd_data(argc, argv, 2);
+    }
+    else if(strcmp(argv[1], "conv3d_bwd_data") == 0)
+    {
+        return profile_convnd_bwd_data(argc, argv, 3);
+    }
+    else if(strcmp(argv[1], "conv2d_bwd_weight") == 0)
+    {
+        return profile_conv_bwd_weight(argc, argv);
+    }
+    else if(strcmp(argv[1], "convnd1d_bwd_weight") == 0)
+    {
+        return profile_convnd_bwd_weight(argc, argv, 1);
+    }
+    else if(strcmp(argv[1], "convnd2d_bwd_weight") == 0)
+    
{ + return profile_convnd_bwd_weight(argc, argv, 2); + } + else if(strcmp(argv[1], "convnd3d_bwd_weight") == 0) + { + return profile_convnd_bwd_weight(argc, argv, 3); + } + else if(strcmp(argv[1], "reduce") == 0) + { + return profile_reduce(argc, argv); + } + else if(strcmp(argv[1], "batchnorm") == 0 || strcmp(argv[1], "layernorm") == 0 || + strcmp(argv[1], "softmax") == 0) + { + return profile_normalization(argc, argv); + } + else + { + print_helper_message(); + + return 0; + } +} diff --git a/rbuild.ini b/rbuild.ini new file mode 100644 index 00000000..3649cedf --- /dev/null +++ b/rbuild.ini @@ -0,0 +1,8 @@ +[develop] +cxx = ${rocm_path}/bin/hipcc +cc = ${rocm_path}/llvm/bin/clang +ignore = pcre +deps = + -f dev-requirements.txt +define = + BUILD_DEV=On diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..b91bf2e5 --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +danmar/cppcheck@dd05839a7e63ef04afd34711cb3e1e0ef742882f diff --git a/script/clang-format-overwrite.sh b/script/clang-format-overwrite.sh new file mode 100644 index 00000000..71df7d10 --- /dev/null +++ b/script/clang-format-overwrite.sh @@ -0,0 +1,2 @@ +#find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}' +git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}' diff --git a/script/cmake-rocm.sh b/script/cmake-rocm.sh index ebfa2b9f..86b62368 100755 --- a/script/cmake-rocm.sh +++ b/script/cmake-rocm.sh @@ -3,16 +3,18 @@ rm -f CMakeCache.txt rm -f *.cmake rm -rf CMakeFiles -MY_PROJECT_SOURCE=../../.. 
+MY_PROJECT_SOURCE=../ MY_PROJECT_INSTALL=../install.dir cmake \ -D CMAKE_INSTALL_PREFIX=${MY_PROJECT_INSTALL} \ --D HALF_INCLUDE_DIR="/root/workspace/external/half/include" \ --D BUILD_DEV=ON \ +-D BUILD_DEV=OFF \ -D CMAKE_BUILD_TYPE=Release \ --D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 -O3 --amdgpu-target=gfx908 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \ +-D CMAKE_CXX_FLAGS=" -O3 -ftemplate-backtrace-limit=0 -gline-tables-only -save-temps=$PWD" \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ ${MY_PROJECT_SOURCE} + +#-D CMAKE_CXX_FLAGS=" --offload-arch=gfx908 --offload-arch=gfx90a -O3 -ftemplate-backtrace-limit=0 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \ +#-D CMAKE_CXX_FLAGS=" --offload-arch=gfx908 --offload-arch=gfx90a -O3 -ftemplate-backtrace-limit=0 -gline-tables-only -save-temps=$PWD" \ diff --git a/script/conv2d_fwd.sh b/script/conv2d_fwd.sh new file mode 100755 index 00000000..acc91e19 --- /dev/null +++ b/script/conv2d_fwd.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +## GPU visibility + export HIP_VISIBLE_DEVICES=0 + + make -j $1 + +DRIVER=example/$1 +VERIFY=$2 +INIT=$3 +REPEAT=$4 + +# test +######## verify init repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads Desired_grid_size__ + $DRIVER $VERIFY $INIT $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT 128 256 64 1 1 1 1 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT 256 64 3 7 7 230 230 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT 128 512 512 3 3 7 7 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT 256 64 3 7 7 224 224 2 2 1 1 3 3 3 3 + + N=$5 + +# Resnet50 +######## verify init repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads Desired_grid_size__ +#$DRIVER $VERIFY $INIT $REPEAT $N 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 128 128 3 3 28 28 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 128 128 3 3 58 58 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 256 256 3 3 30 30 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 128 256 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 512 256 1 1 56 56 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 512 512 3 3 16 16 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 128 512 1 1 28 28 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 256 512 1 1 28 28 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 512 512 3 3 7 7 1 1 1 1 1 
1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 64 64 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE diff --git a/script/conv_driver.sh b/script/conv_driver.sh new file mode 100755 index 00000000..8805e0cc --- /dev/null +++ b/script/conv_driver.sh @@ -0,0 +1,71 @@ +#!/bin/bash + +## GPU visibility + export HIP_VISIBLE_DEVICES=0 + + make -j conv_fwd_driver_offline +#make -j conv_bwd_driver_offline +#make -j conv_wrw_driver_offline + + DRIVER="./host/driver_offline/conv_fwd_driver_offline" +#DRIVER="./host/driver_offline/conv_bwd_driver_offline" +#DRIVER="./host/driver_offline/conv_wrw_driver_offline" + +LAYOUT=$1 +ALGO=$2 +VERIFY=$3 +INIT=$4 +LOG=$5 +REPEAT=$6 + + DESIRED_GRID_SIZE=$7 + +######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads Desired_grid_size__ +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 $DESIRED_GRID_SIZE + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3 $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE + $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 192 3 3 35 35 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE + $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE + $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE + $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 32 256 3 3 1 1 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 32 256 1 1 1 1 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE + $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 2 2 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 128 1 1 2 2 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE + +# Resnet50 +######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads Desired_grid_size__ +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY 
$INIT $LOG $REPEAT 256 128 128 3 3 28 28 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 128 1 1 28 28 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 128 3 3 58 58 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 256 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 256 1 1 56 56 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 512 1 1 28 28 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 512 1 1 28 28 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 64 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 64 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 64 3 3 56 56 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE diff --git a/script/count_vgpr.sh b/script/count_vgpr.sh index 4fbfec02..07debc53 100755 --- a/script/count_vgpr.sh +++ b/script/count_vgpr.sh @@ -1,259 +1,20 @@ #!/bin/bash FILE=$1 -echo v0 $( grep -w v0 $FILE | wc -l ) -echo v1 $( grep -w v1 $FILE | wc -l ) -echo v2 $( grep -w v2 $FILE | wc -l ) -echo v3 $( grep -w v3 $FILE | wc -l ) -echo v4 $( grep -w v4 $FILE | wc -l ) -echo v5 $( grep -w v5 $FILE | wc -l ) -echo v6 $( grep -w v6 $FILE | wc -l ) -echo v7 $( grep -w v7 $FILE | wc -l ) -echo v8 $( grep -w v8 $FILE | wc -l ) -echo v9 $( grep -w v9 $FILE | wc -l ) -echo v10 $( grep -w v10 $FILE | wc -l ) -echo v11 $( grep -w v11 $FILE | wc -l ) -echo v12 $( grep -w v12 $FILE | wc -l ) -echo v13 $( grep -w v13 $FILE | wc -l ) -echo v14 $( grep -w v14 $FILE | wc -l ) -echo v15 $( grep -w v15 $FILE | wc -l ) -echo v16 $( grep -w v16 $FILE | wc -l ) -echo v17 $( grep -w v17 $FILE | wc -l ) -echo v18 $( grep -w v18 $FILE | wc -l ) -echo v19 $( grep -w v19 $FILE | wc -l ) -echo v20 $( grep -w v20 $FILE | wc -l ) -echo v21 $( grep -w v21 $FILE | wc -l ) -echo v22 $( grep -w v22 $FILE | wc -l ) -echo v23 $( grep -w v23 $FILE | wc -l ) -echo v24 $( grep -w v24 $FILE | wc -l ) -echo v25 $( grep -w v25 $FILE | wc -l ) -echo v26 $( grep -w v26 $FILE | wc -l ) -echo v27 $( grep -w v27 $FILE | wc -l ) -echo v28 $( grep -w v28 $FILE | wc -l ) -echo v29 $( grep -w v29 $FILE | wc -l ) -echo v30 $( grep -w v30 $FILE | wc -l ) -echo v31 $( grep -w v31 $FILE | wc -l ) -echo v32 $( grep -w v32 $FILE | wc -l ) -echo v33 $( grep -w v33 
$FILE | wc -l ) -echo v34 $( grep -w v34 $FILE | wc -l ) -echo v35 $( grep -w v35 $FILE | wc -l ) -echo v36 $( grep -w v36 $FILE | wc -l ) -echo v37 $( grep -w v37 $FILE | wc -l ) -echo v38 $( grep -w v38 $FILE | wc -l ) -echo v39 $( grep -w v39 $FILE | wc -l ) -echo v40 $( grep -w v40 $FILE | wc -l ) -echo v41 $( grep -w v41 $FILE | wc -l ) -echo v42 $( grep -w v42 $FILE | wc -l ) -echo v43 $( grep -w v43 $FILE | wc -l ) -echo v44 $( grep -w v44 $FILE | wc -l ) -echo v45 $( grep -w v45 $FILE | wc -l ) -echo v46 $( grep -w v46 $FILE | wc -l ) -echo v47 $( grep -w v47 $FILE | wc -l ) -echo v48 $( grep -w v48 $FILE | wc -l ) -echo v49 $( grep -w v49 $FILE | wc -l ) -echo v50 $( grep -w v50 $FILE | wc -l ) -echo v51 $( grep -w v51 $FILE | wc -l ) -echo v52 $( grep -w v52 $FILE | wc -l ) -echo v53 $( grep -w v53 $FILE | wc -l ) -echo v54 $( grep -w v54 $FILE | wc -l ) -echo v55 $( grep -w v55 $FILE | wc -l ) -echo v56 $( grep -w v56 $FILE | wc -l ) -echo v57 $( grep -w v57 $FILE | wc -l ) -echo v58 $( grep -w v58 $FILE | wc -l ) -echo v59 $( grep -w v59 $FILE | wc -l ) -echo v60 $( grep -w v60 $FILE | wc -l ) -echo v61 $( grep -w v61 $FILE | wc -l ) -echo v62 $( grep -w v62 $FILE | wc -l ) -echo v63 $( grep -w v63 $FILE | wc -l ) -echo v64 $( grep -w v64 $FILE | wc -l ) -echo v65 $( grep -w v65 $FILE | wc -l ) -echo v66 $( grep -w v66 $FILE | wc -l ) -echo v67 $( grep -w v67 $FILE | wc -l ) -echo v68 $( grep -w v68 $FILE | wc -l ) -echo v69 $( grep -w v69 $FILE | wc -l ) -echo v70 $( grep -w v70 $FILE | wc -l ) -echo v71 $( grep -w v71 $FILE | wc -l ) -echo v72 $( grep -w v72 $FILE | wc -l ) -echo v73 $( grep -w v73 $FILE | wc -l ) -echo v74 $( grep -w v74 $FILE | wc -l ) -echo v75 $( grep -w v75 $FILE | wc -l ) -echo v76 $( grep -w v76 $FILE | wc -l ) -echo v77 $( grep -w v77 $FILE | wc -l ) -echo v78 $( grep -w v78 $FILE | wc -l ) -echo v79 $( grep -w v79 $FILE | wc -l ) -echo v80 $( grep -w v80 $FILE | wc -l ) -echo v81 $( grep -w v81 $FILE | wc -l ) -echo v82 $( grep -w v82 $FILE | wc -l ) -echo v83 $( grep -w v83 $FILE | wc -l ) -echo v84 $( grep -w v84 $FILE | wc -l ) -echo v85 $( grep -w v85 $FILE | wc -l ) -echo v86 $( grep -w v86 $FILE | wc -l ) -echo v87 $( grep -w v87 $FILE | wc -l ) -echo v88 $( grep -w v88 $FILE | wc -l ) -echo v89 $( grep -w v89 $FILE | wc -l ) -echo v90 $( grep -w v90 $FILE | wc -l ) -echo v91 $( grep -w v91 $FILE | wc -l ) -echo v92 $( grep -w v92 $FILE | wc -l ) -echo v93 $( grep -w v93 $FILE | wc -l ) -echo v94 $( grep -w v94 $FILE | wc -l ) -echo v95 $( grep -w v95 $FILE | wc -l ) -echo v96 $( grep -w v96 $FILE | wc -l ) -echo v97 $( grep -w v97 $FILE | wc -l ) -echo v98 $( grep -w v98 $FILE | wc -l ) -echo v99 $( grep -w v99 $FILE | wc -l ) -echo v100 $( grep -w v100 $FILE | wc -l ) -echo v101 $( grep -w v101 $FILE | wc -l ) -echo v102 $( grep -w v102 $FILE | wc -l ) -echo v103 $( grep -w v103 $FILE | wc -l ) -echo v104 $( grep -w v104 $FILE | wc -l ) -echo v105 $( grep -w v105 $FILE | wc -l ) -echo v106 $( grep -w v106 $FILE | wc -l ) -echo v107 $( grep -w v107 $FILE | wc -l ) -echo v108 $( grep -w v108 $FILE | wc -l ) -echo v109 $( grep -w v109 $FILE | wc -l ) -echo v110 $( grep -w v110 $FILE | wc -l ) -echo v111 $( grep -w v111 $FILE | wc -l ) -echo v112 $( grep -w v112 $FILE | wc -l ) -echo v113 $( grep -w v113 $FILE | wc -l ) -echo v114 $( grep -w v114 $FILE | wc -l ) -echo v115 $( grep -w v115 $FILE | wc -l ) -echo v116 $( grep -w v116 $FILE | wc -l ) -echo v117 $( grep -w v117 $FILE | wc -l ) -echo v118 $( grep -w v118 $FILE | wc -l ) -echo v119 $( 
grep -w v119 $FILE | wc -l ) -echo v120 $( grep -w v120 $FILE | wc -l ) -echo v121 $( grep -w v121 $FILE | wc -l ) -echo v122 $( grep -w v122 $FILE | wc -l ) -echo v123 $( grep -w v123 $FILE | wc -l ) -echo v124 $( grep -w v124 $FILE | wc -l ) -echo v125 $( grep -w v125 $FILE | wc -l ) -echo v126 $( grep -w v126 $FILE | wc -l ) -echo v127 $( grep -w v127 $FILE | wc -l ) -echo v128 $( grep -w v128 $FILE | wc -l ) -echo v129 $( grep -w v129 $FILE | wc -l ) -echo v130 $( grep -w v130 $FILE | wc -l ) -echo v131 $( grep -w v131 $FILE | wc -l ) -echo v132 $( grep -w v132 $FILE | wc -l ) -echo v133 $( grep -w v133 $FILE | wc -l ) -echo v134 $( grep -w v134 $FILE | wc -l ) -echo v135 $( grep -w v135 $FILE | wc -l ) -echo v136 $( grep -w v136 $FILE | wc -l ) -echo v137 $( grep -w v137 $FILE | wc -l ) -echo v138 $( grep -w v138 $FILE | wc -l ) -echo v139 $( grep -w v139 $FILE | wc -l ) -echo v140 $( grep -w v140 $FILE | wc -l ) -echo v141 $( grep -w v141 $FILE | wc -l ) -echo v142 $( grep -w v142 $FILE | wc -l ) -echo v143 $( grep -w v143 $FILE | wc -l ) -echo v144 $( grep -w v144 $FILE | wc -l ) -echo v145 $( grep -w v145 $FILE | wc -l ) -echo v146 $( grep -w v146 $FILE | wc -l ) -echo v147 $( grep -w v147 $FILE | wc -l ) -echo v148 $( grep -w v148 $FILE | wc -l ) -echo v149 $( grep -w v149 $FILE | wc -l ) -echo v150 $( grep -w v150 $FILE | wc -l ) -echo v151 $( grep -w v151 $FILE | wc -l ) -echo v152 $( grep -w v152 $FILE | wc -l ) -echo v153 $( grep -w v153 $FILE | wc -l ) -echo v154 $( grep -w v154 $FILE | wc -l ) -echo v155 $( grep -w v155 $FILE | wc -l ) -echo v156 $( grep -w v156 $FILE | wc -l ) -echo v157 $( grep -w v157 $FILE | wc -l ) -echo v158 $( grep -w v158 $FILE | wc -l ) -echo v159 $( grep -w v159 $FILE | wc -l ) -echo v160 $( grep -w v160 $FILE | wc -l ) -echo v161 $( grep -w v161 $FILE | wc -l ) -echo v162 $( grep -w v162 $FILE | wc -l ) -echo v163 $( grep -w v163 $FILE | wc -l ) -echo v164 $( grep -w v164 $FILE | wc -l ) -echo v165 $( grep -w v165 $FILE | wc -l ) -echo v166 $( grep -w v166 $FILE | wc -l ) -echo v167 $( grep -w v167 $FILE | wc -l ) -echo v168 $( grep -w v168 $FILE | wc -l ) -echo v169 $( grep -w v169 $FILE | wc -l ) -echo v170 $( grep -w v170 $FILE | wc -l ) -echo v171 $( grep -w v171 $FILE | wc -l ) -echo v172 $( grep -w v172 $FILE | wc -l ) -echo v173 $( grep -w v173 $FILE | wc -l ) -echo v174 $( grep -w v174 $FILE | wc -l ) -echo v175 $( grep -w v175 $FILE | wc -l ) -echo v176 $( grep -w v176 $FILE | wc -l ) -echo v177 $( grep -w v177 $FILE | wc -l ) -echo v178 $( grep -w v178 $FILE | wc -l ) -echo v179 $( grep -w v179 $FILE | wc -l ) -echo v180 $( grep -w v180 $FILE | wc -l ) -echo v181 $( grep -w v181 $FILE | wc -l ) -echo v182 $( grep -w v182 $FILE | wc -l ) -echo v183 $( grep -w v183 $FILE | wc -l ) -echo v184 $( grep -w v184 $FILE | wc -l ) -echo v185 $( grep -w v185 $FILE | wc -l ) -echo v186 $( grep -w v186 $FILE | wc -l ) -echo v187 $( grep -w v187 $FILE | wc -l ) -echo v188 $( grep -w v188 $FILE | wc -l ) -echo v189 $( grep -w v189 $FILE | wc -l ) -echo v190 $( grep -w v190 $FILE | wc -l ) -echo v191 $( grep -w v191 $FILE | wc -l ) -echo v192 $( grep -w v192 $FILE | wc -l ) -echo v193 $( grep -w v193 $FILE | wc -l ) -echo v194 $( grep -w v194 $FILE | wc -l ) -echo v195 $( grep -w v195 $FILE | wc -l ) -echo v196 $( grep -w v196 $FILE | wc -l ) -echo v197 $( grep -w v197 $FILE | wc -l ) -echo v198 $( grep -w v198 $FILE | wc -l ) -echo v199 $( grep -w v199 $FILE | wc -l ) -echo v200 $( grep -w v200 $FILE | wc -l ) -echo v201 $( grep -w v201 $FILE | wc -l ) 
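[Editor's note] The hunk continuing below deletes the rest of these hard-coded per-register echo/grep lines (v202 through v255) and replaces them with a single loop over register indices 0-255. The loop's Perl-compatible pattern counts a register whether it appears bare (v17), as the opening bound of a range (v[17:20]), or as the closing bound of a range (v[4:17]), and does so for the s (scalar), v (vector), and a (accumulator) register files. A minimal standalone sketch of the same counting idea, assuming a PCRE-capable grep; kernel.s is a placeholder ISA dump, not a file from this repository:

    num=17
    pattern="v(\[?${num}\b|\[\d*:${num}\])"   # bare v17, v[17:..., or v[...:17]
    grep -P "${pattern}" kernel.s | wc -l      # lines touching vector register 17

Quoting the pattern is deliberate: unquoted, the [ ? * characters inside it are candidates for filename globbing before grep ever sees them.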
-echo v202 $( grep -w v202 $FILE | wc -l ) -echo v203 $( grep -w v203 $FILE | wc -l ) -echo v204 $( grep -w v204 $FILE | wc -l ) -echo v205 $( grep -w v205 $FILE | wc -l ) -echo v206 $( grep -w v206 $FILE | wc -l ) -echo v207 $( grep -w v207 $FILE | wc -l ) -echo v208 $( grep -w v208 $FILE | wc -l ) -echo v209 $( grep -w v209 $FILE | wc -l ) -echo v210 $( grep -w v210 $FILE | wc -l ) -echo v211 $( grep -w v211 $FILE | wc -l ) -echo v212 $( grep -w v212 $FILE | wc -l ) -echo v213 $( grep -w v213 $FILE | wc -l ) -echo v214 $( grep -w v214 $FILE | wc -l ) -echo v215 $( grep -w v215 $FILE | wc -l ) -echo v216 $( grep -w v216 $FILE | wc -l ) -echo v217 $( grep -w v217 $FILE | wc -l ) -echo v218 $( grep -w v218 $FILE | wc -l ) -echo v219 $( grep -w v219 $FILE | wc -l ) -echo v220 $( grep -w v220 $FILE | wc -l ) -echo v221 $( grep -w v221 $FILE | wc -l ) -echo v222 $( grep -w v222 $FILE | wc -l ) -echo v223 $( grep -w v223 $FILE | wc -l ) -echo v224 $( grep -w v224 $FILE | wc -l ) -echo v225 $( grep -w v225 $FILE | wc -l ) -echo v226 $( grep -w v226 $FILE | wc -l ) -echo v227 $( grep -w v227 $FILE | wc -l ) -echo v228 $( grep -w v228 $FILE | wc -l ) -echo v229 $( grep -w v229 $FILE | wc -l ) -echo v230 $( grep -w v230 $FILE | wc -l ) -echo v231 $( grep -w v231 $FILE | wc -l ) -echo v232 $( grep -w v232 $FILE | wc -l ) -echo v233 $( grep -w v233 $FILE | wc -l ) -echo v234 $( grep -w v234 $FILE | wc -l ) -echo v235 $( grep -w v235 $FILE | wc -l ) -echo v236 $( grep -w v236 $FILE | wc -l ) -echo v237 $( grep -w v237 $FILE | wc -l ) -echo v238 $( grep -w v238 $FILE | wc -l ) -echo v239 $( grep -w v239 $FILE | wc -l ) -echo v240 $( grep -w v240 $FILE | wc -l ) -echo v241 $( grep -w v241 $FILE | wc -l ) -echo v242 $( grep -w v242 $FILE | wc -l ) -echo v243 $( grep -w v243 $FILE | wc -l ) -echo v244 $( grep -w v244 $FILE | wc -l ) -echo v245 $( grep -w v245 $FILE | wc -l ) -echo v246 $( grep -w v246 $FILE | wc -l ) -echo v247 $( grep -w v247 $FILE | wc -l ) -echo v248 $( grep -w v248 $FILE | wc -l ) -echo v249 $( grep -w v249 $FILE | wc -l ) -echo v250 $( grep -w v250 $FILE | wc -l ) -echo v251 $( grep -w v251 $FILE | wc -l ) -echo v252 $( grep -w v252 $FILE | wc -l ) -echo v253 $( grep -w v253 $FILE | wc -l ) -echo v254 $( grep -w v254 $FILE | wc -l ) -echo v255 $( grep -w v255 $FILE | wc -l ) +for num in {0..255} +do + base_pattern="(\[?${num}\b|\[\d*:${num}\])" + spattern="s${base_pattern}" + vpattern="v${base_pattern}" + apattern="a${base_pattern}" + scount=$(grep -P "$spattern" "$FILE" | wc -l) + vcount=$(grep -P "$vpattern" "$FILE" | wc -l) + acount=$(grep -P "$apattern" "$FILE" | wc -l) + echo -n "v${num} ${vcount}" && \ + echo -n ", s${num} ${scount}" && \ + echo -n ", a${num} ${acount}" + if [[ $scount -ne 0 || $vcount -ne 0 || $acount -ne 0 ]]; then + echo -n " *" + fi + echo "" +done diff --git a/script/docker-rocm4.1.sh b/script/docker-rocm4.1.sh deleted file mode 100755 index 61cc33c5..00000000 --- a/script/docker-rocm4.1.sh +++ /dev/null @@ -1,14 +0,0 @@ -WORKSPACE=$1 -echo "workspace: " $WORKSPACE - -docker run \ --it \ ---rm \ ---privileged \ ---group-add sudo \ --w /root/workspace \ --v $WORKSPACE:/root/workspace \ -rocm/tensorflow:rocm4.1-tf1.15-dev \ -/bin/bash - -#--network host \ diff --git a/script/docker-rocm4.3.1.sh b/script/docker-rocm4.3.1.sh deleted file mode 100755 index 48cb675b..00000000 --- a/script/docker-rocm4.3.1.sh +++ /dev/null @@ -1,14 +0,0 @@ -WORKSPACE=$1 -echo "workspace: " $WORKSPACE - -docker run \ --it \ ---rm \ ---privileged \ ---group-add sudo \ --w /root/workspace \ --v
$WORKSPACE:/root/workspace \ -rocm/tensorflow:rocm4.3.1-tf2.6-dev \ -/bin/bash - -#--network host \ diff --git a/script/example_gemm_xdl.sh b/script/example_gemm_xdl.sh new file mode 100755 index 00000000..9e2d77d3 --- /dev/null +++ b/script/example_gemm_xdl.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +## GPU visibility + export HIP_VISIBLE_DEVICES=1 + + make -j gemm_xdl + + DRIVER="./example/gemm_xdl" + +VERIFY=$1 +INIT=$2 +LOG=$3 +REPEAT=$4 + +######### verify init log repeat M___ N___ K___ StrideA StrideB StrideC +#$DRIVER $VERIFY $INIT $LOG $REPEAT 960 1024 1024 1024 1024 1024 +#$DRIVER $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024 +#$DRIVER $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 2048 2048 2048 + $DRIVER $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 4096 4096 4096 +#$DRIVER $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 8192 8192 8192 diff --git a/script/gemm.sh b/script/gemm.sh new file mode 100755 index 00000000..395db86d --- /dev/null +++ b/script/gemm.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +## GPU visibility + export HIP_VISIBLE_DEVICES=0 + + make -j $1 + +DRIVER=example/$1 +VERIFY=$2 +INIT=$3 +REPEAT=$4 + +######## verify init repeat M___ N___ K___ StrideA StrideB StrideC StrideC1 +#$DRIVER $VERIFY $INIT $REPEAT 256 256 256 256 256 256 256 +#$DRIVER $VERIFY $INIT $REPEAT 960 1024 1024 1024 1024 1024 1024 +#$DRIVER $VERIFY $INIT $REPEAT 1920 2048 2048 2048 2048 2048 2048 + $DRIVER $VERIFY $INIT $REPEAT 3840 4096 4096 4096 4096 4096 4096 +#$DRIVER $VERIFY $INIT $REPEAT 7680 8192 8192 8192 8192 8192 8192 +#$DRIVER $VERIFY $INIT $REPEAT 1024 1024 1024 1024 1024 1024 1024 +#$DRIVER $VERIFY $INIT $REPEAT 2048 2048 2048 2048 2048 2048 2048 diff --git a/script/gemm_driver.sh b/script/gemm_driver.sh new file mode 100755 index 00000000..491c14cc --- /dev/null +++ b/script/gemm_driver.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +## GPU visibility + export HIP_VISIBLE_DEVICES=0 + + make -j gemm_driver_offline + + DRIVER="./host/driver_offline/gemm_driver_offline" + +LAYOUT=$1 +ALGO=$2 +VERIFY=$3 +INIT=$4 +LOG=$5 +REPEAT=$6 + + M01=$7 + N01=$8 + +######### layout algo verify init log repeat M___ N___ K___ M01_ N01_ +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024 $M01 $N01 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 $M01 $N01 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 $M01 $N01 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 $M01 $N01 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 $M01 $N01 diff --git a/script/parse_perf_data.py b/script/parse_perf_data.py new file mode 100644 index 00000000..4cb13e62 --- /dev/null +++ b/script/parse_perf_data.py @@ -0,0 +1,290 @@ +#!/usr/bin/env python3 +import os, io, argparse, datetime, re +import numpy as np +import sqlalchemy +from sqlalchemy.types import NVARCHAR, Float, Integer +import pymysql +import pandas as pd +from sshtunnel import SSHTunnelForwarder + +def print_to_string(*args, **kwargs): + output = io.StringIO() + print(*args, file=output, **kwargs) + contents = output.getvalue() + output.close() + return contents + +def parse_args(): + parser = argparse.ArgumentParser(description='Parse results from ckProfiler benchmark runs') + parser.add_argument('filename', type=str, help='Log file to parse or directory containing log files') + args = parser.parse_args() + files = [] + if os.path.isdir(args.filename): + all_files = os.listdir(args.filename) + for name in all_files: + if 'log' not in name: + continue + files.append(os.path.join(args.filename, name)) + else: + files =
[args.filename] + args.files = files + return args + +def main(): + args = parse_args() + tests = [] + kernels=[] + tflops=[] + dtype=[] + alayout=[] + blayout=[] + M=[] + N=[] + K=[] + StrideA=[] + StrideB=[] + StrideC=[] + #parse results, get the Tflops value for "Best Perf" kernels + + glue="" + for filename in args.files: + for line in open(filename): + if 'Branch name' in line: + lst=line.split() + branch_name=lst[2] + if 'On branch' in line: + lst=line.split() + branch_name=lst[2] + if 'Node name' in line: + lst=line.split() + node_id=lst[2] + if 'GPU_arch' in line: + lst=line.split() + gpu_arch=lst[2] + if 'HIP version' in line: + lst=line.split() + hip_vers=lst[2] + if 'Compute Unit' in line: + lst=line.split() + compute_units=lst[2] + if 'InstalledDir' in line: + lst=line.split() + rocm_vers=lst[1][lst[1].find('/opt/rocm-')+len('/opt/rocm-'):lst[1].rfind('/llvm/bin')] + print("Branch name:",branch_name) + print("Node name:",node_id) + print("GPU_arch:",gpu_arch) + print("Compute units:",compute_units) + print("ROCM_version:",rocm_vers) + print("HIP_version:",hip_vers) + + + #parse gemm performance tests: + if 'gemm' in filename: + for filename in args.files: + for line in open(filename): + if 'Best Perf' in line: + lst=line.split() + if len(lst)>=37: #the line is complete + tests.append(glue.join(lst[5:30])) + kernels.append(glue.join(lst[37:])) + tflops.append(lst[33]) + dtype.append(lst[5]) + alayout.append(lst[8]) + blayout.append(lst[11]) + M.append(lst[14]) + N.append(lst[17]) + K.append(lst[20]) + StrideA.append(lst[23]) + StrideB.append(lst[26]) + StrideC.append(lst[29]) + elif len(lst)<37 and len(lst)>=33: #the tflops are available + tests.append(glue.join(lst[5:30])) + kernels.append("N/A") + tflops.append(lst[33]) + dtype.append(lst[5]) + alayout.append(lst[8]) + blayout.append(lst[11]) + M.append(lst[14]) + N.append(lst[17]) + K.append(lst[20]) + StrideA.append(lst[23]) + StrideB.append(lst[26]) + StrideC.append(lst[29]) + print("warning: incomplete line:",lst) + elif len(lst)<33: #even the tflops are not available + print("Error in ckProfiler output!") + print("warning: incomplete line=",lst) + #sort results + #sorted_tests = sorted(tests) + #print("sorted tests:",sorted_tests) + sorted_tflops = [x for _,x in sorted(zip(tests,tflops))] + #sorted_kernels = [x for _,x in sorted(zip(tests,kernels))] + test_list=list(range(1,len(tests)+1)) + + #parse resnet50 performance tests: + if 'resnet50' in filename: + for filename in args.files: + for line in open(filename): + if 'Best Perf' in line: + lst=line.split() + tflops.append(lst[4]) + + print("Number of tests:",len(tflops)) + sql_hostname = '127.0.0.1' + sql_username = os.environ["dbuser"] + sql_password = os.environ["dbpassword"] + sql_main_database = 'miopen_perf' + sql_port = 3306 + ssh_host = os.environ["dbsship"] + ssh_user = os.environ["dbsshuser"] + ssh_port = int(os.environ["dbsshport"]) + ssh_pass = os.environ["dbsshpassword"] + + with SSHTunnelForwarder( + (ssh_host, ssh_port), + ssh_username=ssh_user, + ssh_password=ssh_pass, + remote_bind_address=(sql_hostname, sql_port)) as tunnel: + + sqlEngine = sqlalchemy.create_engine('mysql+pymysql://{0}:{1}@{2}:{3}/{4}'. 
+ format(sql_username, sql_password, sql_hostname, tunnel.local_bind_port, sql_main_database)) + conn = sqlEngine.connect() + + #save gemm performance tests: + if 'gemm' in filename: + + #write the ck_gemm_test_params table + #only needed once the test set changes + ''' + sorted_dtypes = [x for _,x in sorted(zip(tests,dtype))] + sorted_alayout = [x for _,x in sorted(zip(tests,alayout))] + sorted_blayout = [x for _,x in sorted(zip(tests,blayout))] + sorted_M = [x for _,x in sorted(zip(tests,M))] + sorted_N = [x for _,x in sorted(zip(tests,N))] + sorted_K = [x for _,x in sorted(zip(tests,K))] + sorted_StrideA = [x for _,x in sorted(zip(tests,StrideA))] + sorted_StrideB = [x for _,x in sorted(zip(tests,StrideB))] + sorted_StrideC = [x for _,x in sorted(zip(tests,StrideC))] + ck_gemm_params=[test_list,sorted_dtypes,sorted_alayout,sorted_blayout, + sorted_M,sorted_N,sorted_K,sorted_StrideA,sorted_StrideB, + sorted_StrideC] + df=pd.DataFrame(np.transpose(ck_gemm_params),columns=['Test_number','Data_type', + 'Alayout','BLayout','M','N','K', 'StrideA','StrideB','StrideC']) + print(df) + + dtypes = { + 'Test_number': Integer(), + 'Data_type': NVARCHAR(length=5), + 'Alayout': NVARCHAR(length=12), + 'Blayout': NVARCHAR(length=12), + 'M': Integer(), + 'N': Integer(), + 'K': Integer(), + 'StrideA': Integer(), + 'StrideB': Integer(), + 'StrideC': Integer() + } + df.to_sql("ck_gemm_test_params",conn,if_exists='replace',index=False, dtype=dtypes) + ''' + + #read baseline results for the latest develop branch + query = '''SELECT * from ck_gemm_tflops WHERE Datetime = (SELECT MAX(Datetime) FROM ck_gemm_tflops where Branch_ID='develop' );''' + tflops_base = pd.read_sql_query(query, conn) + + #write new results to the db + testlist=[] + for i in range(1,len(tests)+1): + testlist.append("Test%i"%i) + ck_gemm_tflops=[str(branch_name),str(node_id),str(gpu_arch),compute_units,str(rocm_vers),str(hip_vers),str(datetime.datetime.now())] + flops=pd.DataFrame(data=[ck_gemm_tflops],columns=['Branch_ID','Node_ID','GPU_arch','Compute Units','ROCM_version','HIP_version','Datetime']) + df_add=pd.DataFrame(data=[sorted_tflops],columns=testlist) + flops=pd.concat([flops,df_add],axis=1) + print("new tflops for gemm tests:",flops) + flops.to_sql("ck_gemm_tflops",conn,if_exists='append',index=False) + + #save resnet50 performance tests: + if 'resnet50' in filename: + #read baseline results for the latest develop branch + query = '''SELECT * from ck_resnet50_N256_tflops WHERE Datetime = (SELECT MAX(Datetime) FROM ck_resnet50_N256_tflops where Branch_ID='develop' );''' + tflops_base_N256 = pd.read_sql_query(query, conn) + query = '''SELECT * from ck_resnet50_N4_tflops WHERE Datetime = (SELECT MAX(Datetime) FROM ck_resnet50_N4_tflops where Branch_ID='develop' );''' + tflops_base_N4 = pd.read_sql_query(query, conn) + + #write new results to the db + testlist=[] + for i in range(1,50): + testlist.append("Layer%i"%i) + ck_resnet_tflops=[str(branch_name),str(node_id),str(gpu_arch),compute_units,str(rocm_vers),str(hip_vers),str(datetime.datetime.now())] + flops0=pd.DataFrame(data=[ck_resnet_tflops],columns=['Branch_ID','Node_ID','GPU_arch','Compute Units','ROCM_version','HIP_version','Datetime']) + df_add=pd.DataFrame(data=[tflops[0:49]],columns=testlist) + flops=pd.concat([flops0,df_add],axis=1) + print("new tflops for N=256 resnet50 test:",flops) + flops.to_sql("ck_resnet50_N256_tflops",conn,if_exists='append',index=False) + df_add=pd.DataFrame(data=[tflops[49:98]],columns=testlist) + flops=pd.concat([flops0,df_add],axis=1) + 
print("new tflops for N=4 resnet50 test:",flops) + flops.to_sql("ck_resnet50_N4_tflops",conn,if_exists='append',index=False) + + conn.close() + + #compare the results to the baseline if baseline exists + regression=0 + if 'gemm' in filename: + if not tflops_base.empty: + base=tflops_base[testlist].to_numpy(dtype='float') + base_list=base[0] + ave_perf=0 + for i in range(len(base_list)): + # success criterion: + if base_list[i]>1.01*float(sorted_tflops[i]): + print("test # ",i,"shows regression by {:.3f}%".format( + (float(sorted_tflops[i])-base_list[i])/base_list[i]*100)) + regression=1 + ave_perf=ave_perf+float(sorted_tflops[i])/base_list[i] + if regression==0: + print("no regressions found") + ave_perf=ave_perf/len(base_list) + print("average performance relative to baseline:",ave_perf) + else: + print("could not find a baseline") + if 'resnet50' in filename: + if not tflops_base_N256.empty: + base=tflops_base_N256[testlist].to_numpy(dtype='float') + base_list=base[0] + ave_perf=0 + for i in range(len(base_list)): + # success criterion: + if base_list[i]>1.01*float(tflops[i]): + print("layer # ",i,"shows regression by {:.3f}%".format( + (float(tflops[i])-base_list[i])/base_list[i]*100)) + regression=1 + ave_perf=ave_perf+float(tflops[i])/base_list[i] + if regression==0: + print("no regressions found") + ave_perf=ave_perf/len(base_list) + print("average performance relative to baseline:",ave_perf) + else: + print("could not find a baseline for N=256") + if not tflops_base_N4.empty: + base=tflops_base_N4[testlist].to_numpy(dtype='float') + base_list=base[0] + ave_perf=0 + for i in range(len(base_list)): + # success criterion: + if base_list[i]>1.01*float(tflops[i+49]): + print("layer # ",i,"shows regression by {:.3f}%".format( + (float(tflops[i+49])-base_list[i])/base_list[i]*100)) + regression=1 + ave_perf=ave_perf+float(tflops[i+49])/base_list[i] + if regression==0: + print("no regressions found") + ave_perf=ave_perf/len(base_list) + print("average performance relative to baseline:",ave_perf) + else: + print("could not find a baseline for N=4") + + #return 0 if performance criteria met, otherwise return 1 + return regression + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/script/pool2d_fwd.sh b/script/pool2d_fwd.sh new file mode 100755 index 00000000..10acf539 --- /dev/null +++ b/script/pool2d_fwd.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +## GPU visibility + export HIP_VISIBLE_DEVICES=0 + + make -j $1 + +DRIVER=example/$1 +VERIFY=$2 +INIT=$3 +REPEAT=$4 + +# test +######## verify init repeat N__ C___ Y X Hi__ Wi__ Strides LeftPads RightPads +#$DRIVER $VERIFY $INIT $REPEAT 128 192 3 3 71 71 2 2 1 1 1 1 +#$DRIVER $VERIFY $INIT $REPEAT 128 64 1 1 1 1 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT 256 3 7 7 230 230 2 2 0 0 0 0 + $DRIVER $VERIFY $INIT $REPEAT 256 1024 14 14 14 14 1 1 0 0 0 0 + + N=$5 + +# Resnet50 +######## verify init repeat N__ C___ Y X Hi__ Wi__ Strides LeftPads RightPads +#$DRIVER $VERIFY $INIT $REPEAT $N 1024 1 1 14 14 2 2 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 1024 1 1 14 14 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 1024 1 1 14 14 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 128 3 3 28 28 1 1 1 1 1 1 +#$DRIVER $VERIFY $INIT $REPEAT $N 128 1 1 28 28 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 128 3 3 58 58 2 2 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 2048 1 1 7 7 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 256 1 1 14 14 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 256 3 3 14 14 1 1 1 1 1 1 +#$DRIVER $VERIFY $INIT $REPEAT $N 256 
3 3 30 30 2 2 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 256 1 1 56 56 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 256 1 1 56 56 2 2 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 256 1 1 56 56 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 512 3 3 16 16 2 2 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 512 1 1 28 28 2 2 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 512 1 1 28 28 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 512 1 1 28 28 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 512 1 1 7 7 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 512 3 3 7 7 1 1 1 1 1 1 +#$DRIVER $VERIFY $INIT $REPEAT $N 64 1 1 56 56 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 64 1 1 56 56 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 64 3 3 56 56 1 1 1 1 1 1 +#$DRIVER $VERIFY $INIT $REPEAT $N 3 7 7 230 230 2 2 0 0 0 0 diff --git a/script/profile_conv.sh b/script/profile_conv.sh new file mode 100755 index 00000000..c3ba39c9 --- /dev/null +++ b/script/profile_conv.sh @@ -0,0 +1,177 @@ +#!/bin/bash + +## GPU visibility + export HIP_VISIBLE_DEVICES=0 + +# make -j ckProfiler + + DRIVER="../build/bin/ckProfiler" + +OP=$1 +DATATYPE=$2 +IN_LAYOUT=$3 +WEI_LAYOUT=$4 +OUT_LAYOUT=$5 +VERIFY=$6 +INIT=$7 +LOG=$8 +REPEAT=$9 + +# test +######## op datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads Desired_grid_size__ +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 30 30 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 28 28 2 2 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 128 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE + + N=${10} + +# Resnet50 (no duplicated layer) +######## op datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 28 28 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 56 56 2 2 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 28 28 2 2 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 256 1 1 56 56 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 256 1 1 56 56 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT 
$VERIFY $INIT $LOG $REPEAT $N 512 512 3 3 14 14 2 2 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 512 1 1 28 28 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 512 1 1 28 28 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 512 3 3 7 7 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 1 1 56 56 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 3 7 7 224 224 2 2 1 1 3 3 3 3 + + +# Resnet50 fusion +####### op_________________ datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C_ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 3 7 7 224 224 2 2 1 1 3 3 3 3 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 1 1 56 56 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 256 1 1 56 56 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1 +$DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 256 1 1 56 56 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1 +$DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 256 1 1 56 56 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 56 56 2 2 1 1 1 1 1 1 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 512 1 1 28 28 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 28 28 1 1 1 1 1 1 1 1 +$DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 512 1 1 28 28 1 1 1 1 0 0 0 0 +$DRIVER 
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 28 28 1 1 1 1 1 1 1 1 +$DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 512 1 1 28 28 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 28 28 1 1 1 1 1 1 1 1 +$DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 512 1 1 28 28 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 28 28 2 2 1 1 1 1 1 1 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +$DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +$DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +$DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +$DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +$DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 512 3 3 14 14 2 2 1 1 1 1 1 1 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG 
$REPEAT $N 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 512 3 3 7 7 1 1 1 1 1 1 1 1 +$DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 512 3 3 7 7 1 1 1 1 1 1 1 1 +$DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 + + +# Resnet50 +######## op datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads Desired_grid_size__ +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 28 28 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 58 58 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 30 30 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 256 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 256 1 1 56 56 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 512 3 3 16 16 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 512 1 1 28 28 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 512 1 1 28 28 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 
2048 512 1 1 7 7 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 512 3 3 7 7 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 3 7 7 230 230 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE + +# SSD +######## op datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads Desired_grid_size__ +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 64 3 7 7 300 300 2 2 1 1 3 3 3 3 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 64 64 3 3 75 75 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 64 64 3 3 75 75 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 64 64 3 3 75 75 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 64 64 3 3 75 75 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 64 64 3 3 75 75 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 64 64 3 3 75 75 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 128 64 1 1 75 75 2 2 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 128 64 3 3 75 75 2 2 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 128 128 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 128 128 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 128 128 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 128 128 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 128 128 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 128 128 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 128 128 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 128 1 1 38 38 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 128 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 256 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 256 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 256 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT 
$WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 256 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 256 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 256 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 256 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 256 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 256 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 256 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 256 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 256 1 1 38 38 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 512 256 3 3 38 38 2 2 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 512 1 1 19 19 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 512 256 3 3 19 19 2 2 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 128 512 1 1 10 10 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 128 3 3 10 10 2 2 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 128 256 1 1 5 5 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 128 3 3 5 5 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 128 256 1 1 3 3 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 256 128 3 3 3 3 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 340 256 3 3 38 38 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 510 512 3 3 19 19 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 510 512 3 3 10 10 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 510 256 3 3 5 5 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 340 256 3 3 3 3 1 1 1 1 1 1 1 1 diff --git a/script/profile_gemm.sh b/script/profile_gemm.sh new file mode 100755 index 00000000..b816c510 --- /dev/null +++ b/script/profile_gemm.sh @@ -0,0 +1,53 @@ +#!/bin/bash + +## GPU visibility +export HIP_VISIBLE_DEVICES=0 +#make -j ckProfiler +DRIVER="../build/bin/ckProfiler" +echo $DRIVER +OP=$1 +DATATYPE=$2 +LAYOUT=$3 +VERIFY=$4 +INIT=$5 +LOG=$6 +REPEAT=$7 + +######## op datatype layout verify init log repeat M___ N___ K___ StrideA StrideB StrideC +#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 256 256 256 256 256 256 +#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 960 1024 1024 1024 1024 1024 +#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 2048 2048 2048 +#$DRIVER $OP $DATATYPE 
$LAYOUT $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 4096 4096 4096 +#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 8192 8192 8192 +#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024 +#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048 + +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 960 1024 1024 -1 -1 -1 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 -1 -1 -1 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 -1 -1 -1 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 -1 -1 -1 + +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4096 4096 4096 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8192 8192 8192 + +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1056 1056 1056 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2080 2080 2080 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4128 4128 4128 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8224 8224 8224 + +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1088 1088 1088 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2112 2112 2112 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4160 4160 4160 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8256 8256 8256 + +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 6656 8192 8192 -1 -1 -1 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 3328 4096 4096 -1 -1 -1 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1664 2048 2048 -1 -1 -1 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 832 1024 1024 -1 -1 -1 + +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 7040 8192 8192 -1 -1 -1 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 5120 5632 4096 -1 -1 -1 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2560 2816 2048 -1 -1 -1 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1280 1408 1024 -1 -1 -1 diff --git a/script/profile_reduce_no_index.sh b/script/profile_reduce_no_index.sh new file mode 100755 index 00000000..580a7ca1 --- /dev/null +++ b/script/profile_reduce_no_index.sh @@ -0,0 +1,83 @@ +#!/bin/bash + +PRECISION= +##PRECISION=--half +##PRECISION=--double +##PRECISION=--int8 +##PRECISION=--bf16 + +if [ -n $PRECISION ] && [ "$PRECISION" = "--half" -o "$PRECISION" = "--bf16" ]; then + ACCTYPE="-C 1" +elif [ -n $PRECISION ] && [ "$PRECISION" = "--int8" ]; then + ACCTYPE="-C 2" +fi + + +driver="./bin/ckProfiler" + +VERIFY="-v $1" +INIT=$2 +NREPEAT=$3 + + +#### 0 - ADD, 5 - AVG, 7 - NORM2 +Operations="0 5 7" + +#### 0 - ADD, 5 - AVG, for int8, no NORM2 supported +if [ -n $PRECISION ] && [ "$PRECISION" = "--int8" ]; then + Operations=5 +fi + +## for generic validation +for op in $Operations; do + set -x + ####### datatype layout reduce dims op acctype verify init repeats + $driver reduce $PRECISION -D 64,4,280,82 -R 0,1,2,3 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 64,4,280,82 -R 0 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 64,4,280,82 -R 1 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver 
reduce $PRECISION -D 64,4,280,82 -R 2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 64,4,280,82 -R 3 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 64,4,280,82 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 64,4,280,82 -R 1,2,3 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 64,4,280,82 -R 0,2,3 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 64,4,280,82 -R 0,1,3 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,22960 -R 0 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,22960 -R 1 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 4,1469440 -R 0 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 4,1469440 -R 1 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + set +x +done + +#### 0 - ADD, 5 - AVG, 7 - NORM2 +Operations=5 + +## for performance evaluation (resnet50 NHWC => C) +for op in $Operations; do + set -x + ####### datatype layout reduce dims op acctype verify init repeats + $driver reduce $PRECISION -D 256,14,14,1024 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,28,28,128 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,58,58,128 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,7,7,2048 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,14,14,256 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,30,30,256 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,56,56,256 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,16,16,512 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,28,28,512 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,7,7,512 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,56,56,64 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,230,230,3 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,14,14,1024 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,28,28,128 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,58,58,128 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,7,7,2048 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,14,14,256 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,30,30,256 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,56,56,256 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,16,16,512 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,28,28,512 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,7,7,512 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,56,56,64 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + set +x +done + diff --git a/script/profile_reduce_with_index.sh b/script/profile_reduce_with_index.sh new file mode 100755 index 00000000..d4671e39 --- /dev/null +++ b/script/profile_reduce_with_index.sh @@ -0,0 +1,73 @@ +#!/bin/bash + +PRECISION= +##PRECISION=--half +##PRECISION=--double +##PRECISION=--int8 
+##PRECISION=--bf16 + +driver="./bin/ckProfiler" + +VERIFY="-v $1" +INIT=$2 +NREPEAT=$3 + +#### 2 - MIN, 3 - MAX, 4 - AMAX +Operations="2 4" + +## for generic validation +for op in $Operations; do + for use_idx in 0 1; do + set -x + ####### datatype layout reduce dims op use index verify init repeats + $driver reduce $PRECISION -D 64,4,280,82 -R 0,1,2,3 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 64,4,280,82 -R 0 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 64,4,280,82 -R 1 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 64,4,280,82 -R 2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 64,4,280,82 -R 3 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 64,4,280,82 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 64,4,280,82 -R 1,2,3 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 64,4,280,82 -R 0,2,3 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 64,4,280,82 -R 0,1,3 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,22960 -R 0 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,22960 -R 1 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 4,1469440 -R 0 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 4,1469440 -R 1 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + set +x + done +done + +Operations=2 + +## for performance evaluation (resnet50 NHWC => C) +for op in $Operations; do + for use_idx in 0 1; do + set -x + ####### datatype layout reduce dims op use index verify init repeats + $driver reduce $PRECISION -D 256,14,14,1024 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,28,28,128 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,58,58,128 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,7,7,2048 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,14,14,256 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,30,30,256 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,56,56,256 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,16,16,512 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,28,28,512 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,7,7,512 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,56,56,64 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 256,230,230,3 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,14,14,1024 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,28,28,128 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,58,58,128 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,7,7,2048 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,14,14,256 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,30,30,256 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,56,56,256 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce 
$PRECISION -D 128,16,16,512 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,28,28,512 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,7,7,512 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $driver reduce $PRECISION -D 128,56,56,64 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + set +x + done +done + diff --git a/script/run.sh b/script/run.sh deleted file mode 100755 index 1ff56b22..00000000 --- a/script/run.sh +++ /dev/null @@ -1,137 +0,0 @@ -#!/bin/bash - -## GPU visibility - export ROCR_VISIBLE_DEVICE=0 - export GPU_DEVICE_ORDINAL=0 - - make -j conv_fwd_driver_offline -#make -j conv_bwd_driver_offline -#make -j conv_wrw_driver_offline -#make -j gemm_driver_offline - -DRIVER="./host/driver_offline/conv_fwd_driver_offline" -LAYOUT=$1 -ALGO=$2 -VERIFY=$3 -INIT=$4 -LOG=$5 -REPEAT=$6 - -#M01=$7 -#N01=$8 - - KBATCH=$7 - -######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1 - -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 192 3 3 35 35 2 2 1 1 0 0 0 0 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0 - -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 - -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 - -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 128 3 3 14 14 1 1 1 1 1 1 1 1 - -######### layout algo verify init log repeat M___ N___ K___ -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024 $M01 $N01 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 $M01 $N01 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 $M01 $N01 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 $M01 $N01 - -# Resnet50 -######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 128 3 3 28 28 1 1 1 1 1 1 1 1 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 128 1 1 28 28 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 128 3 3 58 58 2 2 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 
14 14 1 1 1 1 1 1 1 1 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 256 1 1 56 56 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 256 1 1 56 56 2 2 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 256 1 1 56 56 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 512 1 1 28 28 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 512 1 1 28 28 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 64 1 1 56 56 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 64 1 1 56 56 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 64 3 3 56 56 1 1 1 1 1 1 1 1 - -# 256x128x32 c64 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 7 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 56 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 56 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 28 28 1 1 1 1 1 1 1 1 $KBATCH -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 128 1 1 28 28 1 1 1 1 0 0 0 0 224 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 58 58 2 2 1 1 0 0 0 0 $KBATCH -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 14 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 56 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 14 14 1 1 1 1 1 1 1 1 28 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 30 30 2 2 1 1 0 0 0 0 28 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 256 1 1 56 56 2 2 1 1 0 0 0 0 224 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 16 16 2 2 1 1 0 0 0 0 7 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 56 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 512 1 1 28 28 1 1 1 1 0 0 0 0 $KBATCH -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 512 1 1 28 28 1 1 1 1 0 0 0 0 224 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 14 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 7 7 1 1 1 1 1 1 1 1 7 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 $KBATCH - - - -# 128x128x32 c64 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 7 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 56 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 28 
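[Editor's note] In the run.sh being deleted here, the commented variants above and below differ mainly in the trailing column, which the script consumed as KBATCH, its seventh positional argument (see the KBATCH=$7 assignment earlier in this hunk). For reference, an illustrative invocation of the deleted script; every value below is a placeholder, not a recommended configuration:

    # layout algo verify init log repeat kbatch
    ./script/run.sh 1 1 0 1 0 5 112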
-#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 28 28 1 1 1 1 1 1 1 1 112 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 128 1 1 28 28 1 1 1 1 0 0 0 0 224 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 58 58 2 2 1 1 0 0 0 0 112 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 14 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 56 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 14 14 1 1 1 1 1 1 1 1 28 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 30 30 2 2 1 1 0 0 0 0 28 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 256 1 1 56 56 1 1 1 1 0 0 0 0 448 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 256 1 1 56 56 2 2 1 1 0 0 0 0 224 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 16 16 2 2 1 1 0 0 0 0 7 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 28 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 512 1 1 28 28 1 1 1 1 0 0 0 0 224 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 512 1 1 28 28 1 1 1 1 0 0 0 0 112 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 14 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 7 7 1 1 1 1 1 1 1 1 7 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 $KBATCH - - -# 128x64x32 c64 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 112 - -# 64x128x32 c64 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH - -# 64x64x32 c32 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 112 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 112 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 448 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 448 diff --git a/script/run_performance_tests.sh b/script/run_performance_tests.sh new file mode 100755 index 00000000..95d63d0f --- /dev/null +++ b/script/run_performance_tests.sh @@ -0,0 +1,57 @@ +#!/bin/bash +# +# in order to run this script you'd first need to build the ckProfiler executable in ../build/bin/ +# and make sure the following python packages are installed in your environment: + +pip3 install --upgrade pip +pip3 install sqlalchemy pymysql pandas sshtunnel + +# you would also need to set up some environment variables in order to +# post your new test results to the database and compare them to the baseline +# please contact Illia.Silin@amd.com for more details +# + +export gemm_log="perf_gemm.log" +rm -f $gemm_log +git status | grep -e 'On branch' > ${gemm_log} +echo -n 'Node name: ' >>${gemm_log}; hostname >> ${gemm_log} +#get GPU_arch and number of compute units from rocminfo +echo -n "GPU_arch: " >> ${gemm_log}; rocminfo | grep "Name:" | grep "gfx" >> ${gemm_log} +rocminfo | grep "Compute Unit:" >> ${gemm_log} +hipcc --version | grep -e 'HIP version' >> ${gemm_log} +/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> ${gemm_log} 
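+# sweep profile_gemm.sh over values 0-3 of its second and third arguments
+# (one run per data-type/layout combination; see profile_gemm.sh for the exact
+# meaning of each positional argument), appending every run to the log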
+./profile_gemm.sh gemm 0 0 0 1 0 5 | tee -a ${gemm_log} +./profile_gemm.sh gemm 1 0 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 2 0 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 3 0 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 0 1 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 1 1 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 2 1 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 3 1 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 0 2 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 1 2 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 2 2 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 3 2 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 0 3 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 1 3 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 2 3 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 3 3 0 1 0 5 | tee -a $gemm_log + +python3 parse_perf_data.py ${gemm_log} + +#run resnet50 test +export resnet_log="perf_resnet50.log" +rm -f $resnet_log +git status | grep -e 'On branch' > ${resnet_log} +echo -n 'Node name: '>>${resnet_log}; hostname >>${resnet_log} +#get GPU_arch and number of compute units from rocminfo +echo -n "GPU_arch: " >> ${resnet_log}; rocminfo | grep "Name:" | grep "gfx" >> ${resnet_log} +rocminfo | grep "Compute Unit:" >> ${resnet_log} +hipcc --version | grep -e 'HIP version' >> ${resnet_log} +/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> ${resnet_log} +#first run tests with N=256 +./profile_conv.sh conv_fwd_bias_relu 1 1 1 1 0 2 0 1 256 | tee -a ${resnet_log} +#then run with N=4 +./profile_conv.sh conv_fwd_bias_relu 1 1 1 1 0 2 0 1 4 | tee -a ${resnet_log} +#the script will put the results from N=256 and N=4 runs into separate tables +python3 parse_perf_data.py ${resnet_log} diff --git a/script/test_convnd_fwd.sh b/script/test_convnd_fwd.sh new file mode 100644 index 00000000..1bd7a6b5 --- /dev/null +++ b/script/test_convnd_fwd.sh @@ -0,0 +1,110 @@ +#!/usr/bin/env bash + +# set -e + +DIM1=False +DIM2=True +DIM3=False +DATE=220317 +GIT_HASH=4e6dfda +LOG_DIR=${DATE}_${GIT_HASH} +SUFFIX=${GIT_HASH} + + +#-------------------------------------------------------------------------- +# Commandline arguments parsing +# like: cmd -key[--key] value +#-------------------------------------------------------------------------- + +POSITIONAL=() +while [[ $# -gt 0 ]] +do +key="$1" + +case $key in + -d1|--d1) + DIM1=True + echo DIM1: "${DIM1}" + shift # past argument + ;; + -d2|--d2) + DIM2=True + echo DIM2: "${DIM2}" + shift # past argument + ;; + -d3|--d3) + DIM3=True + echo DIM3: "${DIM3}" + shift # past argument + ;; + -all|--all) + DIM1=True + DIM2=True + DIM3=True + echo DIM1: "${DIM1}" + echo DIM2: "${DIM2}" + echo DIM3: "${DIM3}" + shift # past argument + ;; + -s|--suffix) + SUFFIX=${SUFFIX}_"$2" + echo SUFFIX: "${SUFFIX}" + shift # past argument + shift # past value + ;; + *) # unknown option + POSITIONAL+=("$1") # save it in an array for later + shift # past argument + ;; +esac +done +set -- "${POSITIONAL[@]}" # restore positional parameters + +#-------------------------------------------------------------------------- + +# NUMACTL="numactl --cpunodebind=1 --membind=1" +NUMACTL= +# ENV_CONF= +GPU=mi100 +PROF_ITER_COUNT=10000 +LOG_DIR_PATH=../log/${LOG_DIR} +set -x + +#------------------------------------------------------------------------------- +# 1D +#------------------------------------------------------------------------------- + +if [[ "${DIM1}" == "True" ]]; then + mkdir -p ${LOG_DIR_PATH} + echo ">>>>>>>> RUN 
test conv1d nwc <<<<<<<<<<" + CMD="./../build/bin/test_conv1d_fwd" + ${NUMACTL} ${CMD} 2>&1 \ + | tee ${LOG_DIR_PATH}/test_conv1d_fwd_nwc_${SUFFIX}_${GPU}.log + +fi + +#------------------------------------------------------------------------------- +# 2D +#------------------------------------------------------------------------------- + +if [[ "${DIM2}" == "True" ]]; then + mkdir -p ${LOG_DIR_PATH} + echo ">>>>>>>> RUN test conv2d nhwc <<<<<<<<<<" + CMD="./../build/bin/test_conv2d_fwd" + ${NUMACTL} ${CMD} 2>&1 \ + | tee ${LOG_DIR_PATH}/test_conv2d_fwd_nhwc_${SUFFIX}_${GPU}.log + +fi + +#------------------------------------------------------------------------------- +# 3D +#------------------------------------------------------------------------------- + +if [[ "${DIM3}" == "True" ]]; then + mkdir -p ${LOG_DIR_PATH} + echo ">>>>>>>> RUN test conv3d ndhwc <<<<<<<<<<" + CMD="./../build/bin/test_conv3d_fwd" + ${NUMACTL} ${CMD} 2>&1 \ + | tee ${LOG_DIR_PATH}/test_conv3d_fwd_ndhwc_${SUFFIX}_${GPU}.log + +fi diff --git a/script/test_reduce_no_index.sh b/script/test_reduce_no_index.sh new file mode 100755 index 00000000..b9563038 --- /dev/null +++ b/script/test_reduce_no_index.sh @@ -0,0 +1,63 @@ +#!/bin/bash + +## The following will be used for CI + +set -x + +## for float +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2,3 0 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2 0 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,3 0 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,2,3 0 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 1,2,3 0 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0 0 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 1 0 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 2 0 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 3 0 2 + +## for float64 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2,3 6 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2 6 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,3 6 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,2,3 6 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 1,2,3 6 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0 6 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 1 6 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 2 6 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 3 6 2 + +## for float16 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2,3 1 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2 1 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,3 1 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,2,3 1 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 1,2,3 1 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0 1 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 1 1 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 2 1 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 3 1 2 + +## for int8_t +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2,3 3 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2 3 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,3 3 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,2,3 3 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 1,2,3 3 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0 3 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 1 3 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 2 3 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 3 3 2 + +## for bfloat16 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2,3 5 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2 5 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,3 5 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,2,3 5 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 1,2,3 5 2 +bin/test_reduce_no_index -D 
64,4,280,82 -R 0 5 2
+bin/test_reduce_no_index -D 64,4,280,82 -R 1 5 2
+bin/test_reduce_no_index -D 64,4,280,82 -R 2 5 2
+bin/test_reduce_no_index -D 64,4,280,82 -R 3 5 2
+
+set +x
+
diff --git a/script/test_reduce_with_index.sh b/script/test_reduce_with_index.sh
new file mode 100755
index 00000000..b0843ba6
--- /dev/null
+++ b/script/test_reduce_with_index.sh
@@ -0,0 +1,63 @@
+#!/bin/bash
+
+## The following will be used for CI
+
+set -x
+
+## for float
+bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2,3 0 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2 0 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,3 0 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 0,2,3 0 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 1,2,3 0 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 0 0 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 1 0 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 2 0 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 3 0 2
+
+## for float64
+bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2,3 6 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2 6 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,3 6 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 0,2,3 6 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 1,2,3 6 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 0 6 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 1 6 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 2 6 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 3 6 2
+
+## for float16
+bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2,3 1 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2 1 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,3 1 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 0,2,3 1 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 1,2,3 1 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 0 1 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 1 1 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 2 1 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 3 1 2
+
+## for int8_t
+bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2,3 3 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2 3 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,3 3 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 0,2,3 3 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 1,2,3 3 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 0 3 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 1 3 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 2 3 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 3 3 2
+
+## for bfloat16
+bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2,3 5 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2 5 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,3 5 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 0,2,3 5 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 1,2,3 5 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 0 5 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 1 5 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 2 5 2
+bin/test_reduce_with_index -D 64,4,280,82 -R 3 5 2
+
+set +x
+
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
new file mode 100644
index 00000000..9bd07495
--- /dev/null
+++ b/test/CMakeLists.txt
@@ -0,0 +1,50 @@
+include_directories(BEFORE
+    ${PROJECT_SOURCE_DIR}/
+)
+
+include(googletest)
+
+add_custom_target(tests)
+
+
+function(add_test_executable TEST_NAME)
+    message("adding test ${TEST_NAME}")
+    add_executable(${TEST_NAME} ${ARGN})
+    add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}>)
+    add_dependencies(tests ${TEST_NAME})
+    add_dependencies(check ${TEST_NAME})
+    rocm_install(TARGETS ${TEST_NAME} COMPONENT tests)
+endfunction(add_test_executable TEST_NAME)
+
+include(GoogleTest)
+
+function(add_gtest_executable TEST_NAME)
+    message("adding gtest ${TEST_NAME}")
+    add_executable(${TEST_NAME} ${ARGN})
+    add_dependencies(tests ${TEST_NAME})
+    add_dependencies(check ${TEST_NAME})
+    # suppress gtest warnings
+    target_compile_options(${TEST_NAME} PRIVATE -Wno-global-constructors -Wno-undef)
+    target_link_libraries(${TEST_NAME} PRIVATE gtest_main)
+    gtest_discover_tests(${TEST_NAME})
+    rocm_install(TARGETS ${TEST_NAME} COMPONENT tests)
+endfunction(add_gtest_executable TEST_NAME)
+
+
+add_subdirectory(magic_number_division)
+add_subdirectory(space_filling_curve)
+add_subdirectory(conv_util)
+add_subdirectory(reference_conv_fwd)
+add_subdirectory(gemm)
+add_subdirectory(gemm_split_k)
+add_subdirectory(gemm_reduce)
+add_subdirectory(batched_gemm)
+add_subdirectory(batched_gemm_reduce)
+add_subdirectory(grouped_gemm)
+add_subdirectory(convnd_fwd)
+add_subdirectory(reduce)
+add_subdirectory(conv2d_bwd_weight)
+add_subdirectory(convnd_bwd_weight)
+add_subdirectory(convnd_bwd_data)
+add_subdirectory(block_to_ctile_map)
+add_subdirectory(softmax)
diff --git a/test/batched_gemm/CMakeLists.txt b/test/batched_gemm/CMakeLists.txt
new file mode 100644
index 00000000..b70e3aae
--- /dev/null
+++ b/test/batched_gemm/CMakeLists.txt
@@ -0,0 +1,4 @@
+add_test_executable(test_batched_gemm_fp16 batched_gemm_fp16.cpp)
+target_link_libraries(test_batched_gemm_fp16 PRIVATE host_tensor)
+target_link_libraries(test_batched_gemm_fp16 PRIVATE device_batched_gemm_instance)
+
diff --git a/test/batched_gemm/batched_gemm_fp16.cpp b/test/batched_gemm/batched_gemm_fp16.cpp
new file mode 100644
index 00000000..7fc1f24f
--- /dev/null
+++ b/test/batched_gemm/batched_gemm_fp16.cpp
@@ -0,0 +1,44 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <iostream>
+
+#include "profiler/include/profile_batched_gemm_impl.hpp"
+
+namespace {
+using ADataType = ck::half_t;
+using BDataType = ck::half_t;
+using CDataType = ck::half_t;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+} // namespace
+
+int main()
+{
+    int M = 512;
+    int N = 256;
+    int K = 128;
+    int BatchCount = 3;
+
+    bool pass = true;
+
+    // all four A/B layout combinations; the strides follow from the layouts
+    pass = pass &&
+           ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Row, Row, Row>(
+               true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount);
+
+    pass = pass &&
+           ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Row, Col, Row>(
+               true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount);
+
+    pass = pass &&
+           ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Col, Row, Row>(
+               true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount);
+
+    pass = pass &&
+           ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Col, Col, Row>(
+               true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount);
+
+    std::cout << "test BatchedGEMM fp16: " << (pass ? "Pass" : "Fail") << std::endl;
+    return pass ?
0 : 1; +} diff --git a/test/batched_gemm_reduce/CMakeLists.txt b/test/batched_gemm_reduce/CMakeLists.txt new file mode 100644 index 00000000..fa1a2bf8 --- /dev/null +++ b/test/batched_gemm_reduce/CMakeLists.txt @@ -0,0 +1,3 @@ +add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp) +target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE host_tensor) +target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE device_batched_gemm_reduce_instance) diff --git a/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp b/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp new file mode 100644 index 00000000..456d2114 --- /dev/null +++ b/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "profiler/include/profile_batched_gemm_reduce_impl.hpp" + +int main() +{ + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + int M = 512; + int N = 256; + int K = 128; + + int BatchCount = 3; + + bool pass = true; + + pass = pass && ck::profiler::profile_batched_gemm_reduce_impl( + true, 1, false, false, M, N, K, K, N, N, BatchCount); + + pass = pass && ck::profiler::profile_batched_gemm_reduce_impl( + true, 1, false, false, M, N, K, K, K, N, BatchCount); + + pass = pass && ck::profiler::profile_batched_gemm_reduce_impl( + true, 1, false, false, M, N, K, M, N, N, BatchCount); + + pass = pass && ck::profiler::profile_batched_gemm_reduce_impl( + true, 1, false, false, M, N, K, M, K, N, BatchCount); + + if(pass) + { + std::cout << "test BatchedGEMM+Reduce fp16: Pass" << std::endl; + return 0; + } + else + { + std::cout << "test BatchedGEMM+Reduce fp16: Fail" << std::endl; + return -1; + } +} diff --git a/test/block_to_ctile_map/CMakeLists.txt b/test/block_to_ctile_map/CMakeLists.txt new file mode 100644 index 00000000..97dfbb2b --- /dev/null +++ b/test/block_to_ctile_map/CMakeLists.txt @@ -0,0 +1 @@ +add_gtest_executable(test_block_to_ctile_map test_block_to_ctile_map.cpp) \ No newline at end of file diff --git a/test/block_to_ctile_map/test_block_to_ctile_map.cpp b/test/block_to_ctile_map/test_block_to_ctile_map.cpp new file mode 100644 index 00000000..55d9b59f --- /dev/null +++ b/test/block_to_ctile_map/test_block_to_ctile_map.cpp @@ -0,0 +1,322 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
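+
+// Each test below enumerates the 1-D block IDs produced by a block-to-C-tile
+// map and checks the computed (m0, n0) tile coordinates and validity flags
+// against a hard-coded expected table.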
+ +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" + +using namespace ck; + +static auto I0 = Number<0>{}; +static auto I1 = Number<1>{}; +static auto I2 = Number<2>{}; + +TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N00_M01_N01_DeviceCTileIndexCheck1) +{ + const index_t M = 384; + const index_t N = 384; + const index_t MPerBlock = 128; + const index_t NPerBlock = 128; + const index_t MBlock = M / MPerBlock; + const index_t NBlock = N / NPerBlock; + const index_t M01 = 4; + const index_t N01 = 4; + + auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N)); + + printf("(M, N, MPerBlock, NPerBlock, M01, N01) = (%d, %d, %d, %d, %d, %d)\n", + M, + N, + MPerBlock, + NPerBlock, + M01, + N01); + + BlockToCTileMap_M00_N00_M01_N01 tile_map( + c_grid_desc_m_n, M01, N01); + + EXPECT_TRUE(tile_map.CheckValidity(c_grid_desc_m_n) == true); + EXPECT_TRUE(tile_map.CalculateGridSize(c_grid_desc_m_n) == 16); + + // clang-format off + std::vector> expected_m0idx_n0idx_valid = { + {0, 0, 1}, + {0, 1, 1}, + {0, 2, 1}, + {0, 3, 0}, + {1, 0, 1}, + {1, 1, 1}, + {1, 2, 1}, + {1, 3, 0}, + {2, 0, 1}, + {2, 1, 1}, + {2, 2, 1}, + {2, 3, 0}, + {3, 0, 0}, + {3, 1, 0}, + {3, 2, 0}, + {3, 3, 0} + }; + // clang-format on + + for(index_t i = 0; i < tile_map.CalculateGridSize(c_grid_desc_m_n); i++) + { + auto m0n0_idx = tile_map.CalculateBottomIndex(make_multi_index(i)); + std::cout << "block_1d_id = " << i << ", m0, n0 = " << m0n0_idx[I0] << ", " << m0n0_idx[I1]; + std::cout << ", valid = " << tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock)) + << std::endl; + bool equal = + expected_m0idx_n0idx_valid[i] == + std::vector{m0n0_idx[I0], + m0n0_idx[I1], + tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock))}; + EXPECT_TRUE(equal); + } +} + +TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N00_M01_N01_DeviceCTileIndexCheck0) +{ + const index_t M = 384; + const index_t N = 384; + const index_t MPerBlock = 128; + const index_t NPerBlock = 128; + + const index_t M01 = 4; + const index_t N01 = 4; + + auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N)); + + printf("(M, N, MPerBlock, NPerBlock, M01, N01) = (%d, %d, %d, %d, %d, %d)\n", + M, + N, + MPerBlock, + NPerBlock, + M01, + N01); + + BlockToCTileMap_M00_N00_M01_N01 + tile_map(c_grid_desc_m_n, M01, N01); + + EXPECT_TRUE(tile_map.CheckValidity(c_grid_desc_m_n) == false); +} + +TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N0_M01_DeviceCTileIndexCheck1) +{ + const index_t M = 384; + const index_t N = 512; + const index_t MPerBlock = 128; + const index_t NPerBlock = 128; + const index_t MBlock = M / MPerBlock; + const index_t NBlock = N / NPerBlock; + const index_t M01 = 4; + + auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N)); + + printf("(M, N, MPerBlock, NPerBlock, M01) = (%d, %d, %d, %d, %d)\n", + M, + N, + MPerBlock, + NPerBlock, + M01); + + BlockToCTileMap_M00_N0_M01 tile_map( + c_grid_desc_m_n, M01); + + EXPECT_TRUE(tile_map.CheckValidity(c_grid_desc_m_n) == true); + EXPECT_TRUE(tile_map.CalculateGridSize(c_grid_desc_m_n) == 16); + + // clang-format off + std::vector> expected_m0idx_n0idx_valid = { + {0, 0, 1}, + {1, 0, 1}, + {2, 0, 1}, + {3, 0, 0}, + {0, 1, 1}, + {1, 1, 1}, + {2, 1, 1}, + {3, 1, 0}, + {0, 2, 1}, + {1, 2, 1}, + {2, 2, 1}, + {3, 2, 0}, + {0, 3, 1}, + {1, 3, 1}, + {2, 3, 1}, + {3, 3, 0} + }; + // clang-format on + + for(index_t i = 0; i < tile_map.CalculateGridSize(c_grid_desc_m_n); i++) + { + auto 
m0n0_idx = tile_map.CalculateBottomIndex(make_multi_index(i)); + std::cout << "block_1d_id = " << i << ", m0, n0 = " << m0n0_idx[I0] << ", " << m0n0_idx[I1]; + std::cout << ", valid = " << tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock)) + << std::endl; + bool equal = + expected_m0idx_n0idx_valid[i] == + std::vector{m0n0_idx[I0], + m0n0_idx[I1], + tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock))}; + EXPECT_TRUE(equal); + } +} + +TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N0_M01_DeviceCTileIndexCheck0) +{ + const index_t M = 512; + const index_t N = 384; + const index_t MPerBlock = 128; + const index_t NPerBlock = 128; + + auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N)); + + // clang-format off + std::vector> expected_m0_gridsize_validity = { + {5, 15, false}, + {4, 12, true}, + {3, 18, false}, + {2, 12, true}, + {1, 12, true} + }; + // clang-format on + + for(auto e : expected_m0_gridsize_validity) + { + const index_t M01 = std::get<0>(e); + + printf("(M, N, MPerBlock, NPerBlock, M01) = (%d, %d, %d, %d, %d)\n", + M, + N, + MPerBlock, + NPerBlock, + M01); + + BlockToCTileMap_M00_N0_M01 tile_map( + c_grid_desc_m_n, M01); + + EXPECT_EQ(tile_map.CalculateGridSize(c_grid_desc_m_n), std::get<1>(e)); + EXPECT_EQ(tile_map.CheckValidity(c_grid_desc_m_n), std::get<2>(e)); + } +} + +TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N0_M01Adapt) +{ + const index_t M = 768; + const index_t N = 384; + const index_t MPerBlock = 128; + const index_t NPerBlock = 128; + const index_t MBlock = M / MPerBlock; + const index_t NBlock = N / NPerBlock; + constexpr index_t M01 = 4; + + auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N)); + + printf("(M, N, MPerBlock, NPerBlock, M01) = (%d, %d, %d, %d, %d)\n", + M, + N, + MPerBlock, + NPerBlock, + M01); + + BlockToCTileMap_M00_N0_M01Adapt tile_map( + c_grid_desc_m_n, M01); + + EXPECT_TRUE(tile_map.CheckValidity(c_grid_desc_m_n) == true); + EXPECT_TRUE(tile_map.CalculateGridSize(c_grid_desc_m_n) == 18); + + // clang-format off + std::vector> expected_m0idx_n0idx_valid = { + {0, 0, 1}, + {1, 0, 1}, + {2, 0, 1}, + {3, 0, 1}, + {0, 1, 1}, + {1, 1, 1}, + {2, 1, 1}, + {3, 1, 1}, + {0, 2, 1}, + {1, 2, 1}, + {2, 2, 1}, + {3, 2, 1}, + {4, 0, 1}, + {5, 0, 1}, + {4, 1, 1}, + {5, 1, 1}, + {4, 2, 1}, + {5, 2, 1}, + }; + // clang-format on + + for(index_t i = 0; i < tile_map.CalculateGridSize(c_grid_desc_m_n); i++) + { + auto m0n0_idx = tile_map.CalculateBottomIndex(make_multi_index(i)); + std::cout << "block_1d_id = " << i << ", m0, n0 = " << m0n0_idx[I0] << ", " << m0n0_idx[I1]; + std::cout << ", valid = " << tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock)) + << std::endl; + bool equal = + expected_m0idx_n0idx_valid[i] == + std::vector{m0n0_idx[I0], + m0n0_idx[I1], + tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock))}; + EXPECT_TRUE(equal); + } +} + +TEST(BlockToCTileMap, TestBlockToCTileMap_KSplit_M00_N0_M01Adapt) +{ + const index_t M = 768; + const index_t N = 384; + const index_t MPerBlock = 128; + const index_t NPerBlock = 128; + const index_t MBlock = M / MPerBlock; + const index_t NBlock = N / NPerBlock; + constexpr index_t M01 = 4; + const index_t KSplit = 3; + + auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N)); + + printf("(M, N, MPerBlock, NPerBlock, M01) = (%d, %d, %d, %d, %d)\n", + M, + N, + MPerBlock, + NPerBlock, + M01); + + BlockToCTileMap_KSplit_M00_N0_M01Adapt + tile_map(c_grid_desc_m_n, M01, KSplit); + + 
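+    // with split-K the grid gains a leading K-split dimension, so the expected
+    // grid size is KSplit times the 18 output tiles checked above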
EXPECT_TRUE(tile_map.CheckValidity(c_grid_desc_m_n) == true); + EXPECT_TRUE(tile_map.CalculateGridSize(c_grid_desc_m_n) == 18 * KSplit); + + std::vector> expected_ksplitidx_m0idx_n0idx_valid = { + {0, 0, 0, 1}, {0, 1, 0, 1}, {0, 2, 0, 1}, {0, 3, 0, 1}, {0, 0, 1, 1}, {0, 1, 1, 1}, + {0, 2, 1, 1}, {0, 3, 1, 1}, {0, 0, 2, 1}, {0, 1, 2, 1}, {0, 2, 2, 1}, {0, 3, 2, 1}, + {0, 4, 0, 1}, {0, 5, 0, 1}, {0, 4, 1, 1}, {0, 5, 1, 1}, {0, 4, 2, 1}, {0, 5, 2, 1}, + {1, 0, 0, 1}, {1, 1, 0, 1}, {1, 2, 0, 1}, {1, 3, 0, 1}, {1, 0, 1, 1}, {1, 1, 1, 1}, + {1, 2, 1, 1}, {1, 3, 1, 1}, {1, 0, 2, 1}, {1, 1, 2, 1}, {1, 2, 2, 1}, {1, 3, 2, 1}, + {1, 4, 0, 1}, {1, 5, 0, 1}, {1, 4, 1, 1}, {1, 5, 1, 1}, {1, 4, 2, 1}, {1, 5, 2, 1}, + {2, 0, 0, 1}, {2, 1, 0, 1}, {2, 2, 0, 1}, {2, 3, 0, 1}, {2, 0, 1, 1}, {2, 1, 1, 1}, + {2, 2, 1, 1}, {2, 3, 1, 1}, {2, 0, 2, 1}, {2, 1, 2, 1}, {2, 2, 2, 1}, {2, 3, 2, 1}, + {2, 4, 0, 1}, {2, 5, 0, 1}, {2, 4, 1, 1}, {2, 5, 1, 1}, {2, 4, 2, 1}, {2, 5, 2, 1}, + }; + + for(index_t i = 0; i < tile_map.CalculateGridSize(c_grid_desc_m_n); i++) + { + auto ksplitm0n0_idx = tile_map.CalculateBottomIndex(make_multi_index(i)); + std::cout << "block_1d_id = " << i << ", ksplit, m0, n0 = " << ksplitm0n0_idx[I0] << ", " + << ksplitm0n0_idx[I1] << ", " << ksplitm0n0_idx[I2]; + std::cout << ", valid = " + << tile_map.ValidCTileIndex(ksplitm0n0_idx, make_tuple(MBlock, NBlock)) + << std::endl; + bool equal = + expected_ksplitidx_m0idx_n0idx_valid[i] == + std::vector{ksplitm0n0_idx[I0], + ksplitm0n0_idx[I1], + ksplitm0n0_idx[I2], + tile_map.ValidCTileIndex(ksplitm0n0_idx, make_tuple(MBlock, NBlock))}; + EXPECT_TRUE(equal); + } +} diff --git a/test/conv2d_bwd_data/CMakeLists.txt b/test/conv2d_bwd_data/CMakeLists.txt new file mode 100644 index 00000000..1b5c03af --- /dev/null +++ b/test/conv2d_bwd_data/CMakeLists.txt @@ -0,0 +1,3 @@ +add_test_executable(test_conv2d_bwd_data conv2d_bwd_data.cpp) +target_link_libraries(test_conv2d_bwd_data PRIVATE host_tensor) +target_link_libraries(test_conv2d_bwd_data PRIVATE device_conv2d_bwd_data_instance) diff --git a/test/conv2d_bwd_data/conv2d_bwd_data.cpp b/test/conv2d_bwd_data/conv2d_bwd_data.cpp new file mode 100644 index 00000000..cb924538 --- /dev/null +++ b/test/conv2d_bwd_data/conv2d_bwd_data.cpp @@ -0,0 +1,330 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_conv.hpp" +#include "tensor_layout.hpp" +#include "device_tensor.hpp" +#include "device_conv_bwd_data.hpp" +#include "element_wise_operation.hpp" +#include "reference_conv_bwd_data.hpp" + +using F16 = ck::half_t; +using F32 = float; +using BF16 = ck::bhalf_t; +using INT8 = int8_t; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using DeviceConvBwdDataNoOpPtr = + DeviceConvBwdDataPtr; + +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector&); +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector&); +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances( + std::vector&); +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances( + std::vector&); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +template +static bool check_out(const Tensor& ref, const Tensor& result) +{ + float max_diff = 1e-6; + + for(int i = 0; i < ref.mData.size(); ++i) + { + float diff = std::abs(double(ref.mData[i]) - double(result.mData[i])); + if(max_diff < diff) + { + return false; + } + } + + return true; +} + +int main(int argc, char* argv[]) +{ + int data_type = 0; + int init_method = 0; + + // Conv shape + ck::index_t N = 128; + ck::index_t K = 256; + ck::index_t C = 192; + ck::index_t Y = 3; + ck::index_t X = 3; + ck::index_t Hi = 71; + ck::index_t Wi = 71; + ck::index_t conv_stride_h = 2; + ck::index_t conv_stride_w = 2; + ck::index_t conv_dilation_h = 1; + ck::index_t conv_dilation_w = 1; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + + if(argc == 1) + { + data_type = 1; + init_method = 1; + } + else if(argc == 3) + { + data_type = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + } + else if(argc == 18) + { + data_type = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + + N = std::stoi(argv[3]); + K = std::stoi(argv[4]); + C = std::stoi(argv[5]); + Y = std::stoi(argv[6]); + X = std::stoi(argv[7]); + Hi = std::stoi(argv[8]); + Wi = std::stoi(argv[9]); + conv_stride_h = std::stoi(argv[10]); + conv_stride_w = std::stoi(argv[11]); + conv_dilation_h = std::stoi(argv[12]); + conv_dilation_w = std::stoi(argv[13]); + in_left_pad_h = std::stoi(argv[14]); + in_left_pad_w = std::stoi(argv[15]); + in_right_pad_h = std::stoi(argv[16]); + in_right_pad_w = std::stoi(argv[17]); + } + else + { + printf("arg1: data type (0=fp32, 1=fp16, 2= bfp16, 3= int8_t )\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3 to 17: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(1); + } + + auto Run = [&](auto input_type, auto wei_type, auto out_type, auto acc_type) { + using InDataType = decltype(input_type); + using WeiDataType = decltype(wei_type); + using OutDataType = decltype(out_type); + using AccDataType = decltype(acc_type); + + using ReferenceConvBwdInstance = + ck::tensor_operation::host::ReferenceConvBwdData; + + const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1; + const ck::index_t XEff = (X - 1) * conv_dilation_w + 1; + + 
const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + + const std::vector input_spatial_lengths{{Hi, Wi}}; + const std::vector filter_spatial_lengths{{Y, X}}; + const std::vector output_spatial_lengths{{Ho, Wo}}; + const std::vector conv_filter_strides{{conv_stride_h, conv_stride_w}}; + const std::vector conv_filter_dilations{{conv_dilation_h, conv_dilation_w}}; + const std::vector input_left_pads{{in_left_pad_h, in_left_pad_w}}; + const std::vector input_right_pads{{in_right_pad_h, in_right_pad_w}}; + + auto f_host_tensor_descriptor = + [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W) { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, 1, W * C_, C_})); + }; + + Tensor out_n_k_ho_wo(f_host_tensor_descriptor(N, K, Ho, Wo)); + Tensor wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X)); + Tensor in_n_c_hi_wi_host_result(f_host_tensor_descriptor(N, C, Hi, Wi)); + Tensor in_n_c_hi_wi_device_result(f_host_tensor_descriptor(N, C, Hi, Wi)); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi_host_result.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1{1}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1{1}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * + in_n_c_hi_wi_device_result.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + out_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + // reset input to zero + in_n_c_hi_wi_device_result.GenerateTensorValue(GeneratorTensor_1{0}); + in_device_buf.ToDevice(in_n_c_hi_wi_device_result.mData.data()); + + // get host result + { + auto ref_conv = ReferenceConvBwdInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi_host_result, + wei_k_c_y_x, + out_n_k_ho_wo, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + ref_invoker.Run(ref_argument); + } + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using DeviceConvBwdDataNoOpPtr = ck::tensor_operation::device:: + DeviceConvBwdDataPtr; + + // add device Conv instances + std::vector conv_ptrs; + + if constexpr(ck::is_same_v, float> && + ck::is_same_v, float> && + ck::is_same_v, float>) + { + ck::tensor_operation::device::instance:: + add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); + } + else if constexpr(ck::is_same_v, ck::half_t> && + ck::is_same_v, ck::half_t> && + ck::is_same_v, ck::half_t>) + { + ck::tensor_operation::device::instance:: + add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); + } + else if constexpr(ck::is_same_v, ck::bhalf_t> && + ck::is_same_v, ck::bhalf_t> && + ck::is_same_v, ck::bhalf_t>) + { + ck::tensor_operation::device::instance:: + add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); + } + else if 
constexpr(ck::is_same_v, int8_t> && + ck::is_same_v, int8_t> && + ck::is_same_v, int8_t>) + { + ck::tensor_operation::device::instance:: + add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs); + } + + if(conv_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device Conv instance found"); + } + + // profile device Conv instances + bool success = true; + for(auto& conv_ptr : conv_ptrs) + { + auto argument_ptr = conv_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(conv_ptr->IsSupportedArgument(argument_ptr.get())) + { + auto invoker_ptr = conv_ptr->MakeInvokerPointer(); + invoker_ptr->Run(argument_ptr.get(), 1); + + in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data()); + + if(!check_out(in_n_c_hi_wi_host_result, in_n_c_hi_wi_device_result)) + { + std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl; + success = false; + } + else + { + std::cout << "Pass Info: " << conv_ptr->GetTypeString() << std::endl; + } + } + else + { + std::cout << "Not support Info: " << conv_ptr->GetTypeString() << std::endl; + } + } + + if(success) + { + std::cout << "test conv2d bwd : Pass" << std::endl; + return 0; + } + else + { + std::cout << "test conv2d bwd: Fail " << std::endl; + return -1; + } + }; + + if(data_type == 0) + { + return Run(F32(), F32(), F32(), F32()); + } + else if(data_type == 1) + { + return Run(F16(), F16(), F16(), F32()); + } + else if(data_type == 2) + { + return Run(BF16(), BF16(), BF16(), F32()); + } + else if(data_type == 3) + { + return Run(INT8(), INT8(), INT8(), int()); + } + else + { + return 1; + } +} diff --git a/test/conv2d_bwd_weight/CMakeLists.txt b/test/conv2d_bwd_weight/CMakeLists.txt new file mode 100644 index 00000000..e61c9299 --- /dev/null +++ b/test/conv2d_bwd_weight/CMakeLists.txt @@ -0,0 +1,2 @@ +add_test_executable(test_conv2d_bwd_weight conv2d_bwd_weight.cpp) +target_link_libraries(test_conv2d_bwd_weight PRIVATE host_tensor device_conv2d_bwd_weight_instance conv_util) diff --git a/test/conv2d_bwd_weight/conv2d_bwd_weight.cpp b/test/conv2d_bwd_weight/conv2d_bwd_weight.cpp new file mode 100644 index 00000000..7af0fa3d --- /dev/null +++ b/test/conv2d_bwd_weight/conv2d_bwd_weight.cpp @@ -0,0 +1,217 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
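+
+// With no command-line arguments, runs the self-test shapes in test_self();
+// otherwise the convolution shape (and split_k) is taken from the command line.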
+ +#include +#include +#include +#include +#include + +#include "test/convnd_fwd/conv_util.hpp" +#include "profiler/include/profile_conv_bwd_weight_impl.hpp" + +int test_self() +{ + bool pass = true; + std::vector params; + + params.push_back({2, 128, 256, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); + params.push_back({2, 128, 256, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + params.push_back({2, 128, 256, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + + for(auto& param : params) + { + // f32 + pass &= ck::profiler::profile_conv_bwd_weight_impl<2, + float, + float, + float, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_, + 2); + + // fp16 + pass &= ck::profiler::profile_conv_bwd_weight_impl<2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_, + 2); + } + return pass; +} +int main(int argc, char* argv[]) +{ + int data_type = 1; + int init_method = 1; + + // Conv shape + ck::index_t N = 128; + ck::index_t K = 256; + ck::index_t C = 192; + ck::index_t Y = 3; + ck::index_t X = 3; + ck::index_t Hi = 71; + ck::index_t Wi = 71; + ck::index_t conv_stride_h = 2; + ck::index_t conv_stride_w = 2; + ck::index_t conv_dilation_h = 1; + ck::index_t conv_dilation_w = 1; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + ck::index_t split_k = 1; + + bool pass = true; + if(argc == 1) + { + pass = test_self(); + } + else + { + if(argc == 3) + { + data_type = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + } + else if(argc == 19) + { + data_type = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + + N = std::stoi(argv[3]); + K = std::stoi(argv[4]); + C = std::stoi(argv[5]); + Y = std::stoi(argv[6]); + X = std::stoi(argv[7]); + Hi = std::stoi(argv[8]); + Wi = std::stoi(argv[9]); + conv_stride_h = std::stoi(argv[10]); + conv_stride_w = std::stoi(argv[11]); + conv_dilation_h = std::stoi(argv[12]); + conv_dilation_w = std::stoi(argv[13]); + in_left_pad_h = std::stoi(argv[14]); + in_left_pad_w = std::stoi(argv[15]); + in_right_pad_h = std::stoi(argv[16]); + in_right_pad_w = std::stoi(argv[17]); + split_k = std::stoi(argv[18]); + } + else + { + printf("arg1: data type (0=fp32, 1=fp16, 2= bfp16, 3= int8_t )\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3 to 17: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(1); + } + + ck::utils::conv::ConvParams param{2, + N, + K, + C, + {Y, X}, + {Hi, Wi}, + {conv_stride_h, conv_stride_w}, + {conv_dilation_h, conv_dilation_w}, + {in_left_pad_h, in_left_pad_w}, + {in_right_pad_h, in_right_pad_w}}; + 
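+        // dispatch on the requested data type; only the fp32 (0) and fp16 (1)
+        // paths are instantiated below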
if(data_type == 0) + { + pass = ck::profiler::profile_conv_bwd_weight_impl<2, + float, + float, + float, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + true, // do_verification + init_method, + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_, + split_k); + } + else if(data_type == 1) + { + pass = ck::profiler::profile_conv_bwd_weight_impl<2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + true, // do_verification + init_method, + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_, + split_k); + } + else + { + std::cout << "Not support data type" << std::endl; + return 1; + } + } + + if(pass) + { + std::cout << "test conv2d bwd weight : Pass" << std::endl; + return 0; + } + else + { + std::cout << "test conv2d bwd weight: Fail " << std::endl; + return -1; + } +} diff --git a/test/conv_util/CMakeLists.txt b/test/conv_util/CMakeLists.txt new file mode 100644 index 00000000..795c9ec0 --- /dev/null +++ b/test/conv_util/CMakeLists.txt @@ -0,0 +1,2 @@ +add_gtest_executable(test_conv_util conv_util.cpp) +target_link_libraries(test_conv_util PRIVATE host_tensor conv_util) diff --git a/test/conv_util/conv_util.cpp b/test/conv_util/conv_util.cpp new file mode 100644 index 00000000..293d9454 --- /dev/null +++ b/test/conv_util/conv_util.cpp @@ -0,0 +1,207 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
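+
+// Unit tests for ck::utils::conv::ConvParams::GetOutputSpatialLengths() in 1,
+// 2 and 3 spatial dimensions, and for get_host_tensor_descriptor() strides
+// across tensor layouts.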
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" + +namespace { + +class TestConvUtil : public ::testing::Test +{ + public: + void SetNDParams(std::size_t ndims) + { + conv_params.num_dim_spatial_ = ndims; + conv_params.filter_spatial_lengths_ = std::vector(ndims, 3); + conv_params.input_spatial_lengths_ = std::vector(ndims, 71); + conv_params.conv_filter_strides_ = std::vector(ndims, 2); + conv_params.conv_filter_dilations_ = std::vector(ndims, 1); + conv_params.input_left_pads_ = std::vector(ndims, 1); + conv_params.input_right_pads_ = std::vector(ndims, 1); + } + + protected: + // ------- default 2D ------- + // input NCHW {128,192,71,71}, + // weights KCYX {256,192,3,3}, + // stride {2,2}, + // dilations {1,1}, + // padding {{1,1}, {1,1}} + ck::utils::conv::ConvParams conv_params; +}; + +} // namespace + +TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths2D) +{ + ck::utils::conv::ConvParams conv_params; + std::vector out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err(out_spatial_len, + std::vector{36, 36}, + "Error: ConvParams 2D default constructor.")); + + conv_params.conv_filter_strides_ = std::vector{1, 1}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err( + out_spatial_len, std::vector{71, 71}, "Error: ConvParams 2D stride {1,1}.")); + + conv_params.conv_filter_strides_ = std::vector{2, 2}; + conv_params.input_left_pads_ = std::vector{2, 2}; + conv_params.input_right_pads_ = std::vector{2, 2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err(out_spatial_len, + std::vector{37, 37}, + "Error: ConvParams 2D padding left/right {2,2}.")); + + conv_params.conv_filter_dilations_ = std::vector{2, 2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err( + out_spatial_len, std::vector{36, 36}, "Error: ConvParams 2D dilation {2,2}.")); + + conv_params.conv_filter_strides_ = std::vector{3, 3}; + conv_params.input_left_pads_ = std::vector{1, 1}; + conv_params.input_right_pads_ = std::vector{1, 1}; + conv_params.conv_filter_dilations_ = std::vector{2, 2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE( + ck::utils::check_err(out_spatial_len, + std::vector{23, 23}, + "Error: ConvParams 2D strides{3,3}, padding {1,1}, dilations {2,2}.")); +} + +TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths1D) +{ + SetNDParams(1); + + std::vector out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err( + out_spatial_len, std::vector{36}, "Error: ConvParams 1D.")); + + conv_params.conv_filter_strides_ = std::vector{1}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err( + out_spatial_len, std::vector{71}, "Error: ConvParams 1D stride {1}.")); + + conv_params.conv_filter_strides_ = std::vector{2}; + conv_params.input_left_pads_ = std::vector{2}; + conv_params.input_right_pads_ = std::vector{2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err(out_spatial_len, + std::vector{37}, + "Error: ConvParams 1D padding left/right {2}.")); + + conv_params.conv_filter_dilations_ = std::vector{2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err( + out_spatial_len, std::vector{36}, 
"Error: ConvParams 1D dilation {2}.")); + + conv_params.conv_filter_strides_ = std::vector{3}; + conv_params.input_left_pads_ = std::vector{1}; + conv_params.input_right_pads_ = std::vector{1}; + conv_params.conv_filter_dilations_ = std::vector{2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE( + ck::utils::check_err(out_spatial_len, + std::vector{23}, + "Error: ConvParams 1D strides{3}, padding {1}, dilations {2}.")); +} + +TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths3D) +{ + SetNDParams(3); + + std::vector out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err( + out_spatial_len, std::vector{36, 36, 36}, "Error: ConvParams 3D.")); + + conv_params.conv_filter_strides_ = std::vector{1, 1, 1}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err(out_spatial_len, + std::vector{71, 71, 71}, + "Error: ConvParams 3D stride {1, 1, 1}.")); + + conv_params.conv_filter_strides_ = std::vector{2, 2, 2}; + conv_params.input_left_pads_ = std::vector{2, 2, 2}; + conv_params.input_right_pads_ = std::vector{2, 2, 2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err(out_spatial_len, + std::vector{37, 37, 37}, + "Error: ConvParams 3D padding left/right {2, 2, 2}.")); + + conv_params.conv_filter_dilations_ = std::vector{2, 2, 2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err(out_spatial_len, + std::vector{36, 36, 36}, + "Error: ConvParams 3D dilation {2, 2, 2}.")); + + conv_params.conv_filter_strides_ = std::vector{3, 3, 3}; + conv_params.input_left_pads_ = std::vector{1, 1, 1}; + conv_params.input_right_pads_ = std::vector{1, 1, 1}; + conv_params.conv_filter_dilations_ = std::vector{2, 2, 2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err( + out_spatial_len, + std::vector{23, 23, 23}, + "Error: ConvParams 3D strides{3, 3, 3}, padding {1, 1, 1}, dilations {2, 2, 2}.")); +} + +TEST(ConvUtil, GetHostTensorDescriptor) +{ + namespace tl = ck::tensor_layout::convolution; + std::vector dims{2, 3, 4, 5}; + HostTensorDescriptor h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NHWC{}); + EXPECT_TRUE(ck::utils::check_err( + h.GetLengths(), {2, 3, 4, 5}, "Error: wrong NHWC dimensions lengths!")); + EXPECT_TRUE(ck::utils::check_err( + h.GetStrides(), {3 * 4 * 5, 1, 3 * 5, 3}, "Error: wrong NHWC dimensions strides!")); + + h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NCHW{}); + EXPECT_TRUE(ck::utils::check_err( + h.GetLengths(), {2, 3, 4, 5}, "Error: wrong NCHW dimensions lengths!")); + EXPECT_TRUE(ck::utils::check_err( + h.GetStrides(), {3 * 4 * 5, 4 * 5, 5, 1}, "Error: wrong NCHW dimensions strides!")); + + dims = std::vector{2, 3, 4}; + h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NWC{}); + EXPECT_TRUE( + ck::utils::check_err(h.GetLengths(), {2, 3, 4}, "Error: wrong NWC dimensions lengths!")); + EXPECT_TRUE(ck::utils::check_err( + h.GetStrides(), {3 * 4, 1, 3}, "Error: wrong NWC dimensions strides!")); + + h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NCW{}); + EXPECT_TRUE( + ck::utils::check_err(h.GetLengths(), {2, 3, 4}, "Error: wrong NCW dimensions lengths!")); + EXPECT_TRUE(ck::utils::check_err( + h.GetStrides(), {3 * 4, 4, 1}, "Error: wrong NCW dimensions strides!")); + + dims = std::vector{2, 3, 4, 5, 6}; + h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NDHWC{}); + EXPECT_TRUE( + 
ck::utils::check_err(h.GetLengths(), dims, "Error: wrong NDHWC dimensions lengths!")); + EXPECT_TRUE(ck::utils::check_err(h.GetStrides(), + {3 * 4 * 5 * 6, // N + 1, // C + 3 * 5 * 6, // D + 3 * 6, // H + 3}, // W + "Error: wrong NDHWC dimensions strides!")); + + h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NCDHW{}); + EXPECT_TRUE( + ck::utils::check_err(h.GetLengths(), dims, "Error: wrong NCDHW dimensions lengths!")); + EXPECT_TRUE(ck::utils::check_err(h.GetStrides(), + {3 * 4 * 5 * 6, // N + 4 * 5 * 6, // C + 5 * 6, // D + 6, // H + 1}, // W + "Error: wrong NCDHW dimensions strides!")); +} diff --git a/test/convnd_bwd_data/CMakeLists.txt b/test/convnd_bwd_data/CMakeLists.txt new file mode 100644 index 00000000..554bcd18 --- /dev/null +++ b/test/convnd_bwd_data/CMakeLists.txt @@ -0,0 +1,2 @@ +add_test_executable(test_convnd_bwd_data convnd_bwd_data.cpp) +target_link_libraries(test_convnd_bwd_data PRIVATE host_tensor device_convnd_bwd_data_instance conv_util) diff --git a/test/convnd_bwd_data/convnd_bwd_data.cpp b/test/convnd_bwd_data/convnd_bwd_data.cpp new file mode 100644 index 00000000..a5b83b9e --- /dev/null +++ b/test/convnd_bwd_data/convnd_bwd_data.cpp @@ -0,0 +1,331 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "profiler/include/profile_convnd_bwd_data_impl.hpp" + +int main() +{ + bool pass = true; + // check 1d + std::vector params; + params.push_back({1, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}}); + params.push_back({1, 128, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}}); + params.push_back({1, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}}); + + for(auto& param : params) + { + pass &= ck::profiler::profile_convnd_bwd_data_impl<1, + float, + float, + float, + float, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>( + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); + + pass &= ck::profiler::profile_convnd_bwd_data_impl<1, + ck::half_t, + ck::half_t, + ck::half_t, + float, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>( + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); + + pass &= ck::profiler::profile_convnd_bwd_data_impl<1, + ck::bhalf_t, + ck::bhalf_t, + ck::bhalf_t, + float, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>( + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); + + pass &= ck::profiler::profile_convnd_bwd_data_impl<1, + int8_t, + int8_t, + int8_t, + int, + 
ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>( + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); + } + + // check 2d + params.clear(); + params.push_back({2, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); + params.push_back({2, 128, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + params.push_back({2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + + for(auto& param : params) + { + pass &= ck::profiler::profile_convnd_bwd_data_impl<2, + float, + float, + float, + float, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); + + pass &= ck::profiler::profile_convnd_bwd_data_impl<2, + ck::half_t, + ck::half_t, + ck::half_t, + float, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); + + pass &= ck::profiler::profile_convnd_bwd_data_impl<2, + ck::bhalf_t, + ck::bhalf_t, + ck::bhalf_t, + float, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); + + pass &= ck::profiler::profile_convnd_bwd_data_impl<2, + int8_t, + int8_t, + int8_t, + int, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); + } + + // check 3d + params.clear(); + params.push_back( + {3, 128, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + params.push_back( + {3, 128, 128, 256, {3, 3, 3}, {14, 14, 14}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + params.push_back( + {3, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + + for(auto& param : params) + { + pass &= ck::profiler::profile_convnd_bwd_data_impl<3, + float, + float, + float, + float, + 
ck::tensor_layout::convolution::NDHWC, + ck::tensor_layout::convolution::KZYXC, + ck::tensor_layout::convolution::NDHWK>( + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); + + pass &= ck::profiler::profile_convnd_bwd_data_impl<3, + ck::half_t, + ck::half_t, + ck::half_t, + float, + ck::tensor_layout::convolution::NDHWC, + ck::tensor_layout::convolution::KZYXC, + ck::tensor_layout::convolution::NDHWK>( + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); + + pass &= ck::profiler::profile_convnd_bwd_data_impl<3, + ck::bhalf_t, + ck::bhalf_t, + ck::bhalf_t, + float, + ck::tensor_layout::convolution::NDHWC, + ck::tensor_layout::convolution::KZYXC, + ck::tensor_layout::convolution::NDHWK>( + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); + + pass &= ck::profiler::profile_convnd_bwd_data_impl<3, + int8_t, + int8_t, + int8_t, + int, + ck::tensor_layout::convolution::NDHWC, + ck::tensor_layout::convolution::KZYXC, + ck::tensor_layout::convolution::NDHWK>( + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); + } + + if(pass) + { + std::cout << "test convnd bwd : Pass" << std::endl; + return 0; + } + else + { + std::cout << "test convnd bwd: Fail " << std::endl; + return -1; + } +} diff --git a/test/convnd_bwd_weight/CMakeLists.txt b/test/convnd_bwd_weight/CMakeLists.txt new file mode 100644 index 00000000..e76c28bf --- /dev/null +++ b/test/convnd_bwd_weight/CMakeLists.txt @@ -0,0 +1,2 @@ +add_test_executable(test_convnd_bwd_weight convnd_bwd_weight.cpp) +target_link_libraries(test_convnd_bwd_weight PRIVATE host_tensor device_convnd_bwd_weight_instance conv_util) diff --git a/test/convnd_bwd_weight/convnd_bwd_weight.cpp b/test/convnd_bwd_weight/convnd_bwd_weight.cpp new file mode 100644 index 00000000..febcef16 --- /dev/null +++ b/test/convnd_bwd_weight/convnd_bwd_weight.cpp @@ -0,0 +1,283 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
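convnd_bwd_weight.cpp below reports results the same way as convnd_bwd_data.cpp above: every precision/rank case folds into a single pass flag, and main() turns that flag into the process exit code that add_test_executable/CTest checks. A distilled, standalone model of the pattern (the lambdas are stand-ins, not the real profiler calls):

    #include <cstdlib>
    #include <functional>
    #include <iostream>
    #include <vector>

    int main()
    {
        bool pass = true;

        // Stand-ins for the per-precision profiler invocations.
        std::vector<std::function<bool()>> cases = {
            [] { return true; }, // e.g. an fp32 case
            [] { return true; }, // e.g. an fp16 case
        };

        for(auto& run_case : cases)
            pass &= run_case(); // keep running even after a failure

        std::cout << (pass ? "Pass" : "Fail") << std::endl;
        return pass ? EXIT_SUCCESS : EXIT_FAILURE; // non-zero marks the test failed
    }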
+ +#include +#include +#include +#include +#include + +#include "test/convnd_fwd/conv_util.hpp" +#include "profiler/include/profile_convnd_bwd_weight_impl.hpp" + +int test_self() +{ + bool pass = true; + std::vector params; + + params.push_back({1, 128, 256, 256, {1}, {7}, {2}, {1}, {0}, {0}}); + params.push_back({1, 128, 256, 256, {3}, {14}, {1}, {1}, {1}, {1}}); + params.push_back({1, 128, 256, 256, {1}, {3}, {1}, {1}, {0}, {0}}); + + for(auto& param : params) + { + // f32 + pass &= ck::profiler::profile_convnd_bwd_weight_impl<1, + float, + float, + float, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>( + true, // do_verification + 1, // init_method + false, // do_log + true, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_, + 2); + + // fp16 + pass &= ck::profiler::profile_convnd_bwd_weight_impl<1, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>( + true, // do_verification + 1, // init_method + false, // do_log + true, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_, + 2); + + // bf16 + pass &= ck::profiler::profile_convnd_bwd_weight_impl<1, + ck::bhalf_t, + ck::bhalf_t, + ck::bhalf_t, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>( + true, // do_verification + 1, // init_method + false, // do_log + true, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_, + 2); + } + + // check 2d + params.clear(); + params.push_back({2, 128, 256, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); + params.push_back({2, 128, 256, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + params.push_back({2, 128, 256, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + + for(auto& param : params) + { + // f32 + pass &= ck::profiler::profile_convnd_bwd_weight_impl<2, + float, + float, + float, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + true, // do_verification + 1, // init_method + false, // do_log + true, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_, + 2); + + // fp16 + pass &= ck::profiler::profile_convnd_bwd_weight_impl<2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + true, // do_verification + 1, // init_method + false, // do_log + true, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.GetOutputSpatialLengths(), + 
param.conv_filter_strides_,
+                                                        param.conv_filter_dilations_,
+                                                        param.input_left_pads_,
+                                                        param.input_right_pads_,
+                                                        2);
+
+        // bf16
+        pass &= ck::profiler::profile_convnd_bwd_weight_impl<2,
+                                                             ck::bhalf_t,
+                                                             ck::bhalf_t,
+                                                             ck::bhalf_t,
+                                                             ck::tensor_layout::convolution::NHWC,
+                                                             ck::tensor_layout::convolution::KYXC,
+                                                             ck::tensor_layout::convolution::NHWK>(
+                                                        true,  // do_verification
+                                                        1,     // init_method
+                                                        false, // do_log
+                                                        true,  // time_kernel
+                                                        param.N_,
+                                                        param.K_,
+                                                        param.C_,
+                                                        param.input_spatial_lengths_,
+                                                        param.filter_spatial_lengths_,
+                                                        param.GetOutputSpatialLengths(),
+                                                        param.conv_filter_strides_,
+                                                        param.conv_filter_dilations_,
+                                                        param.input_left_pads_,
+                                                        param.input_right_pads_,
+                                                        2);
+    }
+
+    // check 3d
+    params.clear();
+    params.push_back(
+        {3, 128, 256, 256, {1, 1, 1}, {4, 4, 4}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
+    params.push_back(
+        {3, 128, 256, 256, {3, 3, 3}, {4, 4, 8}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
+    params.push_back(
+        {3, 128, 256, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
+
+    for(auto& param : params)
+    {
+        // f32
+        pass &= ck::profiler::profile_convnd_bwd_weight_impl<3,
+                                                             float,
+                                                             float,
+                                                             float,
+                                                             ck::tensor_layout::convolution::NDHWC,
+                                                             ck::tensor_layout::convolution::KZYXC,
+                                                             ck::tensor_layout::convolution::NDHWK>(
+                                                        true,  // do_verification
+                                                        1,     // init_method
+                                                        false, // do_log
+                                                        true,  // time_kernel
+                                                        param.N_,
+                                                        param.K_,
+                                                        param.C_,
+                                                        param.input_spatial_lengths_,
+                                                        param.filter_spatial_lengths_,
+                                                        param.GetOutputSpatialLengths(),
+                                                        param.conv_filter_strides_,
+                                                        param.conv_filter_dilations_,
+                                                        param.input_left_pads_,
+                                                        param.input_right_pads_,
+                                                        2);
+
+        // fp16
+        pass &= ck::profiler::profile_convnd_bwd_weight_impl<3,
+                                                             ck::half_t,
+                                                             ck::half_t,
+                                                             ck::half_t,
+                                                             ck::tensor_layout::convolution::NDHWC,
+                                                             ck::tensor_layout::convolution::KZYXC,
+                                                             ck::tensor_layout::convolution::NDHWK>(
+                                                        true,  // do_verification
+                                                        1,     // init_method
+                                                        false, // do_log
+                                                        true,  // time_kernel
+                                                        param.N_,
+                                                        param.K_,
+                                                        param.C_,
+                                                        param.input_spatial_lengths_,
+                                                        param.filter_spatial_lengths_,
+                                                        param.GetOutputSpatialLengths(),
+                                                        param.conv_filter_strides_,
+                                                        param.conv_filter_dilations_,
+                                                        param.input_left_pads_,
+                                                        param.input_right_pads_,
+                                                        2);
+
+        // bf16
+        pass &= ck::profiler::profile_convnd_bwd_weight_impl<3,
+                                                             ck::bhalf_t,
+                                                             ck::bhalf_t,
+                                                             ck::bhalf_t,
+                                                             ck::tensor_layout::convolution::NDHWC,
+                                                             ck::tensor_layout::convolution::KZYXC,
+                                                             ck::tensor_layout::convolution::NDHWK>(
+                                                        true,  // do_verification
+                                                        1,     // init_method
+                                                        false, // do_log
+                                                        true,  // time_kernel
+                                                        param.N_,
+                                                        param.K_,
+                                                        param.C_,
+                                                        param.input_spatial_lengths_,
+                                                        param.filter_spatial_lengths_,
+                                                        param.GetOutputSpatialLengths(),
+                                                        param.conv_filter_strides_,
+                                                        param.conv_filter_dilations_,
+                                                        param.input_left_pads_,
+                                                        param.input_right_pads_,
+                                                        2);
+    }
+
+    return pass;
+}
+int main()
+{
+    bool pass = true;
+
+    pass = test_self();
+
+    if(pass)
+    {
+        std::cout << "test convnd bwd weight: Pass" << std::endl;
+        return 0;
+    }
+    else
+    {
+        std::cout << "test convnd bwd weight: Fail" << std::endl;
+        return -1;
+    }
+} diff --git a/test/convnd_fwd/CMakeLists.txt b/test/convnd_fwd/CMakeLists.txt new file mode 100644 index 00000000..444ec6c8 --- /dev/null +++ b/test/convnd_fwd/CMakeLists.txt @@ -0,0 +1,13 @@ +add_custom_target(test_convnd_fwd)
+
+add_gtest_executable(test_conv1d_fwd conv1d_fwd.cpp)
+target_link_libraries(test_conv1d_fwd PRIVATE host_tensor device_conv1d_fwd_instance conv_util)
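+# test_convnd_fwd is an aggregate custom target: each per-rank gtest executable
+# registers itself via add_dependencies below, so building test_convnd_fwd
+# builds the 1-D, 2-D and 3-D forward-convolution tests in one go.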
+add_dependencies(test_convnd_fwd test_conv1d_fwd) + +add_gtest_executable(test_conv2d_fwd conv2d_fwd.cpp) +target_link_libraries(test_conv2d_fwd PRIVATE host_tensor device_conv2d_fwd_instance device_convnd_2d_fwd_instance conv_util) +add_dependencies(test_convnd_fwd test_conv2d_fwd) + +add_gtest_executable(test_conv3d_fwd conv3d_fwd.cpp) +target_link_libraries(test_conv3d_fwd PRIVATE host_tensor device_conv3d_fwd_instance conv_util) +add_dependencies(test_convnd_fwd test_conv3d_fwd) diff --git a/test/convnd_fwd/conv1d_fwd.cpp b/test/convnd_fwd/conv1d_fwd.cpp new file mode 100644 index 00000000..4d2473f0 --- /dev/null +++ b/test/convnd_fwd/conv1d_fwd.cpp @@ -0,0 +1,192 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "test/convnd_fwd/conv_util.hpp" + +namespace { + +class Conv1dFwdNWCInstances : public ::testing::Test +{ + public: + template + bool test_conv1d_nwc_instances(const std::vector& conv_ptrs, + const ck::utils::conv::ConvParams& params) + { + using namespace std::placeholders; + using namespace ck::utils; + namespace ctl = ck::tensor_layout::convolution; + + conv::ConvFwdOpInstance, + FillUniformDistributionIntegerValue> + conv_instance(params, + true, + FillUniformDistributionIntegerValue{}, + FillUniformDistributionIntegerValue{}); + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<1, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + run_engine.SetAtol(atol_); + run_engine.SetRtol(rtol_); + return run_engine.Test(conv_ptrs); + } + + template + bool test_default() + { + return test_conv1d_nwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<1>(), params_default_); + } + + template + bool test_filter1x1_stride1_pad0() + { + return test_conv1d_nwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<1>(), + params_filter1x1_stride1_pad0_); + } + + template + bool test_filter1x1_pad0() + { + return test_conv1d_nwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<1>(), + params_filter1x1_pad0_); + } + + static inline ck::utils::conv::ConvParams params_default_{ + 1, 4, 256, 64, {3}, {71}, {2}, {2}, {2}, {2}}; + static inline ck::utils::conv::ConvParams params_filter1x1_stride1_pad0_{ + 1, 4, 256, 64, {1}, {28}, {1}, {1}, {0}, {0}}; + static inline ck::utils::conv::ConvParams params_filter1x1_pad0_{ + 1, 4, 256, 64, {1}, {28}, {2}, {1}, {0}, {0}}; + + private: + double atol_{1e-5}; + double rtol_{1e-4}; +}; + +} // anonymous namespace + +TEST(Conv1DFwdNWC, IntegerValues) +{ + using namespace std::placeholders; + using namespace ck::utils; + namespace ctl = ck::tensor_layout::convolution; + using T = float; + + ck::utils::conv::ConvParams params{1, 4, 256, 64, {3}, {36}, {1}, {2}, {2}, {2}}; + + std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<1, T, T, T, T>(conv_ptrs); + conv::ConvFwdOpInstance, + FillUniformDistributionIntegerValue> + conv_instance(params, + true, + FillUniformDistributionIntegerValue{}, + FillUniformDistributionIntegerValue{}); + + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<1, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + 
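+    // Tolerances for the device-vs-reference comparison below; the run engine
+    // is assumed to apply a mixed check of the form
+    //   |device - host| <= atol + rtol * |host|.
+    // The integer-valued fp32 inputs above keep the reference exact, so the
+    // bounds can stay tight.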
run_engine.SetAtol(1e-5); + run_engine.SetRtol(1e-4); + EXPECT_TRUE(run_engine.Test(conv_ptrs)); +} + +TEST(Conv1DFwdNWC, FloatingPointValues) +{ + using namespace std::placeholders; + using namespace ck::utils; + namespace ctl = ck::tensor_layout::convolution; + using T = ck::half_t; + + ck::utils::conv::ConvParams params{1, 4, 256, 64, {3}, {36}, {1}, {2}, {2}, {2}}; + + std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<1, T, T, T, float>(conv_ptrs); + conv::ConvFwdOpInstance, + FillUniformDistribution> + conv_instance(params, true, FillUniformDistribution{}, FillUniformDistribution{}); + + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<1, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + run_engine.SetAtol(0.1); + run_engine.SetRtol(1e-2); + EXPECT_TRUE(run_engine.Test(conv_ptrs)); +} + +TEST_F(Conv1dFwdNWCInstances, BF16_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv1dFwdNWCInstances, BF16_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv1dFwdNWCInstances, BF16_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); +} + +TEST_F(Conv1dFwdNWCInstances, F16_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv1dFwdNWCInstances, F16_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv1dFwdNWCInstances, F16_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); +} + +TEST_F(Conv1dFwdNWCInstances, F32_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv1dFwdNWCInstances, F32_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv1dFwdNWCInstances, F32_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); +} + +TEST_F(Conv1dFwdNWCInstances, I8_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv1dFwdNWCInstances, I8_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv1dFwdNWCInstances, I8_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); +} diff --git a/test/convnd_fwd/conv2d_fwd.cpp b/test/convnd_fwd/conv2d_fwd.cpp new file mode 100644 index 00000000..f4580578 --- /dev/null +++ b/test/convnd_fwd/conv2d_fwd.cpp @@ -0,0 +1,266 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
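conv2d_fwd.cpp below follows the same GoogleTest shape as conv1d_fwd.cpp above: a fixture holding static parameter sets plus a templated runner, with one TEST_F per (data type, parameter set) pair. A minimal self-contained model of that structure (all names here are illustrative, not library APIs):

    #include <gtest/gtest.h>

    namespace {

    class FixtureModel : public ::testing::Test
    {
        public:
        // Templated helper standing in for test_conv2d_nhwc_instances<T>(...).
        template <typename T>
        bool run_case(T param)
        {
            return param + param == T{2} * param; // placeholder check, always true
        }

        // Shared parameter sets live as static members, as in the real fixtures.
        static inline int param_default_{3};
    };

    } // anonymous namespace

    TEST_F(FixtureModel, Int_default) { EXPECT_TRUE(run_case(param_default_)); }
    TEST_F(FixtureModel, Float_default) { EXPECT_TRUE(run_case(2.5f)); }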
+ +#include +#include +#include + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "test/convnd_fwd/conv_util.hpp" + +namespace { + +class Conv2dFwdNHWCInstances : public ::testing::Test +{ + public: + template + bool test_conv2d_nhwc_instances(const std::vector& conv_ptrs, + const ck::utils::conv::ConvParams& params) + { + using namespace std::placeholders; + using namespace ck::utils; + + conv::ConvFwdOpInstance, + FillUniformDistributionIntegerValue> + conv_instance(params, + true, + FillUniformDistributionIntegerValue{}, + FillUniformDistributionIntegerValue{}); + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<2, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + run_engine.SetAtol(atol_); + run_engine.SetRtol(rtol_); + return run_engine.Test(conv_ptrs); + } + + template + bool test_default(bool use_convnd = false) + { + if(use_convnd) + { + return test_conv2d_nhwc_instances( + test::conv::ConvolutionNDFwdInstances::Get(2), params_default_); + } + else + { + return test_conv2d_nhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<2>(), + params_default_); + } + } + + template + bool test_filter1x1_stride1_pad0(bool use_convnd = false) + { + if(use_convnd) + { + return test_conv2d_nhwc_instances( + test::conv::ConvolutionNDFwdInstances::Get(2), + params_filter1x1_stride1_pad0_); + } + else + { + return test_conv2d_nhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<2>(), + params_filter1x1_stride1_pad0_); + } + } + + template + bool test_filter1x1_pad0(bool use_convnd = false) + { + if(use_convnd) + { + return test_conv2d_nhwc_instances( + test::conv::ConvolutionNDFwdInstances::Get(2), params_filter1x1_pad0_); + } + else + { + return test_conv2d_nhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<2>(), + params_filter1x1_pad0_); + } + } + + template + bool test_oddC() + { + return test_conv2d_nhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<2>(), params_oddC_); + } + + static inline ck::utils::conv::ConvParams params_default_{ + 2, 4, 256, 64, {3, 3}, {36, 36}, {2, 2}, {2, 2}, {2, 2}, {2, 2}}; + static inline ck::utils::conv::ConvParams params_filter1x1_stride1_pad0_{ + 2, 4, 256, 64, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; + static inline ck::utils::conv::ConvParams params_filter1x1_pad0_{ + 2, 4, 256, 64, {1, 1}, {28, 28}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}; + static inline ck::utils::conv::ConvParams params_oddC_{ + 2, 4, 256, 3, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; + + private: + double atol_{1e-5}; + double rtol_{1e-4}; +}; + +} // anonymous namespace + +TEST(Conv2DFwdNHWC, IntegerValues) +{ + using namespace std::placeholders; + using namespace ck::utils; + using T = float; + + ck::utils::conv::ConvParams params{ + 2, 4, 256, 64, {3, 3}, {36, 36}, {1, 1}, {2, 2}, {2, 2}, {2, 2}}; + + std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<2, T, T, T, T>(conv_ptrs); + conv::ConvFwdOpInstance, + FillUniformDistributionIntegerValue> + conv_instance(params, + true, + FillUniformDistributionIntegerValue{}, + FillUniformDistributionIntegerValue{}); + + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<2, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + run_engine.SetAtol(1e-5); + 
run_engine.SetRtol(1e-4); + EXPECT_TRUE(run_engine.Test(conv_ptrs)); +} + +TEST(Conv2DFwdNHWC, FloatingPointValues) +{ + using namespace std::placeholders; + using namespace ck::utils; + using T = ck::half_t; + + ck::utils::conv::ConvParams params{ + 2, 4, 256, 64, {3, 3}, {36, 36}, {2, 2}, {2, 2}, {2, 2}, {2, 2}}; + + std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<2, T, T, T, float>(conv_ptrs); + conv::ConvFwdOpInstance, + FillUniformDistribution> + conv_instance(params, true, FillUniformDistribution{}, FillUniformDistribution{}); + + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<2, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + run_engine.SetAtol(2e-4); + run_engine.SetRtol(1e-3); + EXPECT_TRUE(run_engine.Test(conv_ptrs)); +} + +TEST_F(Conv2dFwdNHWCInstances, BF16_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv2dFwdNHWCInstances, BF16_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv2dFwdNHWCInstances, BF16_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); +} +TEST_F(Conv2dFwdNHWCInstances, F16_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv2dFwdNHWCInstances, F16_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv2dFwdNHWCInstances, F16_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); +} +TEST_F(Conv2dFwdNHWCInstances, F16_oddC) { EXPECT_TRUE(this->test_oddC()); } +TEST_F(Conv2dFwdNHWCInstances, F32_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv2dFwdNHWCInstances, F32_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv2dFwdNHWCInstances, F32_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); +} +TEST_F(Conv2dFwdNHWCInstances, I8_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv2dFwdNHWCInstances, I8_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv2dFwdNHWCInstances, I8_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); +} + +TEST_F(Conv2dFwdNHWCInstances, ND_BF16_default) +{ + EXPECT_TRUE(this->test_default(true)); +} +TEST_F(Conv2dFwdNHWCInstances, ND_BF16_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0(true)); +} +TEST_F(Conv2dFwdNHWCInstances, ND_BF16_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0(true)); +} +TEST_F(Conv2dFwdNHWCInstances, ND_F16_default) +{ + EXPECT_TRUE(this->test_default(true)); +} +TEST_F(Conv2dFwdNHWCInstances, ND_F16_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0(true)); +} +TEST_F(Conv2dFwdNHWCInstances, ND_F16_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0(true)); +} +TEST_F(Conv2dFwdNHWCInstances, ND_F32_default) { EXPECT_TRUE(this->test_default(true)); } +TEST_F(Conv2dFwdNHWCInstances, ND_F32_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0(true)); +} +TEST_F(Conv2dFwdNHWCInstances, ND_F32_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0(true)); +} +TEST_F(Conv2dFwdNHWCInstances, ND_I8_default) { EXPECT_TRUE(this->test_default(true)); } +TEST_F(Conv2dFwdNHWCInstances, ND_I8_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0(true)); +} +TEST_F(Conv2dFwdNHWCInstances, ND_I8_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0(true)); +} diff --git a/test/convnd_fwd/conv3d_fwd.cpp b/test/convnd_fwd/conv3d_fwd.cpp new file mode 
100644 index 00000000..0cc2b241 --- /dev/null +++ b/test/convnd_fwd/conv3d_fwd.cpp @@ -0,0 +1,317 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/conv_util.hpp" + +#include "test/convnd_fwd/conv_util.hpp" + +namespace { + +class Conv3dFwdNDHWCInstances : public ::testing::Test +{ + public: + template + bool test_conv3d_nwc_instances(const std::vector& conv_ptrs, + const ck::utils::conv::ConvParams& params) + { + using namespace std::placeholders; + using namespace ck::utils; + namespace ctl = ck::tensor_layout::convolution; + + conv::ConvFwdOpInstance, + FillUniformDistributionIntegerValue> + conv_instance(params, + true, + FillUniformDistributionIntegerValue{}, + FillUniformDistributionIntegerValue{}); + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<3, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + run_engine.SetAtol(atol_); + run_engine.SetRtol(rtol_); + return run_engine.Test(conv_ptrs); + } + + template + bool test_default() + { + return test_conv3d_nwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<3>(), params_default_); + } + + template + bool test_filter1x1_stride1_pad0() + { + return test_conv3d_nwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<3>(), + params_filter1x1_stride1_pad0_); + } + + template + bool test_filter1x1_pad0() + { + return test_conv3d_nwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<3>(), + params_filter1x1_pad0_); + } + + static inline ck::utils::conv::ConvParams params_default_{ + 3, 4, 256, 64, {3, 3, 3}, {28, 28, 28}, {2, 2, 2}, {2, 2, 2}, {2, 2, 2}, {2, 2, 2}}; + static inline ck::utils::conv::ConvParams params_filter1x1_stride1_pad0_{ + 3, 4, 256, 64, {1, 1, 1}, {28, 28, 28}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}; + static inline ck::utils::conv::ConvParams params_filter1x1_pad0_{ + 3, 4, 256, 64, {1, 1, 1}, {28, 28, 28}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}; + + private: + double atol_{1e-5}; + double rtol_{1e-4}; +}; + +} // anonymous namespace + +TEST(Conv3DFwdNDHWC, IntegerValues) +{ + using namespace std::placeholders; + using namespace ck::utils; + namespace ctl = ck::tensor_layout::convolution; + using T = float; + + ck::utils::conv::ConvParams params{ + 3, 4, 256, 64, {3, 3, 3}, {18, 18, 18}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, {2, 2, 2}}; + + std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<3, T, T, T, T>(conv_ptrs); + conv::ConvFwdOpInstance, + FillUniformDistributionIntegerValue> + conv_instance(params, + true, + FillUniformDistributionIntegerValue{}, + FillUniformDistributionIntegerValue{}); + + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<3, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + run_engine.SetAtol(1e-5); + run_engine.SetRtol(1e-3); + EXPECT_TRUE(run_engine.Test(conv_ptrs)); +} + +TEST(Conv3DFwdNDHWC, FloatingPointValues) +{ + using namespace std::placeholders; + using namespace ck::utils; + namespace ctl = ck::tensor_layout::convolution; + using T = ck::half_t; + + ck::utils::conv::ConvParams params{ + 3, 4, 256, 64, {3, 3, 3}, {18, 18, 18}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, {2, 2, 2}}; + + 
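+    // ConvParams brace-initializer order (it matches the field-by-field setup
+    // used in the *Over2GB tests below): num_dim_spatial, N, K, C,
+    // filter_spatial_lengths, input_spatial_lengths, conv_filter_strides,
+    // conv_filter_dilations, input_left_pads, input_right_pads.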
std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<3, T, T, T, float>(conv_ptrs); + conv::ConvFwdOpInstance, + FillUniformDistribution> + conv_instance(params, true, FillUniformDistribution{}, FillUniformDistribution{}); + + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<3, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + run_engine.SetAtol(1e-3); + run_engine.SetRtol(1e-3); + EXPECT_TRUE(run_engine.Test(conv_ptrs)); +} + +TEST(Conv3DFwdNDHWC, InputOver2GB) +{ + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using namespace ck::utils; + using T = float; + + // >2GB Input + conv::ConvParams params; + params.num_dim_spatial_ = 3; + params.N_ = 2; + params.K_ = 16; + params.C_ = 32; + params.filter_spatial_lengths_ = std::vector{3, 3, 3}; + params.input_spatial_lengths_ = std::vector{32, 1000, 1000}; + params.conv_filter_strides_ = std::vector{1, 1, 1}; + params.conv_filter_dilations_ = std::vector{1, 1, 1}; + params.input_left_pads_ = std::vector{1, 1, 1}; + params.input_right_pads_ = std::vector{1, 1, 1}; + + std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<3, T, T, T, T>(conv_ptrs); + auto arg = conv_ptrs.back()->MakeArgumentPointer(nullptr, + nullptr, + nullptr, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + params.GetOutputSpatialLengths(), + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + PassThrough{}, + PassThrough{}, + PassThrough{}); + EXPECT_FALSE(conv_ptrs.back()->IsSupportedArgument(arg.get())); +} + +TEST(Conv3DFwdNDHWC, FiltersOver2GB) +{ + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using namespace ck::utils; + using T = float; + + // >2GB Filters + conv::ConvParams params; + params.num_dim_spatial_ = 3; + params.N_ = 2; + params.K_ = 16; + params.C_ = 32; + params.filter_spatial_lengths_ = std::vector{4, 1000, 1000}; + params.input_spatial_lengths_ = std::vector{16, 16, 16}; + params.conv_filter_strides_ = std::vector{1, 1, 1}; + params.conv_filter_dilations_ = std::vector{1, 1, 1}; + params.input_left_pads_ = std::vector{1, 1, 1}; + params.input_right_pads_ = std::vector{1, 1, 1}; + + std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<3, T, T, T, T>(conv_ptrs); + auto arg = conv_ptrs.back()->MakeArgumentPointer(nullptr, + nullptr, + nullptr, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + params.GetOutputSpatialLengths(), + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + PassThrough{}, + PassThrough{}, + PassThrough{}); + EXPECT_FALSE(conv_ptrs.back()->IsSupportedArgument(arg.get())); +} + +TEST(Conv3DFwdNDHWC, OutputOver2GB) +{ + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using namespace ck::utils; + using T = float; + + // >2GB Output + conv::ConvParams params; + params.num_dim_spatial_ = 3; + params.N_ = 2; + params.K_ = 16; + params.C_ = 2; + params.filter_spatial_lengths_ = std::vector{1, 1, 1}; + params.input_spatial_lengths_ = std::vector{1000, 1000, 30}; + params.conv_filter_strides_ = std::vector{1, 1, 1}; + params.conv_filter_dilations_ = std::vector{1, 1, 1}; + params.input_left_pads_ = std::vector{2, 2, 2}; + params.input_right_pads_ = std::vector{2, 2, 2}; + + std::vector 
conv_ptrs; + test::conv::get_test_convolution_fwd_instance<3, T, T, T, T>(conv_ptrs); + auto arg = conv_ptrs.back()->MakeArgumentPointer(nullptr, + nullptr, + nullptr, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + params.GetOutputSpatialLengths(), + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + PassThrough{}, + PassThrough{}, + PassThrough{}); + EXPECT_FALSE(conv_ptrs.back()->IsSupportedArgument(arg.get())); +} + +TEST_F(Conv3dFwdNDHWCInstances, BF16_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv3dFwdNDHWCInstances, BF16_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv3dFwdNDHWCInstances, BF16_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); +} + +TEST_F(Conv3dFwdNDHWCInstances, F16_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv3dFwdNDHWCInstances, F16_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv3dFwdNDHWCInstances, F16_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); +} + +TEST_F(Conv3dFwdNDHWCInstances, F32_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv3dFwdNDHWCInstances, F32_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv3dFwdNDHWCInstances, F32_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); +} + +TEST_F(Conv3dFwdNDHWCInstances, I8_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv3dFwdNDHWCInstances, I8_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv3dFwdNDHWCInstances, I8_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); +} diff --git a/test/convnd_fwd/conv_util.hpp b/test/convnd_fwd/conv_util.hpp new file mode 100644 index 00000000..8c35570b --- /dev/null +++ b/test/convnd_fwd/conv_util.hpp @@ -0,0 +1,174 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
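The three *Over2GB cases above assert only that IsSupportedArgument rejects the oversized problem; the premise (an assumption about the backend's indexing, not spelled out in the tests) is that a tensor whose flat byte size exceeds the signed 32-bit range cannot be addressed. For the InputOver2GB shape, the input alone is roughly 7.6 GiB:

    #include <cstdint>
    #include <iostream>

    int main()
    {
        // Input tensor from the InputOver2GB case: N=2, D=32, H=1000, W=1000, C=32, fp32.
        const std::uint64_t elems = 2ull * 32 * 1000 * 1000 * 32;
        const std::uint64_t bytes = elems * sizeof(float);
        std::cout << bytes << " bytes, ~" << (bytes >> 30) << " GiB\n"; // well above 2 GiB
        return 0;
    }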
+ +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +using DeviceConvFwdNoOpPtr = DeviceConvFwdPtr; +namespace instance { + +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(std::vector&); +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(std::vector&); +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(std::vector&); +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(std::vector&); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace test { +namespace conv { + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +using DeviceConvFwdNoOpPtr = + ck::tensor_operation::device::DeviceConvFwdPtr; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +template +using DeviceConvNDFwdInstance = ck::tensor_operation::device:: + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + // clang-format off + InDataType, // + WeiDataType, // + OutDataType, // + AccDataType, // Accumulator data type. + InElementOp, // Input Elementwise Operation + WeiElementOp, // Weights Elementwise Operation + OutElementOp, // Output Elementwise Operation + ConvFwdDefault, // ConvForwardSpecialization + SpatialDims, // SptialDims + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 7, // CThreadTransferSrcDstVectorDim + 1>; // CThreadTransferDstScalarPerVector +// clang-format on + +template +void get_test_convolution_fwd_instance(std::vector& instances) +{ + using ConvInstanceT = + DeviceConvNDFwdInstance; + instances.emplace_back(std::make_unique()); +} + +// TODO (aosewski) +// Temporary solution to get all DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K +// instances. When switched over to DeviceConvNDFwdXdl for 2D remove ConvolutionNDFwdInstances +// structures. 
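+// Dispatch note: in the ConvolutionNDFwdInstances specializations that follow,
+// Get(num_dim_spatial) fills the vector only for num_dim_spatial == 2; any
+// other rank returns an empty vector, which callers can treat as "no instances
+// registered for this rank/type".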
+template +struct ConvolutionNDFwdInstances; + +template <> +struct ConvolutionNDFwdInstances +{ + static std::vector Get(std::size_t num_dim_spatial) + { + std::vector conv_ptrs; + if(num_dim_spatial == 2) + { + ck::tensor_operation::device::instance:: + add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); + } + return conv_ptrs; + } +}; + +template <> +struct ConvolutionNDFwdInstances +{ + static std::vector Get(std::size_t num_dim_spatial) + { + std::vector conv_ptrs; + if(num_dim_spatial == 2) + { + ck::tensor_operation::device::instance:: + add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); + } + return conv_ptrs; + } +}; + +template <> +struct ConvolutionNDFwdInstances +{ + static std::vector Get(std::size_t num_dim_spatial) + { + std::vector conv_ptrs; + if(num_dim_spatial == 2) + { + ck::tensor_operation::device::instance:: + add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); + } + return conv_ptrs; + } +}; + +template <> +struct ConvolutionNDFwdInstances +{ + static std::vector Get(std::size_t num_dim_spatial) + { + std::vector conv_ptrs; + if(num_dim_spatial == 2) + { + ck::tensor_operation::device::instance:: + add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs); + } + return conv_ptrs; + } +}; + +} // namespace conv +} // namespace test diff --git a/test/gemm/CMakeLists.txt b/test/gemm/CMakeLists.txt new file mode 100644 index 00000000..83b3c1e2 --- /dev/null +++ b/test/gemm/CMakeLists.txt @@ -0,0 +1,15 @@ +add_test_executable(test_gemm_fp32 gemm_fp32.cpp) +target_link_libraries(test_gemm_fp32 PRIVATE host_tensor) +target_link_libraries(test_gemm_fp32 PRIVATE device_gemm_instance) + +add_test_executable(test_gemm_fp16 gemm_fp16.cpp) +target_link_libraries(test_gemm_fp16 PRIVATE host_tensor) +target_link_libraries(test_gemm_fp16 PRIVATE device_gemm_instance) + +add_test_executable(test_gemm_bf16 gemm_bf16.cpp) +target_link_libraries(test_gemm_bf16 PRIVATE host_tensor) +target_link_libraries(test_gemm_bf16 PRIVATE device_gemm_instance) + +add_test_executable(test_gemm_int8 gemm_int8.cpp) +target_link_libraries(test_gemm_int8 PRIVATE host_tensor) +target_link_libraries(test_gemm_int8 PRIVATE device_gemm_instance) diff --git a/test/gemm/gemm_bf16.cpp b/test/gemm/gemm_bf16.cpp new file mode 100644 index 00000000..6130ec9b --- /dev/null +++ b/test/gemm/gemm_bf16.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
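The gemm_* tests below all sweep the four A/B layout combinations through one generic lambda; a standalone model of that dispatch (simplified tag types, not the CK ones):

    #include <iostream>

    struct Row { static constexpr const char* name = "RowMajor"; };
    struct Col { static constexpr const char* name = "ColumnMajor"; };

    int main()
    {
        // Each call instantiates the test body for one layout combination.
        auto test = [](auto a_layout, auto b_layout) {
            std::cout << "A=" << decltype(a_layout)::name
                      << " B=" << decltype(b_layout)::name << '\n';
            return true; // stands in for running all registered device instances
        };

        bool pass = test(Row{}, Row{}) && test(Row{}, Col{}) &&
                    test(Col{}, Row{}) && test(Col{}, Col{});
        return pass ? 0 : 1;
    }

Note that && short-circuits: once one layout combination fails, the remaining combinations are skipped, whereas the loop over instances inside each combination uses &= and always runs to completion.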
+ +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +#include "test/gemm/gemm_util.hpp" + +int main() +{ + using ADataType = ck::bhalf_t; + using BDataType = ck::bhalf_t; + using CDataType = ck::bhalf_t; + using AccDataType = float; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + auto test = [&](auto a_layout, auto b_layout, auto c_layout) { + bool pass = true; + + using DeviceOp = ck::tensor_operation::device::DeviceGemm; + + const auto gemmPtrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + for(auto& gemmPtr : gemmPtrs) + { + pass &= ck::gemm_util::TestGemm, + ADataType, + BDataType, + CDataType, + AccDataType, + decltype(a_layout), + decltype(b_layout), + decltype(c_layout), + PassThrough, + PassThrough, + PassThrough>{}(gemmPtr); + } + + return pass; + }; + + bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) && + test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{}); + + std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl; + return pass ? 0 : 1; +} diff --git a/test/gemm/gemm_fp16.cpp b/test/gemm/gemm_fp16.cpp new file mode 100644 index 00000000..05e696ca --- /dev/null +++ b/test/gemm/gemm_fp16.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +#include "test/gemm/gemm_util.hpp" + +int main() +{ + using ADataType = ck::half_t; + using BDataType = ck::half_t; + using CDataType = ck::half_t; + using AccDataType = float; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + auto test = [&](auto a_layout, auto b_layout, auto c_layout) { + bool pass = true; + + using DeviceOp = ck::tensor_operation::device::DeviceGemm; + + const auto gemmPtrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + for(auto& gemmPtr : gemmPtrs) + { + pass &= ck::gemm_util::TestGemm, + ADataType, + BDataType, + CDataType, + AccDataType, + decltype(a_layout), + decltype(b_layout), + decltype(c_layout), + PassThrough, + PassThrough, + PassThrough>{}(gemmPtr); + } + + return pass; + }; + + bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) && + test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{}); + + std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl; + return pass ? 0 : 1; +} diff --git a/test/gemm/gemm_fp32.cpp b/test/gemm/gemm_fp32.cpp new file mode 100644 index 00000000..3e141d7b --- /dev/null +++ b/test/gemm/gemm_fp32.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +#include "test/gemm/gemm_util.hpp" + +int main() +{ + using ADataType = float; + using BDataType = float; + using CDataType = float; + using AccDataType = float; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + auto test = [&](auto a_layout, auto b_layout, auto c_layout) { + bool pass = true; + + using DeviceOp = ck::tensor_operation::device::DeviceGemm; + + const auto gemmPtrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + for(auto& gemmPtr : gemmPtrs) + { + pass &= ck::gemm_util::TestGemm, + ADataType, + BDataType, + CDataType, + AccDataType, + decltype(a_layout), + decltype(b_layout), + decltype(c_layout), + PassThrough, + PassThrough, + PassThrough>{}(gemmPtr); + } + + return pass; + }; + + bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) && + test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{}); + + std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl; + return pass ? 0 : 1; +} diff --git a/test/gemm/gemm_fp64.cpp b/test/gemm/gemm_fp64.cpp new file mode 100644 index 00000000..96dc459a --- /dev/null +++ b/test/gemm/gemm_fp64.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
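+// NOTE: unlike the fp32/fp16/bf16/int8 siblings, this fp64 test is not listed
+// in test/gemm/CMakeLists.txt above, so no test_gemm_fp64 target is built.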
+ +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +#include "test/gemm/gemm_util.hpp" + +int main() +{ + using ADataType = double; + using BDataType = double; + using CDataType = double; + using AccDataType = double; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + auto test = [&](auto a_layout, auto b_layout, auto c_layout) { + bool pass = true; + + using DeviceOp = ck::tensor_operation::device::DeviceGemm; + + const auto gemmPtrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + for(auto& gemmPtr : gemmPtrs) + { + pass &= ck::gemm_util::TestGemm, + ADataType, + BDataType, + CDataType, + AccDataType, + decltype(a_layout), + decltype(b_layout), + decltype(c_layout), + PassThrough, + PassThrough, + PassThrough>{}(gemmPtr); + } + + return pass; + }; + + bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) && + test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{}); + + std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl; + return pass ? 0 : 1; +} diff --git a/test/gemm/gemm_int8.cpp b/test/gemm/gemm_int8.cpp new file mode 100644 index 00000000..c7d79782 --- /dev/null +++ b/test/gemm/gemm_int8.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
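+// This int8 test accumulates in int32. Provided the accumulator does not
+// overflow, device and host results are expected to match exactly, so no
+// floating-point tolerance enters the comparison.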
+ +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +#include "test/gemm/gemm_util.hpp" + +int main() +{ + using ADataType = int8_t; + using BDataType = int8_t; + using CDataType = int8_t; + using AccDataType = int32_t; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + auto test = [&](auto a_layout, auto b_layout, auto c_layout) { + bool pass = true; + + using DeviceOp = ck::tensor_operation::device::DeviceGemm; + + const auto gemmPtrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + for(auto& gemmPtr : gemmPtrs) + { + pass &= ck::gemm_util::TestGemm, + ADataType, + BDataType, + CDataType, + AccDataType, + decltype(a_layout), + decltype(b_layout), + decltype(c_layout), + PassThrough, + PassThrough, + PassThrough>{}(gemmPtr); + } + + return pass; + }; + + bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) && + test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{}); + + std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl; + return pass ? 0 : 1; +} diff --git a/test/gemm/gemm_util.hpp b/test/gemm/gemm_util.hpp new file mode 100644 index 00000000..b83b9c61 --- /dev/null +++ b/test/gemm/gemm_util.hpp @@ -0,0 +1,243 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
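gemm_util.hpp below builds host tensors whose strides encode the layout: a row-major (row, col) pair gets strides {stride, 1}, a column-major one {1, stride}. A tiny standalone model of that addressing (illustrative only):

    #include <cassert>
    #include <cstddef>

    // Offset of element (i, j) under a 2-D stride pair, mirroring how
    // f_host_tensor_descriptor maps layouts to strides.
    std::size_t offset(std::size_t i, std::size_t j, std::size_t s0, std::size_t s1)
    {
        return i * s0 + j * s1;
    }

    int main()
    {
        // A 2x3 row-major matrix with stride 3 uses strides {3, 1};
        // the column-major variant with stride 2 uses strides {1, 2}.
        assert(offset(1, 2, 3, 1) == 5); // row-major
        assert(offset(1, 2, 1, 2) == 5); // column-major
        return 0;
    }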
+ +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace gemm_util { + +struct GemmParams +{ + GemmParams() + : M(1024), N(1024), K(1024), StrideA(1024), StrideB(1024), StrideC(1024), alpha(1), beta(0) + { + } + + ck::index_t M; + ck::index_t N; + ck::index_t K; + + ck::index_t StrideA; + ck::index_t StrideB; + ck::index_t StrideC; + + float alpha; + float beta; +}; + +template +void RunHostGEMM(const Tensor& A, + const Tensor& B, + Tensor& C, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) +{ + auto ref_gemm = GemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(A, B, C, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); +} + +template +bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr, + const ck::gemm_util::GemmParams& params, + const Tensor& A, + const Tensor& B, + Tensor& C, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) +{ + DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpace()); + + auto invoker_ptr = gemmPtr->MakeInvokerPointer(); + auto argument_ptr = + gemmPtr->MakeArgumentPointer(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + params.M, + params.N, + params.K, + params.StrideA, + params.StrideB, + params.StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(gemmPtr->IsSupportedArgument(argument_ptr.get())) + { + a_m_k_device_buf.ToDevice(A.mData.data()); + b_k_n_device_buf.ToDevice(B.mData.data()); + invoker_ptr->Run(argument_ptr.get()); + c_m_n_device_buf.FromDevice(C.mData.data()); + + return true; + } + else + { + std::cout << "device_gemm with the specified compilation parameters does " + "not support this GEMM problem" + << std::endl; + + return false; + } +} + +template +struct TestGemm +{ + auto PrepareGemmTensor(const ck::gemm_util::GemmParams& params) + { + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k( + f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); + Tensor b_k_n( + f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); + Tensor c_m_n_host_result( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + Tensor c_m_n_device_result( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + + auto f_generate_tensor_value = [](auto& tensor, auto type) { + using dataType = decltype(type); + + tensor.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + }; + + f_generate_tensor_value(a_m_k, ADataType{}); + f_generate_tensor_value(b_k_n, BDataType{}); + + return 
std::make_tuple(a_m_k, b_k_n, c_m_n_host_result, c_m_n_device_result); + } + + auto operator()(const DeviceGemmPtr_& gemmPtr) + { + std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name + << ", CLayout = " << CLayout{}.name << std::endl; + std::cout << gemmPtr->GetTypeString() << std::endl; + + // Arrange + ck::gemm_util::GemmParams params; + params.M = 1024; + params.N = 1024; + params.K = 1024; + params.StrideA = 1024; + params.StrideB = 1024; + params.StrideC = 1024; + + auto host_tensors = PrepareGemmTensor(params); + + const Tensor& a = std::get<0>(host_tensors); + const Tensor& b = std::get<1>(host_tensors); + Tensor& c_host = std::get<2>(host_tensors); + Tensor& c_device = std::get<3>(host_tensors); + + auto a_element_op = AElementwiseOperation{}; + auto b_element_op = BElementwiseOperation{}; + auto c_element_op = CElementwiseOperation{}; + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemm; + ck::gemm_util::RunHostGEMM( + a, b, c_host, a_element_op, b_element_op, c_element_op); + + // Act + bool is_supported = ck::gemm_util::RunDeviceGEMM( + gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op); + + if(is_supported) + { + // Assert + bool res = false; + if(std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + else if(std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + else if(std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + else if(std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + else if(std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + + return res; + } + else + { + return true; + } + } +}; + +} // namespace gemm_util +} // namespace ck diff --git a/test/gemm_reduce/CMakeLists.txt b/test/gemm_reduce/CMakeLists.txt new file mode 100644 index 00000000..74b787ac --- /dev/null +++ b/test/gemm_reduce/CMakeLists.txt @@ -0,0 +1,3 @@ +add_test_executable(test_gemm_reduce_fp16 gemm_reduce_fp16.cpp) +target_link_libraries(test_gemm_reduce_fp16 PRIVATE host_tensor) +target_link_libraries(test_gemm_reduce_fp16 PRIVATE device_gemm_reduce_instance) diff --git a/test/gemm_reduce/gemm_reduce_fp16.cpp b/test/gemm_reduce/gemm_reduce_fp16.cpp new file mode 100644 index 00000000..16f787e0 --- /dev/null +++ b/test/gemm_reduce/gemm_reduce_fp16.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
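+// Stride convention for the four calls below: the last three arguments are the
+// leading dimensions StrideA, StrideB, StrideC. A row-major M x K matrix takes
+// StrideA = K, a column-major one StrideA = M; B follows the same rule, and C
+// is row-major throughout with StrideC = N.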
+ +#include + +#include "profiler/include/profile_gemm_reduce_impl.hpp" + +int main() +{ + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + int M = 512; + int N = 256; + int K = 128; + + bool pass = true; + + pass = pass && + ck::profiler:: + profile_gemm_reduce_impl( + true, 1, false, false, M, N, K, K, N, N); + + pass = pass && + ck::profiler:: + profile_gemm_reduce_impl( + true, 1, false, false, M, N, K, K, K, N); + + pass = pass && + ck::profiler:: + profile_gemm_reduce_impl( + true, 1, false, false, M, N, K, M, N, N); + + pass = pass && + ck::profiler:: + profile_gemm_reduce_impl( + true, 1, false, false, M, N, K, M, K, N); + + if(pass) + { + std::cout << "test GEMM+Reduce fp16: Pass" << std::endl; + return 0; + } + else + { + std::cout << "test GEMM+Reduce fp16: Fail" << std::endl; + return -1; + } +} diff --git a/test/gemm_split_k/CMakeLists.txt b/test/gemm_split_k/CMakeLists.txt new file mode 100644 index 00000000..ab1d016c --- /dev/null +++ b/test/gemm_split_k/CMakeLists.txt @@ -0,0 +1,3 @@ +add_test_executable(test_gemm_split_k gemm_split_k.cpp) +target_link_libraries(test_gemm_split_k PRIVATE host_tensor) +target_link_libraries(test_gemm_split_k PRIVATE device_gemm_splitk_instance) diff --git a/test/gemm_split_k/gemm_split_k.cpp b/test/gemm_split_k/gemm_split_k.cpp new file mode 100644 index 00000000..27d9d052 --- /dev/null +++ b/test/gemm_split_k/gemm_split_k.cpp @@ -0,0 +1,263 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +#include "ck/library/utility/host_gemm.hpp" + +enum struct GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 +}; + +template +static bool check_out(const Tensor& ref, const Tensor& result) +{ + float max_diff = 1e-6; + + for(std::size_t i = 0; i < ref.mData.size(); ++i) + { + float diff = std::abs(double(ref.mData[i]) - double(result.mData[i])); + if(max_diff < diff) + { + return false; + } + } + + return true; +} + +struct gemmArgs +{ + GemmMatrixLayout layout; + int M; + int N; + int K; + int StrideA; + int StrideB; + int StrideC; + int KBatch; +}; + +int test_gemm(const gemmArgs& args) +{ + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + bool a_row_major, b_row_major, c_row_major; + + switch(args.layout) + { + case GemmMatrixLayout::MK_KN_MN: + a_row_major = true; + b_row_major = true; + c_row_major = true; + break; + case GemmMatrixLayout::MK_NK_MN: + a_row_major = true; + b_row_major = false; + c_row_major = true; + break; + case GemmMatrixLayout::KM_KN_MN: + a_row_major = false; + b_row_major = true; + c_row_major = true; + break; + case GemmMatrixLayout::KM_NK_MN: + a_row_major = false; + b_row_major = false; 
+ c_row_major = true; + break; + default: printf("not supported layout"); return 1; + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, bool row_major) { + if(row_major) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(args.M, args.K, args.StrideA, a_row_major)); + Tensor b_k_n(f_host_tensor_descriptor(args.K, args.N, args.StrideB, b_row_major)); + Tensor c_m_n_host_result( + f_host_tensor_descriptor(args.M, args.N, args.StrideC, c_row_major)); + Tensor c_m_n_device_result( + f_host_tensor_descriptor(args.M, args.N, args.StrideC, c_row_major)); + + // init data + std::size_t num_thread = 1; + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + // set zero to c_device_buf + c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0{}, num_thread); + + host_gemm_mk_kn_mn(a_m_k, + b_k_n, + c_m_n_host_result, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}); + + DeviceMem a_device_buf(sizeof(float) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(float) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(float) * c_m_n_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + c_device_buf.ToDevice(c_m_n_device_result.mData.data()); + + auto test = [&](auto a_layout, auto b_layout, auto c_layout) { + bool success = false; + + using DeviceOp = ck::tensor_operation::device::DeviceGemmSplitK; + + const auto gemm_ptrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + for(auto& gemm_ptr : gemm_ptrs) + { + auto argument_ptr = + gemm_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + args.M, + args.N, + args.K, + args.StrideA, + args.StrideB, + args.StrideC, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + args.KBatch); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get()); + + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + if(!check_out(c_m_n_host_result, c_m_n_device_result)) + { + success = false; + break; + } + success = true; + } + } + + return success; + }; + + bool success = false; + + if(args.layout == GemmMatrixLayout::MK_KN_MN) + { + success = test(Row{}, Row{}, Row{}); + } + else if(args.layout == GemmMatrixLayout::MK_NK_MN) + { + success = test(Row{}, Col{}, Row{}); + } + else if(args.layout == GemmMatrixLayout::KM_KN_MN) + { + success = test(Col{}, Row{}, Row{}); + } + else + { + success = test(Col{}, Col{}, Row{}); + } + + auto error_code = 0; + if(success) + { + std::cout << "test split k : Pass" << std::endl; + } + else + { + std::cout << "test split k: Fail " << std::endl; + error_code = -1; // test needs to report failure + } + return error_code; +} + +int main(int argc, char* argv[]) +{ + std::vector test_cases; + if(argc == 1) + { + test_cases = 
{{GemmMatrixLayout::MK_KN_MN, 3, 3, 3, 3, 3, 3, 1}};
+        // TODO(JD): populate with more meaningful default test cases.
+        return 0;
+    }
+    else if(argc == 9)
+    {
+        const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[1]));
+
+        const int M = std::stoi(argv[2]);
+        const int N = std::stoi(argv[3]);
+        const int K = std::stoi(argv[4]);
+
+        const int StrideA = std::stoi(argv[5]);
+        const int StrideB = std::stoi(argv[6]);
+        const int StrideC = std::stoi(argv[7]);
+        const int KBatch  = std::stoi(argv[8]);
+        test_cases = {{layout, M, N, K, StrideA, StrideB, StrideC, KBatch}};
+    }
+    else
+    {
+        printf("arg1: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
+        printf("                     1: A[m, k] * B[n, k] = C[m, n];\n");
+        printf("                     2: A[k, m] * B[k, n] = C[m, n];\n");
+        printf("                     3: A[k, m] * B[n, k] = C[m, n])\n");
+        printf("arg2 to arg8: M, N, K, StrideA, StrideB, StrideC, KBatch\n");
+        return -1;
+    }
+    for(const auto& test_case : test_cases)
+    {
+        // test_gemm returns 0 on success and a non-zero error code on failure.
+        if(test_gemm(test_case) != 0)
+            return -1;
+    }
+    return 0;
+}
diff --git a/test/grouped_gemm/CMakeLists.txt b/test/grouped_gemm/CMakeLists.txt
new file mode 100644
index 00000000..f04ee770
--- /dev/null
+++ b/test/grouped_gemm/CMakeLists.txt
@@ -0,0 +1,3 @@
+add_test_executable(test_grouped_gemm_fp16 grouped_gemm_fp16.cpp)
+target_link_libraries(test_grouped_gemm_fp16 PRIVATE host_tensor)
+target_link_libraries(test_grouped_gemm_fp16 PRIVATE device_grouped_gemm_instance)
diff --git a/test/grouped_gemm/grouped_gemm_fp16.cpp b/test/grouped_gemm/grouped_gemm_fp16.cpp
new file mode 100644
index 00000000..d2f9fc0f
--- /dev/null
+++ b/test/grouped_gemm/grouped_gemm_fp16.cpp
@@ -0,0 +1,213 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+#include <iostream>
+#include <memory>
+#include <vector>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+using DeviceGroupedGemmPtr_ = ck::tensor_operation::device::DeviceGroupedGemmPtr<
+    ck::tensor_operation::element_wise::PassThrough,
+    ck::tensor_operation::element_wise::PassThrough,
+    ck::tensor_operation::element_wise::PassThrough>;
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
+    std::vector<DeviceGroupedGemmPtr_>&);
+}
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
+
+namespace {
+
+using ADataType   = ck::half_t;
+using BDataType   = ck::half_t;
+using CDataType   = ck::half_t;
+using AccDataType = float;
+
+using ALayout = ck::tensor_layout::gemm::RowMajor;
+using BLayout = ck::tensor_layout::gemm::ColumnMajor;
+using CLayout = ck::tensor_layout::gemm::RowMajor;
+
+bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
+{
+    int group_count = rand() % 10 + 1;
+
+    // GEMM shape
+    std::vector<ck::tensor_operation::device::GemmDesc> gemm_shapes;
+    std::vector<const void*> p_a, p_b;
+    std::vector<void*> p_c;
+
+    gemm_shapes.reserve(group_count);
+
+    for(int i = 0; i < group_count; i++)
+    {
+        int M = 256 + 256 * (rand() % 10);
+        int N = 256 + 256 * (rand() % 10);
+        int K = 128
+ 128 * (rand() % 10); + + int AStride = std::is_same::value ? K : M; + int BStride = std::is_same::value ? N : K; + int CStride = std::is_same::value ? N : M; + + gemm_shapes.push_back({M, N, K, AStride, BStride, CStride}); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + std::vector> a_tensors; + ; + std::vector> b_tensors; + std::vector> c_host_tensors; + std::vector> c_device_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, c_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + for(std::size_t i = 0; i < gemm_shapes.size(); i++) + { + a_tensors.emplace_back(Tensor(f_host_tensor_descriptor( + gemm_shapes[i].M, gemm_shapes[i].K, gemm_shapes[i].StrideA, ALayout{}))); + b_tensors.emplace_back(Tensor(f_host_tensor_descriptor( + gemm_shapes[i].K, gemm_shapes[i].N, gemm_shapes[i].StrideB, BLayout{}))); + c_host_tensors.emplace_back(Tensor(f_host_tensor_descriptor( + gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{}))); + c_device_tensors.emplace_back(Tensor(f_host_tensor_descriptor( + gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{}))); + + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + } + + for(std::size_t i = 0; i < gemm_shapes.size(); i++) + { + a_tensors_device.emplace_back( + std::make_unique(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize())); + b_tensors_device.emplace_back( + std::make_unique(sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize())); + c_tensors_device.emplace_back(std::make_unique( + sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize())); + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); + + p_a.push_back(a_tensors_device[i]->GetDeviceBuffer()); + p_b.push_back(b_tensors_device[i]->GetDeviceBuffer()); + p_c.push_back(c_tensors_device[i]->GetDeviceBuffer()); + } + + auto a_element_op = PassThrough{}; + auto b_element_op = PassThrough{}; + auto c_element_op = PassThrough{}; + + // do GEMM + auto invoker_ptr = groupedGemmPtr->MakeInvokerPointer(); + + auto argument_ptr = groupedGemmPtr->MakeArgumentPointer( + p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op); + + DeviceMem gemm_desc_workspace(groupedGemmPtr->GetWorkSpaceSize(argument_ptr.get())); + + groupedGemmPtr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer()); + + invoker_ptr->Run(argument_ptr.get()); + + for(std::size_t i = 0; i < gemm_shapes.size(); i++) + { + c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data()); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + c_host_tensors[i], + a_element_op, + b_element_op, + c_element_op); + + 
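+        // An instance that cannot support the generated problem shapes is
+        // treated as a test failure here rather than being skipped.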
if(!groupedGemmPtr->IsSupportedArgument(argument_ptr.get()))
+        {
+            return false;
+        }
+
+        ref_invoker.Run(ref_argument);
+
+        bool res = ck::utils::check_err(c_host_tensors[i].mData, c_device_tensors[i].mData);
+
+        std::cout << "group_id: " << i << (res ? " SUCCESS" : " FAILURE") << std::endl;
+
+        if(!res)
+            return false;
+    }
+
+    return true;
+}
+
+} // anonymous namespace
+
+int main()
+{
+    std::vector<DeviceGroupedGemmPtr_> groupedGemmPtrs;
+    ck::tensor_operation::device::instance::
+        add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(groupedGemmPtrs);
+
+    bool res = true;
+
+    for(auto& gemmPtr : groupedGemmPtrs)
+    {
+        res &= TestGroupedGemm(gemmPtr);
+    }
+
+    std::cout << "TestGroupedGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
+
+    return res ? 0 : 1;
+}
diff --git a/test/magic_number_division/CMakeLists.txt b/test/magic_number_division/CMakeLists.txt
new file mode 100644
index 00000000..c7d3f45c
--- /dev/null
+++ b/test/magic_number_division/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_test_executable(test_magic_number_division magic_number_division.cpp)
+target_link_libraries(test_magic_number_division PRIVATE host_tensor)
diff --git a/test/magic_number_division/magic_number_division.cpp b/test/magic_number_division/magic_number_division.cpp
new file mode 100644
index 00000000..680fddf1
--- /dev/null
+++ b/test/magic_number_division/magic_number_division.cpp
@@ -0,0 +1,150 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdint>
+#include <cstdlib>
+#include <iostream>
+#include <vector>
+
+#include "ck/ck.hpp"
+#include "ck/utility/magic_division.hpp"
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+
+// Magic division replaces integer division by a fixed divisor with a multiply
+// and a shift: n / d == (n * multiplier) >> shift for the supported range of n.
+__global__ void gpu_magic_number_division(uint32_t magic_multiplier,
+                                          uint32_t magic_shift,
+                                          const int32_t* p_dividend,
+                                          int32_t* p_result,
+                                          uint64_t num)
+{
+    uint64_t global_thread_num = blockDim.x * gridDim.x;
+
+    uint64_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;
+
+    for(uint64_t data_id = global_thread_id; data_id < num; data_id += global_thread_num)
+    {
+        p_result[data_id] =
+            ck::MagicDivision::DoMagicDivision(p_dividend[data_id], magic_multiplier, magic_shift);
+    }
+}
+
+__global__ void
+gpu_naive_division(int32_t divisor, const int32_t* p_dividend, int32_t* p_result, uint64_t num)
+{
+    uint64_t global_thread_num = blockDim.x * gridDim.x;
+
+    uint64_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;
+
+    for(uint64_t data_id = global_thread_id; data_id < num; data_id += global_thread_num)
+    {
+        p_result[data_id] = p_dividend[data_id] / divisor;
+    }
+}
+
+__host__ void cpu_magic_number_division(uint32_t magic_multiplier,
+                                        uint32_t magic_shift,
+                                        const int32_t* p_dividend,
+                                        int32_t* p_result,
+                                        uint64_t num)
+{
+    for(uint64_t data_id = 0; data_id < num; ++data_id)
+    {
+        p_result[data_id] =
+            ck::MagicDivision::DoMagicDivision(p_dividend[data_id], magic_multiplier, magic_shift);
+    }
+}
+
+int main(int, char*[])
+{
+    uint64_t num_divisor  = 4096;
+    uint64_t num_dividend = 1L << 16;
+
+    std::vector<int32_t> divisors_host(num_divisor);
+    std::vector<int32_t> dividends_host(num_dividend);
+
+    // generate divisors 1..num_divisor
+    for(uint64_t i = 0; i < num_divisor; ++i)
+    {
+        divisors_host[i] = i + 1;
+    }
+
+    // generate dividends 0..num_dividend-1
+    for(uint64_t i = 0; i < num_dividend; ++i)
+    {
+        dividends_host[i] = i;
+    }
+
+    DeviceMem dividends_dev_buf(sizeof(int32_t) * num_dividend);
+    DeviceMem
naive_result_dev_buf(sizeof(int32_t) * num_dividend); + DeviceMem magic_result_dev_buf(sizeof(int32_t) * num_dividend); + + std::vector naive_result_host(num_dividend); + std::vector magic_result_host(num_dividend); + std::vector magic_result_host2(num_dividend); + + dividends_dev_buf.ToDevice(dividends_host.data()); + + bool pass = true; + + for(std::size_t i = 0; i < num_divisor; ++i) + { + // run naive division on GPU + gpu_naive_division<<<1024, 256>>>( + divisors_host[i], + static_cast(dividends_dev_buf.GetDeviceBuffer()), + static_cast(naive_result_dev_buf.GetDeviceBuffer()), + num_dividend); + + // calculate magic number + uint32_t magic_multiplier, magic_shift; + + ck::tie(magic_multiplier, magic_shift) = + ck::MagicDivision::CalculateMagicNumbers(divisors_host[i]); + + // run magic division on GPU + gpu_magic_number_division<<<1024, 256>>>( + magic_multiplier, + magic_shift, + static_cast(dividends_dev_buf.GetDeviceBuffer()), + static_cast(magic_result_dev_buf.GetDeviceBuffer()), + num_dividend); + + naive_result_dev_buf.FromDevice(naive_result_host.data()); + magic_result_dev_buf.FromDevice(magic_result_host.data()); + + bool res = ck::utils::check_err(magic_result_host, naive_result_host); + + if(!res) + { + pass = false; + continue; + } + + cpu_magic_number_division(magic_multiplier, + magic_shift, + dividends_host.data(), + magic_result_host2.data(), + num_dividend); + + res = ck::utils::check_err(magic_result_host2, naive_result_host); + + if(!res) + { + pass = false; + continue; + } + } + + if(pass) + { + std::cout << "test magic number division: Pass" << std::endl; + return 0; + } + else + { + std::cout << "test magic number division: Fail" << std::endl; + return -1; + } +} diff --git a/test/reduce/CMakeLists.txt b/test/reduce/CMakeLists.txt new file mode 100644 index 00000000..4e11b049 --- /dev/null +++ b/test/reduce/CMakeLists.txt @@ -0,0 +1,7 @@ +add_test_executable(test_reduce_no_index reduce_no_index.cpp) +add_test_executable(test_reduce_with_index reduce_with_index.cpp) +target_link_libraries(test_reduce_no_index PRIVATE host_tensor) +target_link_libraries(test_reduce_no_index PRIVATE device_reduce_instance) +target_link_libraries(test_reduce_with_index PRIVATE host_tensor) +target_link_libraries(test_reduce_with_index PRIVATE device_reduce_instance) + diff --git a/test/reduce/reduce_no_index.cpp b/test/reduce/reduce_no_index.cpp new file mode 100644 index 00000000..475ebfd0 --- /dev/null +++ b/test/reduce/reduce_no_index.cpp @@ -0,0 +1,248 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
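+
+// Smoke test for reductions without index output: drives profile_reduce_impl
+// (reduce op AVG) over 4-D inputs and several reduce-dimension combinations,
+// taken either from built-in defaults or from the command line.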
+
+#include <getopt.h>
+#include <iostream>
+
+#include "ck/library/utility/host_common_util.hpp"
+#include "profiler/include/profile_reduce_impl.hpp"
+
+using namespace ck;
+
+static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'},
+                                       {"reduceDimensions", required_argument, nullptr, 'R'},
+                                       {"scales", required_argument, nullptr, 'S'},
+                                       {"help", no_argument, nullptr, '?'},
+                                       {nullptr, 0, nullptr, 0}};
+
+class SimpleAppArgs
+{
+    private:
+    int option_index = 0;
+
+    public:
+    std::vector<size_t> inLengths;
+    std::vector<int> reduceDims;
+    std::vector<float> scales;
+
+    int data_type;
+    int init_method = 1;
+
+    public:
+    void show_usage(const char* cmd)
+    {
+        std::cout << "Usage of " << cmd << std::endl;
+        std::cout << "--inLengths or -D, comma separated list of input tensor dimension lengths "
+                     "(only 4-d tensor supported)"
+                  << std::endl;
+        std::cout << "--reduceDimensions or -R, comma separated list of dimension indexes to reduce "
+                     "(only 1 or 3 or 4 dimensions supported)"
+                  << std::endl;
+        std::cout << "--scales or -S, comma separated two float values for alpha and beta"
+                  << std::endl;
+        std::cout << "Arg1 -- data type (0: fp16, 1: fp32, 3: int8, 5: bf16, 6: fp64)" << std::endl;
+        std::cout << "Arg2 -- init method (0=no init, 1=single integer value, 2=scoped integer "
+                     "value, 3=decimal value)"
+                  << std::endl;
+    };
+
+    int processArgs(int argc, char* argv[])
+    {
+        using ck::host_common::getTypeValuesFromString;
+
+        int ch;
+
+        while(1)
+        {
+            ch = getopt_long(argc, argv, "D:R:S:", long_options, &option_index);
+            if(ch == -1)
+                break;
+            switch(ch)
+            {
+            case 'D':
+                if(!optarg)
+                    throw std::runtime_error("Invalid option format!");
+
+                inLengths = getTypeValuesFromString<size_t>(optarg);
+                break;
+            case 'R':
+                if(!optarg)
+                    throw std::runtime_error("Invalid option format!");
+
+                reduceDims = getTypeValuesFromString<int>(optarg);
+                break;
+            case 'S':
+                if(!optarg)
+                    throw std::runtime_error("Invalid option format!");
+
+                scales = getTypeValuesFromString<float>(optarg);
+                break;
+            case '?':
+                if(std::string(long_options[option_index].name) == "help")
+                {
+                    show_usage(argv[0]);
+                    return (-1);
+                };
+                break;
+            default: show_usage(argv[0]); return (-1);
+            };
+        };
+
+        if(optind + 2 > argc)
+            throw std::runtime_error("Invalid command-line arguments, more arguments are needed!");
+
+        data_type   = std::atoi(argv[optind++]);
+        init_method = std::atoi(argv[optind]);
+
+        if(scales.empty())
+        {
+            scales.push_back(1.0f);
+            scales.push_back(0.0f);
+        };
+
+        if(inLengths.size() != 4 ||
+           (reduceDims.size() != 1 && reduceDims.size() != 3 && reduceDims.size() != 4))
+            return (-1);
+
+        if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5 && data_type != 6)
+            return (-1);
+
+        return (0);
+    };
+};
+
+bool test_reduce_no_index(int data_type,
+                          int init_method,
+                          std::vector<int> reduceDims,
+                          std::vector<size_t> inLengths,
+                          ReduceTensorOp reduceOpId,
+                          bool propagateNan,
+                          float alpha,
+                          float beta)
+{
+    using ck::profiler::profile_reduce_impl;
+
+    bool result = true;
+
+    if(data_type == 0)
+    {
+        result = profile_reduce_impl(true,
+                                     init_method,
+                                     false,
+                                     false,
+                                     inLengths,
+                                     reduceDims,
+                                     reduceOpId,
+                                     propagateNan,
+                                     false,
+                                     alpha,
+                                     beta);
+    }
+    else if(data_type == 1)
+    {
+        result = profile_reduce_impl(true,
+                                     init_method,
+                                     false,
+                                     false,
+                                     inLengths,
+                                     reduceDims,
+                                     reduceOpId,
+                                     propagateNan,
+                                     false,
+                                     alpha,
+                                     beta);
+    }
+    else if(data_type == 3)
+    {
+        result = profile_reduce_impl(true,
+                                     init_method,
+                                     false,
+                                     false,
+                                     inLengths,
+                                     reduceDims,
+                                     reduceOpId,
+                                     propagateNan,
+                                     false,
+                                     alpha,
+                                     beta);
+    }
+    else if(data_type == 5)
+    {
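+        // data_type 5: bfloat16 input path.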
result = profile_reduce_impl(true, + init_method, + false, + false, + inLengths, + reduceDims, + reduceOpId, + propagateNan, + false, + alpha, + beta); + } + else if(data_type == 6) + { + result = profile_reduce_impl(true, + init_method, + false, + false, + inLengths, + reduceDims, + reduceOpId, + propagateNan, + false, + alpha, + beta); + } + + return (result); +}; + +constexpr ReduceTensorOp reduceOpId = ReduceTensorOp::AVG; +constexpr bool propagateNan = false; + +int main(int argc, char* argv[]) +{ + SimpleAppArgs args; + + bool result = true; + + if(argc == 1) + { + int data_type = 1; + int init_method = 2; + std::vector inLengths{64, 4, 280, 80}; + std::vector> v_reduceDims{ + {0, 1, 2, 3}, {0, 1, 2}, {1, 2, 3}, {0, 1, 3}, {0, 2, 3}, {0}, {1}, {2}, {3}}; + + for(auto& reduceDims : v_reduceDims) + result = result && test_reduce_no_index(data_type, + init_method, + reduceDims, + inLengths, + reduceOpId, + propagateNan, + 1.0f, + 0.0f); + } + else + { + if(args.processArgs(argc, argv) < 0) + { + throw std::runtime_error( + "Invalid input arguments, test_reduce_no_index could not be executed!"); + }; + + result = test_reduce_no_index(args.data_type, + args.init_method, + args.reduceDims, + args.inLengths, + reduceOpId, + propagateNan, + args.scales[0], + args.scales[1]); + } + + std::cout << "test_reduce_no_index ..... " << (result ? "SUCCESS" : "FAILURE") << std::endl; + + return (result ? 0 : -1); +} diff --git a/test/reduce/reduce_with_index.cpp b/test/reduce/reduce_with_index.cpp new file mode 100644 index 00000000..c319dca6 --- /dev/null +++ b/test/reduce/reduce_with_index.cpp @@ -0,0 +1,248 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/library/utility/host_common_util.hpp" +#include "profiler/include/profile_reduce_impl.hpp" + +using namespace ck; + +static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'}, + {"reduceDimensions", required_argument, nullptr, 'R'}, + {"scales", required_argument, nullptr, 'S'}, + {"help", no_argument, nullptr, '?'}, + {nullptr, 0, nullptr, 0}}; + +class SimpleAppArgs +{ + private: + int option_index = 0; + + public: + std::vector inLengths; + std::vector reduceDims; + std::vector scales; + + int data_type; + int init_method = 1; + + public: + void show_usage(const char* cmd) + { + std::cout << "Usage of " << cmd << std::endl; + std::cout << "--inLengths or -D, comma separated list of input tensor dimension lengths " + "(only 4-d tensor supported)" + << std::endl; + std::cout << "--reduceDimensions or -R comma seperated list of dimension indexes to reduce " + "(only 1 or 3 or 4 dimensions supported)" + << std::endl; + std::cout << "--scales or -S, comma separated two float values for alpha and beta" + << std::endl; + std::cout << "Arg1 -- data type (1: fp32, 3: int8, 5: bp16, 6: fp64)" << std::endl; + std::cout << "Arg2 -- init method(0=no init, 1=single integer value, 2=scope integer " + "value, 3=decimal value)" + << std::endl; + }; + + int processArgs(int argc, char* argv[]) + { + using ck::host_common::getTypeValuesFromString; + + int ch; + + while(1) + { + ch = getopt_long(argc, argv, "D:R:S:", long_options, &option_index); + if(ch == -1) + break; + switch(ch) + { + case 'D': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + inLengths = getTypeValuesFromString(optarg); + break; + case 'R': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + reduceDims = 
getTypeValuesFromString(optarg); + break; + case 'S': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + scales = getTypeValuesFromString(optarg); + break; + case '?': + if(std::string(long_options[option_index].name) == "help") + { + show_usage(argv[0]); + return (-1); + }; + break; + default: show_usage(argv[0]); return (-1); + }; + }; + + if(optind + 2 > argc) + throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!"); + + data_type = std::atoi(argv[optind++]); + init_method = std::atoi(argv[optind]); + + if(scales.empty()) + { + scales.push_back(1.0f); + scales.push_back(0.0f); + }; + + if(inLengths.size() != 4 || + (reduceDims.size() != 1 && reduceDims.size() != 3 && reduceDims.size() != 4)) + return (-1); + + if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5 && data_type != 6) + return (-1); + + return (0); + }; +}; + +bool test_reduce_with_index(int data_type, + int init_method, + std::vector reduceDims, + std::vector inLengths, + ReduceTensorOp reduceOpId, + bool propagateNan, + float alpha, + float beta) +{ + using ck::profiler::profile_reduce_impl; + + bool result = true; + + if(data_type == 0) + { + result = profile_reduce_impl(true, + init_method, + false, + false, + inLengths, + reduceDims, + reduceOpId, + propagateNan, + true, + alpha, + beta); + } + else if(data_type == 1) + { + result = profile_reduce_impl(true, + init_method, + false, + false, + inLengths, + reduceDims, + reduceOpId, + propagateNan, + true, + alpha, + beta); + } + else if(data_type == 3) + { + result = profile_reduce_impl(true, + init_method, + false, + false, + inLengths, + reduceDims, + reduceOpId, + propagateNan, + true, + alpha, + beta); + } + else if(data_type == 5) + { + result = profile_reduce_impl(true, + init_method, + false, + false, + inLengths, + reduceDims, + reduceOpId, + propagateNan, + true, + alpha, + beta); + } + else if(data_type == 6) + { + result = profile_reduce_impl(true, + init_method, + false, + false, + inLengths, + reduceDims, + reduceOpId, + propagateNan, + true, + alpha, + beta); + } + + return (result); +}; + +constexpr ReduceTensorOp reduceOpId = ReduceTensorOp::AMAX; +constexpr bool propagateNan = false; + +int main(int argc, char* argv[]) +{ + SimpleAppArgs args; + + bool result = true; + + if(argc == 1) + { + int data_type = 1; + int init_method = 2; + std::vector inLengths{64, 4, 280, 80}; + std::vector> v_reduceDims{ + {0, 1, 2, 3}, {0, 1, 2}, {1, 2, 3}, {0, 1, 3}, {0, 2, 3}, {0}, {1}, {2}, {3}}; + + for(auto& reduceDims : v_reduceDims) + result = result && test_reduce_with_index(data_type, + init_method, + reduceDims, + inLengths, + reduceOpId, + propagateNan, + 1.0f, + 0.0f); + } + else + { + if(args.processArgs(argc, argv) < 0) + { + throw std::runtime_error( + "Invalid input arguments, test_reduce_with_index could not be executed!"); + }; + + result = test_reduce_with_index(args.data_type, + args.init_method, + args.reduceDims, + args.inLengths, + reduceOpId, + propagateNan, + args.scales[0], + args.scales[1]); + } + + std::cout << "test_reduce_with_index ..... " << (result ? "SUCCESS" : "FAILURE") << std::endl; + + return (result ? 
0 : -1); +} diff --git a/test/reference_conv_fwd/CMakeLists.txt b/test/reference_conv_fwd/CMakeLists.txt new file mode 100644 index 00000000..04b720b1 --- /dev/null +++ b/test/reference_conv_fwd/CMakeLists.txt @@ -0,0 +1,2 @@ +add_gtest_executable(test_reference_conv_fwd reference_conv_fwd.cpp) +target_link_libraries(test_reference_conv_fwd PRIVATE host_tensor conv_util) diff --git a/test/reference_conv_fwd/reference_conv_fwd.cpp b/test/reference_conv_fwd/reference_conv_fwd.cpp new file mode 100644 index 00000000..6ab5a23f --- /dev/null +++ b/test/reference_conv_fwd/reference_conv_fwd.cpp @@ -0,0 +1,392 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +namespace { +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +template , + typename FillWeightsOp = ck::utils::FillConstant> +Tensor +run_reference_convolution_forward(const ck::utils::conv::ConvParams& params, + const FillInputOp& fill_input_op = FillInputOp{}, + const FillWeightsOp& fill_weights_op = FillWeightsOp{0.5f}) +{ + std::vector input_dims{static_cast(params.N_), + static_cast(params.C_)}; + input_dims.insert(std::end(input_dims), + std::begin(params.input_spatial_lengths_), + std::end(params.input_spatial_lengths_)); + + std::vector filter_dims{static_cast(params.K_), + static_cast(params.C_)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params.filter_spatial_lengths_), + std::end(params.filter_spatial_lengths_)); + + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + std::vector output_dims{static_cast(params.N_), + static_cast(params.K_)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor input(ck::utils::conv::get_host_tensor_descriptor(input_dims, InLayout{})); + Tensor weights( + ck::utils::conv::get_host_tensor_descriptor(filter_dims, WeiLayout{})); + Tensor host_output( + ck::utils::conv::get_host_tensor_descriptor(output_dims, OutLayout{})); + + fill_input_op(input.begin(), input.end()); + fill_weights_op(weights.begin(), weights.end()); + std::fill(host_output.begin(), host_output.end(), OutDataType(0.f)); + + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(input, + weights, + host_output, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); + return host_output; +} + +} // anonymous namespace + +TEST(ReferenceConvolutionFWD, Conv2DNHWC) +{ + ck::utils::conv::ConvParams params; + params.N_ = 1; + params.K_ = 1; + params.C_ = 2; + params.filter_spatial_lengths_ = std::vector{3, 3}; + params.input_spatial_lengths_ = std::vector{6, 6}; + 
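+    // 6x6 input, 3x3 filter, unit stride, no padding -> 4x4 output.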
params.conv_filter_strides_ = std::vector{1, 1}; + params.conv_filter_dilations_ = std::vector{1, 1}; + params.input_left_pads_ = std::vector{0, 0}; + params.input_right_pads_ = std::vector{0, 0}; + + auto out_tensor = run_reference_convolution_forward<2>(params); + std::vector ref_dims{1, 1, 4, 4}; + std::vector ref_data{130.5, + 148.5, + 166.5, + 184.5, + 238.5, + 256.5, + 274.5, + 292.5, + 346.5, + 364.5, + 382.5, + 400.5, + 454.5, + 472.5, + 490.5, + 508.5}; + EXPECT_TRUE(ck::utils::check_err( + out_tensor.mDesc.GetLengths(), ref_dims, "Error: wrong output tensor dimensions!")); + EXPECT_TRUE(ck::utils::check_err(out_tensor.mData, ref_data, "Error: incorrect results!")); +} + +TEST(ReferenceConvolutionFWD, Conv2DNHWCStridesDilationsPadding) +{ + ck::utils::conv::ConvParams params; + params.N_ = 1; + params.K_ = 2; + params.C_ = 2; + params.filter_spatial_lengths_ = std::vector{3, 3}; + params.input_spatial_lengths_ = std::vector{12, 12}; + params.conv_filter_strides_ = std::vector{2, 2}; + params.conv_filter_dilations_ = std::vector{2, 2}; + params.input_left_pads_ = std::vector{1, 1}; + params.input_right_pads_ = std::vector{1, 1}; + + auto out_tensor = run_reference_convolution_forward<2>(params); + std::vector ref_dims = std::vector{1, 2, 5, 5}; + std::vector ref_data{ + 210., 210., 327., 327., 351., 351., 375., 375., 399., 399., + 459., 459., 706.5, 706.5, 742.5, 742.5, 778.5, 778.5, 814.5, 814.5, + 747., 747., 1138.5, 1138.5, 1174.5, 1174.5, 1210.5, 1210.5, 1246.5, 1246.5, + 1035., 1035., 1570.5, 1570.5, 1606.5, 1606.5, 1642.5, 1642.5, 1678.5, 1678.5, + 1323., 1323., 2002.5, 2002.5, 2038.5, 2038.5, 2074.5, 2074.5, 2110.5, 2110.5}; + EXPECT_TRUE(ck::utils::check_err( + out_tensor.mDesc.GetLengths(), ref_dims, "Error: wrong output tensor dimensions!")); + EXPECT_TRUE(ck::utils::check_err(out_tensor.mData, ref_data, "Error: incorrect results!")); +} + +TEST(ReferenceConvolutionFWD, Conv1DNWC) +{ + ck::utils::conv::ConvParams params; + params.num_dim_spatial_ = 1; + params.N_ = 1; + params.K_ = 1; + params.C_ = 2; + params.filter_spatial_lengths_ = std::vector{3}; + params.input_spatial_lengths_ = std::vector{6}; + params.conv_filter_strides_ = std::vector{1}; + params.conv_filter_dilations_ = std::vector{1}; + params.input_left_pads_ = std::vector{0}; + params.input_right_pads_ = std::vector{0}; + + auto out_tensor = + run_reference_convolution_forward<1, + float, + float, + float, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>(params); + std::vector ref_dims{1, 1, 4}; + std::vector ref_data{7.5, 13.5, 19.5, 25.5}; + EXPECT_TRUE(ck::utils::check_err( + out_tensor.mDesc.GetLengths(), ref_dims, "Error: wrong output tensor dimensions!")); + EXPECT_TRUE(ck::utils::check_err(out_tensor.mData, ref_data, "Error: incorrect results!")); +} + +TEST(ReferenceConvolutionFWD, Conv1DNWCStridesDilationsPadding) +{ + ck::utils::conv::ConvParams params; + params.num_dim_spatial_ = 1; + params.N_ = 1; + params.K_ = 2; + params.C_ = 2; + params.filter_spatial_lengths_ = std::vector{3}; + params.input_spatial_lengths_ = std::vector{12}; + params.conv_filter_strides_ = std::vector{2}; + params.conv_filter_dilations_ = std::vector{2}; + params.input_left_pads_ = std::vector{1}; + params.input_right_pads_ = std::vector{1}; + + auto out_tensor = + run_reference_convolution_forward<1, + float, + float, + float, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>(params); + 
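+    // Effective filter width = dilation * (X - 1) + 1 = 5, so the output width is
+    // (12 + 2 * 1 - 5) / 2 + 1 = 5, matching ref_dims below.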
std::vector ref_dims{1, 2, 5}; + std::vector ref_data{9., 9., 19.5, 19.5, 31.5, 31.5, 43.5, 43.5, 55.5, 55.5}; + EXPECT_TRUE(ck::utils::check_err( + out_tensor.mDesc.GetLengths(), ref_dims, "Error: wrong output tensor dimensions!")); + EXPECT_TRUE(ck::utils::check_err(out_tensor.mData, ref_data, "Error: incorrect results!")); +} + +TEST(ReferenceConvolutionFWD, Conv1DNWCSameOutputSize) +{ + ck::utils::conv::ConvParams params; + params.num_dim_spatial_ = 1; + params.N_ = 2; + params.K_ = 16; + params.C_ = 4; + params.filter_spatial_lengths_ = std::vector{3}; + params.input_spatial_lengths_ = std::vector{16}; + params.conv_filter_strides_ = std::vector{1}; + params.conv_filter_dilations_ = std::vector{1}; + params.input_left_pads_ = std::vector{1}; + params.input_right_pads_ = std::vector{1}; + + auto out_tensor2 = run_reference_convolution_forward<1, + float, + float, + float, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>( + params, ck::utils::FillMonotonicSeq{0.f, 0.1f}); + + std::vector ref_dims{2, 16, 16}; + std::vector ref_data{ + 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, + 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, + 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, + 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, + 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, + 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, + 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, + 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, + 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, + 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, + 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, + 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, + 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, + 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, + 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, + 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, + 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, + 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, + 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, + 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, + 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, + 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, + 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, + 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, + 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, + 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, + 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, + 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, + 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, + 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, + 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, + 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, + 27., 27., 27., 27., 27., 27., 27., 27., + 27., 27., 27., 27., 27., 27., 27., 27., + 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, + 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, + 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, + 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, + 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, + 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, + 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, + 48.899998, 48.899998, 
48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, + 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, + 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, + 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, + 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, + 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, + 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, + 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, + 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, + 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, + 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, + 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, + 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, + 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, + 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, + 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, + 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, + 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, + 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, + 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, + 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, + 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, + 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4}; + EXPECT_TRUE(ck::utils::check_err( + out_tensor2.mDesc.GetLengths(), ref_dims, "Error: wrong output tensor dimensions!")); + EXPECT_TRUE(ck::utils::check_err(out_tensor2.mData, ref_data, "Error: incorrect results!")); +} + +TEST(ReferenceConvolutionFWD, Conv3DNCDHW) +{ + ck::utils::conv::ConvParams params; + params.num_dim_spatial_ = 3; + params.N_ = 1; + params.K_ = 1; + params.C_ = 2; + params.filter_spatial_lengths_ = std::vector{3, 3, 3}; + params.input_spatial_lengths_ = std::vector{6, 6, 6}; + params.conv_filter_strides_ = std::vector{1, 1, 1}; + params.conv_filter_dilations_ = std::vector{1, 1, 1}; + params.input_left_pads_ = std::vector{0, 0, 0}; + params.input_right_pads_ = std::vector{0, 0, 0}; + + auto out_tensor = run_reference_convolution_forward<3, + float, + float, + float, + ck::tensor_layout::convolution::NCDHW, + ck::tensor_layout::convolution::KCZYX, + ck::tensor_layout::convolution::NKDHW>( + params, ck::utils::FillMonotonicSeq{0.f, 0.1f}); + std::vector ref_dims{1, 1, 4, 4, 4}; + std::vector ref_data{ + 407.7, 410.40002, 413.09998, 415.80002, 423.90002, 426.6, 429.30002, 432., + 440.1, 442.80002, 445.5, 448.2, 456.30002, 459., 461.7, 464.40002, + 504.90002, 507.6, 510.30002, 513., 521.1, 523.8, 526.5, 529.2001, + 537.3, 540., 542.7001, 545.4, 553.5, 556.2001, 558.9, 561.6, + 602.10004, 604.8, 607.5, 610.2, 618.3, 621., 623.7, 626.4, + 634.5, 637.2, 639.9, 642.60004, 650.7, 653.4, 656.10004, 658.8, + 699.3, 702., 704.7, 707.4, 715.5, 718.2, 720.9, 723.60004, + 731.7, 734.4001, 737.10004, 739.8, 747.9001, 750.60004, 753.3, 756.}; + EXPECT_TRUE(ck::utils::check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error [case 1]: wrong output tensor dimensions!")); + EXPECT_TRUE( + ck::utils::check_err(out_tensor.mData, ref_data, "Error [case 1]: incorrect results!")); +} + +TEST(ReferenceConvolutionFWD, Conv3DNCDHWStridesDilations) +{ + ck::utils::conv::ConvParams params; + params.num_dim_spatial_ = 3; + params.N_ = 1; + params.K_ = 2; + params.C_ = 2; + params.filter_spatial_lengths_ = std::vector{3, 3, 3}; + params.input_spatial_lengths_ = std::vector{12, 12, 12}; + params.conv_filter_strides_ = std::vector{3, 3, 3}; + params.conv_filter_dilations_ = std::vector{1, 1, 
1}; + params.input_left_pads_ = std::vector{0, 0, 0}; + params.input_right_pads_ = std::vector{0, 0, 0}; + + auto out_tensor = run_reference_convolution_forward<3, + float, + float, + float, + ck::tensor_layout::convolution::NCDHW, + ck::tensor_layout::convolution::KCZYX, + ck::tensor_layout::convolution::NKDHW>( + params, ck::utils::FillMonotonicSeq{0.f, 0.1f}); + std::vector ref_dims{1, 2, 4, 4, 4}; + std::vector ref_data{ + 2756.7002, 2764.7998, 2772.9001, 2781., 2853.9001, 2862., 2870.1, 2878.2002, + 2951.1, 2959.2002, 2967.2998, 2975.4001, 3048.2998, 3056.4001, 3064.5, 3072.6, + 3923.1, 3931.2, 3939.2998, 3947.4, 4020.2998, 4028.4001, 4036.5002, 4044.5999, + 4117.5, 4125.6, 4133.7, 4141.8, 4214.7, 4222.8, 4230.9004, 4239., + 5089.5, 5097.5996, 5105.7, 5113.8, 5186.7, 5194.8, 5202.9, 5211., + 5283.9004, 5292., 5300.0996, 5308.2, 5381.0996, 5389.2, 5397.3, 5405.4004, + 6255.9004, 6264.0005, 6272.1, 6280.2, 6353.1, 6361.2, 6369.301, 6377.4, + 6450.301, 6458.4, 6466.5, 6474.6, 6547.5, 6555.6, 6563.699, 6571.801, + 2756.7002, 2764.7998, 2772.9001, 2781., 2853.9001, 2862., 2870.1, 2878.2002, + 2951.1, 2959.2002, 2967.2998, 2975.4001, 3048.2998, 3056.4001, 3064.5, 3072.6, + 3923.1, 3931.2, 3939.2998, 3947.4, 4020.2998, 4028.4001, 4036.5002, 4044.5999, + 4117.5, 4125.6, 4133.7, 4141.8, 4214.7, 4222.8, 4230.9004, 4239., + 5089.5, 5097.5996, 5105.7, 5113.8, 5186.7, 5194.8, 5202.9, 5211., + 5283.9004, 5292., 5300.0996, 5308.2, 5381.0996, 5389.2, 5397.3, 5405.4004, + 6255.9004, 6264.0005, 6272.1, 6280.2, 6353.1, 6361.2, 6369.301, 6377.4, + 6450.301, 6458.4, 6466.5, 6474.6, 6547.5, 6555.6, 6563.699, 6571.801}; + EXPECT_TRUE(ck::utils::check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error [case 2]: wrong output tensor dimensions!")); + EXPECT_TRUE(ck::utils::check_err( + out_tensor.mData, ref_data, "Error [case 2]: incorrect results!", 1e-4f, 1e-6f)); +} diff --git a/test/softmax/CMakeLists.txt b/test/softmax/CMakeLists.txt new file mode 100644 index 00000000..da80e372 --- /dev/null +++ b/test/softmax/CMakeLists.txt @@ -0,0 +1,11 @@ +add_custom_target(test_softmax) + +add_gtest_executable(test_softmax_fp32 test_softmax_fp32.cpp) +add_gtest_executable(test_softmax_fp16 test_softmax_fp16.cpp) +add_gtest_executable(test_softmax_int8 test_softmax_int8.cpp) +target_link_libraries(test_softmax_fp32 PRIVATE host_tensor) +target_link_libraries(test_softmax_fp16 PRIVATE host_tensor) +target_link_libraries(test_softmax_int8 PRIVATE host_tensor) +add_dependencies(test_softmax test_softmax_fp32) +add_dependencies(test_softmax test_softmax_fp16) +add_dependencies(test_softmax test_softmax_int8) \ No newline at end of file diff --git a/test/softmax/test_softmax_fp16.cpp b/test/softmax/test_softmax_fp16.cpp new file mode 100644 index 00000000..cce6a422 --- /dev/null +++ b/test/softmax/test_softmax_fp16.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
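+
+// Typed-test suite for fp16 softmax: each tuple row encodes one DeviceSoftmax
+// configuration, following the column legend given above the type list.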
+ +#include "gtest/gtest.h" +#include "test_softmax_util.hpp" + +template +using I = ck::Number; + +template +class TestSoftmaxFP16 : public ck::TestSoftmax +{ +}; + +// clang-format off +using KernelTypes = ::testing::Types< +// InDataType, AccDataType, OutDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize> + std::tuple, I<1>, I<256>, I<8>, I<32>, I<1>, I<8>, I<1>, I<8>, I<4>>, // mixed precision + std::tuple, I<1>, I<256>, I<8>, I<32>, I<1>, I<8>, I<1>, I<8>, I<8>>, + std::tuple, I<1>, I<256>, I<4>, I<64>, I<1>, I<8>, I<1>, I<8>, I<8>>, + std::tuple, I<1>, I<256>, I<2>, I<128>, I<1>, I<8>, I<1>, I<8>, I<8>>, + std::tuple, I<1>, I<256>, I<1>, I<256>, I<1>, I<8>, I<1>, I<8>, I<8>>, + std::tuple, I<1>, I<256>, I<1>, I<256>, I<1>, I<16>, I<1>, I<8>, I<8>>, + std::tuple, I<1>, I<256>, I<1>, I<256>, I<1>, I<32>, I<1>, I<8>, I<8>>, + std::tuple, I<2>, I<256>, I<8>, I<32>, I<1>, I<8>, I<1>, I<8>, I<8>>, + std::tuple, I<2>, I<256>, I<4>, I<64>, I<1>, I<8>, I<1>, I<8>, I<8>>, + std::tuple, I<2>, I<256>, I<2>, I<128>, I<1>, I<8>, I<1>, I<8>, I<8>>, + std::tuple, I<2>, I<256>, I<1>, I<256>, I<1>, I<8>, I<1>, I<8>, I<8>>, + std::tuple, I<2>, I<256>, I<1>, I<256>, I<1>, I<16>, I<1>, I<8>, I<8>>, + std::tuple, I<2>, I<256>, I<1>, I<256>, I<1>, I<32>, I<1>, I<8>, I<8>> + >; +// clang-format on +TYPED_TEST_SUITE(TestSoftmaxFP16, KernelTypes); +TYPED_TEST(TestSoftmaxFP16, Test_FP16) { this->Run(); } diff --git a/test/softmax/test_softmax_fp32.cpp b/test/softmax/test_softmax_fp32.cpp new file mode 100644 index 00000000..4301a5ae --- /dev/null +++ b/test/softmax/test_softmax_fp32.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "gtest/gtest.h" +#include "test_softmax_util.hpp" + +template +using I = ck::Number; + +template +class TestSoftmaxFP32 : public ck::TestSoftmax +{ +}; + +// clang-format off +using KernelTypes = ::testing::Types< +// InDataType, AccDataType, OutDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize> + std::tuple, I<2>, I<256>, I<1>, I<256>, I<1>, I<8>, I<1>, I<4>, I<8>>, // mixed precision + std::tuple, I<1>, I<256>, I<8>, I<32>, I<1>, I<4>, I<1>, I<4>, I<4>>, + std::tuple, I<1>, I<256>, I<4>, I<64>, I<1>, I<4>, I<1>, I<4>, I<4>>, + std::tuple, I<1>, I<256>, I<2>, I<128>, I<1>, I<4>, I<1>, I<4>, I<4>>, + std::tuple, I<1>, I<256>, I<1>, I<256>, I<1>, I<4>, I<1>, I<4>, I<4>>, + std::tuple, I<1>, I<256>, I<1>, I<256>, I<1>, I<8>, I<1>, I<4>, I<4>>, + std::tuple, I<1>, I<256>, I<1>, I<256>, I<1>, I<16>, I<1>, I<4>, I<4>>, + std::tuple, I<2>, I<256>, I<8>, I<32>, I<1>, I<4>, I<1>, I<4>, I<4>>, + std::tuple, I<2>, I<256>, I<4>, I<64>, I<1>, I<4>, I<1>, I<4>, I<4>>, + std::tuple, I<2>, I<256>, I<2>, I<128>, I<1>, I<4>, I<1>, I<4>, I<4>>, + std::tuple, I<2>, I<256>, I<1>, I<256>, I<1>, I<4>, I<1>, I<4>, I<4>>, + std::tuple, I<2>, I<256>, I<1>, I<256>, I<1>, I<8>, I<1>, I<4>, I<4>>, + std::tuple, I<2>, I<256>, I<1>, I<256>, I<1>, I<16>, I<1>, I<4>, I<4>> + >; +// clang-format on +TYPED_TEST_SUITE(TestSoftmaxFP32, KernelTypes); +TYPED_TEST(TestSoftmaxFP32, Test_FP32) { this->Run(); } diff --git a/test/softmax/test_softmax_int8.cpp b/test/softmax/test_softmax_int8.cpp new file mode 100644 index 00000000..dde16529 --- /dev/null +++ b/test/softmax/test_softmax_int8.cpp @@ -0,0 +1,30 @@ +#include "gtest/gtest.h" +#include "test_softmax_util.hpp" + +template +using I = ck::Number; + +template +class TestSoftmaxINT8 : public ck::TestSoftmax +{ +}; + +// clang-format off +using KernelTypes = ::testing::Types< +// InDataType, AccDataType, OutDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize> + std::tuple, I<1>, I<256>, I<8>, I<32>, I<1>, I<16>, I<1>, I<16>, I<16>>, + std::tuple, I<1>, I<256>, I<4>, I<64>, I<1>, I<16>, I<1>, I<16>, I<16>>, + std::tuple, I<1>, I<256>, I<2>, I<128>, I<1>, I<16>, I<1>, I<16>, I<16>>, + std::tuple, I<1>, I<256>, I<1>, I<256>, I<1>, I<16>, I<1>, I<16>, I<16>>, + std::tuple, I<1>, I<256>, I<1>, I<256>, I<1>, I<32>, I<1>, I<16>, I<16>>, + std::tuple, I<1>, I<256>, I<1>, I<256>, I<1>, I<64>, I<1>, I<16>, I<16>>, + std::tuple, I<2>, I<256>, I<8>, I<32>, I<1>, I<16>, I<1>, I<16>, I<16>>, + std::tuple, I<2>, I<256>, I<4>, I<64>, I<1>, I<16>, I<1>, I<16>, I<16>>, + std::tuple, I<2>, I<256>, I<2>, I<128>, I<1>, I<16>, I<1>, I<16>, I<16>>, + std::tuple, I<2>, I<256>, I<1>, I<256>, I<1>, I<16>, I<1>, I<16>, I<16>>, + std::tuple, I<2>, I<256>, I<1>, I<256>, I<1>, I<32>, I<1>, I<16>, I<16>>, + std::tuple, I<2>, I<256>, I<1>, I<256>, I<1>, I<64>, I<1>, I<16>, I<16>> + >; +// clang-format on +TYPED_TEST_SUITE(TestSoftmaxINT8, KernelTypes); +TYPED_TEST(TestSoftmaxINT8, Test_INT8) { this->Run(); } diff --git a/test/softmax/test_softmax_util.hpp b/test/softmax/test_softmax_util.hpp new file mode 100644 index 00000000..d2082509 --- /dev/null +++ b/test/softmax/test_softmax_util.hpp @@ -0,0 +1,153 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/number.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +namespace ck { + +template +std::string serialize_range(const Range& range) +{ + std::stringstream ss; + for(auto& r : range) + { + ss << r << ", "; + } + std::string str = ss.str(); + return std::string(str.begin(), str.end() - 2); +} + +template +class TestSoftmax : public ::testing::Test +{ + protected: + using InDataType = std::tuple_element_t<0, Tuple>; + using AccDataType = std::tuple_element_t<1, Tuple>; + using OutDataType = std::tuple_element_t<2, Tuple>; + static constexpr index_t Rank = std::tuple_element_t<3, Tuple>{}.value; + static constexpr index_t NumReduceDim = std::tuple_element_t<4, Tuple>{}.value; + static constexpr index_t BlockSize = std::tuple_element_t<5, Tuple>{}.value; + static constexpr index_t MThreadClusterSize = std::tuple_element_t<6, Tuple>{}.value; + static constexpr index_t KThreadClusterSize = std::tuple_element_t<7, Tuple>{}.value; + static constexpr index_t MThreadSliceSize = std::tuple_element_t<8, Tuple>{}.value; + static constexpr index_t KThreadSliceSize = std::tuple_element_t<9, Tuple>{}.value; + static constexpr index_t InSrcVectorDim = std::tuple_element_t<10, Tuple>{}.value; + static constexpr index_t InSrcVectorSize = std::tuple_element_t<11, Tuple>{}.value; + static constexpr index_t OutDstVectorSize = std::tuple_element_t<12, Tuple>{}.value; + + using ReferenceInstance = + tensor_operation::host::ReferenceSoftmax; + + using DeviceInstance = tensor_operation::device::DeviceSoftmax; + + TestSoftmax() : ref_instance_invoker_(ReferenceInstance{}.MakeInvoker()) {} + + void RunSingle(std::vector in_length, AccDataType alpha, AccDataType beta) + { + std::vector reduce_dims(NumReduceDim); + std::iota(reduce_dims.begin(), reduce_dims.end(), Rank - NumReduceDim); + + Tensor in(in_length); + Tensor out(in_length); + + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + + Tensor out_ref(out); + + DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpace()); + DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpace()); + in_dev.ToDevice(in.mData.data()); + out_dev.ToDevice(out.mData.data()); + + std::vector i_in_lengths(in.mDesc.GetLengths().begin(), + in.mDesc.GetLengths().end()); + std::vector i_in_strides(in.mDesc.GetStrides().begin(), + in.mDesc.GetStrides().end()); + + auto device_instance = DeviceInstance{}; + auto argument_ptr = device_instance.MakeArgumentPointer(i_in_lengths, + i_in_strides, + reduce_dims, + &alpha, + &beta, + in_dev.GetDeviceBuffer(), + out_dev.GetDeviceBuffer()); + + if(!device_instance.IsSupportedArgument(argument_ptr.get())) + { + // std::cout << "Skipped due to unsupported argument: " + // << "input lengths = [" << serialize_range(in_length) << "], " + // << "scaler = [" << alpha << ", " << beta << "]." 
<< std::endl; + return; + } + + auto invoker_ptr = device_instance.MakeInvokerPointer(); + invoker_ptr->Run(argument_ptr.get()); + + ref_instance_invoker_.Run({in, out_ref, alpha, beta, reduce_dims}); + + out_dev.FromDevice(out.mData.data()); + + bool pass; + + if(std::is_same::value) + { + EXPECT_TRUE(pass = ck::utils::check_err( + out.mData, out_ref.mData, "Error: Incorrect results!", 0, 1)); + } + else + { + EXPECT_TRUE(pass = ck::utils::check_err(out.mData, out_ref.mData)); + } + + if(!pass) + { + FAIL() << "Failure in input lengths = [" << serialize_range(in_length) << "], " + << "scaler = [" << alpha << ", " << beta << "]."; + } + } + + void Run() + { + for(auto in_length : this->in_lengths_) + { + for(auto scale : this->scales_) + { + this->RunSingle(in_length, scale[0], scale[1]); + } + } + } + + std::vector> in_lengths_ = { + {1, 8, 128}, {2, 128, 1024}, {3, 9, 1032}, {4, 4, 2048}, {8, 1, 8192}}; + std::vector> scales_ = {{1, 0}, {1, 1}, {0, 1}, {2, 2}}; + + typename ReferenceInstance::Invoker ref_instance_invoker_; +}; +} // namespace ck diff --git a/test/space_filling_curve/CMakeLists.txt b/test/space_filling_curve/CMakeLists.txt new file mode 100644 index 00000000..a5272680 --- /dev/null +++ b/test/space_filling_curve/CMakeLists.txt @@ -0,0 +1 @@ +add_test_executable(test_space_filling_curve space_filling_curve.cpp) diff --git a/test/space_filling_curve/space_filling_curve.cpp b/test/space_filling_curve/space_filling_curve.cpp new file mode 100644 index 00000000..500717dd --- /dev/null +++ b/test/space_filling_curve/space_filling_curve.cpp @@ -0,0 +1,132 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" + +using namespace ck; + +void traverse_using_space_filling_curve(); + +int main(int argc, char** argv) +{ + (void)argc; + (void)argv; + + traverse_using_space_filling_curve(); + + return 0; +} + +void traverse_using_space_filling_curve() +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + using TensorLengths = Sequence<16, 10, 9>; + using DimAccessOrder = Sequence<2, 0, 1>; + using ScalarsPerAccess = Sequence<4, 2, 3>; + using SpaceFillingCurve = SpaceFillingCurve; + + constexpr auto expected = make_tuple(make_tuple(0, 0, 0), + make_tuple(0, 2, 0), + make_tuple(0, 4, 0), + make_tuple(0, 6, 0), + make_tuple(0, 8, 0), + make_tuple(4, 8, 0), + make_tuple(4, 6, 0), + make_tuple(4, 4, 0), + make_tuple(4, 2, 0), + make_tuple(4, 0, 0), + make_tuple(8, 0, 0), + make_tuple(8, 2, 0), + make_tuple(8, 4, 0), + make_tuple(8, 6, 0), + make_tuple(8, 8, 0), + make_tuple(12, 8, 0), + make_tuple(12, 6, 0), + make_tuple(12, 4, 0), + make_tuple(12, 2, 0), + make_tuple(12, 0, 0), + make_tuple(12, 0, 3), + make_tuple(12, 2, 3), + make_tuple(12, 4, 3), + make_tuple(12, 6, 3), + make_tuple(12, 8, 3), + make_tuple(8, 8, 3), + make_tuple(8, 6, 3), + make_tuple(8, 4, 3), + make_tuple(8, 2, 3), + make_tuple(8, 0, 3), + make_tuple(4, 0, 3), + make_tuple(4, 2, 3), + make_tuple(4, 4, 3), + make_tuple(4, 6, 3), + make_tuple(4, 8, 3), + make_tuple(0, 8, 3), + make_tuple(0, 6, 3), + make_tuple(0, 4, 3), + make_tuple(0, 2, 3), + make_tuple(0, 0, 3), + make_tuple(0, 0, 6), + make_tuple(0, 2, 6), + make_tuple(0, 4, 6), + make_tuple(0, 6, 6), + make_tuple(0, 8, 6), + make_tuple(4, 8, 6), + make_tuple(4, 6, 6), + 
make_tuple(4, 4, 6), + make_tuple(4, 2, 6), + make_tuple(4, 0, 6), + make_tuple(8, 0, 6), + make_tuple(8, 2, 6), + make_tuple(8, 4, 6), + make_tuple(8, 6, 6), + make_tuple(8, 8, 6), + make_tuple(12, 8, 6), + make_tuple(12, 6, 6), + make_tuple(12, 4, 6), + make_tuple(12, 2, 6), + make_tuple(12, 0, 6)); + + constexpr index_t num_access = SpaceFillingCurve::GetNumOfAccess(); + + static_assert(num_access == reduce_on_sequence(TensorLengths{} / ScalarsPerAccess{}, + math::multiplies{}, + Number<1>{})); + + static_for<1, num_access, 1>{}([&](auto i) { + constexpr auto idx_curr = SpaceFillingCurve::GetIndex(i); + + static_assert(idx_curr[I0] == expected[i][I0]); + static_assert(idx_curr[I1] == expected[i][I1]); + static_assert(idx_curr[I2] == expected[i][I2]); + + constexpr auto backward_step = SpaceFillingCurve::GetBackwardStep(i); + constexpr auto expected_step = expected[i - I1] - expected[i]; + static_assert(backward_step[I0] == expected_step[I0]); + static_assert(backward_step[I1] == expected_step[I1]); + static_assert(backward_step[I2] == expected_step[I2]); + }); + + static_for<0, num_access - 1, 1>{}([&](auto i) { + constexpr auto idx_curr = SpaceFillingCurve::GetIndex(i); + + static_assert(idx_curr[I0] == expected[i][I0]); + static_assert(idx_curr[I1] == expected[i][I1]); + static_assert(idx_curr[I2] == expected[i][I2]); + + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(i); + constexpr auto expected_step = expected[i + I1] - expected[i]; + static_assert(forward_step[I0] == expected_step[I0]); + static_assert(forward_step[I1] == expected_step[I1]); + static_assert(forward_step[I2] == expected_step[I2]); + }); +}
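+
+// With DimAccessOrder = {2, 0, 1}, the last listed dimension (1) advances
+// fastest and the first (2) advances slowest; each sweep reverses direction,
+// so consecutive accesses stay adjacent, which is exactly the order recorded
+// in `expected` above.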