diff --git a/.gitignore b/.gitignore
new file mode 100644
index 00000000..cdf5b64d
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,48 @@
+# Compiled Object files
+*.slo
+*.lo
+*.o
+*.obj
+
+# Precompiled Headers
+*.gch
+*.pch
+*.ipch
+
+# Compiled Dynamic libraries
+*.so
+*.dylib
+*.dll
+
+# Fortran module files
+*.mod
+
+# Compiled Static libraries
+*.lai
+*.la
+*.a
+*.lib
+
+# Executables
+*.exe
+*.out
+*.app
+
+# vim tags
+tags
+.tags
+.*.swp
+
+# Editors
+.vscode
+
+# build-in-source directory
+build*
+
+# emacs temporary/backup files
+.\#*
+\#*\#
+*~
+
+# GDB temporary files
+.gdb_history
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 306e6ca6..9f706207 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,10 +1,26 @@
-cmake_minimum_required(VERSION 3.5)
+cmake_minimum_required(VERSION 3.14)
+
+# Check support for CUDA/HIP in CMake
project(composable_kernel)
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
+enable_testing()
+
+set(ROCM_SYMLINK_LIBS OFF)
+find_package(ROCM 0.8 REQUIRED PATHS /opt/rocm)
+
+include(ROCMInstallTargets)
+include(ROCMPackageConfigHelpers)
+include(ROCMSetupVersion)
+include(ROCMInstallSymlinks)
+include(ROCMCreatePackage)
include(CheckCXXCompilerFlag)
+rocm_setup_version(VERSION 0.2.0)
+include(TargetFlags)
+list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip)
+
## C++
enable_language(CXX)
set(CMAKE_CXX_STANDARD 17)
@@ -30,35 +46,42 @@ message("OpenMP_gomp_LIBRARY: ${OpenMP_gomp_LIBRARY}")
message("OpenMP_pthread_LIBRARY: ${OpenMP_pthread_LIBRARY}")
message("OpenMP_CXX_FLAGS: ${OpenMP_CXX_FLAGS}")
-set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
link_libraries(${OpenMP_gomp_LIBRARY})
link_libraries(${OpenMP_pthread_LIBRARY})
## HIP
find_package(HIP REQUIRED)
-message(STATUS "Build with HIP ${hip_VERSION}")
-
-## half
-#find_path(HALF_INCLUDE_DIR half.hpp)
-message("HALF_INCLUDE_DIR: ${HALF_INCLUDE_DIR}")
-
-# CMAKE_CXX_FLAGS
-SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
-if(BUILD_DEV)
- string(APPEND CMAKE_CXX_FLAGS " -Werror -Weverything")
+# Override HIP version in config.h, if necessary.
+# The variables set by find_package() can't be overwritten,
+# therefore let's use intermediate variables.
+set(CK_HIP_VERSION_MAJOR "${HIP_VERSION_MAJOR}")
+set(CK_HIP_VERSION_MINOR "${HIP_VERSION_MINOR}")
+set(CK_HIP_VERSION_PATCH "${HIP_VERSION_PATCH}")
+if( DEFINED CK_OVERRIDE_HIP_VERSION_MAJOR )
+ set(CK_HIP_VERSION_MAJOR "${CK_OVERRIDE_HIP_VERSION_MAJOR}")
+    message(STATUS "CK_HIP_VERSION_MAJOR overridden with ${CK_OVERRIDE_HIP_VERSION_MAJOR}")
endif()
-message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
+if( DEFINED CK_OVERRIDE_HIP_VERSION_MINOR )
+ set(CK_HIP_VERSION_MINOR "${CK_OVERRIDE_HIP_VERSION_MINOR}")
+    message(STATUS "CK_HIP_VERSION_MINOR overridden with ${CK_OVERRIDE_HIP_VERSION_MINOR}")
+endif()
+if( DEFINED CK_OVERRIDE_HIP_VERSION_PATCH )
+ set(CK_HIP_VERSION_PATCH "${CK_OVERRIDE_HIP_VERSION_PATCH}")
+    message(STATUS "CK_HIP_VERSION_PATCH overridden with ${CK_OVERRIDE_HIP_VERSION_PATCH}")
+endif()
+message(STATUS "Build with HIP ${HIP_VERSION}")
## tidy
include(EnableCompilerWarnings)
-set(MIOPEN_TIDY_ERRORS ERRORS * -readability-inconsistent-declaration-parameter-name)
+set(CK_TIDY_ERRORS ERRORS * -readability-inconsistent-declaration-parameter-name)
if(CMAKE_CXX_COMPILER MATCHES ".*hcc" OR CMAKE_CXX_COMPILER MATCHES ".*clang\\+\\+")
- set(MIOPEN_TIDY_CHECKS -modernize-use-override -readability-non-const-parameter)
+ set(CK_TIDY_CHECKS -modernize-use-override -readability-non-const-parameter)
# Enable tidy on hip
-elseif(MIOPEN_BACKEND STREQUAL "HIP" OR MIOPEN_BACKEND STREQUAL "HIPNOGPU")
- set(MIOPEN_TIDY_ERRORS ALL)
+elseif(CK_BACKEND STREQUAL "HIP" OR CK_BACKEND STREQUAL "HIPNOGPU")
+ set(CK_TIDY_ERRORS ALL)
endif()
+
include(ClangTidy)
enable_clang_tidy(
CHECKS
@@ -150,13 +173,12 @@ enable_clang_tidy(
-cppcoreguidelines-narrowing-conversions
-altera-struct-pack-align
-cppcoreguidelines-prefer-member-initializer
-
- ${MIOPEN_TIDY_CHECKS}
- ${MIOPEN_TIDY_ERRORS}
+ ${CK_TIDY_CHECKS}
+ ${CK_TIDY_ERRORS}
HEADER_FILTER
"\.hpp$"
EXTRA_ARGS
- -DMIOPEN_USE_CLANG_TIDY
+ -DCK_USE_CLANG_TIDY
)
include(CppCheck)
@@ -180,19 +202,74 @@ enable_cppcheck(
unmatchedSuppression
FORCE
SOURCES
- host/host_tensor/src
- host/driver_offline/src
- composable_kernel/src/kernel_wrapper
+ library/src
INCLUDE
- host/host_tensor/include
- host/solver/include
- host/driver_offline/include
- composable_kernel/include/*
${CMAKE_CURRENT_SOURCE_DIR}/include
${CMAKE_CURRENT_BINARY_DIR}/include
+ ${CMAKE_CURRENT_SOURCE_DIR}/library/include
DEFINE
CPPCHECK=1
__linux__=1
)
-add_subdirectory(host)
+set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
+set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
+set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin)
+
+include_directories(BEFORE
+ ${PROJECT_SOURCE_DIR}/include
+ ${PROJECT_SOURCE_DIR}/library/include
+)
+
+
+SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
+if(BUILD_DEV)
+ add_compile_options(-Werror)
+ add_compile_options(-Weverything)
+endif()
+message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
+
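+# The 'check' target runs the registered CTest suite and prints output from failing tests.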
+add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR})
+
+rocm_package_setup_component(tests
+ LIBRARY_NAME composablekernel
+ PACKAGE_NAME tests # Prevent -static suffix on package name
+)
+
+add_subdirectory(library)
+add_subdirectory(example)
+add_subdirectory(test)
+add_subdirectory(profiler)
+
+# Create an interface target for the include-only files and call it "composablekernels"
+include(CMakePackageConfigHelpers)
+
+set(version 1.0.0)
+write_basic_package_version_file(
+ "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfigVersion.cmake"
+ VERSION "${version}"
+ COMPATIBILITY AnyNewerVersion
+)
+
+configure_package_config_file(${CMAKE_CURRENT_SOURCE_DIR}/Config.cmake.in
+ "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfig.cmake"
+ INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
+ NO_CHECK_REQUIRED_COMPONENTS_MACRO
+)
+
+rocm_install(FILES
+ "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfig.cmake"
+ "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfigVersion.cmake"
+ DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
+)
+
+set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE")
+set(CPACK_RPM_PACKAGE_LICENSE "MIT")
+
+rocm_create_package(
+ NAME composablekernel
+ DESCRIPTION "High Performance Composable Kernel for AMD GPUs"
+    MAINTAINER "MIOpen Kernels Dev Team"
+ LDCONFIG
+ HEADER_ONLY
+)
diff --git a/Config.cmake.in b/Config.cmake.in
new file mode 100644
index 00000000..12b5c331
--- /dev/null
+++ b/Config.cmake.in
@@ -0,0 +1,11 @@
+@PACKAGE_INIT@
+
+set(_composable_kernel_supported_components device_operations host_tensor)
+
+foreach(_comp ${composable_kernel_FIND_COMPONENTS})
+ if(NOT _comp IN_LIST _composable_kernel_supported_components)
+ set(composable_kernel_FOUND False)
+ set(composable_kernel_NOT_FOUND_MESSAGE "Unsupported component: ${_comp}")
+ endif()
+ include("${CMAKE_CURRENT_LIST_DIR}/composable_kernel${_comp}Targets.cmake")
+endforeach()
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 00000000..0d32b52f
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,95 @@
+FROM ubuntu:18.04
+
+ARG ROCMVERSION=5.1
+ARG OSDB_BKC_VERSION
+
+RUN set -xe
+
+ARG BUILD_THREADS=8
+ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/
+# Add rocm repository
+RUN apt-get update
+RUN apt-get install -y wget gnupg
+RUN wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -
+RUN sh -c "echo deb [arch=amd64] $DEB_ROCM_REPO ubuntu main > /etc/apt/sources.list.d/rocm.list"
+RUN wget --no-check-certificate -qO - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | apt-key add -
+RUN sh -c "echo deb https://apt.kitware.com/ubuntu/ bionic main | tee -a /etc/apt/sources.list"
+
+# ADD requirements.txt requirements.txt
+# Install dependencies
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
+ apt-utils \
+ build-essential \
+ cmake-data=3.15.1-0kitware1 \
+ cmake=3.15.1-0kitware1 \
+ curl \
+ g++ \
+ gdb \
+ git \
+ hip-rocclr \
+ jq \
+ libelf-dev \
+ libncurses5-dev \
+ libnuma-dev \
+ libpthread-stubs0-dev \
+ llvm-amdgpu \
+ pkg-config \
+ python \
+ python3.8 \
+ python-dev \
+ python3-dev \
+ python-pip \
+ python3-pip \
+ software-properties-common \
+ wget \
+ rocm-dev \
+ rocm-device-libs \
+ rocm-cmake \
+ vim \
+ zlib1g-dev \
+ openssh-server \
+ clang-format-10 \
+ kmod && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+# Set up the UBSan environment to print stack traces
+RUN ln -s /usr/bin/llvm-symbolizer-3.8 /usr/local/bin/llvm-symbolizer
+ENV UBSAN_OPTIONS=print_stacktrace=1
+
+# Install an init system
+RUN wget https://github.com/Yelp/dumb-init/releases/download/v1.2.0/dumb-init_1.2.0_amd64.deb
+RUN dpkg -i dumb-init_*.deb && rm dumb-init_*.deb
+
+# Install cget
+RUN pip install cget
+
+# Install rclone
+RUN pip install https://github.com/pfultz2/rclone/archive/master.tar.gz
+
+ARG PREFIX=/opt/rocm
+# Install dependencies
+RUN cget install pfultz2/rocm-recipes
+# Install rbuild
+RUN pip3 install https://github.com/RadeonOpenCompute/rbuild/archive/6d78a0553babdaea8d2da5de15cbda7e869594b8.tar.gz
+# Install packages for processing the performance results
+RUN pip3 install --upgrade pip
+RUN pip3 install sqlalchemy
+RUN pip3 install pymysql
+RUN pip3 install pandas
+RUN pip3 install setuptools-rust
+RUN pip3 install sshtunnel
+# Set up the UBSan environment to print stack traces
+ENV UBSAN_OPTIONS=print_stacktrace=1
+
+ENV LC_ALL=C.UTF-8
+ENV LANG=C.UTF-8
+ADD rbuild.ini /rbuild.ini
+ADD dev-requirements.txt dev-requirements.txt
+RUN rbuild prepare -s develop -d $PREFIX
+RUN groupadd -f render
+
+# Install the new rocm-cmake version
+RUN git clone -b master https://github.com/RadeonOpenCompute/rocm-cmake.git && \
+ cd rocm-cmake && mkdir build && cd build && \
+ cmake .. && cmake --build . && cmake --build . --target install
diff --git a/Jenkinsfile b/Jenkinsfile
new file mode 100644
index 00000000..15be3e54
--- /dev/null
+++ b/Jenkinsfile
@@ -0,0 +1,441 @@
+def rocmnode(name) {
+ return 'rocmtest && miopen && ' + name
+}
+
+def show_node_info() {
+ sh """
+ echo "NODE_NAME = \$NODE_NAME"
+ lsb_release -sd
+ uname -r
+ ls /opt/ -la
+ """
+}
+
+def cmake_build(Map conf=[:]){
+
+ def compiler = conf.get("compiler","/opt/rocm/bin/hipcc")
+ def config_targets = conf.get("config_targets","check")
+ def debug_flags = "-g -fno-omit-frame-pointer -fsanitize=undefined -fno-sanitize-recover=undefined " + conf.get("extradebugflags", "")
+ def build_envs = "CTEST_PARALLEL_LEVEL=4 " + conf.get("build_env","")
+ def prefixpath = conf.get("prefixpath","/opt/rocm")
+ def setup_args = conf.get("setup_args","")
+
+ if (prefixpath != "/usr/local"){
+ setup_args = setup_args + " -DCMAKE_PREFIX_PATH=${prefixpath} "
+ }
+
+ def build_type_debug = (conf.get("build_type",'release') == 'debug')
+
+ //cmake_env can overwrite default CXX variables.
+ def cmake_envs = "CXX=${compiler} CXXFLAGS='-Werror' " + conf.get("cmake_ex_env","")
+
+ def package_build = (conf.get("package_build","") == "true")
+
+ if (package_build == true) {
+ config_targets = "package"
+ }
+
+ if(conf.get("build_install","") == "true")
+ {
+ config_targets = 'install ' + config_targets
+ setup_args = ' -DBUILD_DEV=Off -DCMAKE_INSTALL_PREFIX=../install' + setup_args
+ } else{
+ setup_args = ' -DBUILD_DEV=On' + setup_args
+ }
+
+ if(build_type_debug){
+ setup_args = " -DCMAKE_BUILD_TYPE=debug -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'" + setup_args
+ }else{
+ setup_args = " -DCMAKE_BUILD_TYPE=release" + setup_args
+ }
+
+ def pre_setup_cmd = """
+ echo \$HSA_ENABLE_SDMA
+ ulimit -c unlimited
+ rm -rf build
+ mkdir build
+ rm -rf install
+ mkdir install
+ cd build
+ """
+ def setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ")
+ // reduce parallelism when compiling, clang uses too much memory
+ def build_cmd = conf.get("build_cmd", "${build_envs} dumb-init make -j\$(( \$(nproc) / 1 )) ${config_targets}")
+ def execute_cmd = conf.get("execute_cmd", "")
+
+ def cmd = conf.get("cmd", """
+ ${pre_setup_cmd}
+ ${setup_cmd}
+ ${build_cmd}
+ ${execute_cmd}
+ """)
+
+ echo cmd
+ sh cmd
+
+ // Only archive from master or develop
+ if (package_build == true && (env.BRANCH_NAME == "develop" || env.BRANCH_NAME == "master")) {
+ archiveArtifacts artifacts: "build/*.deb", allowEmptyArchive: true, fingerprint: true
+ }
+}
+
+def buildHipClangJob(Map conf=[:]){
+ show_node_info()
+
+ env.HSA_ENABLE_SDMA=0
+ checkout scm
+
+ def image = "composable_kernels"
+ def prefixpath = conf.get("prefixpath", "/opt/rocm")
+ def gpu_arch = conf.get("gpu_arch", "gfx908")
+
+ // Jenkins is complaining about the render group
+ // def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
+ def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
+ if (conf.get("enforce_xnack_on", false)) {
+ dockerOpts = dockerOpts + " --env HSA_XNACK=1"
+ }
+ def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg GPU_ARCH='${gpu_arch}' "
+
+ def variant = env.STAGE_NAME
+
+ def retimage
+
+ gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') {
+ if (params.USE_DOCKERFILE){
+ try {
+ retimage = docker.build("${image}", dockerArgs + '.')
+ withDockerContainer(image: image, args: dockerOpts) {
+ timeout(time: 5, unit: 'MINUTES')
+ {
+ sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
+ }
+ }
+ }
+ catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){
+ echo "The job was cancelled or aborted"
+ throw e
+ }
+ catch(Exception ex) {
+ retimage = docker.build("${image}", dockerArgs + "--no-cache .")
+ withDockerContainer(image: image, args: dockerOpts) {
+ timeout(time: 5, unit: 'MINUTES')
+ {
+ sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
+ }
+ }
+ }
+ }
+ else{
+ timeout(time: 3, unit: 'HOURS'){
+ retimage = docker.image('compute-artifactory.amd.com:5000/rocm-plus-docker/framework/compute-rocm-dkms-no-npi-hipclang:9110_ubuntu18.04_py3.6_pytorch_rocm5.0_internal_testing_7ff5b54').pull()
+ image="b56f8ac0d6ea"
+ sh "docker images"
+ }
+ }
+
+ withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') {
+ timeout(time: 5, unit: 'HOURS')
+ {
+ sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
+ cmake_build(conf)
+ }
+ }
+ }
+ return retimage
+}
+
+def reboot(){
+ build job: 'reboot-slaves', propagate: false , parameters: [string(name: 'server', value: "${env.NODE_NAME}"),]
+}
+
+
+
+
+
+def buildHipClangJobAndReboot(Map conf=[:]){
+ try{
+ buildHipClangJob(conf)
+ }
+ catch(e){
+        echo "Rethrowing exception from this stage"
+ echo 'Exception occurred: ' + e.toString()
+ throw e
+ }
+ finally{
+ if (!conf.get("no_reboot", false)) {
+ reboot()
+ }
+ }
+}
+
+
+def runCKProfiler(Map conf=[:]){
+ show_node_info()
+
+ env.HSA_ENABLE_SDMA=0
+ checkout scm
+
+ def image = "composable_kernels"
+ def prefixpath = conf.get("prefixpath", "/opt/rocm")
+ def gpu_arch = conf.get("gpu_arch", "gfx908")
+
+ // Jenkins is complaining about the render group
+ // def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
+ def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
+ if (conf.get("enforce_xnack_on", false)) {
+ dockerOpts = dockerOpts + " --env HSA_XNACK=1"
+ }
+ def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg GPU_ARCH='${gpu_arch}' "
+
+ def variant = env.STAGE_NAME
+
+ def retimage
+
+ gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') {
+ if (params.USE_DOCKERFILE){
+ try {
+ retimage = docker.build("${image}", dockerArgs + '.')
+ withDockerContainer(image: image, args: dockerOpts) {
+ timeout(time: 5, unit: 'MINUTES')
+ {
+ sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
+ }
+ }
+ }
+ catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){
+ echo "The job was cancelled or aborted"
+ throw e
+ }
+ catch(Exception ex) {
+ retimage = docker.build("${image}", dockerArgs + "--no-cache .")
+ withDockerContainer(image: image, args: dockerOpts) {
+ timeout(time: 5, unit: 'MINUTES')
+ {
+ sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
+ }
+ }
+ }
+ }
+ else{
+ timeout(time: 3, unit: 'HOURS'){
+ retimage = docker.image('compute-artifactory.amd.com:5000/rocm-plus-docker/framework/compute-rocm-dkms-no-npi-hipclang:9110_ubuntu18.04_py3.6_pytorch_rocm5.0_internal_testing_7ff5b54').pull()
+ image="b56f8ac0d6ea"
+ sh "docker images"
+ }
+ }
+
+ withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') {
+ timeout(time: 5, unit: 'HOURS')
+ {
+ cmake_build(conf)
+ dir("script"){
+ //run gemm performance tests
+ def gemm_log = "perf_gemm_${gpu_arch}.log"
+ sh "rm -f ${gemm_log}"
+ sh "echo Branch name: ${env.BRANCH_NAME} > ${gemm_log}"
+ sh "echo Node name: ${NODE_NAME} >> ${gemm_log}"
+ sh "echo GPU_arch name: ${gpu_arch} >> ${gemm_log}"
+ sh "rocminfo | grep 'Compute Unit:' >> ${gemm_log} "
+ sh "hipcc --version | grep -e 'HIP version' >> ${gemm_log}"
+ sh "/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> ${gemm_log}"
+ sh "./profile_gemm.sh gemm 0 0 0 1 0 5 | tee -a ${gemm_log}"
+ sh "./profile_gemm.sh gemm 1 0 0 1 0 5 | tee -a ${gemm_log}"
+ sh "./profile_gemm.sh gemm 2 0 0 1 0 5 | tee -a ${gemm_log}"
+ sh "./profile_gemm.sh gemm 3 0 0 1 0 5 | tee -a ${gemm_log}"
+ sh "./profile_gemm.sh gemm 0 1 0 1 0 5 | tee -a ${gemm_log}"
+ sh "./profile_gemm.sh gemm 1 1 0 1 0 5 | tee -a ${gemm_log}"
+ sh "./profile_gemm.sh gemm 2 1 0 1 0 5 | tee -a ${gemm_log}"
+ sh "./profile_gemm.sh gemm 3 1 0 1 0 5 | tee -a ${gemm_log}"
+ sh "./profile_gemm.sh gemm 0 2 0 1 0 5 | tee -a ${gemm_log}"
+ sh "./profile_gemm.sh gemm 1 2 0 1 0 5 | tee -a ${gemm_log}"
+ sh "./profile_gemm.sh gemm 2 2 0 1 0 5 | tee -a ${gemm_log}"
+ sh "./profile_gemm.sh gemm 3 2 0 1 0 5 | tee -a ${gemm_log}"
+ sh "./profile_gemm.sh gemm 0 3 0 1 0 5 | tee -a ${gemm_log}"
+ sh "./profile_gemm.sh gemm 1 3 0 1 0 5 | tee -a ${gemm_log}"
+ sh "./profile_gemm.sh gemm 2 3 0 1 0 5 | tee -a ${gemm_log}"
+ sh "./profile_gemm.sh gemm 3 3 0 1 0 5 | tee -a ${gemm_log}"
+ //results will be parsed, stored, and analyzed within the python script
+ //the script will return 0 if the performance criteria are met
+ //or return 1 if the criteria are not met
+ archiveArtifacts "${gemm_log}"
+ sh "python3 parse_perf_data.py ${gemm_log} "
+ //run resnet50 test
+ def resnet_log = "perf_resnet50_${gpu_arch}.log"
+ sh "rm -f ${resnet_log}"
+ sh "echo Branch name: ${env.BRANCH_NAME} > ${resnet_log}"
+ sh "echo Node name: ${NODE_NAME} >> ${resnet_log}"
+ sh "echo GPU_arch name: ${gpu_arch} >> ${resnet_log}"
+ sh "rocminfo | grep 'Compute Unit:' >> ${resnet_log} "
+ sh "hipcc --version | grep -e 'HIP version' >> ${resnet_log}"
+ sh "/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> ${resnet_log}"
+ //first run tests with N=256
+ sh "./profile_conv.sh conv_fwd_bias_relu 1 1 1 1 0 2 0 1 256 | tee -a ${resnet_log}"
+ //then run with N=4
+ sh "./profile_conv.sh conv_fwd_bias_relu 1 1 1 1 0 2 0 1 4 | tee -a ${resnet_log}"
+ archiveArtifacts "${resnet_log}"
+ //the script will put the results from N=256 and N=4 runs into separate tables
+ sh "python3 parse_perf_data.py ${resnet_log} "
+ }
+ }
+ }
+ }
+ return retimage
+}
+
+
+def runPerfTest(Map conf=[:]){
+ try{
+ runCKProfiler(conf)
+ }
+ catch(e){
+        echo "Rethrowing exception from performance tests"
+ echo 'Exception occurred: ' + e.toString()
+ throw e
+ }
+ finally{
+ if (!conf.get("no_reboot", false)) {
+ reboot()
+ }
+ }
+}
+
+pipeline {
+ agent none
+ options {
+ parallelsAlwaysFailFast()
+ }
+ parameters {
+ booleanParam(
+ name: "USE_DOCKERFILE",
+ defaultValue: true,
+ description: "")
+ }
+ environment{
+ dbuser = "${dbuser}"
+ dbpassword = "${dbpassword}"
+ dbsship = "${dbsship}"
+ dbsshport = "${dbsshport}"
+ dbsshuser = "${dbsshuser}"
+ dbsshpassword = "${dbsshpassword}"
+ status_wrapper_creds = "${status_wrapper_creds}"
+ }
+ stages{
+ stage("Static checks") {
+ parallel{
+ // enable after we move from hipcc to hip-clang
+ // stage('Tidy') {
+ // agent{ label rocmnode("nogpu") }
+ // environment{
+ // // setup_cmd = "CXX='/opt/rocm/bin/hipcc' cmake -DBUILD_DEV=On .. "
+ // build_cmd = "make -j\$(nproc) -k analyze"
+ // }
+ // steps{
+ // buildHipClangJobAndReboot(build_cmd: build_cmd, no_reboot:true, prefixpath: '/opt/rocm', build_type: 'debug')
+ // }
+ // }
+ stage('Clang Format') {
+ agent{ label rocmnode("nogpu") }
+ environment{
+ execute_cmd = "find .. -iname \'*.h\' \
+ -o -iname \'*.hpp\' \
+ -o -iname \'*.cpp\' \
+ -o -iname \'*.h.in\' \
+ -o -iname \'*.hpp.in\' \
+ -o -iname \'*.cpp.in\' \
+ -o -iname \'*.cl\' \
+ | grep -v 'build/' \
+ | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-10 -style=file {} | diff - {}\'"
+ }
+ steps{
+ buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true)
+ }
+ }
+ }
+ }
+ stage("Tests")
+ {
+ parallel
+ {
+ stage("Run Tests: gfx908")
+ {
+ agent{ label rocmnode("gfx908")}
+ environment{
+ setup_args = """ -D CMAKE_CXX_FLAGS=" --offload-arch=gfx908 -O3 " -DBUILD_DEV=On """
+ }
+ steps{
+ buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "check", no_reboot:true, build_type: 'Release', gpu_arch: "gfx908")
+ }
+ }
+ stage("Run Tests: gfx90a")
+ {
+ agent{ label rocmnode("gfx90a")}
+ environment{
+ setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx90a -O3 " -DBUILD_DEV=On """
+ }
+ steps{
+ buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "check", no_reboot:true, build_type: 'Release', gpu_arch: "gfx90a")
+ }
+ }
+ }
+ }
+ stage("Client App")
+ {
+ parallel
+ {
+ stage("Run Client App")
+ {
+ agent{ label rocmnode("gfx908")}
+ environment{
+                    setup_args = """ -DBUILD_DEV=Off -DCMAKE_INSTALL_PREFIX=../install -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " """
+ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc .. && make -j """
+ }
+ steps{
+ buildHipClangJobAndReboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
+ }
+ }
+ }
+ }
+ stage("Performance Tests")
+ {
+ parallel
+ {
+ stage("Run ckProfiler: gfx908")
+ {
+ agent{ label rocmnode("gfx908")}
+ environment{
+ setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """
+ }
+ steps{
+ runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release', gpu_arch: "gfx908")
+ }
+ }
+ stage("Run ckProfiler: gfx90a")
+ {
+ agent{ label rocmnode("gfx90a")}
+ environment{
+ setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx90a -O3 " -DBUILD_DEV=On """
+ }
+ steps{
+ runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release', gpu_arch: "gfx90a")
+ }
+ }
+ }
+ }
+ /* enable after the cmake file supports packaging
+ stage("Packages") {
+ when {
+ expression { params.BUILD_PACKAGES && params.TARGET_NOGPU && params.DATATYPE_NA }
+ }
+ parallel {
+ stage("Package /opt/rocm") {
+ agent{ label rocmnode("nogpu") }
+ steps{
+ buildHipClangJobAndReboot( package_build: "true", prefixpath: '/opt/rocm', gpu_arch: "gfx906;gfx908;gfx90a")
+ }
+ }
+ }
+ }
+ */
+ }
+}
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 00000000..2fe9a845
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,28 @@
+Copyright (c) 2018- , Advanced Micro Devices, Inc. (Chao Liu, Jing Zhang)
+Copyright (c) 2019- , Advanced Micro Devices, Inc. (Letao Qin, Qianfeng Zhang, Liang Huang, Shaojie Wang)
+Copyright (c) 2022- , Advanced Micro Devices, Inc. (Anthony Chang, Chunyu Lai, Illia Silin, Adam Osewski, Poyen Chen, Jehandad Khan)
+Copyright (c) 2019-2021, Advanced Micro Devices, Inc. (Hanwen Chang)
+Copyright (c) 2019-2020, Advanced Micro Devices, Inc. (Tejash Shah)
+Copyright (c) 2020 , Advanced Micro Devices, Inc. (Xiaoyan Zhou)
+Copyright (c) 2021-2022, Advanced Micro Devices, Inc. (Jianfeng Yan)
+
+SPDX-License-Identifier: MIT
+Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
index 4f071d58..bbc4d2bc 100644
--- a/README.md
+++ b/README.md
@@ -1,177 +1,66 @@
-# How to build and run
-
-# Docker
-```
-docker run \
--it \
---rm \
---privileged \
---group-add sudo \
--w /root/workspace \
--v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
-rocm/tensorflow:rocm4.2-tf2.4-dev \
+## Docker script
+```bash
+docker run \
+-it \
+--privileged \
+--group-add sudo \
+-w /root/workspace \
+-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
+rocm/tensorflow:rocm5.1-tf2.6-dev \
/bin/bash
```
-# Install Boost for online compilation
-https://www.boost.org/doc/libs/1_66_0/more/getting_started/unix-variants.html#easy-build-and-install
+## Install a newer version of rocm-cmake
+https://github.com/RadeonOpenCompute/rocm-cmake
-
-# Build
-Add path of Boost
-```
- export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
-```
-
-```
+## Build
+```bash
mkdir build && cd build
```
-cmake cmd. Need to Specify target ID, example below is gfx908
-```
-cmake \
--D CMAKE_BUILD_TYPE=Release \
--D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 -O3 --amdgpu-target=gfx908 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \
--D HIP_ONLINE_COMPILER_FLAGS="-DCK_AMD_GPU_GFX908" \
--D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
--D CMAKE_PREFIX_PATH=/opt/rocm \
--D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
+```bash
+# Need to specify the target ID; the example below builds for gfx908 and gfx90a
+cmake \
+-D BUILD_DEV=OFF \
+-D CMAKE_BUILD_TYPE=Release \
+-D CMAKE_CXX_FLAGS=" --offload-arch=gfx908 --offload-arch=gfx90a -O3" \
+-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
+-D CMAKE_PREFIX_PATH=/opt/rocm \
+-D CMAKE_INSTALL_PREFIX=${PATH_TO_CK_INSTALL_DIRECTORY} \
..
```
-Build drivers: \
-``conv_fwd_driver_offline`` is (offline compilation) driver for forward convolution, \
-``conv_bwd_driver_offline`` is (offline compilation) driver for backward-data convolution \
-``conv_fwd_driver_online`` is (online compilation) driver for forward convolution
-```
- make -j conv_fwd_driver_offline
- make -j conv_bwd_driver_offline
- make -j conv_fwd_driver_online
+### Build and Run Examples
+```bash
+ make -j examples
```
+Instructions for running each individual example are under ```example/```
-# Run
-* layout: 0 = NCHW; 1 = NHWC
-* algo: algorithm
-* verify: 0 = no verification; 1 = do verification
-* init: 0 ~ 5. initialization method
-* log: 0 = no log; 1 = do log
-* repeat: number of time kernel being launched
-```
-######################################################## layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads
- ./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
- ./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1
- ./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
- ./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1
- ./host/driver_offline/conv_bwd_driver_offline 1 5 0 0 0 1 256 256 1024 3 3 14 14 1 1 1 1 1 1 1 1
+## Tests
+```bash
+ make -j examples tests
+ make test
```
-# Result
-Forward convoltuion, FP16, NCHW
+## Build ckProfiler
+```bash
+ make -j ckProfiler
```
-./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
+Instructions for running ckProfiler are under ```profiler/```
-layout: 0
-in: dim 4, lengths {128, 192, 71, 71}, strides {967872, 5041, 71, 1}
-wei: dim 4, lengths {256, 192, 3, 3}, strides {1728, 9, 3, 1}
-out: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1296, 36, 1}
-InLeftPads size 2, {1, 1, }
-InRightPads size 2, {1, 1, }
-ConvStrides size 2, {2, 2, }
-ConvDilations size 2, {1, 1, }
-device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
-a_k0_m_k1_grid_desc{216, 256, 8}
-b_k0_n_k1_grid_desc{216, 165888, 8}
-c_m_n_grid_desc{ 256, 165888}
-launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1}
-Warm up
-Start running 1 times...
-Average time : 1.4155 ms, 103.686 TFlop/s
+## Install CK
+```bash
+make install
```
-Forward convoltuion, FP16, NCHW
-```
- ./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1
-
- layout: 0
-in: dim 4, lengths {256, 256, 14, 14}, strides {50176, 196, 14, 1}
-wei: dim 4, lengths {1024, 256, 3, 3}, strides {2304, 9, 3, 1}
-out: dim 4, lengths {256, 1024, 14, 14}, strides {200704, 196, 14, 1}
-InLeftPads size 2, {1, 1, }
-InRightPads size 2, {1, 1, }
-ConvStrides size 2, {1, 1, }
-ConvDilations size 2, {1, 1, }
-device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
-a_k0_m_k1_grid_desc{288, 1024, 8}
-b_k0_n_k1_grid_desc{288, 50176, 8}
-c_m_n_grid_desc{ 1024, 50176}
-launch_and_time_kernel: grid_dim {1568, 1, 1}, block_dim {256, 1, 1}
-Warm up
-Start running 1 times...
-Average time : 2.21357 ms, 106.959 TFlop/s
- ```
-
- Forward convolution, FP16, NHWC
- ```
- ./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
-
- layout: 1
-in: dim 4, lengths {128, 71, 71, 192}, strides {967872, 13632, 192, 1}
-wei: dim 4, lengths {256, 3, 3, 192}, strides {1728, 576, 192, 1}
-out: dim 4, lengths {128, 36, 36, 256}, strides {331776, 9216, 256, 1}
-InLeftPads size 2, {1, 1, }
-InRightPads size 2, {1, 1, }
-ConvStrides size 2, {2, 2, }
-ConvDilations size 2, {1, 1, }
-device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
-a_k0_m_k1_grid_desc{216, 165888, 8}
-b_k0_n_k1_grid_desc{216, 256, 8}
-c_m_n_grid_desc{ 165888, 256}
-launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1}
-Warm up
-Start running 1 times...
-Average time : 1.12014 ms, 131.025 TFlop/s
- ```
-
- Forward convolution, FP16, NHWC
- ```
- ./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1
-
- layout: 1
-in: dim 4, lengths {256, 14, 14, 256}, strides {50176, 3584, 256, 1}
-wei: dim 4, lengths {1024, 3, 3, 256}, strides {2304, 768, 256, 1}
-out: dim 4, lengths {256, 14, 14, 1024}, strides {200704, 14336, 1024, 1}
-InLeftPads size 2, {1, 1, }
-InRightPads size 2, {1, 1, }
-ConvStrides size 2, {1, 1, }
-ConvDilations size 2, {1, 1, }
-device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
-a_k0_m_k1_grid_desc{288, 50176, 8}
-b_k0_n_k1_grid_desc{288, 1024, 8}
-c_m_n_grid_desc{ 50176, 1024}
-launch_and_time_kernel: grid_dim {1568, 1, 1}, block_dim {256, 1, 1}
-Warm up
-Start running 1 times...
-Average time : 1.86877 ms, 126.693 TFlop/s
- ```
-
- Backward data convolution, FP16, NHWC
- ```
- ./host/driver_offline/conv_bwd_driver_offline 1 1 0 3 0 1 256 256 1024 3 3 14 14 1 1 1 1 1 1 1 1
-
- layout: 1
-in: dim 4, lengths {256, 14, 14, 1024}, strides {200704, 14336, 1024, 1}
-wei: dim 4, lengths {256, 3, 3, 1024}, strides {9216, 3072, 1024, 1}
-out: dim 4, lengths {256, 14, 14, 256}, strides {50176, 3584, 256, 1}
-InLeftPads size 2, {1, 1, }
-InRightPads size 2, {1, 1, }
-ConvStrides size 2, {1, 1, }
-ConvDilations size 2, {1, 1, }
-device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
-a_k0_m_k1_grid_desc{288, 50176, 8}
-b_k0_n_k1_grid_desc{288, 1024, 8}
-c_m_n_grid_desc{ 50176, 1024}
-launch_and_time_kernel: grid_dim {1568, 1, 1}, block_dim {256, 1, 1}
-Warm up
-Start running 1 times...
-Average time : 2.22461 ms, 106.428 TFlop/s
-```
+## Using CK as a pre-built kernel library
+Instructions for using CK as a pre-built kernel library are under ```client_example/```
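+
+A minimal downstream `CMakeLists.txt` for such a client is sketched below (project and
+target names are illustrative; the `device_operations` component and the
+`composable_kernel::device_operations` target come from the package config installed by
+this repository):
+```cmake
+cmake_minimum_required(VERSION 3.14)
+project(ck_client LANGUAGES CXX)
+
+# Configure with the CK install prefix and ROCm on the prefix path, e.g.
+#   cmake -DCMAKE_PREFIX_PATH="${PATH_TO_CK_INSTALL_DIRECTORY};/opt/rocm" -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc ..
+find_package(composable_kernel REQUIRED COMPONENTS device_operations)
+
+add_executable(my_gemm gemm.cpp)
+target_link_libraries(my_gemm PRIVATE composable_kernel::device_operations)
+```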
+
+## Caveat
+### Kernel Timing and Verification
+CK's own kernel timer warms up the kernel once and then runs it multiple times to
+obtain an average kernel time. For kernels that use atomic add, the repeated launches
+accumulate into the output buffer multiple times, which causes verification failures.
+To work around this, do not enable CK's own timer and verification at the same time.
+Both the timer and the verification can be enabled or disabled from the command line
+in each example and in ckProfiler.
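+
+As an illustration, the client examples in this repository control timing through the
+`time_kernel` flag of `StreamConfig`; a minimal sketch that keeps verification and timing
+separate looks like this:
+```cpp
+// Verification run: launch the kernel exactly once, so atomic-add outputs are not
+// accumulated across repeated launches; copy the result back and verify it here.
+invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, /*time_kernel=*/false});
+
+// Timing run: the kernel is launched repeatedly to obtain an average time; for
+// atomic-add based kernels, do not verify the output of this run.
+float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, /*time_kernel=*/true});
+```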
diff --git a/client_example/01_gemm/CMakeLists.txt b/client_example/01_gemm/CMakeLists.txt
new file mode 100644
index 00000000..9e741192
--- /dev/null
+++ b/client_example/01_gemm/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_executable(client_gemm gemm.cpp)
+target_link_libraries(client_gemm PRIVATE composable_kernel::device_operations)
diff --git a/client_example/01_gemm/gemm.cpp b/client_example/01_gemm/gemm.cpp
new file mode 100644
index 00000000..9b7b7a66
--- /dev/null
+++ b/client_example/01_gemm/gemm.cpp
@@ -0,0 +1,218 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdio>
+#include <cstdlib>
+#include <iomanip>
+#include <iostream>
+#include <string>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/gpu/gemm.hpp"
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+using AElementOp = PassThrough;
+using BElementOp = PassThrough;
+using CElementOp = PassThrough;
+
+using ADataType = F16;
+using BDataType = F16;
+using CDataType = F16;
+
+using ALayout = Row;
+using BLayout = Col;
+using CLayout = Row;
+
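+// Thin RAII wrapper around hipMalloc/hipFree used by this example.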
+struct SimpleDeviceMem
+{
+ SimpleDeviceMem() = delete;
+
+ SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
+ {
+        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
+ }
+
+ void* GetDeviceBuffer() { return p_mem_; }
+
+ ~SimpleDeviceMem() { (void)hipFree(p_mem_); }
+
+ void* p_mem_;
+};
+
+int main(int argc, char* argv[])
+{
+ // GEMM shape
+ ck::index_t M = 3840;
+ ck::index_t N = 4096;
+ ck::index_t K = 4096;
+
+ ck::index_t StrideA = 4096;
+ ck::index_t StrideB = 4096;
+ ck::index_t StrideC = 4096;
+
+ if(argc == 1)
+ {
+ // use default case
+ }
+    else if(argc == 7)
+ {
+ M = std::stoi(argv[1]);
+ N = std::stoi(argv[2]);
+ K = std::stoi(argv[3]);
+
+ StrideA = std::stoi(argv[4]);
+ StrideB = std::stoi(argv[5]);
+ StrideC = std::stoi(argv[6]);
+ }
+ else
+ {
+ printf("arg1 to 6: M, N, K, StrideA, StrideB, StrideC\n");
+ exit(0);
+ }
+
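+    // Minimal number of elements needed to store an nRow x nCol matrix with the given leading-dimension stride.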
+ auto f_matrix_space_size =
+ [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
+ using Layout = decltype(layout);
+
+            if(std::is_same<Layout, Row>::value)
+ {
+ return (nRow - 1) * stride + nCol;
+ }
+ else
+ {
+ return (nCol - 1) * stride + nRow;
+ }
+ };
+
+ SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{}));
+ SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{}));
+ SimpleDeviceMem c_device_buf(sizeof(CDataType) * f_matrix_space_size(M, N, StrideC, CLayout{}));
+
+    using DeviceOp = ck::tensor_operation::device::DeviceGemm<
+        ALayout, BLayout, CLayout, ADataType, BDataType, CDataType,
+        AElementOp, BElementOp, CElementOp>;
+
+ // get device op instances
+ const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
+ DeviceOp>::GetInstances();
+
+ std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
+
+ const auto a_element_op = AElementOp{};
+ const auto b_element_op = BElementOp{};
+ const auto c_element_op = CElementOp{};
+
+ std::string best_op_name;
+ bool found = false;
+ int best_op_id = -1;
+ float best_ave_time = 0;
+ float best_tflops = 0;
+ float best_gb_per_sec = 0;
+
+ // profile device operation instances
+ std::cout << "Run all instances and do timing" << std::endl;
+
+ for(int i = 0; i < op_ptrs.size(); ++i)
+ {
+ auto& op_ptr = op_ptrs[i];
+
+ auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
+ b_device_buf.GetDeviceBuffer(),
+ c_device_buf.GetDeviceBuffer(),
+ M,
+ N,
+ K,
+ StrideA,
+ StrideB,
+ StrideC,
+ a_element_op,
+ b_element_op,
+ c_element_op);
+
+ auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+ std::string op_name = op_ptr->GetTypeString();
+
+ if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+ {
+ float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
+
+ std::size_t flop = std::size_t(2) * M * N * K;
+
+ std::size_t num_btype =
+ sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
+
+            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+ float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+ std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
+ << gb_per_sec << " GB/s, " << op_name << std::endl;
+
+ if(tflops > best_tflops)
+ {
+ found = true;
+ best_op_id = i;
+ best_op_name = op_name;
+ best_tflops = tflops;
+ best_ave_time = ave_time;
+ best_gb_per_sec = gb_per_sec;
+ }
+ }
+ else
+ {
+ std::cout << op_name << " does not support this problem" << std::endl;
+ }
+ }
+
+ std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
+ << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
+
+    // run the best instance
+ {
+ auto& op_ptr = op_ptrs[best_op_id];
+
+ std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
+ << std::endl;
+
+ auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
+ b_device_buf.GetDeviceBuffer(),
+ c_device_buf.GetDeviceBuffer(),
+ M,
+ N,
+ K,
+ StrideA,
+ StrideB,
+ StrideC,
+ a_element_op,
+ b_element_op,
+ c_element_op);
+
+ auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+ if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+ {
+ invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
+ }
+
+ std::cout << "Done" << std::endl;
+ }
+
+ return 0;
+}
diff --git a/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt b/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt
new file mode 100644
index 00000000..1064abc8
--- /dev/null
+++ b/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_executable(client_gemm_add_add_fastgelu gemm_add_add_fastgelu.cpp)
+target_link_libraries(client_gemm_add_add_fastgelu PRIVATE composable_kernel::device_operations)
diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp
new file mode 100644
index 00000000..dbf2e634
--- /dev/null
+++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp
@@ -0,0 +1,241 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <array>
+#include <cstdio>
+#include <cstdlib>
+#include <iomanip>
+#include <iostream>
+#include <string>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp"
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
+
+using AElementOp = PassThrough;
+using BElementOp = PassThrough;
+using CDEElementOp = AddAddFastGelu;
+
+using ADataType = F16;
+using BDataType = F16;
+using D0DataType = F16;
+using D1DataType = F16;
+using EDataType = F16;
+
+using ALayout = Row;
+using BLayout = Col;
+using DDELayout = Row;
+using DELayout = Row;
+
+struct SimpleDeviceMem
+{
+ SimpleDeviceMem() = delete;
+
+ SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
+ {
+        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
+ }
+
+ void* GetDeviceBuffer() { return p_mem_; }
+
+ ~SimpleDeviceMem() { (void)hipFree(p_mem_); }
+
+ void* p_mem_;
+};
+
+int main(int argc, char* argv[])
+{
+ // GEMM shape
+ ck::index_t M = 3840;
+ ck::index_t N = 4096;
+ ck::index_t K = 4096;
+
+ ck::index_t StrideA = 4096;
+ ck::index_t StrideB = 4096;
+ ck::index_t StrideD0 = 0;
+ ck::index_t StrideD1 = 4096;
+ ck::index_t StrideE = 4096;
+
+ if(argc == 1)
+ {
+ // use default case
+ }
+ else if(argc == 9)
+ {
+ M = std::stoi(argv[1]);
+ N = std::stoi(argv[2]);
+ K = std::stoi(argv[3]);
+
+ StrideA = std::stoi(argv[4]);
+ StrideB = std::stoi(argv[5]);
+ StrideD0 = std::stoi(argv[6]);
+ StrideD1 = std::stoi(argv[7]);
+ StrideE = std::stoi(argv[8]);
+ }
+ else
+ {
+ printf("arg1 to 8: M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE\n");
+ exit(0);
+ }
+
+ auto f_matrix_space_size =
+ [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
+ using Layout = decltype(layout);
+
+            if(std::is_same<Layout, Row>::value)
+ {
+ return (nRow - 1) * stride + nCol;
+ }
+ else
+ {
+ return (nCol - 1) * stride + nRow;
+ }
+ };
+
+ SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{}));
+ SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{}));
+ SimpleDeviceMem d0_m_n_device_buf(sizeof(D0DataType) *
+ f_matrix_space_size(M, N, StrideD0, DDELayout{}));
+ SimpleDeviceMem d1_m_n_device_buf(sizeof(D1DataType) *
+ f_matrix_space_size(M, N, StrideD1, DDELayout{}));
+ SimpleDeviceMem e_device_buf(sizeof(EDataType) *
+ f_matrix_space_size(M, N, StrideE, DELayout{}));
+
+ using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD<
+ ALayout,
+ BLayout,
+ DDELayout,
+ ADataType,
+ BDataType,
+        ck::Tuple<D0DataType, D1DataType>,
+ EDataType,
+ ck::tensor_operation::element_wise::PassThrough,
+ ck::tensor_operation::element_wise::PassThrough,
+ ck::tensor_operation::element_wise::AddAddFastGelu>;
+
+ // get device op instances
+ const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
+ DeviceOp>::GetInstances();
+
+ std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
+
+ const auto a_element_op = AElementOp{};
+ const auto b_element_op = BElementOp{};
+ const auto cde_element_op = CDEElementOp{};
+
+ std::string best_op_name;
+ bool found = false;
+ int best_op_id = -1;
+ float best_ave_time = 0;
+ float best_tflops = 0;
+ float best_gb_per_sec = 0;
+
+ // profile device operation instances
+ std::cout << "Run all instances and do timing" << std::endl;
+
+ for(int i = 0; i < op_ptrs.size(); ++i)
+ {
+ auto& op_ptr = op_ptrs[i];
+
+ auto argument_ptr = op_ptr->MakeArgumentPointer(
+ a_device_buf.GetDeviceBuffer(),
+ b_device_buf.GetDeviceBuffer(),
+ std::array{d0_m_n_device_buf.GetDeviceBuffer(),
+ d1_m_n_device_buf.GetDeviceBuffer()},
+ e_device_buf.GetDeviceBuffer(),
+ M,
+ N,
+ K,
+ StrideA,
+ StrideB,
+ std::array{StrideD0, StrideD1},
+ StrideE,
+ a_element_op,
+ b_element_op,
+ cde_element_op);
+
+ auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+ std::string op_name = op_ptr->GetTypeString();
+
+ if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+ {
+ float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
+
+ std::size_t flop = std::size_t(2) * M * N * K;
+
+ std::size_t num_btype =
+ sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
+
+            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+ float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+ std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
+ << gb_per_sec << " GB/s, " << op_name << std::endl;
+
+ if(tflops > best_tflops)
+ {
+ found = true;
+ best_op_id = i;
+ best_op_name = op_name;
+ best_tflops = tflops;
+ best_ave_time = ave_time;
+ best_gb_per_sec = gb_per_sec;
+ }
+ }
+ else
+ {
+ std::cout << op_name << " does not support this problem" << std::endl;
+ }
+ }
+
+ std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
+ << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
+
+    // run the best instance
+ {
+ auto& op_ptr = op_ptrs[best_op_id];
+
+ std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
+ << std::endl;
+
+ auto argument_ptr = op_ptr->MakeArgumentPointer(
+ a_device_buf.GetDeviceBuffer(),
+ b_device_buf.GetDeviceBuffer(),
+ std::array{d0_m_n_device_buf.GetDeviceBuffer(),
+ d1_m_n_device_buf.GetDeviceBuffer()},
+ e_device_buf.GetDeviceBuffer(),
+ M,
+ N,
+ K,
+ StrideA,
+ StrideB,
+ std::array{StrideD0, StrideD1},
+ StrideE,
+ a_element_op,
+ b_element_op,
+ cde_element_op);
+
+ auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+ if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+ {
+ invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
+ }
+
+ std::cout << "Done" << std::endl;
+ }
+
+ return 0;
+}
diff --git a/client_example/03_gemm_layernorm/CMakeLists.txt b/client_example/03_gemm_layernorm/CMakeLists.txt
new file mode 100644
index 00000000..3742e708
--- /dev/null
+++ b/client_example/03_gemm_layernorm/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_executable(client_gemm_add_add_reduce_normalize gemm_add_add_layernorm.cpp)
+target_link_libraries(client_gemm_add_add_reduce_normalize PRIVATE composable_kernel::device_operations)
diff --git a/client_example/03_gemm_layernorm/gemm_add_add_layernorm.cpp b/client_example/03_gemm_layernorm/gemm_add_add_layernorm.cpp
new file mode 100644
index 00000000..8f142937
--- /dev/null
+++ b/client_example/03_gemm_layernorm/gemm_add_add_layernorm.cpp
@@ -0,0 +1,271 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <array>
+#include <iomanip>
+#include <iostream>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
+#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp"
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using ADataType = F16;
+using BDataType = F16;
+using BiasDataType = F32;
+using CDataType = F16;
+using D0DataType = F16;
+using ReduceDataType = F32;
+using GammaDataType = F16;
+using BetaDataType = F16;
+using LayerNormOutDataType = F16;
+
+using ALayout = ck::tensor_layout::gemm::RowMajor;
+using BLayout = ck::tensor_layout::gemm::ColumnMajor;
+using CLayout = ck::tensor_layout::gemm::RowMajor;
+
+struct SimpleDeviceMem
+{
+ SimpleDeviceMem() = delete;
+
+ SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
+ {
+        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
+ }
+
+ void* GetDeviceBuffer() { return p_mem_; }
+
+ ~SimpleDeviceMem() { (void)hipFree(p_mem_); }
+
+ void* p_mem_;
+};
+
+template <typename gemm_reduce_op_ptr>
+bool RunDeviceGemmMeanSquareMean(gemm_reduce_op_ptr& p_op,
+ const void* p_a,
+ const void* p_b,
+ const void* p_bias,
+ const void* p_d0,
+ void* p_c,
+ void* p_mean,
+ void* p_square_mean,
+ int M,
+ int N,
+ int K,
+ int StrideA,
+ int StrideB,
+ int StrideC,
+ int StrideD0,
+ bool time_kernel)
+{
+ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+ using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
+ using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
+
+ auto passOp = PassThrough{};
+ auto squareOp = UnarySquareElementOp{};
+ auto divOp = UnaryDivElementOp{N};
+
+ auto argument_ptr =
+ p_op->MakeArgumentPointer(p_a,
+ p_b,
+ p_bias,
+ {p_d0},
+ p_c,
+ {p_mean, p_square_mean},
+ M,
+ N,
+ K,
+ StrideA,
+ StrideB,
+ StrideC,
+ {StrideD0},
+ {&passOp, &passOp, &passOp}, // functor for a, b, c
+ {&passOp}, // functor for d0
+ {&passOp, &squareOp}, // functor for inputs of reduction
+ {&divOp, &divOp}); // functor for outputs of reduction
+
+ if(p_op->IsSupportedArgument(argument_ptr.get()))
+ {
+ auto invoker_ptr = p_op->MakeInvokerPointer();
+
+        // If we measure the running time of gemm_reduce, the output may be wrong,
+        // because the reduction tensor has to be initialized before running the kernel,
+        // but with time_kernel = true the kernel is launched many times without
+        // reinitializing the reduction output tensor.
+ float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
+
+ if(time_kernel)
+ std::cout << "Gemm + reduce Perf: " << std::setw(10) << ave_time << " ms" << std::endl;
+
+ return true;
+ }
+
+ return false;
+}
+
+template <typename normalize_op_ptr>
+bool RunDeviceNormalize2D(normalize_op_ptr& p_op,
+ const void* p_x,
+ const void* p_mean,
+ const void* p_square_mean,
+ const void* p_gamma,
+ const void* p_beta,
+ void* p_y,
+ int M,
+ int N,
+ int StrideX,
+ bool time_kernel)
+{
+ std::array input = {p_x, p_mean, p_square_mean, p_gamma, p_beta};
+ std::array output = {p_y};
+ auto normalize_functor = ck::tensor_operation::element_wise::Normalize{};
+
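+    // Per-input 2D strides: {1, 0} broadcasts mean/square-mean along N; {0, 1} broadcasts gamma/beta along M.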
+ auto argument_ptr = p_op->MakeArgumentPointer(input,
+ output,
+ {M, N},
+ {{StrideX, 1}, {1, 0}, {1, 0}, {0, 1}, {0, 1}},
+ {{StrideX, 1}},
+ ck::tensor_operation::element_wise::Normalize{});
+
+ if(p_op->IsSupportedArgument(argument_ptr.get()))
+ {
+ auto invoker_ptr = p_op->MakeInvokerPointer();
+ float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
+
+ if(time_kernel)
+ std::cout << "Normalize Perf: " << std::setw(10) << ave_time << " ms" << std::endl;
+
+ return true;
+ }
+
+ return false;
+}
+
+int main()
+{
+ ck::index_t M = 1024;
+ ck::index_t N = 1024;
+ ck::index_t K = 1024;
+
+ ck::index_t StrideA = 1024;
+ ck::index_t StrideB = 1024;
+ ck::index_t StrideC = 1024;
+ ck::index_t StrideD0 = 1024;
+
+ const auto gemm_reduce_ptrs =
+ ck::tensor_operation::device::instance::get_device_gemm_add_add_mean_squaremean_instances<
+ ADataType,
+ BDataType,
+ CDataType,
+ ALayout,
+ BLayout,
+ CLayout>();
+
+ const auto normalize_ptrs =
+ ck::tensor_operation::device::instance::get_device_normalize_from_mean_meansquare_instances<
+ CDataType,
+ ReduceDataType,
+ ReduceDataType,
+ GammaDataType,
+ BetaDataType,
+ LayerNormOutDataType>();
+
+ std::cout << "found " << gemm_reduce_ptrs.size()
+ << " gemm_reduceMean_reduceSquareMean instances" << std::endl;
+
+ std::cout << "found " << normalize_ptrs.size() << " normalize instances" << std::endl;
+
+ auto f_matrix_space_size =
+ [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
+ using Layout = decltype(layout);
+
+            if(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
+ {
+ return (nRow - 1) * stride + nCol;
+ }
+ else
+ {
+ return (nCol - 1) * stride + nRow;
+ }
+ };
+
+ SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{}));
+ SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{}));
+ SimpleDeviceMem bias_device_buf(sizeof(BiasDataType) * N);
+ SimpleDeviceMem c_device_buf(sizeof(CDataType) * f_matrix_space_size(M, N, StrideC, CLayout{}));
+ SimpleDeviceMem d0_device_buf(sizeof(D0DataType) *
+ f_matrix_space_size(M, N, StrideD0, CLayout{}));
+ SimpleDeviceMem reduceMean_device_buf(sizeof(ReduceDataType) * M);
+ SimpleDeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) * M);
+ SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * N);
+ SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * N);
+ SimpleDeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) * M * N);
+
+ bool b_time_kernel = true;
+ bool b_only_run_first_kernel = true;
+
+ // layernorm => (1) + (2)
+ // (1). c = gemm(a, b), reduce_mean(c), reduce_square_mean(c)
+ // (2). normalize(c, mean, square_mean, gamma, beta)
+ for(auto& gemm_reduce_ptr : gemm_reduce_ptrs)
+ {
+ // run first available kernel
+ if(RunDeviceGemmMeanSquareMean(gemm_reduce_ptr,
+ a_device_buf.GetDeviceBuffer(),
+ b_device_buf.GetDeviceBuffer(),
+ bias_device_buf.GetDeviceBuffer(),
+ d0_device_buf.GetDeviceBuffer(),
+ c_device_buf.GetDeviceBuffer(),
+ reduceMean_device_buf.GetDeviceBuffer(),
+ reduceMeanSquare_device_buf.GetDeviceBuffer(),
+ M,
+ N,
+ K,
+ StrideA,
+ StrideB,
+ StrideC,
+ StrideD0,
+ b_time_kernel))
+ {
+ if(b_only_run_first_kernel)
+ break;
+ }
+ else
+ {
+ std::cout << gemm_reduce_ptr->GetTypeString() << " does not support this problem"
+ << std::endl;
+ }
+ }
+
+ for(auto& normalize_ptr : normalize_ptrs)
+ {
+ if(RunDeviceNormalize2D(normalize_ptr,
+ c_device_buf.GetDeviceBuffer(),
+ reduceMean_device_buf.GetDeviceBuffer(),
+ reduceMeanSquare_device_buf.GetDeviceBuffer(),
+ gamma_device_buf.GetDeviceBuffer(),
+ beta_device_buf.GetDeviceBuffer(),
+ layerNorm_device_buf.GetDeviceBuffer(),
+ M,
+ N,
+ StrideC,
+ b_time_kernel))
+ {
+ if(b_only_run_first_kernel)
+ break;
+ }
+ else
+ {
+ std::cout << normalize_ptr->GetTypeString() << " does not support this problem"
+ << std::endl;
+ }
+ }
+}
diff --git a/client_example/04_contraction/CMakeLists.txt b/client_example/04_contraction/CMakeLists.txt
new file mode 100644
index 00000000..4bc6780f
--- /dev/null
+++ b/client_example/04_contraction/CMakeLists.txt
@@ -0,0 +1,6 @@
+add_executable(client_contraction_scale contraction_scale.cpp)
+target_link_libraries(client_contraction_scale PRIVATE composable_kernel::device_operations)
+
+add_executable(client_contraction_bilinear contraction_bilinear.cpp)
+target_link_libraries(client_contraction_bilinear PRIVATE composable_kernel::device_operations)
+
diff --git a/client_example/04_contraction/contraction_bilinear.cpp b/client_example/04_contraction/contraction_bilinear.cpp
new file mode 100644
index 00000000..b71c51c0
--- /dev/null
+++ b/client_example/04_contraction/contraction_bilinear.cpp
@@ -0,0 +1,241 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdio>
+#include <cstdlib>
+#include <iomanip>
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp"
+
+using F32 = float;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using Bilinear = ck::tensor_operation::element_wise::Bilinear;
+
+using AElementOp = PassThrough;
+using BElementOp = PassThrough;
+using CDEElementOp = Bilinear;
+
+using ADataType = F32;
+using BDataType = F32;
+using AccDataType = F32;
+using CShuffleDataType = F32;
+using DDataType = F32;
+using DsDataType = ck::Tuple<DDataType>;
+using EDataType = F32;
+
+static constexpr ck::index_t NumDimM = 2;
+static constexpr ck::index_t NumDimN = 2;
+static constexpr ck::index_t NumDimK = 2;
+
+struct SimpleDeviceMem
+{
+ SimpleDeviceMem() = delete;
+
+ SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
+ {
+        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
+ }
+
+ void* GetDeviceBuffer() { return p_mem_; }
+
+ ~SimpleDeviceMem() { (void)hipFree(p_mem_); }
+
+ void* p_mem_;
+};
+
+int main(int argc, char* argv[])
+{
+ // A[M0, M1, K0, K1]
+    std::vector<ck::index_t> a_ms_ks_lengths{30, 128, 32, 64};
+    std::vector<ck::index_t> a_ms_ks_strides{524288, 4096, 128, 1};
+    // B[N0, N1, K0, K1]
+    std::vector<ck::index_t> b_ns_ks_lengths{32, 64, 32, 64};
+    std::vector<ck::index_t> b_ns_ks_strides{524288, 4096, 128, 1};
+    // D[M0, M1, N0, N1]
+    std::vector<ck::index_t> d_ms_ns_lengths{30, 128, 32, 64};
+    std::vector<ck::index_t> d_ms_ns_strides{524288, 4096, 128, 1};
+    // E[M0, M1, N0, N1]
+    std::vector<ck::index_t> e_ms_ns_lengths{30, 128, 32, 64};
+    std::vector<ck::index_t> e_ms_ns_strides{524288, 4096, 128, 1};
+
+ float alpha = 1.f;
+ float beta = 1.f;
+
+ if(argc == 1)
+ {
+ // use default case
+ }
+ else if(argc == 25)
+ {
+ const ck::index_t M0 = std::stoi(argv[1]);
+ const ck::index_t M1 = std::stoi(argv[2]);
+
+ const ck::index_t N0 = std::stoi(argv[3]);
+ const ck::index_t N1 = std::stoi(argv[4]);
+
+ const ck::index_t K0 = std::stoi(argv[5]);
+ const ck::index_t K1 = std::stoi(argv[6]);
+
+ a_ms_ks_lengths = {M0, M1, K0, K1};
+ a_ms_ks_strides = {
+ std::stoi(argv[7]), std::stoi(argv[8]), std::stoi(argv[9]), std::stoi(argv[10])};
+
+ b_ns_ks_lengths = {N0, N1, K0, K1};
+ b_ns_ks_strides = {
+ std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13]), std::stoi(argv[14])};
+
+ d_ms_ns_lengths = {M0, M1, N0, N1};
+ d_ms_ns_strides = {
+ std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17]), std::stoi(argv[18])};
+
+ e_ms_ns_lengths = {M0, M1, N0, N1};
+ e_ms_ns_strides = {
+ std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21]), std::stoi(argv[22])};
+
+ alpha = std::stof(argv[23]);
+ beta = std::stof(argv[24]);
+ }
+ else
+ {
+ printf("arg1 to 6: M0, M1, N0, N1, K0, K1\n");
+ printf("arg7 to 10: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n");
+ printf("arg11 to 14: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n");
+ printf("arg15 to 18: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n");
+ printf("arg19 to 22: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n");
+ printf("arg23 to 24: alpha, beta\n");
+ exit(0);
+ }
+
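+    // element-space size of a strided tensor: offset of the last element plus one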
+ auto f_tensor_space_size = [](auto lengths, auto strides) {
+ std::size_t space_size = 1;
+ for(std::size_t i = 0; i < lengths.size(); ++i)
+ {
+ space_size += (lengths[i] - 1) * strides[i];
+ }
+ return space_size;
+ };
+
+ SimpleDeviceMem a_device_buf(sizeof(ADataType) *
+ f_tensor_space_size(a_ms_ks_lengths, a_ms_ks_strides));
+ SimpleDeviceMem b_device_buf(sizeof(BDataType) *
+ f_tensor_space_size(b_ns_ks_lengths, b_ns_ks_strides));
+ SimpleDeviceMem d_device_buf(sizeof(DDataType) *
+ f_tensor_space_size(d_ms_ns_lengths, d_ms_ns_strides));
+ SimpleDeviceMem e_device_buf(sizeof(EDataType) *
+ f_tensor_space_size(e_ms_ns_lengths, e_ms_ns_strides));
+
+ using DeviceOp = ck::tensor_operation::device::DeviceContractionMultipleD<
+ NumDimM,
+ NumDimN,
+ NumDimK,
+ ADataType,
+ BDataType,
+        ck::Tuple<DDataType>,
+ EDataType,
+ ck::tensor_operation::element_wise::PassThrough,
+ ck::tensor_operation::element_wise::PassThrough,
+ ck::tensor_operation::element_wise::Bilinear>;
+
+ // get device op instances
+ const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
+ DeviceOp>::GetInstances();
+
+ std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
+
+ const auto a_element_op = AElementOp{};
+ const auto b_element_op = BElementOp{};
+ const auto cde_element_op = CDEElementOp{alpha, beta};
+
+ std::string best_op_name;
+ bool found = false;
+ int best_op_id = -1;
+ float best_ave_time = 0;
+ float best_tflops = 0;
+ float best_gb_per_sec = 0;
+
+ // profile device operation instances
+ std::cout << "Run all instances and do timing" << std::endl;
+
+ for(int i = 0; i < op_ptrs.size(); ++i)
+ {
+ auto& op_ptr = op_ptrs[i];
+
+        auto argument_ptr =
+            op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
+                                        b_device_buf.GetDeviceBuffer(),
+                                        std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
+                                        e_device_buf.GetDeviceBuffer(),
+                                        a_ms_ks_lengths,
+                                        a_ms_ks_strides,
+                                        b_ns_ks_lengths,
+                                        b_ns_ks_strides,
+                                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
+                                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
+                                        e_ms_ns_lengths,
+                                        e_ms_ns_strides,
+                                        a_element_op,
+                                        b_element_op,
+                                        cde_element_op);
+
+ auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+ std::string op_name = op_ptr->GetTypeString();
+
+ if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+ {
+ float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
+
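+            // flatten the grouped dims for FLOP/bandwidth accounting: M = M0*M1, N = N0*N1, K = K0*K1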
+            ck::index_t M = std::accumulate(e_ms_ns_lengths.begin(),
+                                            e_ms_ns_lengths.begin() + NumDimM,
+                                            ck::index_t{1},
+                                            std::multiplies<ck::index_t>{});
+
+            ck::index_t N = std::accumulate(e_ms_ns_lengths.begin() + NumDimM,
+                                            e_ms_ns_lengths.begin() + NumDimM + NumDimN,
+                                            ck::index_t{1},
+                                            std::multiplies<ck::index_t>{});
+
+            ck::index_t K = std::accumulate(a_ms_ks_lengths.begin() + NumDimM,
+                                            a_ms_ks_lengths.begin() + NumDimM + NumDimK,
+                                            ck::index_t{1},
+                                            std::multiplies<ck::index_t>{});
+
+ std::size_t flop = std::size_t(2) * M * N * K;
+ std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
+ sizeof(DDataType) * M * N + sizeof(EDataType) * M * N;
+
+            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+ float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+ std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
+ << gb_per_sec << " GB/s, " << op_name << std::endl;
+
+ if(tflops > best_tflops)
+ {
+ found = true;
+ best_op_id = i;
+ best_op_name = op_name;
+ best_tflops = tflops;
+ best_ave_time = ave_time;
+ best_gb_per_sec = gb_per_sec;
+ }
+ }
+ else
+ {
+ std::cout << op_name << " does not support this problem" << std::endl;
+ }
+ }
+
+ std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
+ << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
+
+ return 0;
+}
diff --git a/client_example/04_contraction/contraction_scale.cpp b/client_example/04_contraction/contraction_scale.cpp
new file mode 100644
index 00000000..5908c1d8
--- /dev/null
+++ b/client_example/04_contraction/contraction_scale.cpp
@@ -0,0 +1,227 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <iomanip>
+#include <iostream>
+#include <numeric>
+#include <vector>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/gpu/contraction_scale.hpp"
+
+using F32 = float;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using Scale = ck::tensor_operation::element_wise::Scale;
+
+using AElementOp = PassThrough;
+using BElementOp = PassThrough;
+using CDEElementOp = Scale;
+
+using ADataType = F32;
+using BDataType = F32;
+using AccDataType = F32;
+using CShuffleDataType = F32;
+using DsDataType = ck::Tuple<>;
+using EDataType = F32;
+
+static constexpr ck::index_t NumDimM = 2;
+static constexpr ck::index_t NumDimN = 2;
+static constexpr ck::index_t NumDimK = 2;
+
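+// Minimal RAII holder for a device buffer: hipMalloc on construction, hipFree on destruction.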
+struct SimpleDeviceMem
+{
+ SimpleDeviceMem() = delete;
+
+ SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
+ {
+        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
+ }
+
+ void* GetDeviceBuffer() { return p_mem_; }
+
+ ~SimpleDeviceMem() { (void)hipFree(p_mem_); }
+
+ void* p_mem_;
+};
+
+int main(int argc, char* argv[])
+{
+ // A[M0, M1, K0, K1]
+    std::vector<ck::index_t> a_ms_ks_lengths{30, 128, 32, 64};
+    std::vector<ck::index_t> a_ms_ks_strides{524288, 4096, 128, 1};
+    // B[N0, N1, K0, K1]
+    std::vector<ck::index_t> b_ns_ks_lengths{32, 64, 32, 64};
+    std::vector<ck::index_t> b_ns_ks_strides{524288, 4096, 128, 1};
+    // E[M0, M1, N0, N1]
+    std::vector<ck::index_t> e_ms_ns_lengths{30, 128, 32, 64};
+    std::vector<ck::index_t> e_ms_ns_strides{524288, 4096, 128, 1};
+
+ float scale = 1.f;
+
+ if(argc == 1)
+ {
+ // use default case
+ }
+ else if(argc == 20)
+ {
+ const ck::index_t M0 = std::stoi(argv[1]);
+ const ck::index_t M1 = std::stoi(argv[2]);
+
+ const ck::index_t N0 = std::stoi(argv[3]);
+ const ck::index_t N1 = std::stoi(argv[4]);
+
+ const ck::index_t K0 = std::stoi(argv[5]);
+ const ck::index_t K1 = std::stoi(argv[6]);
+
+ a_ms_ks_lengths = {M0, M1, K0, K1};
+ a_ms_ks_strides = {
+ std::stoi(argv[7]), std::stoi(argv[8]), std::stoi(argv[9]), std::stoi(argv[10])};
+
+ b_ns_ks_lengths = {N0, N1, K0, K1};
+ b_ns_ks_strides = {
+ std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13]), std::stoi(argv[14])};
+
+ e_ms_ns_lengths = {M0, M1, N0, N1};
+ e_ms_ns_strides = {
+ std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17]), std::stoi(argv[18])};
+
+ scale = std::stof(argv[19]);
+ }
+ else
+ {
+ printf("arg1 to 6: M0, M1, N0, N1, K0, K1\n");
+ printf("arg7 to 10: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n");
+ printf("arg11 to 14: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n");
+ printf("arg15 to 18: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n");
+ printf("arg19: scale\n");
+ exit(0);
+ }
+
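+    // element-space size of a strided tensor: offset of the last element plus one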
+ auto f_tensor_space_size = [](auto lengths, auto strides) {
+ std::size_t space_size = 1;
+ for(std::size_t i = 0; i < lengths.size(); ++i)
+ {
+ space_size += (lengths[i] - 1) * strides[i];
+ }
+ return space_size;
+ };
+
+ SimpleDeviceMem a_device_buf(sizeof(ADataType) *
+ f_tensor_space_size(a_ms_ks_lengths, a_ms_ks_strides));
+ SimpleDeviceMem b_device_buf(sizeof(BDataType) *
+ f_tensor_space_size(b_ns_ks_lengths, b_ns_ks_strides));
+ SimpleDeviceMem e_device_buf(sizeof(EDataType) *
+ f_tensor_space_size(e_ms_ns_lengths, e_ms_ns_strides));
+
+ using DeviceOp = ck::tensor_operation::device::DeviceContractionMultipleD<
+ NumDimM,
+ NumDimN,
+ NumDimK,
+ ADataType,
+ BDataType,
+ ck::Tuple<>,
+ EDataType,
+ ck::tensor_operation::element_wise::PassThrough,
+ ck::tensor_operation::element_wise::PassThrough,
+ ck::tensor_operation::element_wise::Scale>;
+
+ // get device op instances
+ const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
+ DeviceOp>::GetInstances();
+
+ std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
+
+ const auto a_element_op = AElementOp{};
+ const auto b_element_op = BElementOp{};
+ const auto cde_element_op = CDEElementOp{scale};
+
+ std::string best_op_name;
+ bool found = false;
+ int best_op_id = -1;
+ float best_ave_time = 0;
+ float best_tflops = 0;
+ float best_gb_per_sec = 0;
+
+ // profile device operation instances
+ std::cout << "Run all instances and do timing" << std::endl;
+
+ for(int i = 0; i < op_ptrs.size(); ++i)
+ {
+ auto& op_ptr = op_ptrs[i];
+
+        auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
+                                                        b_device_buf.GetDeviceBuffer(),
+                                                        std::array<const void*, 0>{},
+                                                        e_device_buf.GetDeviceBuffer(),
+                                                        a_ms_ks_lengths,
+                                                        a_ms_ks_strides,
+                                                        b_ns_ks_lengths,
+                                                        b_ns_ks_strides,
+                                                        std::array<std::vector<ck::index_t>, 0>{},
+                                                        std::array<std::vector<ck::index_t>, 0>{},
+                                                        e_ms_ns_lengths,
+                                                        e_ms_ns_strides,
+                                                        a_element_op,
+                                                        b_element_op,
+                                                        cde_element_op);
+
+ auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+ std::string op_name = op_ptr->GetTypeString();
+
+ if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+ {
+ float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
+
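+            // flatten the grouped dims for FLOP/bandwidth accounting: M = M0*M1, N = N0*N1, K = K0*K1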
+            ck::index_t M = std::accumulate(e_ms_ns_lengths.begin(),
+                                            e_ms_ns_lengths.begin() + NumDimM,
+                                            ck::index_t{1},
+                                            std::multiplies<ck::index_t>{});
+
+            ck::index_t N = std::accumulate(e_ms_ns_lengths.begin() + NumDimM,
+                                            e_ms_ns_lengths.begin() + NumDimM + NumDimN,
+                                            ck::index_t{1},
+                                            std::multiplies<ck::index_t>{});
+
+            ck::index_t K = std::accumulate(a_ms_ks_lengths.begin() + NumDimM,
+                                            a_ms_ks_lengths.begin() + NumDimM + NumDimK,
+                                            ck::index_t{1},
+                                            std::multiplies<ck::index_t>{});
+
+ std::size_t flop = std::size_t(2) * M * N * K;
+ std::size_t num_btype =
+ sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
+
+            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+ float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+ std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
+ << gb_per_sec << " GB/s, " << op_name << std::endl;
+
+ if(tflops > best_tflops)
+ {
+ found = true;
+ best_op_id = i;
+ best_op_name = op_name;
+ best_tflops = tflops;
+ best_ave_time = ave_time;
+ best_gb_per_sec = gb_per_sec;
+ }
+ }
+ else
+ {
+ std::cout << op_name << " does not support this problem" << std::endl;
+ }
+ }
+
+ std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
+ << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
+
+ return 0;
+}
diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt
new file mode 100644
index 00000000..3e04a185
--- /dev/null
+++ b/client_example/CMakeLists.txt
@@ -0,0 +1,12 @@
+cmake_minimum_required(VERSION 3.15)
+project(ck_app)
+add_compile_options(-std=c++17)
+
+find_package(composable_kernel 1.0.0 COMPONENTS device_operations)
+find_package(hip REQUIRED PATHS /opt/rocm)
+message(STATUS "Build with HIP ${hip_VERSION}")
+
+add_subdirectory(01_gemm)
+add_subdirectory(02_gemm_add_add_fastgelu)
+add_subdirectory(03_gemm_layernorm)
+add_subdirectory(04_contraction)
diff --git a/client_example/README.md b/client_example/README.md
new file mode 100644
index 00000000..64a7130d
--- /dev/null
+++ b/client_example/README.md
@@ -0,0 +1,21 @@
+## Introduction
+The client applications link against the CK library, so the CK library must be installed before building them.
+
+
+## Build
+```bash
+mkdir -p client_example/build
+cd client_example/build
+```
+
+```bash
+cmake \
+-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
+-D CMAKE_PREFIX_PATH="/opt/rocm;${PATH_TO_CK_INSTALL_DIRECTORY}" \
+..
+```
+
+### Build the client examples
+```bash
+make -j
+```
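+
+### Run
+After a successful build, each example binary is placed in the matching subdirectory of the
+build tree (the target names come from the per-example CMakeLists, e.g.
+`client_contraction_scale` in `04_contraction`; the exact layout may differ with other
+generators). Run an example without arguments to use its built-in default problem size:
+
+```bash
+# run the contraction + scale client example with its default problem size
+./04_contraction/client_contraction_scale
+```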
diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake
index 9f193b20..78133af0 100644
--- a/cmake/EnableCompilerWarnings.cmake
+++ b/cmake/EnableCompilerWarnings.cmake
@@ -66,7 +66,7 @@ else()
-Wunreachable-code
-Wunused
- -Wno-sign-compare
+ -Wsign-compare
-Wno-extra-semi-stmt
)
if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "Clang")
diff --git a/cmake/TargetFlags.cmake b/cmake/TargetFlags.cmake
new file mode 100644
index 00000000..4f83fb5d
--- /dev/null
+++ b/cmake/TargetFlags.cmake
@@ -0,0 +1,50 @@
+
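+# target_flags() flattens a target's INTERFACE_* usage requirements (compile options and
+# definitions, include/link directories, link options and libraries) into a single flag string.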
+function(get_target_property2 VAR TARGET PROPERTY)
+ get_target_property(_pflags ${TARGET} ${PROPERTY})
+ if(_pflags)
+ set(${VAR} ${_pflags} PARENT_SCOPE)
+ else()
+ set(${VAR} "" PARENT_SCOPE)
+ endif()
+endfunction()
+
+
+macro(append_flags FLAGS TARGET PROPERTY PREFIX)
+ get_target_property2(_pflags ${TARGET} ${PROPERTY})
+ foreach(FLAG ${_pflags})
+ if(TARGET ${FLAG})
+ target_flags(_pflags2 ${FLAG})
+ string(APPEND ${FLAGS} " ${_pflags2}")
+ else()
+ string(APPEND ${FLAGS} " ${PREFIX}${FLAG}")
+ endif()
+ endforeach()
+endmacro()
+
+macro(append_link_flags FLAGS TARGET PROPERTY)
+ get_target_property2(_pflags ${TARGET} ${PROPERTY})
+ foreach(FLAG ${_pflags})
+ if(TARGET ${FLAG})
+ target_flags(_pflags2 ${FLAG})
+ string(APPEND ${FLAGS} " ${_pflags2}")
+ elseif(FLAG MATCHES "^-.*")
+ string(APPEND ${FLAGS} " ${FLAG}")
+ elseif(EXISTS ${FLAG})
+ string(APPEND ${FLAGS} " ${FLAG}")
+ else()
+ string(APPEND ${FLAGS} " -l${FLAG}")
+ endif()
+ endforeach()
+endmacro()
+
+function(target_flags FLAGS TARGET)
+ set(_flags)
+ append_flags(_flags ${TARGET} "INTERFACE_COMPILE_OPTIONS" "")
+ append_flags(_flags ${TARGET} "INTERFACE_COMPILE_DEFINITIONS" "-D")
+ append_flags(_flags ${TARGET} "INTERFACE_INCLUDE_DIRECTORIES" "-isystem ")
+ append_flags(_flags ${TARGET} "INTERFACE_LINK_DIRECTORIES" "-L ")
+ append_flags(_flags ${TARGET} "INTERFACE_LINK_OPTIONS" "")
+ append_link_flags(_flags ${TARGET} "INTERFACE_LINK_LIBRARIES" "")
+ # message("_flags: ${_flags}")
+ set(${FLAGS} ${_flags} PARENT_SCOPE)
+endfunction()
diff --git a/cmake/googletest.cmake b/cmake/googletest.cmake
new file mode 100644
index 00000000..3718b916
--- /dev/null
+++ b/cmake/googletest.cmake
@@ -0,0 +1,43 @@
+include(FetchContent)
+
+set(GOOGLETEST_DIR "" CACHE STRING "Location of local GoogleTest repo to build against")
+
+if(GOOGLETEST_DIR)
+ set(FETCHCONTENT_SOURCE_DIR_GOOGLETEST ${GOOGLETEST_DIR} CACHE STRING "GoogleTest source directory override")
+endif()
+
+message(STATUS "Fetching GoogleTest")
+
+list(APPEND GTEST_CMAKE_CXX_FLAGS
+ -Wno-undef
+ -Wno-reserved-identifier
+ -Wno-global-constructors
+ -Wno-missing-noreturn
+ -Wno-disabled-macro-expansion
+ -Wno-used-but-marked-unused
+ -Wno-switch-enum
+ -Wno-zero-as-null-pointer-constant
+ -Wno-unused-member-function
+ -Wno-comma
+ -Wno-old-style-cast
+)
+message(STATUS "Suppressing googltest warnings with flags: ${GTEST_CMAKE_CXX_FLAGS}")
+
+FetchContent_Declare(
+ googletest
+ GIT_REPOSITORY https://github.com/google/googletest.git
+ GIT_TAG b85864c64758dec007208e56af933fc3f52044ee
+)
+
+# Will be necessary for windows build
+# set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
+FetchContent_GetProperties(googletest)
+if(NOT googletest_POPULATED)
+ FetchContent_Populate(googletest)
+ add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR} EXCLUDE_FROM_ALL)
+endif()
+
+target_compile_options(gtest PRIVATE ${GTEST_CMAKE_CXX_FLAGS})
+target_compile_options(gtest_main PRIVATE ${GTEST_CMAKE_CXX_FLAGS})
+target_compile_options(gmock PRIVATE ${GTEST_CMAKE_CXX_FLAGS})
+target_compile_options(gmock_main PRIVATE ${GTEST_CMAKE_CXX_FLAGS})
diff --git a/composable_kernel/include/gridwise_operation_wrapper.hpp b/composable_kernel/include/gridwise_operation_wrapper.hpp
deleted file mode 100644
index 0a1e07ec..00000000
--- a/composable_kernel/include/gridwise_operation_wrapper.hpp
+++ /dev/null
@@ -1,14 +0,0 @@
-#ifndef CK_GRIDWISE_OPERATION_KERNEL_WRAPPER
-#define CK_GRIDWISE_OPERATION_KERNEL_WRAPPER
-
-template
-__global__ void
-#if CK_USE_LAUNCH_BOUNDS
- __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
-#endif
- run_gridwise_operation(Xs... xs)
-{
- GridwiseOp{}.Run(xs...);
-}
-
-#endif
diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp
deleted file mode 100644
index 5cc2f239..00000000
--- a/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp
+++ /dev/null
@@ -1,183 +0,0 @@
-#ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
-#define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
-
-#include "common_header.hpp"
-#include "threadwise_gemm_dlops_v3.hpp"
-
-namespace ck {
-
-template
-struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
-{
- struct MatrixIndex
- {
- index_t k;
- index_t h;
- index_t w;
- };
-
- // HACK: fix this @Jing Zhang
- static constexpr index_t KPerThreadSubC = 4;
-
- static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed(
- make_tuple(Number{}, Number{}));
-
- static constexpr auto b_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
- Number{}, Number<1>{}, Number{}, Number{}));
-
- static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
- Number{}, Number<1>{}, Number{}, Number{}));
-
- using AThreadCopy = ThreadwiseTensorSliceTransfer_v4,
- Sequence<0, 1>,
- 1,
- ThreadGemmADataPerRead_K,
- 1>;
-
- __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
- : c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
- a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.k * KPerThread)}
- {
- static_assert(BlockMatrixA::IsKnownAtCompileTime() &&
- BlockMatrixB::IsKnownAtCompileTime() &&
- ThreadMatrixC::IsKnownAtCompileTime(),
- "wrong! Desc should be known at compile-time");
-
- constexpr auto I0 = Number<0>{};
- constexpr auto I1 = Number<1>{};
- constexpr auto I2 = Number<2>{};
- constexpr auto I3 = Number<3>{};
-
- static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
- "wrong! K dimension not consistent\n");
-
- constexpr index_t K = BlockMatrixA{}.GetLength(I1); // A is transposed
- constexpr index_t H = BlockMatrixB{}.GetLength(I2);
- constexpr index_t W = BlockMatrixB{}.GetLength(I3);
-
- static_assert(K % KPerThread == 0 && H % HPerThread == 0 && W % WPerThread == 0,
- "wrong! Cannot evenly divide work among\n");
-
- constexpr auto KThreadCluster = K / KPerThread;
- constexpr auto HThreadCluster = H / HPerThread;
- constexpr auto WThreadCluster = W / WPerThread;
-
- static_assert(BlockSize == KThreadCluster * HThreadCluster * WThreadCluster,
- "wrong! wrong blocksize\n");
- }
-
- __device__ static constexpr auto GetThreadMatrixCLengths()
- {
- return Sequence{};
- }
-
- __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
- {
- constexpr index_t H = BlockMatrixB{}.GetLength(Number<2>{});
- constexpr index_t W = BlockMatrixB{}.GetLength(Number<3>{});
-
- constexpr auto num_w_threads = W / WPerThread;
- constexpr auto num_h_threads = H / HPerThread;
- constexpr auto num_hw_threads = num_w_threads * num_h_threads;
-
- index_t k_thread_id = thread_id / num_hw_threads;
- index_t hw_thread_id = thread_id % num_hw_threads;
-
- index_t h_thread_id = hw_thread_id / num_w_threads;
- index_t w_thread_id = hw_thread_id % num_w_threads;
-
- return MatrixIndex{k_thread_id, h_thread_id, w_thread_id};
- }
-
- template
- __device__ void Run(const ABlockBuffer& a_block_buf,
- const BThreadBuffer& b_thread_buf,
- CThreadBuffer& c_thread_buf) const
- {
- static_assert(
- is_same, remove_cvref_t>::value &&
- is_same, remove_cvref_t>::value &&
- is_same, remove_cvref_t>::value &&
- "wrong! inconsistent type");
-
- constexpr auto I0 = Number<0>{};
-
- constexpr auto a_block_mtx = BlockMatrixA{};
-
- constexpr auto EPerBlock = a_block_mtx.GetLength(I0);
-
- // HACK: fix this @Jing Zhang
- constexpr auto HoPerThreadSubC = 2;
- constexpr auto WoPerThreadSubC = 2;
-
- static_assert(KPerThread % KPerThreadSubC == 0, "");
- static_assert(HPerThread % HoPerThreadSubC == 0, "");
- static_assert(WPerThread % WoPerThreadSubC == 0, "");
-
- // thread A buffer for GEMM
- StaticBuffer
- a_thread_buf;
-
- constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3{};
-
- static_for<0, EPerBlock, EPerThreadLoop>{}([&](auto e_begin) {
- static_for<0, KPerThread, KPerThreadSubC>{}([&](auto k_begin) {
- a_thread_copy_.Run(a_block_mtx,
- make_tuple(e_begin, k_begin),
- a_block_buf,
- a_thread_mtx_,
- make_tuple(I0, I0),
- a_thread_buf);
-
- static_for<0, HPerThread, HoPerThreadSubC>{}([&](auto h_begin) {
- static_for<0, WPerThread, WoPerThreadSubC>{}([&](auto w_begin) {
- threadwise_gemm.Run(a_thread_buf,
- make_tuple(I0, I0),
- b_thread_buf,
- make_tuple(e_begin, I0, h_begin, w_begin),
- c_thread_buf,
- make_tuple(k_begin, I0, h_begin, w_begin));
- });
- });
- });
- });
- }
-
- template
- __device__ void MoveASliceWindow(const BlockMatrixA&,
- const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx)
- {
- a_thread_copy_.MoveSrcSliceWindow(BlockMatrixA{}, a_block_slice_move_step_idx);
- }
-
- private:
- MatrixIndex c_thread_begin_mtx_idx_;
-
- AThreadCopy a_thread_copy_;
-};
-
-} // namespace ck
-#endif
diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
deleted file mode 100644
index 36c67832..00000000
--- a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
+++ /dev/null
@@ -1,282 +0,0 @@
-#ifndef CK_BLOCKWISE_GEMM_XDLOPS_HPP
-#define CK_BLOCKWISE_GEMM_XDLOPS_HPP
-
-#include "common_header.hpp"
-#include "threadwise_tensor_slice_transfer.hpp"
-#include "xdlops_gemm.hpp"
-#include "tensor_adaptor.hpp"
-
-namespace ck {
-
-template
-struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
-{
- static constexpr auto I0 = Number<0>{};
- static constexpr auto I1 = Number<1>{};
- static constexpr auto I2 = Number<2>{};
- static constexpr auto I3 = Number<3>{};
-
- static constexpr index_t WaveSize = 64;
-
- static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
- static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
-
- static constexpr index_t K0 = BK0NK1BlockDesc{}.GetLength(I0);
-
- static constexpr auto xdlops_gemm = XdlopsGemm{};
-
- static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
- static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
-
- StaticBufferV2, MRepeat * NRepeat, true>
- c_thread_buf_;
-
- __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
-
- __device__ static auto GetWaveIdx()
- {
- const index_t thread_id = get_thread_local_1d_id();
-
- constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
- make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
- make_tuple(Sequence<0, 1, 2>{}),
- make_tuple(Sequence<0>{}));
-
- return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
- }
-
- __device__ static auto CalculateAThreadOriginDataIndex()
- {
- const auto wave_idx = GetWaveIdx();
-
- const auto waveId_m = wave_idx[I0];
-
- const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
-
- return make_tuple(xdlops_a_idx[I0], 0, waveId_m, xdlops_a_idx[I1], 0);
- }
-
- __device__ static auto CalculateBThreadOriginDataIndex()
- {
- const auto wave_idx = GetWaveIdx();
-
- const auto waveId_n = wave_idx[I1];
-
- const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
-
- return make_tuple(xdlops_b_idx[I0], 0, waveId_n, xdlops_b_idx[I1], 0);
- }
-
- template
- __device__ static auto
- CalculateCThreadOriginDataIndex(Number, Number, Number, Number)
- {
- const auto wave_idx = GetWaveIdx();
-
- const auto waveId_m = wave_idx[I0];
- const auto waveId_n = wave_idx[I1];
-
- const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
-
- constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
- make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
- make_tuple(Sequence<0>{}),
- make_tuple(Sequence<0, 1, 2>{}));
-
- constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
- make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
- make_tuple(Sequence<0>{}),
- make_tuple(Sequence<0, 1, 2>{}));
-
- const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
- make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
- const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
- make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
-
- return make_tuple(c_thread_m, c_thread_n);
- }
-
- __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
- {
- static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
- BK0NK1BlockDesc::IsKnownAtCompileTime(),
- "wrong! Desc should be known at compile-time");
-
- static_assert(AK0MK1BlockDesc{}.GetLength(I0) == BK0NK1BlockDesc{}.GetLength(I0),
- "wrong! K0 dimension not consistent");
-
- static_assert(AK0MK1BlockDesc{}.GetLength(I2) == BK0NK1BlockDesc{}.GetLength(I2),
- "wrong! K1 dimension not consistent");
-
- static_assert(BlockSize == MWaves * NWaves * WaveSize,
- "BlockSize != MWaves * NWaves * WaveSize\n");
-
- static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
- "wrong!");
- }
-
- __host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2ThreadDescriptor()
- {
- constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
-
- constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
- constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
- constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
- constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
-
- return make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, I1, M0, M1, M2, N));
- }
-
- __host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2BlockDescriptor()
- {
- constexpr auto c_m0_n0_m1_n1_m2_n2_block_desc =
- make_naive_tensor_descriptor_packed(make_tuple(Number{},
- Number{},
- Number{},
- Number{},
- Number{},
- Number{}));
-
- return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_block_desc);
- }
-
- template
- __host__ __device__ static constexpr auto
- MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
- {
- const auto c_m0_n0_m1_n1_m2_n2_grid_desc = transform_tensor_descriptor(
- c_m_n_grid_desc,
- make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL)),
- make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
- make_tuple(Sequence<0>{}, Sequence<1>{}),
- make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
-
- return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_grid_desc);
- }
-
- __host__ __device__ static constexpr auto MakeAK0M0M1M2K1BlockDescriptor()
- {
- return transform_tensor_descriptor(
- AK0MK1BlockDesc{},
- make_tuple(make_pass_through_transform(Number{}),
- make_unmerge_transform(
- make_tuple(Number{}, Number{}, Number{})),
- make_pass_through_transform(Number{})),
- make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
- make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
- }
-
- __host__ __device__ static constexpr auto MakeBK0N0N1N2K1BlockDescriptor()
- {
- return transform_tensor_descriptor(
- BK0NK1BlockDesc{},
- make_tuple(make_pass_through_transform(Number{}),
- make_unmerge_transform(
- make_tuple(Number{}, Number{}, Number{})),
- make_pass_through_transform(Number{})),
- make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
- make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
- }
-
- static constexpr auto a_k0_m0_m1_m2_k1_block_desc = MakeAK0M0M1M2K1BlockDescriptor();
- static constexpr auto b_k0_n0_n1_n2_k1_block_desc = MakeBK0N0N1N2K1BlockDescriptor();
-
- template
- __device__ void Run(const ABlockBuffer& a_block_buf,
- const BBlockBuffer& b_block_buf,
- CThreadBuffer& c_thread_buf) const
- {
- auto a_thread_buf = make_static_buffer(
- a_thread_desc_.GetElementSpaceSize());
- auto b_thread_buf = make_static_buffer(
- b_thread_desc_.GetElementSpaceSize());
-
- static_for<0, MRepeat, 1>{}([&](auto m0) {
- // read A
- a_thread_copy_.Run(a_k0_m0_m1_m2_k1_block_desc,
- make_tuple(I0, m0, I0, I0, I0),
- a_block_buf,
- a_thread_desc_,
- make_tuple(I0, I0, I0, I0, I0),
- a_thread_buf);
-
- static_for<0, NRepeat, 1>{}([&](auto n0) {
- // read B
- b_thread_copy_.Run(b_k0_n0_n1_n2_k1_block_desc,
- make_tuple(I0, n0, I0, I0, I0),
- b_block_buf,
- b_thread_desc_,
- make_tuple(I0, I0, I0, I0, I0),
- b_thread_buf);
-
- static_for<0, K0, xdlops_gemm.K0PerXdlops>{}([&](auto k0) {
- vector_type a_thread_vec;
- vector_type b_thread_vec;
-
- static_for<0, K1, 1>{}([&](auto i) {
- a_thread_vec.template AsType()(i) = a_thread_buf
- [Number{}];
- b_thread_vec.template AsType()(i) = b_thread_buf
- [Number{}];
- });
-
- using mfma_input_type =
- typename vector_type::type;
-
- constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0));
-
- xdlops_gemm.template Run(a_thread_vec.template AsType(),
- b_thread_vec.template AsType(),
- c_thread_buf.GetVector(Number{}));
- });
- });
- });
- }
-
- private:
- // A[K, M]
- static constexpr auto a_thread_desc_ =
- make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1, I1, I1, Number{}));
-
- // B[K, N]
- static constexpr auto b_thread_desc_ =
- make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1, I1, I1, Number{}));
-
- static constexpr auto c_thread_desc_ =
- make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{}));
-
- using AThreadCopy = ThreadwiseTensorSliceTransfer_v4,
- Sequence<0, 1, 2, 3, 4>,
- 4,
- K1,
- K1>;
-
- using BThreadCopy = ThreadwiseTensorSliceTransfer_v4,
- Sequence<0, 1, 2, 3, 4>,
- 4,
- K1,
- K1>;
-
- AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
- BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
-};
-
-} // namespace ck
-#endif
diff --git a/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp b/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp
deleted file mode 100644
index 0214b713..00000000
--- a/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp
+++ /dev/null
@@ -1,170 +0,0 @@
-#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP
-#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP
-
-#include "common_header.hpp"
-#include "tensor_descriptor.hpp"
-#include "tensor_descriptor_helper.hpp"
-#include "cluster_descriptor.hpp"
-#include "threadwise_tensor_slice_transfer.hpp"
-
-namespace ck {
-
-// this version does following things to avoid scratch memory issue
-// 1. Use StaticallyIndexedArray instead of C array for thread buffer
-// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
-// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
-template
-struct BlockwiseTensorSliceTransfer_v4
-{
- static constexpr index_t nDim = remove_reference_t::GetNumOfDimension();
-
- using Index = MultiIndex;
-
- __device__ constexpr BlockwiseTensorSliceTransfer_v4(const SrcDesc& src_desc,
- const Index& src_block_slice_origin,
- const DstDesc& dst_desc,
- const Index& dst_block_slice_origin)
- : threadwise_transfer_(
- src_desc, make_zero_multi_index(), dst_desc, make_zero_multi_index())
-
- {
- static_assert(nDim == remove_reference_t>::GetNumOfDimension() &&
- nDim == remove_reference_t>::GetNumOfDimension() &&
- nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
- nDim == ThreadClusterLengths::Size() &&
- nDim == ThreadClusterArrangeOrder::Size() &&
- nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
- "wrong! nDim not consistent");
-
- static_assert(
- is_same{},
- "wrong! threads should be mapped to cover entire slicing window");
-
- static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
- "wrong! BlockSize too small");
-
- if(BlockSize == thread_cluster_desc_.GetElementSize() or
- get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
- {
- const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
- make_multi_index(get_thread_local_1d_id()));
-
- const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{};
-
- threadwise_transfer_.SetSrcSliceOrigin(src_desc,
- src_block_slice_origin + thread_data_idx_begin);
- threadwise_transfer_.SetDstSliceOrigin(dst_desc,
- dst_block_slice_origin + thread_data_idx_begin);
- }
- }
-
- template
- __device__ void
- RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
- {
- if(BlockSize == thread_cluster_desc_.GetElementSize() or
- get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
- {
- threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks);
- }
- }
-
- template
- __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
- {
- if(BlockSize == thread_cluster_desc_.GetElementSize() or
- get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
- {
- threadwise_transfer_.RunRead(src_desc, src_buf);
- }
- }
-
- template
- __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
- {
- if(BlockSize == thread_cluster_desc_.GetElementSize() or
- get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
- {
- threadwise_transfer_.RunWrite(dst_desc, dst_buf);
- }
- }
-
- __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
- {
- if(BlockSize == thread_cluster_desc_.GetElementSize() or
- get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
- {
- threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
- }
- }
-
- // SrcMoveSliceWindowStepHack to control index calculation move slice window
- template
- __device__ void
- MoveSrcSliceWindow(const SrcDesc& src_desc,
- const Index& step,
- const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
- {
- if(BlockSize == thread_cluster_desc_.GetElementSize() or
- get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
- {
- threadwise_transfer_.MoveSrcSliceWindow(
- src_desc, step, src_move_slice_window_step_hack);
- }
- }
-
- __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
- {
- if(BlockSize == thread_cluster_desc_.GetElementSize() or
- get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
- {
- threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
- }
- }
-
- private:
- static constexpr auto thread_cluster_desc_ =
- make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
-
- using ThreadwiseTransfer =
- ThreadwiseTensorSliceTransfer_v3;
-
- ThreadwiseTransfer threadwise_transfer_;
-};
-
-} // namespace ck
-#endif
diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r3.hpp
deleted file mode 100644
index 2653dd43..00000000
--- a/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r3.hpp
+++ /dev/null
@@ -1,650 +0,0 @@
-#ifndef CK_GRIDWISE_GEMM_V1R3_HPP
-#define CK_GRIDWISE_GEMM_V1R3_HPP
-
-#include "common_header.hpp"
-#include "multi_index_transform_helper.hpp"
-#include "tensor_descriptor.hpp"
-#include "tensor_descriptor_helper.hpp"
-#include "blockwise_gemm_dlops_v2r3.hpp"
-#include "blockwise_tensor_slice_transfer_v2.hpp"
-#include "threadwise_tensor_slice_transfer_v2.hpp"
-#include "threadwise_tensor_slice_set.hpp"
-
-namespace ck {
-
-#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
-template
-__global__ void
-#if CK_USE_LAUNCH_BOUNDS
- __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
-#endif
- kernel_gemm_dlops_v1r3(
- const FloatAB* __restrict__ p_a_grid,
- const FloatAB* __restrict__ p_b_grid,
- FloatC* __restrict__ p_c_grid,
- const AK0M0M1K1GridDesc a_k0_m0_m1_k1_grid_desc,
- const BK0N0N1K1GridDesc b_k0_n0_n1_k1_grid_desc,
- const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
- const CBlockIdToM0N0BlockClusterAdaptor c_blockid_to_m0_n0_block_cluster_adaptor)
-{
- constexpr index_t shared_block_size =
- GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
-
- __shared__ FloatAB p_shared_block[shared_block_size];
-
- GridwiseGemm::Run(p_a_grid,
- p_b_grid,
- p_c_grid,
- p_shared_block,
- a_k0_m0_m1_k1_grid_desc,
- b_k0_n0_n1_k1_grid_desc,
- c_m0_m10_m11_n0_n10_n11_grid_desc,
- c_blockid_to_m0_n0_block_cluster_adaptor,
- integral_constant{},
- integral_constant{});
-}
-#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
-// pass tensor descriptor by CONSTANT void pointer
-// CONSTANT is needed to inform compiler void pointers in the kernel signature are pointing to
-// non-modifiable parameter address space, so compiler can enable corresponding optimization
-template
-__global__ void
-#if CK_USE_LAUNCH_BOUNDS
- __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
-#endif
- kernel_gemm_dlops_v1r3(const FloatAB* __restrict__ p_a_grid,
- const FloatAB* __restrict__ p_b_grid,
- FloatC* __restrict__ p_c_grid,
- const void CONSTANT* p_a_k0_m0_m1_k1_grid_desc,
- const void CONSTANT* p_b_k0_n0_n1_k1_grid_desc,
- const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
- const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
-{
- // first cast void CONSTANT void* to void*
- // second cast void* to Desc*
- // the copy constructor of tensor descriptor doesn't take address_space(4)
- const auto a_k0_m0_m1_k1_grid_desc = *reinterpret_cast(
- cast_pointer_to_generic_address_space(p_a_k0_m0_m1_k1_grid_desc));
- const auto b_k0_n0_n1_k1_grid_desc = *reinterpret_cast(
- cast_pointer_to_generic_address_space(p_b_k0_n0_n1_k1_grid_desc));
- const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
- *reinterpret_cast(
- cast_pointer_to_generic_address_space(p_c_m0_m10_m11_n0_n10_n11_grid_desc));
- const auto c_blockid_to_m0_n0_block_cluster_adaptor =
- *reinterpret_cast(
- cast_pointer_to_generic_address_space(p_c_blockid_to_m0_n0_block_cluster_adaptor));
-
- constexpr index_t shared_block_size =
- GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
-
- __shared__ FloatAB p_shared_block[shared_block_size];
-
- GridwiseGemm::Run(p_a_grid,
- p_b_grid,
- p_c_grid,
- p_shared_block,
- a_k0_m0_m1_k1_grid_desc,
- b_k0_n0_n1_k1_grid_desc,
- c_m0_m10_m11_n0_n10_n11_grid_desc,
- c_blockid_to_m0_n0_block_cluster_adaptor,
- integral_constant{},
- integral_constant{});
-}
-#endif
-
-template
-struct GridwiseGemmDlops_km_kn_mn_v1r3
-{
- static constexpr auto I0 = Number<0>{};
- static constexpr auto I1 = Number<1>{};
- static constexpr auto I2 = Number<2>{};
- static constexpr auto I3 = Number<3>{};
-
- // K1 should be Number<...>
- static constexpr auto K1 = AK0MK1GridDesc{}.GetLength(I2);
-
- __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
- {
- // TODO: change this. I think it needs multi-dimensional alignment
- constexpr auto max_lds_align = K1;
-
- // TODO: check alignment
- // A matrix in LDS memory, dst of blockwise copy
- constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned(
- make_tuple(Number{}, Number{}, K1), max_lds_align);
-
- // TODO: check alignment
- // B matrix in LDS memory, dst of blockwise copy
- constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned(
- make_tuple(Number{}, Number{}, K1), max_lds_align);
-
- // TODO: check alignment
- // LDS allocation for A and B: be careful of alignment
- constexpr auto a_block_aligned_space_size =
- math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
-
- constexpr auto b_block_aligned_space_size =
- math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
-
- return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
- }
-
- __host__ __device__ static constexpr bool
- CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
- const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
- const CMNGridDesc& c_m_n_grid_desc)
- {
- const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
- const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
- const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
-
- // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
-
- return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
- K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
- K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
- K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
- (M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K0 % KPerBlock == 0);
- }
-
- __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
- {
- const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1);
-
- return grid_size;
- }
-
- __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
- {
- const bool has_main_k_block_loop = (K0 + KPerBlock) / (2 * KPerBlock) > 1;
-
- return has_main_k_block_loop;
- }
-
- __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
- {
- const bool has_double_tail_k_block_loop = (K0 / KPerBlock) % 2 == 0;
-
- return has_double_tail_k_block_loop;
- }
-
- __host__ __device__ static constexpr auto
- MakeAK0M0M1K1GridDescriptor(const AK0MK1GridDesc& a_k0_m_k1_grid_desc)
- {
- const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
- const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
-
- const auto M1 = Number{};
- const auto M0 = M / M1;
-
- const auto a_k0_m0_m1_k1_grid_desc =
- transform_tensor_descriptor(a_k0_m_k1_grid_desc,
- make_tuple(make_pass_through_transform(K0),
- make_unmerge_transform(make_tuple(M0, M1)),
- make_pass_through_transform(K1)),
- make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
- make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
-
- return a_k0_m0_m1_k1_grid_desc;
- }
-
- __host__ __device__ static constexpr auto
- MakeBK0N0N1K1GridDescriptor(const BK0NK1GridDesc& b_k0_n_k1_grid_desc)
- {
- const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0);
- const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
-
- const auto N1 = Number{};
- const auto N0 = N / N1;
-
- const auto b_k0_n0_n1_k1_grid_desc =
- transform_tensor_descriptor(b_k0_n_k1_grid_desc,
- make_tuple(make_pass_through_transform(K0),
- make_unmerge_transform(make_tuple(N0, N1)),
- make_pass_through_transform(K1)),
- make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
- make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
-
- return b_k0_n0_n1_k1_grid_desc;
- }
-
- __host__ __device__ static constexpr auto
- MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
- {
- const auto M = c_m_n_grid_desc.GetLength(I0);
- const auto N = c_m_n_grid_desc.GetLength(I1);
-
- constexpr auto M1 = Number{};
- constexpr auto N1 = Number{};
-
- const auto M0 = M / M1;
- const auto N0 = N / N1;
-
- constexpr auto M11 =
- Number{};
- constexpr auto N11 =
- Number{};
-
- constexpr auto M10 = M1 / M11;
- constexpr auto N10 = N1 / N11;
-
- const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor(
- c_m_n_grid_desc,
- make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
- make_unmerge_transform(make_tuple(N0, N10, N11))),
- make_tuple(Sequence<0>{}, Sequence<1>{}),
- make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
-
- return c_m0_m10_m11_n0_n10_n11_grid_desc;
- }
-
- __host__ __device__ static constexpr auto
- MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
- {
- const auto M = c_m_n_grid_desc.GetLength(I0);
- const auto N = c_m_n_grid_desc.GetLength(I1);
-
- constexpr auto M1 = Number{};
- constexpr auto N1 = Number{};
-
- const auto M0 = M / M1;
- const auto N0 = N / N1;
-
- const auto c_blockid_to_m0_n0_block_cluster_adaptor =
- make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
- make_tuple(Sequence<0, 1>{}),
- make_tuple(Sequence<0>{}));
-
- return c_blockid_to_m0_n0_block_cluster_adaptor;
- }
-
- using AK0M0M1K1GridDesc = decltype(MakeAK0M0M1K1GridDescriptor(AK0MK1GridDesc{}));
- using BK0N0N1K1GridDesc = decltype(MakeBK0N0N1K1GridDescriptor(BK0NK1GridDesc{}));
- using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{}));
- using CBlockIdToM0N0BlockClusterAdaptor =
- decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{}));
-
- template
- __device__ static void
- Run(const FloatAB* __restrict__ p_a_grid,
- const FloatAB* __restrict__ p_b_grid,
- FloatC* __restrict__ p_c_grid,
- FloatAB* __restrict__ p_shared_block,
- const AK0M0M1K1GridDesc& a_k0_m0_m1_k1_grid_desc,
- const BK0N0N1K1GridDesc& b_k0_n0_n1_k1_grid_desc,
- const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc,
- const CBlockIdToM0N0BlockClusterAdaptor& c_blockid_to_m0_n0_block_cluster_adaptor,
- integral_constant,
- integral_constant)
- {
- const auto a_global_buf = make_dynamic_buffer(
- p_a_grid, a_k0_m0_m1_k1_grid_desc.GetElementSpaceSize());
- const auto b_global_buf = make_dynamic_buffer(
- p_b_grid, b_k0_n0_n1_k1_grid_desc.GetElementSpaceSize());
- auto c_grid_buf = make_dynamic_buffer(
- p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());
-
- // divide block work by [M, N]
- const auto c_m0_n0_block_cluster_idx =
- c_blockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex(
- make_multi_index(get_block_1d_id()));
-
- // HACK: this force index data into SGPR
- const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
- const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
-
- // TODO: change this. I think it needs multi-dimensional alignment
- constexpr auto max_lds_align = K1;
-
- // TODO: check alignment
- // A matrix in LDS memory, dst of blockwise copy
- // be careful of LDS alignment
- constexpr auto a_k0_m0_m1_k1_block_desc = make_naive_tensor_descriptor_aligned(
- make_tuple(Number{}, I1, Number{}, K1), max_lds_align);
-
- // TODO: check alignment
- // B matrix in LDS memory, dst of blockwise copy
- // be careful of LDS alignment
- constexpr auto b_k0_n0_n1_k1_block_desc = make_naive_tensor_descriptor_aligned(
- make_tuple(Number{}, I1, Number{}, K1), max_lds_align);
-
- // TODO: check alignment
- // A matrix in LDS memory, for blockwise GEMM
- constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
- make_tuple(Number{}, Number{}, K1), max_lds_align);
-
- // TODO: check alignment
- // B matrix in LDS memory, for blockwise GEMM
- constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
- make_tuple(Number{}, Number{}, K1), max_lds_align);
-
- static_assert(a_k0_m0_m1_k1_block_desc.GetElementSpaceSize() ==
- a_k0_m_k1_block_desc.GetElementSpaceSize() &&
- b_k0_n0_n1_k1_block_desc.GetElementSpaceSize() ==
- b_k0_n_k1_block_desc.GetElementSpaceSize() &&
- "wrong!");
-
- // A matrix blockwise copy
- auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
- BlockSize,
- InMemoryDataOperationEnum_t::Set,
- Sequence,
- ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
- ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
- ABlockTransferThreadClusterArrangeOrder,
- FloatAB,
- FloatAB,
- decltype(a_k0_m0_m1_k1_grid_desc),
- decltype(a_k0_m0_m1_k1_block_desc),
- ABlockTransferSrcAccessOrder,
- Sequence<0, 1, 2, 3>,
- ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths
- ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths
- ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
- Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
- false,
- true>(a_k0_m0_m1_k1_grid_desc,
- make_multi_index(0, im0, 0, 0),
- a_k0_m0_m1_k1_block_desc,
- make_multi_index(0, 0, 0, 0));
-
- // B matrix blockwise copy
- auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
- BlockSize,
- InMemoryDataOperationEnum_t::Set,
- Sequence,
- BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
- BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
- BBlockTransferThreadClusterArrangeOrder,
- FloatAB,
- FloatAB,
- decltype(b_k0_n0_n1_k1_grid_desc),
- decltype(b_k0_n0_n1_k1_block_desc),
- BBlockTransferSrcAccessOrder,
- Sequence<0, 1, 2, 3>,
- BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths
- BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths
- BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
- Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
- false,
- true>(b_k0_n0_n1_k1_grid_desc,
- make_multi_index(0, in0, 0, 0),
- b_k0_n0_n1_k1_block_desc,
- make_multi_index(0, 0, 0, 0));
-
- // GEMM definition
- // c_mtx += transpose(a_mtx) * b_mtx
- // a_mtx[KPerBlock, MPerBlockM1] is in LDS
- // b_mtx[KPerBlocl, NPerBlockN1] is in LDS
- // c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in
- // register
- const auto blockwise_gemm =
- BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
- BlockSize,
- FloatAB,
- FloatAB,
- FloatAcc,
- decltype(a_k0_m_k1_block_desc),
- decltype(b_k0_n_k1_block_desc),
- M1PerThreadM111,
- N1PerThreadN111,
- KPerThread,
- M11N11ThreadClusterM110Xs,
- M11N11ThreadClusterN110Xs,
- M1PerThreadM111,
- N1PerThreadN111>{};
-
- constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
- decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
-
- constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed(
- sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
-
- // LDS allocation for A and B: be careful of alignment
- constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
- a_k0_m0_m1_k1_block_desc.GetElementSpaceSize(), max_lds_align);
-
- constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
- b_k0_n0_n1_k1_block_desc.GetElementSpaceSize(), max_lds_align);
-
- FloatAB* p_a_block_double = p_shared_block;
- FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
-
- // register allocation for output
- auto c_thread_buf = make_static_buffer(
- c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());
-
- ThreadwiseTensorSliceSet_v1{}
- .Run(c_m10_m11_n10_n11_thread_desc,
- make_tuple(I0, I0, I0, I0),
- c_thread_buf,
- FloatAcc{0});
-
- constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
- constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
-
- auto a_block_even_buf = make_dynamic_buffer(
- p_a_block_double, a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
- auto b_block_even_buf = make_dynamic_buffer(
- p_b_block_double, b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());
-
- auto a_block_odd_buf = make_dynamic_buffer(
- p_a_block_double + a_block_aligned_space_size,
- a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
- auto b_block_odd_buf = make_dynamic_buffer(
- p_b_block_double + b_block_aligned_space_size,
- b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());
-
- // LDS double buffer: preload data into LDS
- {
- a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
- b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
-
- a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf);
- b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf);
- }
-
- if constexpr(HasMainKBlockLoop)
- {
- const auto K0 = a_k0_m0_m1_k1_grid_desc.GetLength(I0);
-
- index_t k_block_data_begin = 0;
-
- // LDS double buffer: main body
- // use Do-While loop instead of For loop to simplify control flow
- do
- {
- // even iteration
- a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
- a_block_slice_copy_step,
- AGridMoveSliceWindowStepHacks{});
- b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
- b_block_slice_copy_step,
- BGridMoveSliceWindowStepHacks{});
-
- __syncthreads();
-
- // LDS doubel buffer: load next data from device mem
- a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
- b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
-
- // LDS double buffer: GEMM on current data
- blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
- a_block_even_buf,
- b_block_even_buf,
- c_thread_buf);
-
- // LDS double buffer: store next data to LDS
- a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf);
- b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf);
-
- // odd iteration
- a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
- a_block_slice_copy_step,
- AGridMoveSliceWindowStepHacks{});
- b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
- b_block_slice_copy_step,
- BGridMoveSliceWindowStepHacks{});
-
- __syncthreads();
-
- // LDS doubel buffer: load next data from device mem
- a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
- b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
-
- // LDS double buffer: GEMM on current data
- blockwise_gemm.Run(
- c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
-
- // LDS double buffer: store next data to LDS
- a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf);
- b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf);
-
- k_block_data_begin += 2 * KPerBlock;
- } while(k_block_data_begin < K0 - 2 * KPerBlock);
- }
-
- // LDS double buffer: tail
- if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
- {
- a_blockwise_copy.MoveSrcSliceWindow(
- a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step, AGridMoveSliceWindowStepHacks{});
- b_blockwise_copy.MoveSrcSliceWindow(
- b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step, BGridMoveSliceWindowStepHacks{});
-
- __syncthreads();
-
- // LDS double buffer: load last data from device mem
- a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
- b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
-
- // LDS double buffer: GEMM on 2nd-last data
- blockwise_gemm.Run(
- c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
-
- // LDS double buffer: store last data to LDS
- a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf);
- b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf);
-
- __syncthreads();
-
- // LDS double buffer: GEMM on last data
- blockwise_gemm.Run(
- c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
- }
- else // if has 1 iteration left
- {
- __syncthreads();
-
- // LDS double buffer: GEMM on last data
- blockwise_gemm.Run(
- c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
- }
-
- // output: register to global memory
- {
- constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
- make_naive_tensor_descriptor_packed(
- make_tuple(I1,
- Number{},
- Number{},
- I1,
- Number{},
- Number{}));
-
- const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
- blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
- get_thread_local_1d_id());
-
- ThreadwiseTensorSliceTransfer_v1r3<
- FloatAcc,
- FloatC,
- decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
- decltype(c_m0_m10_m11_n0_n10_n11_grid_desc),
- Sequence<1,
- c_m10_m11_n10_n11_thread_tensor_lengths[I0],
- c_m10_m11_n10_n11_thread_tensor_lengths[I1],
- 1,
- c_m10_m11_n10_n11_thread_tensor_lengths[I2],
- c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
- CThreadTransferSrcDstAccessOrder,
- CThreadTransferSrcDstVectorDim,
- CThreadTransferDstScalarPerVector,
- CGlobalMemoryDataOperation,
- 1,
- true>{c_m0_m10_m11_n0_n10_n11_grid_desc,
- make_multi_index(im0,
- c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
- c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
- in0,
- c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
- c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])}
- .Run(c_m0_m10_m11_n0_n10_n11_thread_desc,
- make_tuple(I0, I0, I0, I0, I0, I0),
- c_thread_buf,
- c_m0_m10_m11_n0_n10_n11_grid_desc,
- c_grid_buf,
- CGridStepHacks{});
- }
- }
-};
-
-} // namespace ck
-#endif
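Both deleted gridwise GEMM kernels follow the same software-pipelining idea: prefetch the next K-slice of A and B from global memory while the GEMM for the current slice runs out of LDS. The v1r3 kernel above does this with an explicit even/odd LDS double buffer; the v2r3 xdlops kernel below uses a single LDS buffer guarded by block_sync_lds(). The following is a minimal, illustrative host-side C++ sketch of the even/odd loop structure only (it is not part of this patch); load_tile and gemm_tile are hypothetical stand-ins for the blockwise copy (RunRead/RunWrite) and blockwise_gemm.Run, and the sizes are made up.

// Minimal sketch (illustrative only): even/odd double-buffer pipeline,
// reduced to plain host C++. load_tile/gemm_tile are hypothetical helpers.
#include <cstddef>
#include <cstdio>
#include <vector>

using Tile = std::vector<float>;

// Stand-in for RunRead + RunWrite: fill one K-slice worth of data.
static void load_tile(Tile& dst, int k_begin) { dst.assign(4, static_cast<float>(k_begin)); }

// Stand-in for blockwise_gemm.Run: accumulate the contribution of one K-slice.
static void gemm_tile(const Tile& a, const Tile& b, float& acc)
{
    for(std::size_t i = 0; i < a.size(); ++i)
        acc += a[i] * b[i];
}

int main()
{
    constexpr int K0 = 8, KPerBlock = 1; // 8 K-slices, one slice per step
    Tile a_even, b_even, a_odd, b_odd;   // the two "LDS" buffers
    float acc = 0.f;                     // plays the role of c_thread_buf

    // prologue: load the first slice into the even buffers
    load_tile(a_even, 0);
    load_tile(b_even, 0);

    int k = 0;
    // main loop: two slices per iteration, computing on one buffer
    // while the other one is being refilled
    while(k < K0 - 2 * KPerBlock)
    {
        load_tile(a_odd, k + KPerBlock);      // prefetch next slice into odd
        load_tile(b_odd, k + KPerBlock);
        gemm_tile(a_even, b_even, acc);       // GEMM on current (even) slice

        load_tile(a_even, k + 2 * KPerBlock); // prefetch the following slice into even
        load_tile(b_even, k + 2 * KPerBlock);
        gemm_tile(a_odd, b_odd, acc);         // GEMM on the odd slice

        k += 2 * KPerBlock;
    }

    // tail: one or two slices remain (the HasDoubleTailKBlockLoop split)
    if(K0 - k == 2 * KPerBlock)
    {
        load_tile(a_odd, k + KPerBlock);
        load_tile(b_odd, k + KPerBlock);
        gemm_tile(a_even, b_even, acc); // 2nd-last slice
        gemm_tile(a_odd, b_odd, acc);   // last slice
    }
    else
    {
        gemm_tile(a_even, b_even, acc); // single remaining slice
    }

    std::printf("acc = %f\n", acc); // expect 4 * (0^2 + 1^2 + ... + 7^2) = 560
    return 0;
}

On the GPU the same structure lets the global-to-LDS reads for slice k+1 overlap the matrix work on slice k, with __syncthreads() / block_sync_lds() separating the LDS writes of one buffer from the reads of the other.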
diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
deleted file mode 100644
index 86e047c9..00000000
--- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
+++ /dev/null
@@ -1,639 +0,0 @@
-#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP
-#define CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP
-
-#include "common_header.hpp"
-#include "multi_index_transform_helper.hpp"
-#include "tensor_descriptor.hpp"
-#include "tensor_descriptor_helper.hpp"
-#include "blockwise_gemm_xdlops.hpp"
-#include "blockwise_tensor_slice_transfer.hpp"
-#include "threadwise_tensor_slice_transfer.hpp"
-#include "threadwise_tensor_slice_set.hpp"
-
-namespace ck {
-
-#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
-template <typename GridwiseGemm,
-          typename FloatAB,
-          typename FloatC,
-          typename AK0MK1GridDesc,
-          typename BK0NK1GridDesc,
-          typename CM0N0M1N1M2M3M4N2GridDesc,
-          typename CBlockClusterAdaptor,
-          bool HasMainKBlockLoop>
-__global__ void
-#if CK_USE_LAUNCH_BOUNDS
- __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
-#endif
- kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
- const FloatAB* __restrict__ p_b_grid,
- FloatC* __restrict__ p_c_grid,
- const AK0MK1GridDesc a_k0_m_k1_grid_desc,
- const BK0NK1GridDesc b_k0_n_k1_grid_desc,
- const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
- const CBlockClusterAdaptor c_block_cluster_adaptor)
-{
- constexpr index_t shared_block_size =
- GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
-
- __shared__ FloatAB p_shared_block[shared_block_size];
-
-    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
- p_b_grid,
- p_c_grid,
- p_shared_block,
- a_k0_m_k1_grid_desc,
- b_k0_n_k1_grid_desc,
- c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
- c_block_cluster_adaptor);
-}
-#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
-template <typename GridwiseGemm,
-          typename FloatAB,
-          typename FloatC,
-          typename AK0MK1GridDesc,
-          typename BK0NK1GridDesc,
-          typename CM0N0M1N1M2M3M4N2GridDesc,
-          typename CBlockClusterAdaptor,
-          bool HasMainKBlockLoop>
-__global__ void
-#if CK_USE_LAUNCH_BOUNDS
- __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
-#endif
- kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
- const FloatAB* __restrict__ p_b_grid,
- FloatC* __restrict__ p_c_grid,
- const void CONSTANT* p_a_k0_m_k1_grid_desc,
- const void CONSTANT* p_b_k0_n_k1_grid_desc,
- const void CONSTANT* p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
- const void CONSTANT* p_c_block_cluster_adaptor)
-{
- constexpr index_t shared_block_size =
- GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
-
-    const auto a_k0_m_k1_grid_desc = *reinterpret_cast<const AK0MK1GridDesc*>(
-        cast_pointer_to_generic_address_space(p_a_k0_m_k1_grid_desc));
-    const auto b_k0_n_k1_grid_desc = *reinterpret_cast<const BK0NK1GridDesc*>(
-        cast_pointer_to_generic_address_space(p_b_k0_n_k1_grid_desc));
-    const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
-        *reinterpret_cast<const CM0N0M1N1M2M3M4N2GridDesc*>(
-            cast_pointer_to_generic_address_space(p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc));
-    const auto c_block_cluster_adaptor = *reinterpret_cast<const CBlockClusterAdaptor*>(
-        cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor));
-
- __shared__ FloatAB p_shared_block[shared_block_size];
-
-    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
- p_b_grid,
- p_c_grid,
- p_shared_block,
- a_k0_m_k1_grid_desc,
- b_k0_n_k1_grid_desc,
- c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
- c_block_cluster_adaptor);
-}
-#endif
-
-template
-struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
-{
- static constexpr auto I0 = Number<0>{};
- static constexpr auto I1 = Number<1>{};
- static constexpr auto I2 = Number<2>{};
- static constexpr auto I3 = Number<3>{};
- static constexpr auto I4 = Number<4>{};
- static constexpr auto I5 = Number<5>{};
- static constexpr auto I6 = Number<6>{};
- static constexpr auto I7 = Number<7>{};
-
- // K1 should be Number<...>
-    static constexpr auto K1 = Number<K1Value>{};
-
- __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
- {
- constexpr auto max_lds_align = K1;
-
- // A matrix in LDS memory, dst of blockwise copy
- constexpr auto a_k0_m_k1_block_desc = [&]() {
- if constexpr(ABlockLdsExtraM)
- {
-                return make_naive_tensor_descriptor(
-                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
-                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
- }
- else
- {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
- }
- }();
-
- // B matrix in LDS memory, dst of blockwise copy
- constexpr auto b_k0_n_k1_block_desc = [&]() {
- if constexpr(BBlockLdsExtraN)
- {
-                return make_naive_tensor_descriptor(
-                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
-                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
- }
- else
- {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
- }
- }();
-
- // LDS allocation for A and B: be careful of alignment
- constexpr auto a_block_space_size =
- math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
-
- constexpr auto b_block_space_size =
- math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
-
- return (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
- }
-
-    // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
- __host__ __device__ static constexpr bool
- CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
- const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
- const CMNGridDesc& c_m_n_grid_desc,
- index_t M01,
- index_t N01)
- {
-        static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
- "wrong! K1 need to be known at compile-time");
-
- static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
- (NPerBlock % (NRepeat * NPerXDL)) == 0,
- "Invalid tuning param!");
-
- const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
- const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
- const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
-
- if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
- K0 == b_k0_n_k1_grid_desc.GetLength(I0) && K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
- K1 == b_k0_n_k1_grid_desc.GetLength(I2)))
- return false;
-
- if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
- return false;
-
- // check M01, N01
-        constexpr auto M1 = Number<MPerBlock>{};
-        constexpr auto N1 = Number<NPerBlock>{};
-
- const auto M0 = M / M1;
- const auto N0 = N / N1;
-
- if(!(M0 % M01 == 0 && N0 % N01 == 0))
- return false;
-
- // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
- return true;
- }
-
- __host__ __device__ static constexpr index_t
- CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc)
- {
- const auto M = c_m_n_grid_desc.GetLength(I0);
- const auto N = c_m_n_grid_desc.GetLength(I1);
-
- const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
-
- return grid_size;
- }
-
- __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
- {
- const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1;
-
- return has_main_k0_block_loop;
- }
-
- __host__ __device__ static constexpr auto
- MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
- {
- constexpr auto max_lds_align = K1;
-
- // A matrix in LDS memory, dst of blockwise copy
- constexpr auto a_k0_m_k1_block_desc = [&]() {
- if constexpr(ABlockLdsExtraM)
- {
-                return make_naive_tensor_descriptor(
-                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
-                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
- }
- else
- {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
- }
- }();
-
- // B matrix in LDS memory, dst of blockwise copy
- constexpr auto b_k0_n_k1_block_desc = [&]() {
- if constexpr(BBlockLdsExtraN)
- {
-                return make_naive_tensor_descriptor(
-                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
-                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
- }
- else
- {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
- }
- }();
-
- using BlockwiseGemm =
- BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1;
-
- return BlockwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
- }
-
- // return block_id to C matrix tile idx (m0, n0) mapping
- __host__ __device__ static constexpr auto
- MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01)
- {
- const auto M = c_m_n_grid_desc.GetLength(I0);
- const auto N = c_m_n_grid_desc.GetLength(I1);
-
-        constexpr auto M1 = Number<MPerBlock>{};
-        constexpr auto N1 = Number<NPerBlock>{};
-
- const auto M0 = M / M1;
- const auto N0 = N / N1;
-
- const auto M00 = M0 / M01;
- const auto N00 = N0 / N01;
-
- const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
- make_single_stage_tensor_adaptor(
- make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
- make_unmerge_transform(make_tuple(N00, N01))),
- make_tuple(Sequence<0>{}, Sequence<1>{}),
- make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
-
- const auto c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor =
- make_single_stage_tensor_adaptor(
- make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
- make_tuple(Sequence<0, 1, 2, 3>{}),
- make_tuple(Sequence<0>{}));
-
- const auto c_blockid_to_m0_n0_block_cluster_adaptor =
- chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
- c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor);
-
- return c_blockid_to_m0_n0_block_cluster_adaptor;
- }
-
- using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
- using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1));
-
-    template <bool HasMainKBlockLoop>
- __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
- const FloatAB* __restrict__ p_b_grid,
- FloatC* __restrict__ p_c_grid,
- FloatAB* __restrict__ p_shared_block,
- const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
- const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
- const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
- const CBlockClusterAdaptor& c_block_cluster_adaptor)
- {
-        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
-            p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize());
-        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
-            p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize());
-        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
-            p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize());
-
- const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
-
- // divide block work by [M, N]
- const auto block_work_idx =
- c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
-
-        // HACK: this forces m/n_block_data_idx_on_grid into SGPR
- const index_t m_block_data_idx_on_grid =
- __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
-
- const index_t n_block_data_idx_on_grid =
- __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
-
- // lds max alignment
- constexpr auto max_lds_align = K1;
-
- // A matrix in LDS memory, dst of blockwise copy
- constexpr auto a_k0_m_k1_block_desc = [&]() {
- if constexpr(ABlockLdsExtraM)
- {
-                return make_naive_tensor_descriptor(
-                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
-                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
- }
- else
- {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
- }
- }();
-
- // B matrix in LDS memory, dst of blockwise copy
- constexpr auto b_k0_n_k1_block_desc = [&]() {
- if constexpr(BBlockLdsExtraN)
- {
-                return make_naive_tensor_descriptor(
-                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
-                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
- }
- else
- {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
- }
- }();
-
- // A matrix blockwise copy
- auto a_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4<BlockSize,
-                                            InMemoryDataOperationEnum_t::Set,
-                                            Sequence<K0PerBlock, MPerBlock, K1>,
- ABlockTransferThreadSliceLengths_K0_M_K1,
- ABlockTransferThreadClusterLengths_K0_M_K1,
- ABlockTransferThreadClusterArrangeOrder,
- FloatAB,
- FloatAB,
- decltype(a_k0_m_k1_grid_desc),
- decltype(a_k0_m_k1_block_desc),
- ABlockTransferSrcAccessOrder,
- Sequence<1, 0, 2>,
- ABlockTransferSrcVectorDim,
- 2,
- ABlockTransferSrcScalarPerVector,
- ABlockTransferDstScalarPerVector_K1,
- 1,
- 1,
- AThreadTransferSrcResetCoordinateAfterRun,
- true>(a_k0_m_k1_grid_desc,
- make_multi_index(0, m_block_data_idx_on_grid, 0),
- a_k0_m_k1_block_desc,
- make_multi_index(0, 0, 0));
-
- // B matrix blockwise copy
- auto b_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4<BlockSize,
-                                            InMemoryDataOperationEnum_t::Set,
-                                            Sequence<K0PerBlock, NPerBlock, K1>,
- BBlockTransferThreadSliceLengths_K0_N_K1,
- BBlockTransferThreadClusterLengths_K0_N_K1,
- BBlockTransferThreadClusterArrangeOrder,
- FloatAB,
- FloatAB,
- decltype(b_k0_n_k1_grid_desc),
- decltype(b_k0_n_k1_block_desc),
- BBlockTransferSrcAccessOrder,
- Sequence<1, 0, 2>,
- BBlockTransferSrcVectorDim,
- 2,
- BBlockTransferSrcScalarPerVector,
- BBlockTransferDstScalarPerVector_K1,
- 1,
- 1,
- BThreadTransferSrcResetCoordinateAfterRun,
- true>(b_k0_n_k1_grid_desc,
- make_multi_index(0, n_block_data_idx_on_grid, 0),
- b_k0_n_k1_block_desc,
- make_multi_index(0, 0, 0));
-
- // GEMM definition
- // c_mtx += transpose(a_mtx) * b_mtx
- // a_mtx[K0PerBlock, MPerBlock] is in LDS
- // b_mtx[K0PerBlock, NPerBlock] is in LDS
- // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
- // register
- // sanity check
-
- auto blockwise_gemm =
- BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{};
-
- auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
-
- // LDS allocation for A and B: be careful of alignment
- constexpr auto a_block_space_size =
- math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
-
- FloatAB* p_a_block = p_shared_block;
- FloatAB* p_b_block = p_shared_block + a_block_space_size;
-
- constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
- constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
-
- // hack to control index calculation when iterating over A and B matrix for threadwise copy
- constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
- constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{};
-
- // hack to control index calculation when move slice window for A and B matrix for
- // threadwise copy
- constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{};
- constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{};
-
-        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
-            p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
-        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
-            p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
-
- // preload data into LDS
- {
- a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
- b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
-
- a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
- b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
- }
-
- // main body
- index_t k0_block_data_begin = 0;
-
- if constexpr(HasMainKBlockLoop)
- {
- do
- {
- a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc,
- a_block_slice_copy_step,
- a_k0_m_k1_grid_move_slice_window_step_hack);
- b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc,
- b_block_slice_copy_step,
- b_k0_n_k1_grid_move_slice_window_step_hack);
-
- a_blockwise_copy.RunRead(
- a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
-
- block_sync_lds();
-
- b_blockwise_copy.RunRead(
- b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
-
- blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
-
- block_sync_lds();
-
- a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
- b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
-
- k0_block_data_begin += K0PerBlock;
- } while(k0_block_data_begin < (K0 - K0PerBlock));
- }
-
- // tail
- {
- block_sync_lds();
-
- blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
- }
-
- // output: register to global memory
- {
- constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
- blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor();
-
- constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
- constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
- constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
- constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
- constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
- constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
- constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
- constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
-
- constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
- make_naive_tensor_descriptor_packed(make_tuple(
- Number