```diff
@@ -14,11 +14,9 @@
 namespace GauXC {
 
 #ifdef __HIP_PLATFORM_NVIDIA__
-#define VVAR_KERNEL_SM_BLOCK 32
 #define GGA_KERNEL_SM_WARPS 16
 #define MGGA_KERNEL_SM_BLOCK 32
 #else
-#define VVAR_KERNEL_SM_BLOCK 16
 #define GGA_KERNEL_SM_WARPS 8
 #define MGGA_KERNEL_SM_BLOCK 16
 #endif
```
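These are compile-time tile sizes for the kernels that still stage data through shared memory (LDS); `VVAR_KERNEL_SM_BLOCK` becomes dead on both the NVIDIA and AMD branches once the vvar gradient kernel below stops tiling through LDS, so the commit deletes it. For context, a minimal sketch of the pattern such a macro serves, using a hypothetical transpose kernel rather than GauXC's actual code:

```cpp
// Sketch only: illustrates how an SM_BLOCK-style macro sizes a padded LDS
// tile. The transpose kernel here is hypothetical, not GauXC code.
#include <hip/hip_runtime.h>

#define SM_BLOCK 16  // stands in for e.g. MGGA_KERNEL_SM_BLOCK above

// Launch with blockDim = dim3( SM_BLOCK, SM_BLOCK ).
__global__ void transpose_sketch( int n, const double* in, double* out ) {
  // The +1 pad staggers rows across LDS banks so the column-wise read
  // in the second half does not generate bank conflicts.
  __shared__ double tile[SM_BLOCK][SM_BLOCK + 1];

  int x = blockIdx.x * SM_BLOCK + threadIdx.x;
  int y = blockIdx.y * SM_BLOCK + threadIdx.y;
  if( x < n and y < n ) tile[threadIdx.y][threadIdx.x] = in[y*n + x];
  __syncthreads();

  x = blockIdx.y * SM_BLOCK + threadIdx.x;
  y = blockIdx.x * SM_BLOCK + threadIdx.y;
  if( x < n and y < n ) out[y*n + x] = tile[threadIdx.x][threadIdx.y];
}
```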
```diff
@@ -501,8 +499,6 @@ __global__ void eval_vvar_grad_kern( size_t ntasks,
   double* den_y_eval_device = nullptr;
   double* den_z_eval_device = nullptr;
 
-  constexpr auto warp_size = hip::warp_size;
-
   if constexpr (den_select == DEN_S) {
     den_eval_device = task.den_s;
     den_x_eval_device = task.dden_sx;
@@ -535,63 +531,39 @@ __global__ void eval_vvar_grad_kern( size_t ntasks,
 
   const auto* den_basis_prod_device = task.zmat;
 
-  __shared__ double den_shared[4][warp_size][VVAR_KERNEL_SM_BLOCK+1];
+  double den_reg = 0.;
+  double dx_reg = 0.;
+  double dy_reg = 0.;
+  double dz_reg = 0.;
 
-  for ( int bid_x = blockIdx.x * blockDim.x;
-        bid_x < nbf;
-        bid_x += blockDim.x * gridDim.x ) {
-
-    for ( int bid_y = blockIdx.y * VVAR_KERNEL_SM_BLOCK;
-          bid_y < npts;
-          bid_y += VVAR_KERNEL_SM_BLOCK * gridDim.y ) {
-
-      for (int sm_y = threadIdx.y; sm_y < VVAR_KERNEL_SM_BLOCK; sm_y += blockDim.y) {
-        den_shared[0][threadIdx.x][sm_y] = 0.;
-        den_shared[1][threadIdx.x][sm_y] = 0.;
-        den_shared[2][threadIdx.x][sm_y] = 0.;
-        den_shared[3][threadIdx.x][sm_y] = 0.;
+  int ipt = blockIdx.x * blockDim.x + threadIdx.x;
 
-        if (bid_y + threadIdx.x < npts and bid_x + sm_y < nbf) {
-          const double* db_col   = den_basis_prod_device + (bid_x + sm_y)*npts;
-          const double* bf_col   = basis_eval_device     + (bid_x + sm_y)*npts;
-          const double* bf_x_col = dbasis_x_eval_device  + (bid_x + sm_y)*npts;
-          const double* bf_y_col = dbasis_y_eval_device  + (bid_x + sm_y)*npts;
-          const double* bf_z_col = dbasis_z_eval_device  + (bid_x + sm_y)*npts;
+  if (ipt < npts) {
 
-          den_shared[0][threadIdx.x][sm_y] = bf_col  [ bid_y + threadIdx.x ] * db_col[ bid_y + threadIdx.x ];
-          den_shared[1][threadIdx.x][sm_y] = bf_x_col[ bid_y + threadIdx.x ] * db_col[ bid_y + threadIdx.x ];
-          den_shared[2][threadIdx.x][sm_y] = bf_y_col[ bid_y + threadIdx.x ] * db_col[ bid_y + threadIdx.x ];
-          den_shared[3][threadIdx.x][sm_y] = bf_z_col[ bid_y + threadIdx.x ] * db_col[ bid_y + threadIdx.x ];
-        }
-      }
-      __syncthreads();
+    // Have each thread accumulate its own reduction result into a register.
+    // There's no real _need_ for LDS because the reductions are small and
+    // therefore can be done without sharing.
+    for( int ibf = 0; ibf < nbf; ibf++ ) {
 
 
-      for (int sm_y = threadIdx.y; sm_y < VVAR_KERNEL_SM_BLOCK; sm_y += blockDim.y) {
-        const int tid_y = bid_y + sm_y;
-        double den_reg = den_shared[0][sm_y][threadIdx.x];
-        double dx_reg = den_shared[1][sm_y][threadIdx.x];
-        double dy_reg = den_shared[2][sm_y][threadIdx.x];
-        double dz_reg = den_shared[3][sm_y][threadIdx.x];
+      const double* bf_col   = basis_eval_device    + ibf*npts;
+      const double* bf_x_col = dbasis_x_eval_device + ibf*npts;
+      const double* bf_y_col = dbasis_y_eval_device + ibf*npts;
+      const double* bf_z_col = dbasis_z_eval_device + ibf*npts;
+      const double* db_col   = den_basis_prod_device + ibf*npts;
 
-        // Warp blocks are stored col major
-        den_reg = hip::warp_reduce_sum<warp_size>( den_reg );
-        dx_reg = 2. * hip::warp_reduce_sum<warp_size>( dx_reg );
-        dy_reg = 2. * hip::warp_reduce_sum<warp_size>( dy_reg );
-        dz_reg = 2. * hip::warp_reduce_sum<warp_size>( dz_reg );
+      den_reg += bf_col[ ipt ] * db_col[ ipt ];
+      dx_reg += 2 * bf_x_col[ ipt ] * db_col[ ipt ];
+      dy_reg += 2 * bf_y_col[ ipt ] * db_col[ ipt ];
+      dz_reg += 2 * bf_z_col[ ipt ] * db_col[ ipt ];
+    }
 
 
-        if( threadIdx.x == 0 and tid_y < npts ) {
-          atomicAdd( den_eval_device + tid_y, den_reg );
-          atomicAdd( den_x_eval_device + tid_y, dx_reg );
-          atomicAdd( den_y_eval_device + tid_y, dy_reg );
-          atomicAdd( den_z_eval_device + tid_y, dz_reg );
-        }
-      }
-      __syncthreads();
-    }
+    den_eval_device  [ipt] = den_reg;
+    den_x_eval_device[ipt] = dx_reg;
+    den_y_eval_device[ipt] = dy_reg;
+    den_z_eval_device[ipt] = dz_reg;
   }
-
 }
 
 
```
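The net effect of this hunk: the old version tiled the `nbf × npts` products through the `den_shared` LDS array, warp-reduced along the basis dimension, and merged partial sums with `atomicAdd`; the new version assigns one thread per grid point and keeps the entire reduction over `nbf` in four registers, so no LDS, `__syncthreads()`, or atomics remain, and the result is a plain store. A self-contained sketch of the same pattern, under assumed names (`A` and `B` are hypothetical stand-ins for the column-major `npts × nbf` basis and density-basis-product matrices):

```cpp
#include <hip/hip_runtime.h>

// Sketch of the per-point register-accumulation pattern adopted above.
// A and B are hypothetical npts x nbf column-major matrices standing in
// for basis_eval_device and den_basis_prod_device; out[ipt] receives the
// dot product of the ipt-th rows of A and B.
__global__ void rowwise_dot_sketch( int npts, int nbf,
                                    const double* A, const double* B,
                                    double* out ) {
  const int ipt = blockIdx.x * blockDim.x + threadIdx.x;
  if( ipt >= npts ) return;

  // Each thread owns exactly one grid point, so the reduction over nbf
  // lives entirely in a register: no LDS tile, no __syncthreads().
  double acc = 0.;
  for( int ibf = 0; ibf < nbf; ibf++ )
    acc += A[ ibf*npts + ipt ] * B[ ibf*npts + ipt ];

  // Exclusive writer per point: a plain store replaces atomicAdd.
  out[ipt] = acc;
}
```

Adjacent threads read consecutive `ipt` values within each column, so the loads stay coalesced even though each thread serializes over `nbf`; the trade is parallelism across basis functions for contention-free accumulation, which the comment in the hunk argues is a win for these small reductions.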
```diff
@@ -656,9 +628,9 @@ void eval_vvar( size_t ntasks, int32_t nbf_max, int32_t npts_max, bool do_grad,
   dim3 threads;
   dim3 blocks;
   if( do_grad ) {
-    threads = dim3( hip::warp_size, hip::max_warps_per_thread_block / 2, 1 );
-    blocks = dim3( std::min(uint64_t(4), util::div_ceil( nbf_max, 4 )),
-                   std::min(uint64_t(16), util::div_ceil( nbf_max, 16 )),
+    threads = dim3(hip::max_warps_per_thread_block, 1, 1);
+    blocks = dim3( util::div_ceil( npts_max, threads.x ),
+                   1,
                    ntasks );
   } else {
     threads = dim3( hip::warp_size, hip::max_warps_per_thread_block, 1 );
```
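With the kernel now indexed one thread per point, the launch math simplifies accordingly: x covers `npts_max` points, y collapses to 1, and z still indexes tasks. A hedged host-side sketch of the resulting grid computation, assuming `util::div_ceil` is round-up integer division and using a placeholder value for the `hip::max_warps_per_thread_block` constant referenced above:

```cpp
#include <hip/hip_runtime.h>
#include <cstdint>
#include <cstdio>

// Assumed behavior of util::div_ceil: round-up integer division.
static inline uint32_t div_ceil( uint32_t n, uint32_t d ) {
  return (n + d - 1) / d;
}

int main() {
  // Illustrative sizes; in GauXC these come from the task batch.
  const uint32_t npts_max = 10000;
  const uint32_t ntasks   = 32;

  // Placeholder for the hip::max_warps_per_thread_block constant above.
  const uint32_t block_size = 16;

  // Mirrors the do_grad branch: x covers points, z indexes tasks.
  dim3 threads( block_size, 1, 1 );
  dim3 blocks ( div_ceil( npts_max, threads.x ), 1, ntasks );

  std::printf( "grid = (%u, %u, %u), block = (%u, %u, %u)\n",
               blocks.x, blocks.y, blocks.z,
               threads.x, threads.y, threads.z );
  // e.g. hipLaunchKernelGGL( eval_vvar_grad_kern, blocks, threads, 0, stream, ... );
  return 0;
}
```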