Commit 5882e88

Remove shmem to fix bug in HIP vvar_grad kernel
1 parent 63bb6cd commit 5882e88

File tree: 1 file changed, +27 −55 lines
src/xc_integrator/local_work_driver/device/hip/kernels/uvvars.hip

Lines changed: 27 additions & 55 deletions
@@ -14,11 +14,9 @@
 namespace GauXC {

 #ifdef __HIP_PLATFORM_NVIDIA__
-#define VVAR_KERNEL_SM_BLOCK 32
 #define GGA_KERNEL_SM_WARPS 16
 #define MGGA_KERNEL_SM_BLOCK 32
 #else
-#define VVAR_KERNEL_SM_BLOCK 16
 #define GGA_KERNEL_SM_WARPS 8
 #define MGGA_KERNEL_SM_BLOCK 16
 #endif
@@ -501,8 +499,6 @@ __global__ void eval_vvar_grad_kern( size_t ntasks,
   double* den_y_eval_device = nullptr;
   double* den_z_eval_device = nullptr;

-  constexpr auto warp_size = hip::warp_size;
-
   if constexpr (den_select == DEN_S) {
     den_eval_device   = task.den_s;
     den_x_eval_device = task.dden_sx;
@@ -535,63 +531,39 @@
   const auto* den_basis_prod_device = task.zmat;

-  __shared__ double den_shared[4][warp_size][VVAR_KERNEL_SM_BLOCK+1];
+  double den_reg = 0.;
+  double dx_reg  = 0.;
+  double dy_reg  = 0.;
+  double dz_reg  = 0.;

-  for ( int bid_x = blockIdx.x * blockDim.x;
-        bid_x < nbf;
-        bid_x += blockDim.x * gridDim.x ) {
-
-    for ( int bid_y = blockIdx.y * VVAR_KERNEL_SM_BLOCK;
-          bid_y < npts;
-          bid_y += VVAR_KERNEL_SM_BLOCK * gridDim.y ) {
-
-      for (int sm_y = threadIdx.y; sm_y < VVAR_KERNEL_SM_BLOCK; sm_y += blockDim.y) {
-        den_shared[0][threadIdx.x][sm_y] = 0.;
-        den_shared[1][threadIdx.x][sm_y] = 0.;
-        den_shared[2][threadIdx.x][sm_y] = 0.;
-        den_shared[3][threadIdx.x][sm_y] = 0.;
+  int ipt = blockIdx.x * blockDim.x + threadIdx.x;

-        if (bid_y + threadIdx.x < npts and bid_x + sm_y < nbf) {
-          const double* db_col   = den_basis_prod_device + (bid_x + sm_y)*npts;
-          const double* bf_col   = basis_eval_device     + (bid_x + sm_y)*npts;
-          const double* bf_x_col = dbasis_x_eval_device  + (bid_x + sm_y)*npts;
-          const double* bf_y_col = dbasis_y_eval_device  + (bid_x + sm_y)*npts;
-          const double* bf_z_col = dbasis_z_eval_device  + (bid_x + sm_y)*npts;
+  if (ipt < npts) {

-          den_shared[0][threadIdx.x][sm_y] = bf_col  [ bid_y + threadIdx.x ] * db_col[ bid_y + threadIdx.x ];
-          den_shared[1][threadIdx.x][sm_y] = bf_x_col[ bid_y + threadIdx.x ] * db_col[ bid_y + threadIdx.x ];
-          den_shared[2][threadIdx.x][sm_y] = bf_y_col[ bid_y + threadIdx.x ] * db_col[ bid_y + threadIdx.x ];
-          den_shared[3][threadIdx.x][sm_y] = bf_z_col[ bid_y + threadIdx.x ] * db_col[ bid_y + threadIdx.x ];
-        }
-      }
-      __syncthreads();
+    // Have each thread accumulate its own reduction result into a register.
+    // There's no real _need_ for LDS because the reductions are small and
+    // therefore can be done without sharing.
+    for( int ibf = 0; ibf < nbf; ibf++ ) {

-      for (int sm_y = threadIdx.y; sm_y < VVAR_KERNEL_SM_BLOCK; sm_y += blockDim.y) {
-        const int tid_y = bid_y + sm_y;
-        double den_reg = den_shared[0][sm_y][threadIdx.x];
-        double dx_reg  = den_shared[1][sm_y][threadIdx.x];
-        double dy_reg  = den_shared[2][sm_y][threadIdx.x];
-        double dz_reg  = den_shared[3][sm_y][threadIdx.x];
+      const double* bf_col   = basis_eval_device     + ibf*npts;
+      const double* bf_x_col = dbasis_x_eval_device  + ibf*npts;
+      const double* bf_y_col = dbasis_y_eval_device  + ibf*npts;
+      const double* bf_z_col = dbasis_z_eval_device  + ibf*npts;
+      const double* db_col   = den_basis_prod_device + ibf*npts;

-        // Warp blocks are stored col major
-        den_reg = hip::warp_reduce_sum<warp_size>( den_reg );
-        dx_reg  = 2. * hip::warp_reduce_sum<warp_size>( dx_reg );
-        dy_reg  = 2. * hip::warp_reduce_sum<warp_size>( dy_reg );
-        dz_reg  = 2. * hip::warp_reduce_sum<warp_size>( dz_reg );
+      den_reg += bf_col[ ipt ] * db_col[ ipt ];
+      dx_reg  += 2 * bf_x_col[ ipt ] * db_col[ ipt ];
+      dy_reg  += 2 * bf_y_col[ ipt ] * db_col[ ipt ];
+      dz_reg  += 2 * bf_z_col[ ipt ] * db_col[ ipt ];
+    }

-        if( threadIdx.x == 0 and tid_y < npts ) {
-          atomicAdd( den_eval_device   + tid_y, den_reg );
-          atomicAdd( den_x_eval_device + tid_y, dx_reg  );
-          atomicAdd( den_y_eval_device + tid_y, dy_reg  );
-          atomicAdd( den_z_eval_device + tid_y, dz_reg  );
-        }
-      }
-      __syncthreads();
-    }
+    den_eval_device  [ipt] = den_reg;
+    den_x_eval_device[ipt] = dx_reg;
+    den_y_eval_device[ipt] = dy_reg;
+    den_z_eval_device[ipt] = dz_reg;
   }
-
 }
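For reference, the new kernel body reduces to the pattern below. This is a minimal standalone sketch, not the GauXC source: the kernel and array names (vvar_grad_sketch, bf, db, den) are hypothetical stand-ins, and only one of the four accumulated components is shown, without the den_select dispatch.

#include <hip/hip_runtime.h>

// One thread per grid point: each thread privately reduces over all basis
// functions in a register, so no LDS tile, warp shuffle, or atomicAdd is
// needed, and there is exactly one writer per output element.
__global__ void vvar_grad_sketch( int nbf, int npts,
                                  const double* bf,   // basis values, column layout [ibf*npts + ipt]
                                  const double* db,   // density-basis product, same layout
                                  double*       den ) {
  const int ipt = blockIdx.x * blockDim.x + threadIdx.x;
  if( ipt >= npts ) return;

  double acc = 0.;
  for( int ibf = 0; ibf < nbf; ++ibf ) {
    // Adjacent threads read adjacent ipt entries of each column, so the
    // global loads stay coalesced even without a shared-memory staging tile.
    acc += bf[ ibf*npts + ipt ] * db[ ibf*npts + ipt ];
  }

  den[ ipt ] = acc;  // plain store replaces the old atomicAdd
}

The trade-off is that each column is re-read from global memory once per point instead of being staged through LDS, which is acceptable here because, as the in-diff comment notes, the per-point reductions are small.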

@@ -656,9 +628,9 @@ void eval_vvar( size_t ntasks, int32_t nbf_max, int32_t npts_max, bool do_grad,
   dim3 threads;
   dim3 blocks;
   if( do_grad ) {
-    threads = dim3( hip::warp_size, hip::max_warps_per_thread_block / 2, 1 );
-    blocks  = dim3( std::min(uint64_t(4), util::div_ceil( nbf_max, 4 )),
-                    std::min(uint64_t(16), util::div_ceil( nbf_max, 16 )),
+    threads = dim3( hip::max_warps_per_thread_block, 1, 1 );
+    blocks  = dim3( util::div_ceil( npts_max, threads.x ),
+                    1,
                     ntasks );
   } else {
     threads = dim3( hip::warp_size, hip::max_warps_per_thread_block, 1 );
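The matching host-side launch arithmetic is a one-liner per dimension; a minimal sketch, with a hypothetical block size standing in for hip::max_warps_per_thread_block:

// 1-D launch: one thread per grid point, z-dimension indexes tasks.
// util::div_ceil(a, b) in the real code computes (a + b - 1) / b.
const unsigned threads_x = 256;  // hypothetical stand-in block size
dim3 threads( threads_x, 1, 1 );
dim3 blocks ( (npts_max + threads_x - 1) / threads_x, 1, ntasks );

Because blocks.z still carries the task count, each block presumably recovers its task via blockIdx.z, and the old 2-D tiling over basis functions and points is no longer needed.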
