From fa6b7b5071c880bbcf33b19cb362713a8d445d96 Mon Sep 17 00:00:00 2001
From: JBFYS-XD
Date: Sun, 24 Aug 2025 11:28:16 +0800
Subject: [PATCH 1/5] JSBFYS_XD: ALL PASSED

---
 src/kernels.cu | 229 +++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 224 insertions(+), 5 deletions(-)

diff --git a/src/kernels.cu b/src/kernels.cu
index 8df8130..0d26999 100644
--- a/src/kernels.cu
+++ b/src/kernels.cu
@@ -1,4 +1,5 @@
 #include 
+#include 
 
 #include "../tester/utils.h"
 
@@ -15,10 +16,122 @@
  * @note For invalid cases, return T(-100).
  * @note Handles device memory management (allocate/copy/free) internally. Errors should be thrown.
  */
+
+template <typename T>
+__global__ void kTopSort1(T* input, int n) {
+    extern __shared__ uint8_t shared_mem[];  // a single symbol, so no type conflict
+
+    int blockSize = blockDim.x;
+    // manual split: the first 512 T slots are smem, the last 512 are tmp
+    T* smem = reinterpret_cast<T*>(shared_mem);
+    T* tmp = smem + (blockSize << 1);  // points at the second half
+
+    int tid = threadIdx.x;
+    int idx = blockDim.x * blockIdx.x + threadIdx.x;
+    int posb = tid << 1, posg = idx << 1;
+
+    smem[posb] = posg < n ? input[posg] : 1e9;
+    smem[posb | 1] = (posg | 1) < n ? input[posg | 1] : 1e9;
+    __syncthreads();
+    for (int flag = 1; flag <= blockSize; flag <<= 1) {
+        int start = posb;
+        int mid = min(start + flag, blockSize << 1);
+        int end = min(start + (flag << 1), blockSize << 1);
+        if ((tid % flag) == 0) {
+            int l = start, r = mid, k = start;
+            while (l < mid && r < end) {
+                if (smem[l] < smem[r])
+                    tmp[k ++] = smem[l ++];
+                else
+                    tmp[k ++] = smem[r ++];
+            }
+
+            while (l < mid) tmp[k ++] = smem[l ++];
+            while (r < end) tmp[k ++] = smem[r ++];
+
+        }
+        __syncthreads();
+
+        smem[start] = tmp[start];
+        smem[start + 1] = tmp[start + 1];
+        __syncthreads();
+    }
+
+    if (posg < n) input[posg] = smem[posb];
+    if ((posg | 1) < n) input[posg | 1] = smem[posb | 1];
+}
+
+template <typename T>
+__global__ void kTopSort2(T* input, int n, T* gtmp, int flag) {
+    int tid = threadIdx.x;
+    int bid = blockIdx.x;
+
+    int start = bid * flag * 2;
+    int mid = min(start + flag, n);
+    int end = min(start + (flag << 1), n);
+
+    if (tid == 0) {
+        int l = start, r = mid, k = start;
+        while (l < mid && r < end) {
+            if (input[l] < input[r])
+                gtmp[k ++] = input[l ++];
+            else
+                gtmp[k ++] = input[r ++];
+        }
+
+        while (l < mid) gtmp[k ++] = input[l ++];
+        while (r < end) gtmp[k ++] = input[r ++];
+    }
+
+    __syncthreads();
+    for (int i = start + tid; i < end; i += blockDim.x) {
+        if (i < n)
+            input[i] = gtmp[i];
+    }
+
+}
+
+template <typename T>
+void kTopSort_Work(T* input, T* gtmp, size_t n) {
+    int blockSize = 256;
+
+    int flag = 1;
+    int gridSize = ((n + flag - 1) / flag + blockSize - 1) / blockSize;
+    size_t shared_mem = blockSize * sizeof(T) * 4;
+    kTopSort1<<<gridSize, blockSize, shared_mem>>>(input, n);
+    CUDA_CHECK(cudaDeviceSynchronize());
+
+    for (flag = 512; flag < n; flag <<= 1) {
+        gridSize = (n + flag - 1) / flag;
+        kTopSort2<<<gridSize, blockSize>>>(input, n, gtmp, flag);
+    }
+    CUDA_CHECK(cudaDeviceSynchronize());
+}
+
 template <typename T>
 T kthLargest(const std::vector<T>& h_input, size_t k) {
     // TODO: Implement the kthLargest function
-    return T(-1000);
+
+    size_t n = h_input.size();
+    if (k < 1 || k > n) return T(-100);
+
+    size_t size_T = sizeof(T);
+    size_t size_arr = n * size_T;
+
+    T *d_input;
+    T* gtmp;
+    CUDA_CHECK(cudaMalloc(&d_input, size_arr));
+    CUDA_CHECK(cudaMalloc(&gtmp, n * sizeof(T)));
+    CUDA_CHECK(cudaMemcpy(d_input, h_input.data(), size_arr, cudaMemcpyHostToDevice));
+
+    kTopSort_Work(d_input, gtmp, n);
+
+    T result;
+    CUDA_CHECK(cudaMemcpy(&result, d_input + (n - k), size_T, cudaMemcpyDeviceToHost));
+
+    CUDA_CHECK(cudaFree(gtmp));
+    CUDA_CHECK(cudaFree(d_input));
+    return result;
 }
 
 /**
@@ -37,11 +150,117 @@ T kthLargest(const std::vector<T>& h_input, size_t k) {
  * @param[in] head_dim Dimension size of each attention head
  * @param[in] is_causal Whether to apply causal masking
  */
+
 template <typename T>
-void flashAttention(const std::vector<T>& h_q, const std::vector<T>& h_k,
-                    const std::vector<T>& h_v, std::vector<T>& h_o,
-                    int batch_size, int target_seq_len, int src_seq_len,
-                    int query_heads, int kv_heads, int head_dim, bool is_causal) {
+__global__ void flashAttentionKernel(
+    const T* query, const T* key,
+    const T* value, T* output,
+    int batch_size, int target_seq_len, int src_seq_len,
+    int query_heads, int kv_heads, int head_dim, bool is_causal
+) {
+
+
+    int batch_id = blockIdx.x;
+    int heads_id = blockIdx.y;
+    int kv_id = heads_id * kv_heads / query_heads;
+
+    if (batch_id >= batch_size || heads_id >= query_heads) return;
+
+    extern __shared__ uint8_t shared_mem[];
+    T* scores = reinterpret_cast<T*>(shared_mem);
+
+    for (int tgt = 0; tgt < target_seq_len; tgt ++) {
+        for (int src = threadIdx.x; src < src_seq_len; src += blockDim.x) {
+            if (is_causal && tgt < src) {
+                scores[src] = -1e9f;
+            } else {
+                float sum = 0.;
+                for (int dim = 0; dim < head_dim; dim ++) {
+                    // query[batch_id][tgt][heads_id][dim]
+                    int qid = dim + head_dim * (heads_id + query_heads * (tgt + target_seq_len * batch_id));
+                    // key[batch_id][src][kv_id][dim]
+                    int kid = dim + head_dim * (kv_id + kv_heads * (src + src_seq_len * batch_id));
+                    sum += query[qid] * key[kid];
+                }
+                scores[src] = sum / sqrtf(float(head_dim));
+            }
+        }
+        __syncthreads();
+
+        if (threadIdx.x == 0) {
+            float mx = -1e9;
+            for (int i = 0; i < src_seq_len; i ++)
+                mx = fmaxf(mx, scores[i]);
+            float sum = 0.;
+            for (int i = 0; i < src_seq_len; i ++) {
+                scores[i] = expf(scores[i] - mx);
+                sum += scores[i];
+            }
+            float val = 1. / sum;
+            for (int i = 0; i < src_seq_len; i ++)
+                scores[i] *= val;
+        }
+        __syncthreads();
+
+        for (int dim = threadIdx.x; dim < head_dim; dim += blockDim.x) {
+            float sum = 0.;
+            for (int src = 0; src < src_seq_len; src ++) {
+                // value[batch_id][src][kv_id][dim]
+                int vidx = dim + head_dim * (kv_id + kv_heads * (src + src_seq_len * batch_id));
+                sum += scores[src] * value[vidx];
+            }
+            // output[batch_id][tgt][heads_id][dim]
+            int oidx = dim + head_dim * (heads_id + query_heads * (tgt + target_seq_len * batch_id));
+            output[oidx] = sum;
+        }
+        __syncthreads();
+    }
+}
+
+
+template <typename T>
+void flashAttention(
+    const std::vector<T>& h_q, const std::vector<T>& h_k,
+    const std::vector<T>& h_v, std::vector<T>& h_o,
+    int batch_size, int target_seq_len, int src_seq_len,
+    int query_heads, int kv_heads, int head_dim, bool is_causal
+) {
+
+    h_o.resize(batch_size * target_seq_len * query_heads * head_dim);
+
+    T* d_q, *d_k, *d_v, *d_o;
+    size_t size_q = h_q.size() * sizeof(T);
+    size_t size_k = h_k.size() * sizeof(T);
+    size_t size_v = h_v.size() * sizeof(T);
+    size_t size_o = h_o.size() * sizeof(T);
+
+    CUDA_CHECK(cudaMalloc(&d_q, size_q));
+    CUDA_CHECK(cudaMalloc(&d_k, size_k));
+    CUDA_CHECK(cudaMalloc(&d_v, size_v));
+    CUDA_CHECK(cudaMalloc(&d_o, size_o));
+
+    CUDA_CHECK(cudaMemcpy(d_q, h_q.data(), size_q, cudaMemcpyHostToDevice));
+    CUDA_CHECK(cudaMemcpy(d_k, h_k.data(), size_k, cudaMemcpyHostToDevice));
+    CUDA_CHECK(cudaMemcpy(d_v, h_v.data(), size_v, cudaMemcpyHostToDevice));
+
+    dim3 gridSize(batch_size, query_heads);
+    int blockSize = 256;
+    size_t shared_mem = src_seq_len * sizeof(T);
+    flashAttentionKernel<<<gridSize, blockSize, shared_mem>>>(
+        d_q, d_k, d_v, d_o,
+        batch_size, target_seq_len, src_seq_len,
+        query_heads, kv_heads, head_dim, is_causal
+    );
+
+    CUDA_CHECK(cudaDeviceSynchronize());
+
+    CUDA_CHECK(cudaMemcpy(h_o.data(), d_o, size_o, cudaMemcpyDeviceToHost));
+
+    CUDA_CHECK(cudaFree(d_q));
+    CUDA_CHECK(cudaFree(d_k));
+    CUDA_CHECK(cudaFree(d_v));
+    CUDA_CHECK(cudaFree(d_o));
+
 }
 
 // *********************************************************************
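The kernel added above handles one (batch, query-head) pair per block and materialises the full score row in shared memory, so its results can be checked directly against a naive CPU evaluation of the same contract: identical [batch, seq_len, heads, head_dim] flattening, the same query-head to kv-head mapping, 1/sqrt(head_dim) scaling, and masking of positions with src > tgt. The following reference is an illustrative sketch only — it is not part of the patch series, and the function and variable names are arbitrary:

#include <algorithm>
#include <cmath>
#include <vector>

// Illustrative CPU reference (not part of the patch series): same layout and
// semantics as flashAttentionKernel, for spot-checking small problem sizes.
void attentionReference(const std::vector<float>& q, const std::vector<float>& k,
                        const std::vector<float>& v, std::vector<float>& o,
                        int batch, int tgt_len, int src_len,
                        int q_heads, int kv_heads, int head_dim, bool causal) {
    o.assign(static_cast<size_t>(batch) * tgt_len * q_heads * head_dim, 0.0f);
    std::vector<float> scores(src_len);
    for (int b = 0; b < batch; ++b)
        for (int h = 0; h < q_heads; ++h) {
            int kvh = h * kv_heads / q_heads;  // grouped-query head mapping, as in the kernel
            for (int t = 0; t < tgt_len; ++t) {
                float mx = -1e9f;
                for (int s = 0; s < src_len; ++s) {
                    if (causal && s > t) {
                        scores[s] = -1e9f;     // masked position
                    } else {
                        float dot = 0.0f;
                        for (int d = 0; d < head_dim; ++d)
                            dot += q[d + head_dim * (h + q_heads * (t + (size_t)tgt_len * b))] *
                                   k[d + head_dim * (kvh + kv_heads * (s + (size_t)src_len * b))];
                        scores[s] = dot / std::sqrt(static_cast<float>(head_dim));
                    }
                    mx = std::max(mx, scores[s]);
                }
                float sum = 0.0f;
                for (int s = 0; s < src_len; ++s) {
                    scores[s] = std::exp(scores[s] - mx);  // numerically stable softmax
                    sum += scores[s];
                }
                for (int d = 0; d < head_dim; ++d) {
                    float acc = 0.0f;
                    for (int s = 0; s < src_len; ++s)
                        acc += (scores[s] / sum) * v[d + head_dim * (kvh + kv_heads * (s + (size_t)src_len * b))];
                    o[d + head_dim * (h + q_heads * (t + (size_t)tgt_len * b))] = acc;
                }
            }
        }
}

Comparing h_o from flashAttention against such a reference with a small absolute tolerance is usually enough to catch indexing or masking mistakes on tiny shapes.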
From 31282a64a26c100cc308c2c62cf405204bb146d1 Mon Sep 17 00:00:00 2001
From: JBFYS-XD
Date: Sun, 24 Aug 2025 13:45:40 +0800
Subject: [PATCH 2/5] JBFYS_XD: Optimized reduction for softmax

---
 src/kernels.cu | 86 ++++++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 73 insertions(+), 13 deletions(-)

diff --git a/src/kernels.cu b/src/kernels.cu
index 0d26999..731d270 100644
--- a/src/kernels.cu
+++ b/src/kernels.cu
@@ -151,6 +151,62 @@ T kthLargest(const std::vector<T>& h_input, size_t k) {
  * @param[in] is_causal Whether to apply causal masking
  */
 
+template <typename T>
+__device__ T block_reduce_max(T* smem, T val) {
+    int tid = threadIdx.x;
+    int lane = tid % 32;
+    int warp_id = tid / 32;
+
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        T other = __shfl_xor_sync(0xFFFFFFFF, val, offset);
+        val = fmaxf(val, other);
+    }
+
+    if (lane == 0)
+        smem[warp_id] = val;
+    __syncthreads();
+
+    if (warp_id == 0) {
+        val = lane >= 8 ? -1e9 : smem[lane];
+        for (int offset = 16; offset > 0; offset >>= 1) {
+            T other = __shfl_xor_sync(0xFFFFFFFF, val, offset);
+            val = fmaxf(val, other);
+        }
+        smem[0] = val;
+    }
+    __syncthreads();
+
+    return smem[0];
+}
+
+template <typename T>
+__device__ T block_reduce_sum(T* smem, T val) {
+    int tid = threadIdx.x;
+    int lane = tid % 32;
+    int warp_id = tid / 32;
+
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        T other = __shfl_xor_sync(0xFFFFFFFF, val, offset);
+        val = val + other;
+    }
+
+    if (lane == 0)
+        smem[warp_id] = val;
+    __syncthreads();
+
+    if (warp_id == 0) {
+        val = lane >= 8 ? 0 : smem[lane];
+        for (int offset = 16; offset > 0; offset >>= 1) {
+            T other = __shfl_xor_sync(0xFFFFFFFF, val, offset);
+            val = val + other;
+        }
+        smem[0] = val;
+    }
+    __syncthreads();
+
+    return smem[0];
+}
+
 template <typename T>
 __global__ void flashAttentionKernel(
     const T* query, const T* key,
@@ -168,8 +224,10 @@ __global__ void flashAttentionKernel(
 
     if (batch_id >= batch_size || heads_id >= query_heads) return;
 
     extern __shared__ uint8_t shared_mem[];
     T* scores = reinterpret_cast<T*>(shared_mem);
+    T* smem = scores + src_seq_len;
 
     for (int tgt = 0; tgt < target_seq_len; tgt ++) {
+        float mx = -1e9;
         for (int src = threadIdx.x; src < src_seq_len; src += blockDim.x) {
             if (is_causal && tgt < src) {
                 scores[src] = -1e9f;
             } else {
@@ -183,22 +241,24 @@ __global__ void flashAttentionKernel(
                     sum += query[qid] * key[kid];
                 }
                 scores[src] = sum / sqrtf(float(head_dim));
+                mx = fmaxf(mx, scores[src]);
             }
         }
        __syncthreads();
 
-        if (threadIdx.x == 0) {
-            float mx = -1e9;
-            for (int i = 0; i < src_seq_len; i ++)
-                mx = fmaxf(mx, scores[i]);
-            float sum = 0.;
-            for (int i = 0; i < src_seq_len; i ++) {
-                scores[i] = expf(scores[i] - mx);
-                sum += scores[i];
-            }
-            float val = 1. / sum;
-            for (int i = 0; i < src_seq_len; i ++)
-                scores[i] *= val;
+        mx = block_reduce_max(smem, mx);
+
+        T sum = 0.;
+        for (int src = threadIdx.x; src < src_seq_len; src += blockDim.x) {
+            scores[src] = expf(scores[src] - mx);
+            sum += scores[src];
+        }
+        __syncthreads();
+
+        sum = block_reduce_sum(smem, sum);
+
+        for (int src = threadIdx.x; src < src_seq_len; src += blockDim.x) {
+            scores[src] = scores[src] / (sum + 1e-8f);
         }
         __syncthreads();
 
@@ -245,7 +305,7 @@ void flashAttention(
     dim3 gridSize(batch_size, query_heads);
     int blockSize = 256;
-    size_t shared_mem = src_seq_len * sizeof(T);
+    size_t shared_mem = src_seq_len * sizeof(T) + 8 * sizeof(T);
     flashAttentionKernel<<<gridSize, blockSize, shared_mem>>>(
         d_q, d_k, d_v, d_o,
         batch_size, target_seq_len, src_seq_len,
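The helpers introduced in this patch follow the standard two-level block reduction: a warp-wide butterfly with __shfl_xor_sync, then a second pass over the per-warp partials parked in shared memory, with the warp count hard-coded to 8 because the kernel always launches 256 threads per block. A self-contained demo of the same pattern, outside the attention kernel, might look like this (illustrative sketch, not part of the repository; names are arbitrary):

#include <cstdio>
#include <cuda_runtime.h>

// Illustrative sketch (not part of the patch series): warp-shuffle max reduction
// plus a shared-memory pass over per-warp partials, for one 256-thread block.
__device__ float warpReduceMax(float val) {
    for (int offset = 16; offset > 0; offset >>= 1)
        val = fmaxf(val, __shfl_xor_sync(0xFFFFFFFF, val, offset));
    return val;                              // every lane ends up holding the warp max
}

__global__ void blockMaxDemo(const float* in, float* out, int n) {
    __shared__ float warp_max[8];            // one partial per warp (256 / 32)
    int tid = threadIdx.x;
    float val = tid < n ? in[tid] : -1e9f;   // lanes past n contribute a sentinel
    val = warpReduceMax(val);
    if (tid % 32 == 0) warp_max[tid / 32] = val;
    __syncthreads();
    if (tid < 32) {                          // first warp combines the 8 partials
        val = (tid < 8) ? warp_max[tid] : -1e9f;
        val = warpReduceMax(val);
        if (tid == 0) out[0] = val;
    }
}

int main() {
    const int n = 200;
    float h_in[n], h_out = 0.0f;
    for (int i = 0; i < n; ++i) h_in[i] = float((i * 37) % 101) - 50.0f;  // max works out to 50
    float *d_in = nullptr, *d_out = nullptr;
    cudaMalloc(&d_in, n * sizeof(float));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);
    blockMaxDemo<<<1, 256>>>(d_in, d_out, n);
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    std::printf("block max = %f\n", h_out);
    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}

The same structure generalises to the sum reduction; only the combine operation and the identity element (-1e9 versus 0) change, which is exactly the difference between block_reduce_max and block_reduce_sum above.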
From 2294e9015e19a4f8a91c2e755daa484c23d8b750 Mon Sep 17 00:00:00 2001
From: JBFYS-XD
Date: Sun, 24 Aug 2025 19:56:21 +0800
Subject: [PATCH 3/5] Bitonic sort replaces merge sort

---
 src/kernels.cu | 140 +++++++++++++++++--------------------------------
 1 file changed, 47 insertions(+), 93 deletions(-)

diff --git a/src/kernels.cu b/src/kernels.cu
index 731d270..966025a 100644
--- a/src/kernels.cu
+++ b/src/kernels.cu
@@ -18,119 +18,73 @@
  */
 
 template <typename T>
-__global__ void kTopSort1(T* input, int n) {
-    extern __shared__ uint8_t shared_mem[];  // a single symbol, so no type conflict
-
-    int blockSize = blockDim.x;
-    // manual split: the first 512 T slots are smem, the last 512 are tmp
-    T* smem = reinterpret_cast<T*>(shared_mem);
-    T* tmp = smem + (blockSize << 1);  // points at the second half
-
-    int tid = threadIdx.x;
-    int idx = blockDim.x * blockIdx.x + threadIdx.x;
-    int posb = tid << 1, posg = idx << 1;
-
-    smem[posb] = posg < n ? input[posg] : 1e9;
-    smem[posb | 1] = (posg | 1) < n ? input[posg | 1] : 1e9;
-    __syncthreads();
-    for (int flag = 1; flag <= blockSize; flag <<= 1) {
-        int start = posb;
-        int mid = min(start + flag, blockSize << 1);
-        int end = min(start + (flag << 1), blockSize << 1);
-        if ((tid % flag) == 0) {
-            int l = start, r = mid, k = start;
-            while (l < mid && r < end) {
-                if (smem[l] < smem[r])
-                    tmp[k ++] = smem[l ++];
-                else
-                    tmp[k ++] = smem[r ++];
-            }
-
-            while (l < mid) tmp[k ++] = smem[l ++];
-            while (r < end) tmp[k ++] = smem[r ++];
-
+__global__ void bitonic_sort_kernel(T* input, int n, int desc, int k, int j) {
+    int x = blockDim.x * blockIdx.x + threadIdx.x;
+    int y = x ^ j;
+    if (x >= n || x > y) return;
+    T valx = input[x];
+    T valy = input[y];
+    T val;
+    if ((valx > valy) == (desc ^ ((x & k) != 0))) {
+        if (x > y) {
+            input[x] = valy;
+            input[y] = valx;
+        }
+    } else {
+        if (x < y) {
+            input[x] = valy;
+            input[y] = valx;
         }
-        __syncthreads();
-
-        smem[start] = tmp[start];
-        smem[start + 1] = tmp[start + 1];
-        __syncthreads();
     }
-
-    if (posg < n) input[posg] = smem[posb];
-    if ((posg | 1) < n) input[posg | 1] = smem[posb | 1];
 }
 
 template <typename T>
-__global__ void kTopSort2(T* input, int n, T* gtmp, int flag) {
-    int tid = threadIdx.x;
-    int bid = blockIdx.x;
-
-    int start = bid * flag * 2;
-    int mid = min(start + flag, n);
-    int end = min(start + (flag << 1), n);
-
-    if (tid == 0) {
-        int l = start, r = mid, k = start;
-        while (l < mid && r < end) {
-            if (input[l] < input[r])
-                gtmp[k ++] = input[l ++];
-            else
-                gtmp[k ++] = input[r ++];
-        }
-
-        while (l < mid) gtmp[k ++] = input[l ++];
-        while (r < end) gtmp[k ++] = input[r ++];
+__global__ void init_input(T* input, int n, int mx) {
+    int x = blockDim.x * blockIdx.x + threadIdx.x;
+    if (x < n && x >= mx) {
+        input[x] = T(-1e9);
     }
-
-    __syncthreads();
-    for (int i = start + tid; i < end; i += blockDim.x) {
-        if (i < n)
-            input[i] = gtmp[i];
-    }
-
 }
 
 template <typename T>
-void kTopSort_Work(T* input, T* gtmp, size_t n) {
-    int blockSize = 256;
+T kthLargest(const std::vector<T>& h_input, size_t k) {
+    // TODO: Implement the kthLargest function
+    int n = h_input.size();
+
+    if (k < 1 || k > n)
+        return T(-100);
+
+    while (__builtin_popcount(n) != 1) n += (n & (-n));
 
-    int flag = 1;
-    int gridSize = ((n + flag - 1) / flag + blockSize - 1) / blockSize;
-    size_t shared_mem = blockSize * sizeof(T) * 4;
-    kTopSort1<<<gridSize, blockSize, shared_mem>>>(input, n);
-    CUDA_CHECK(cudaDeviceSynchronize());
+    T *d_input;
 
-    for (flag = 512; flag < n; flag <<= 1) {
-        gridSize = (n + flag - 1) / flag;
-        kTopSort2<<<gridSize, blockSize>>>(input, n, gtmp, flag);
-    }
-    CUDA_CHECK(cudaDeviceSynchronize());
-}
+    size_t size_input = h_input.size() * sizeof(T);
+    size_t size_malloc = n * sizeof(T);
 
-template <typename T>
-T kthLargest(const std::vector<T>& h_input, size_t k) {
-    // TODO: Implement the kthLargest function
+    CUDA_CHECK(cudaMalloc(&d_input, size_malloc));
 
-    size_t n = h_input.size();
-    if (k < 1 || k > n) return T(-100);
+    CUDA_CHECK(cudaMemcpy(d_input, h_input.data(), size_input, cudaMemcpyHostToDevice));
 
-    size_t size_T = sizeof(T);
-    size_t size_arr = n * size_T;
 
-    T *d_input;
-    T* gtmp;
-    CUDA_CHECK(cudaMalloc(&d_input, size_arr));
-    CUDA_CHECK(cudaMalloc(&gtmp, n * sizeof(T)));
-    CUDA_CHECK(cudaMemcpy(d_input, h_input.data(), size_arr, cudaMemcpyHostToDevice));
+    int blockSize = 256;
+    int gridSize = (n + 255) / 256;
 
-    kTopSort_Work(d_input, gtmp, n);
+    init_input<<<gridSize, blockSize>>>(d_input, n, h_input.size());
+    CUDA_CHECK(cudaDeviceSynchronize());
+
+
+    for (int k = 2; k <= n; k <<= 1) {
+        for (int j = (k >> 1); j > 0; j >>= 1) {
+            bitonic_sort_kernel<<<gridSize, blockSize>>>(d_input, n, 1, k, j);
+            CUDA_CHECK(cudaDeviceSynchronize());
+        }
+    }
 
     T result;
-    CUDA_CHECK(cudaMemcpy(&result, d_input + (n - k), size_T, cudaMemcpyDeviceToHost));
+    CUDA_CHECK(cudaMemcpy(&result, d_input + (k - 1), sizeof(T), cudaMemcpyDeviceToHost));
 
-    CUDA_CHECK(cudaFree(gtmp));
     CUDA_CHECK(cudaFree(d_input));
+
     return result;
 }
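Because each (k, j) step of the bitonic network touches disjoint index pairs {x, x ^ j}, the exact compare-exchange schedule that bitonic_sort_kernel runs in parallel can be replayed sequentially on the CPU, which makes the loop structure and the descending direction (desc = 1) easy to verify. The sketch below also mimics the power-of-two padding with -1e9 that init_input performs on the device, so the padding sinks to the tail of the sorted output. Illustrative only, not part of the patch series:

#include <cstdio>
#include <utility>
#include <vector>

// Illustrative sketch (not part of the patch series): sequential replay of the
// same k/j compare-exchange schedule, descending order, with -1e9 padding.
void bitonicDescending(std::vector<float>& a) {
    size_t n = 1;
    while (n < a.size()) n <<= 1;                // round up to a power of two
    a.resize(n, -1e9f);                          // padding sinks to the end
    for (size_t k = 2; k <= n; k <<= 1)          // size of the bitonic sequences
        for (size_t j = k >> 1; j > 0; j >>= 1)  // compare distance
            for (size_t x = 0; x < n; ++x) {
                size_t y = x ^ j;
                if (x < y) {
                    bool up = (x & k) != 0;
                    // same swap predicate as the kernel with desc == 1
                    if ((a[x] < a[y]) == !up) std::swap(a[x], a[y]);
                }
            }
}

int main() {
    std::vector<float> v = {3.0f, -7.0f, 12.0f, 0.5f, 9.0f, -2.0f, 4.0f};
    bitonicDescending(v);
    for (float x : v) std::printf("%g ", x);  // expect: 12 9 4 3 0.5 -2 -7 -1e+09
    std::printf("\n");
    return 0;
}

Since the sort is descending and the padding is smaller than any real element (for inputs above -1e9), the kth largest value ends up at index k - 1, which is the offset the host code copies back.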
From 365b3e83786eca96febfed15f2a90ec25d39f44f Mon Sep 17 00:00:00 2001
From: JBFYS-XD
Date: Tue, 2 Sep 2025 16:13:42 +0800
Subject: [PATCH 4/5] JBFYS-XD: change uint8_t to unsigned char

---
 src/kernels.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/kernels.cu b/src/kernels.cu
index 966025a..9f1ccf2 100644
--- a/src/kernels.cu
+++ b/src/kernels.cu
@@ -176,7 +176,7 @@ __global__ void flashAttentionKernel(
 
     if (batch_id >= batch_size || heads_id >= query_heads) return;
 
-    extern __shared__ uint8_t shared_mem[];
+    extern __shared__ unsigned char shared_mem[];
     T* scores = reinterpret_cast<T*>(shared_mem);
     T* smem = scores + src_seq_len;
 

From 71d4063ac10e20de1b87d3e7d0fcf655be41e2a1 Mon Sep 17 00:00:00 2001
From: JBFYS-XD
Date: Sat, 27 Sep 2025 21:46:24 +0800
Subject: [PATCH 5/5] JBFYS-XD: Remove redundant code branches to reduce code complexity

---
 src/kernels.cu | 9 +--------
 1 file changed, 1 insertion(+), 8 deletions(-)

diff --git a/src/kernels.cu b/src/kernels.cu
index 9f1ccf2..c81b460 100644
--- a/src/kernels.cu
+++ b/src/kernels.cu
@@ -25,16 +25,9 @@ __global__ void bitonic_sort_kernel(T* input, int n, int desc, int k, int j) {
     T valx = input[x];
     T valy = input[y];
     T val;
-    if ((valx > valy) == (desc ^ ((x & k) != 0))) {
-        if (x > y) {
+    if ((valx < valy) == (desc ^ ((x & k) != 0))) {
             input[x] = valy;
             input[y] = valx;
-        }
-    } else {
-        if (x < y) {
-            input[x] = valy;
-            input[y] = valx;
-        }
     }
 }
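Across the series, the host-visible contract of kthLargest stays the one documented in the header comment: the kth largest element for 1 <= k <= n, and T(-100) otherwise. A tiny CPU reference makes that easy to regression-test; unlike the device path it does not depend on the -1e9 padding sentinel, so it also behaves sensibly for inputs at or below that value. Illustrative sketch only, not part of the repository:

#include <algorithm>
#include <functional>
#include <vector>

// Illustrative reference (not part of the patch series): descending sort of a
// copy, element k-1, and T(-100) for out-of-range k, per the header comment.
template <typename T>
T kthLargestReference(const std::vector<T>& h_input, size_t k) {
    if (k < 1 || k > h_input.size()) return T(-100);
    std::vector<T> sorted(h_input);
    std::sort(sorted.begin(), sorted.end(), std::greater<T>());
    return sorted[k - 1];
}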