157 changes: 157 additions & 0 deletions brainstorm/handlers/pycuda_handler.py
@@ -143,6 +143,49 @@ def avgpool2d_backward_batch(self, inputs, window, outputs, padding,
in_deltas,
block=(NUM_CUDA_THREADS, 1, 1),
grid=(get_blocks(inputs.size), 1))

def slice_copy_strided(self, inputs, outputs, slice_shape_inputs,
                       slice_shape_outputs):
    """Copy strided slices; each shape array is a flat list of
    (start, length, segments, stride) quadruples, one per slice."""
    _slice_copy_impl(inputs, outputs,
                     slice_shape_inputs, slice_shape_outputs,
                     np.int32(len(slice_shape_inputs) // 4),
                     block=(NUM_CUDA_THREADS, 1, 1),
                     grid=(get_blocks(inputs.size), 1))


def handle_shape(self, shape):
    """Left-pad `shape` with ones to 5 dims, e.g. (4, 5) -> [1, 1, 1, 4, 5]."""
    shape = [np.int32(dim) for dim in shape]
    while len(shape) < 5:
        shape = [np.int32(1)] + shape
    return shape

def strided_elementwise_inplace(self, inputs, idx, func):
    """Apply `func` in place to flat block `idx` of `inputs`."""
    shape = self.handle_shape(inputs.shape)

    if func not in strided_inp_funcs:
        raise ValueError("Strided function not supported. "
                         "Supported functions are: {0}"
                         .format(list(strided_inp_funcs.keys())))

    strided_inp_funcs[func](inputs, np.int32(idx),
                            shape[0], shape[1], shape[2], shape[3], shape[4],
                            block=(NUM_CUDA_THREADS, 1, 1),
                            grid=(get_blocks(inputs.size), 1))


def strided_elementwise(self, inputs, outputs, idx, func):
    """Copy `inputs` to `outputs`, applying `func` to flat block `idx`."""
    shape = self.handle_shape(inputs.shape)

    if func not in strided_funcs:
        raise ValueError("Strided function not supported. "
                         "Supported functions are: {0}"
                         .format(list(strided_funcs.keys())))

    strided_funcs[func](inputs, outputs, np.int32(idx),
                        shape[0], shape[1], shape[2], shape[3], shape[4],
                        block=(NUM_CUDA_THREADS, 1, 1),
                        grid=(get_blocks(inputs.size), 1))

def avgpool2d_forward_batch(self, inputs, window, outputs, padding,
stride):
@@ -877,3 +920,117 @@ def tanh_deriv(self, x, y, dy, dx):
"""
_mod_avepool_bwd_fp32 = SourceModule(__avepool_bwd_fp32_kernel)
_avepool_bwd_fp32_impl = _mod_avepool_bwd_fp32.get_function("ave_pool_bwd")


__strided_elewise_inp = """
__global__ void strided_elementwise_inp_{0}(float *in,
    int matrixIndex, int dim1, int dim2, int dim3, int dim4, int dim5)
{{
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int idx = 0;

    // The flat buffer holds dim5 contiguous blocks of dim1*dim2*dim3*dim4
    // elements each; only block `matrixIndex` is transformed, in place.
    for(int i = index; i < dim1*dim2*dim3*dim4; i += blockDim.x * gridDim.x)
    {{
        idx = (dim1*dim2*dim3*dim4*matrixIndex) + i;

        in[idx] = {1};
    }}
}}
"""
__strided_logistic_inp = __strided_elewise_inp.format(
    "logistic", "1./(1.+__expf(-in[idx]))")
_mod_strided_elewise_kernel_logistic = SourceModule(__strided_logistic_inp)
_strided_elewise_inp_logistic = _mod_strided_elewise_kernel_logistic.get_function(
    "strided_elementwise_inp_logistic")

__strided_tanh_inp = __strided_elewise_inp.format("tanh", "tanhf(in[idx])")
_mod_strided_ele_kernel_tanh_inp = SourceModule(__strided_tanh_inp)
_strided_elewise_inp_tanh = _mod_strided_ele_kernel_tanh_inp.get_function(
    "strided_elementwise_inp_tanh")

strided_inp_funcs = {}
strided_inp_funcs['logistic'] = _strided_elewise_inp_logistic
strided_inp_funcs['tanh'] = _strided_elewise_inp_tanh
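

# For orientation, a minimal NumPy sketch of the semantics encoded above,
# assuming a C-contiguous array whose trailing dimension counts the flat
# blocks. `_strided_inp_reference` is an illustrative name, not handler API.
def _strided_inp_reference(arr, idx, func):
    flat = arr.reshape(-1)              # view, since arr is contiguous
    block = flat.size // arr.shape[-1]  # dim1*dim2*dim3*dim4 in the kernel
    flat[block * idx:block * (idx + 1)] = \
        func(flat[block * idx:block * (idx + 1)])
    return arr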


__strided_elewise = """
__global__ void strided_elementwise_{0}(float *in, float *out,
    int matrixIndex, int dim1, int dim2, int dim3, int dim4, int dim5)
{{
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int idx = 0;

    // Copy every block from `in` to `out`, applying the elementwise
    // function only to block `matrixIndex`.
    for(int d5 = 0; d5 < dim5; d5++)
    {{
        for(int i = index; i < dim1*dim2*dim3*dim4; i += blockDim.x * gridDim.x)
        {{
            idx = (dim1*dim2*dim3*dim4*d5) + i;

            if(d5 == matrixIndex){{ out[idx] = {1}; }}
            else{{ out[idx] = in[idx]; }}
        }}
    }}
}}
"""
__strided_logistic = __strided_elewise.format(
    "logistic", "1./(1.+__expf(-in[idx]))")
_mod_strided_elewise_kernel_logistic = SourceModule(__strided_logistic)
_strided_elewise_logistic = _mod_strided_elewise_kernel_logistic.get_function(
    "strided_elementwise_logistic")

__strided_tanh = __strided_elewise.format("tanh", "tanhf(in[idx])")
_mod_strided_ele_kernel_tanh = SourceModule(__strided_tanh)
_strided_elewise_tanh = _mod_strided_ele_kernel_tanh.get_function(
    "strided_elementwise_tanh")

strided_funcs = {}
strided_funcs['logistic'] = _strided_elewise_logistic
strided_funcs['tanh'] = _strided_elewise_tanh
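

# A matching sketch for the out-of-place variant (illustrative only): every
# block is copied, and `func` is applied to block `idx` alone, mirroring the
# d5 == matrixIndex branch in the kernel. Assumes a C-contiguous array.
def _strided_reference(arr, idx, func):
    out = arr.copy()
    flat = out.reshape(-1)
    block = flat.size // arr.shape[-1]
    flat[block * idx:block * (idx + 1)] = \
        func(flat[block * idx:block * (idx + 1)])
    return out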

__slice_copy_kernel = """
__global__ void slice_copy(float *in, float *out,
    float *from_shape, float *to_shape, int shapes)
{
    // Each descriptor is a (start, length, segments, stride) quadruple;
    // `stride` is the gap between consecutive segments in the flat buffer.
    int in_start = 0;
    int in_length = 0;  // segment length
    int in_segments = 0;
    int in_stride = 0;
    int in_current_segment = 0;
    int in_slice_idx = 0;

    int out_start = 0;
    int out_length = 0;
    int out_segments = 0;
    int out_stride = 0;
    int out_current_segment = 0;
    int out_slice_idx = 0;

    for(int shape = 0; shape < shapes; shape++)
    {
        in_start = (int)from_shape[shape*4];
        in_length = (int)from_shape[shape*4+1];
        in_segments = (int)from_shape[shape*4+2];
        in_stride = (int)from_shape[shape*4+3];

        out_start = (int)to_shape[shape*4];
        out_length = (int)to_shape[shape*4+1];
        out_segments = (int)to_shape[shape*4+2];
        out_stride = (int)to_shape[shape*4+3];

        for(int idx = (blockIdx.x * blockDim.x) + threadIdx.x;
            idx < in_length*in_segments; idx += blockDim.x * gridDim.x)
        {
            in_current_segment = idx / in_length;
            out_current_segment = idx / out_length;
            in_slice_idx = in_start + idx + (in_current_segment*in_stride);
            out_slice_idx = out_start + idx + (out_current_segment*out_stride);

            out[out_slice_idx] = in[in_slice_idx];
        }
    }
}
"""
_mod_slice_copy_kernel = SourceModule(__slice_copy_kernel)
_slice_copy_impl = _mod_slice_copy_kernel.get_function("slice_copy")
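

# Hedged NumPy sketch of the slice_copy indexing, for reference only: each
# descriptor is a flat (start, length, segments, stride) quadruple, where
# `stride` is the gap between consecutive segments, so segment k starts at
# start + k * (length + stride) in the raveled buffer.
def _slice_copy_reference(src_flat, dst_flat, from_desc, to_desc):
    for q in range(len(from_desc) // 4):
        i_start, i_len, i_seg, i_gap = (int(v) for v in from_desc[4*q:4*q + 4])
        o_start, o_len, o_seg, o_gap = (int(v) for v in to_desc[4*q:4*q + 4])
        for idx in range(i_len * i_seg):
            src_i = i_start + idx + (idx // i_len) * i_gap
            dst_i = o_start + idx + (idx // o_len) * o_gap
            dst_flat[dst_i] = src_flat[src_i]
    return dst_flat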
132 changes: 132 additions & 0 deletions brainstorm/tests/test_handler_operations.py
@@ -9,11 +9,14 @@
from brainstorm.handlers import NumpyHandler
from brainstorm.optional import has_pycuda


# np.random.seed(1234)
dtype = np.float32
NO_CON = set()


def _conv2d_forward_batch(inputs, weights, bias, outputs, padding, stride):
"""
Loop-based implementation of 2D convolution to check against.
@@ -167,3 +170,132 @@ def test_conv2d_forward_batch_pycuda():
print("Expected:\n", true_outputs)
print("Obtained:\n", outputs)
assert passed


@pytest.mark.skipif(has_pycuda is False, reason='requires PyCUDA+scikit-cuda')
def test_strided_elementwise():
    from brainstorm.handlers import PyCudaHandler
    _h = PyCudaHandler()
    rdm = np.random.RandomState(1345)

    def get_rdm_array(shape):
        return rdm.randn(*shape)

    for dims in range(2, 5):
        for i in range(10):
            shape = rdm.randint(1, 17, dims)
            a1 = np.float32(get_rdm_array(shape))
            a2 = np.float32(get_rdm_array(shape))
            a3 = np.float32(get_rdm_array(shape))
            a = np.vstack([a1, a2, a3])
            original_shape = a.shape
            a = a.reshape([original_shape[0] // 3] +
                          list(original_shape[1:]) + [3])
            b = np.zeros_like(a, dtype=np.float32)
            A = _h.create_from_numpy(a)
            B = _h.create_from_numpy(b)

            idx = rdm.randint(0, 2)
            func = ['logistic', 'tanh'][idx]

            _h.strided_elementwise(A, B, 1, func)
            outputs = _h.get_numpy_copy(B).reshape(original_shape)

            c1 = a1
            c2 = 1. / (1. + np.exp(-a2)) if idx == 0 else np.tanh(a2)
            c3 = a3
            c = np.vstack([c1, c2, c3])

            passed = np.allclose(outputs, c)
            assert passed

@pytest.mark.skipif(has_pycuda is False, reason='requires PyCUDA+scikit-cuda')
def test_strided_elementwise_inplace():
    from brainstorm.handlers import PyCudaHandler
    _h = PyCudaHandler()
    rdm = np.random.RandomState(1345)

    def get_rdm_array(shape):
        return rdm.randn(*shape)

    for dims in range(2, 5):
        for i in range(10):
            shape = rdm.randint(1, 17, dims)
            a1 = np.float32(get_rdm_array(shape))
            a2 = np.float32(get_rdm_array(shape))
            a3 = np.float32(get_rdm_array(shape))
            a = np.vstack([a1, a2, a3])
            original_shape = a.shape
            a = a.reshape([original_shape[0] // 3] +
                          list(original_shape[1:]) + [3])
            A = _h.create_from_numpy(a)

            _h.strided_elementwise_inplace(A, 1, 'logistic')
            _h.strided_elementwise_inplace(A, 0, 'tanh')
            outputs = _h.get_numpy_copy(A).reshape(original_shape)

            c1 = np.tanh(a1)
            c2 = 1. / (1. + np.exp(-a2))
            c3 = a3
            c = np.vstack([c1, c2, c3])

            passed = np.allclose(outputs, c)
            assert passed



'''
@pytest.mark.skipif(has_pycuda is False, reason='requires PyCUDA+scikit-cuda')
def test_slice_copy_stride():
    from brainstorm.handlers import PyCudaHandler
    _h = PyCudaHandler()
    # 2-dim test
    a = np.float32(np.random.rand(10, 10))
    start = 4
    length = 2
    segments = 3
    stride = 1
    slices = [start, length, segments, stride]
    data = []
    for seg in range(segments):
        row = np.int32(start / a.shape[1])
        offset = start - (row * a.shape[0])
        data += a[row, offset + (length * seg) + (seg * stride):
                  offset + (length * seg) + (seg * stride) + length].tolist()

    s = np.array(data, dtype=np.float32)
    A = _h.create_from_numpy(a)
    S = _h.create_from_numpy(np.zeros_like(s, dtype=np.float32))
    slices_A = _h.create_from_numpy(np.array(slices, dtype=np.float32))
    slices_B = _h.create_from_numpy(
        np.array([0, length * segments, 1, 0], dtype=np.float32))
    _h.slice_copy_strided(A, S, slices_A, slices_B)
    outputs = _h.get_numpy_copy(S)
    passed = np.allclose(outputs, s)
    assert passed

    # 3-dim test
    a = np.float32(np.random.rand(10, 10, 10))
    start = 50
    length = 6
    segments = 4
    stride = 5
    slices = [start, length, segments, stride]
    data = []
    for seg in range(segments):
        row = np.int32(start / (a.shape[1] * a.shape[2]))
        col = np.int32(start / (a.shape[1]))
        offset = start - (row * (a.shape[1] * a.shape[2])) - (col * a.shape[1])
        data += a[row, col, offset + (length * seg) + (seg * stride):
                  offset + (length * seg) + (seg * stride) + length].tolist()

    s = np.array(data, dtype=np.float32)
    A = _h.create_from_numpy(a)
    S = _h.create_from_numpy(np.zeros_like(s, dtype=np.float32))
    slices_A = _h.create_from_numpy(np.array(slices, dtype=np.float32))
    slices_B = _h.create_from_numpy(
        np.array([0, length * segments, 1, 0], dtype=np.float32))
    _h.slice_copy_strided(A, S, slices_A, slices_B)
    outputs = _h.get_numpy_copy(S)
    passed = np.allclose(outputs, s)
    assert passed
'''