Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/common.glslh
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
// Round x up to the next multiple of 4 / 8 (x itself if already a multiple).
// The argument is parenthesized so that expressions containing operators of
// lower precedence than `+` (e.g. `a << b`, `a | b`, `c ? a : b`) expand
// correctly, matching the other helper macros in this header.
#define align_up_4(x) (((x) + 3) & -4)
#define align_up_8(x) (((x) + 7) & -8)

// Round x down to a multiple of 4 by clearing its two lowest bits.
#define align_down_4(x) ((x) & ~3)

// Keep only the low 1 / 2 / 3 bits of x, i.e. x modulo 2 / 4 / 8 for
// non-negative x.
#define mod_2(x) ((x) & 0x1)
#define mod_4(x) ((x) & 0x3)
#define mod_8(x) ((x) & 0x7)
Expand Down
60 changes: 60 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_q8_utils.glslh
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,24 @@ bool in_bounds(
return true;
}

/*
 * Reads one ivec4 block from t_packed_int8_input at block coordinates
 * (in_x4, in_y, ic4).
 *
 * If in_bounds() rejects the coordinates for in_block_extents, nothing is
 * read and input_zps is returned instead (presumably the packed input zero
 * points, so that padding contributes the "zero" value - confirm with caller).
 *
 * With PACKED_INT8_INPUT_BUFFER defined, the input is indexed as a linear
 * buffer: in_y * data_xz + in_x4 * data.z + ic4. Otherwise the block is read
 * with texelFetch at ivec3(in_x4, in_y, ic4), mip level 0.
 */
ivec4 load_input_block(
    const int in_x4,
    const int in_y,
    const int ic4,
    const Conv2dBlockExtents in_block_extents,
    const ivec4 input_zps) {
  const bool inside = in_bounds(in_x4, in_y, ic4, in_block_extents);
  if (inside) {
#ifdef PACKED_INT8_INPUT_BUFFER
    const int row_base = in_y * in_block_extents.data_xz;
    const int col_base = in_x4 * in_block_extents.data.z;
    return t_packed_int8_input[row_base + col_base + ic4];
#else
    const ivec3 texel_pos = ivec3(in_x4, in_y, ic4);
    return texelFetch(t_packed_int8_input, texel_pos, 0);
#endif
  }
  // Out-of-range coordinates fall back to the caller-supplied value.
  return input_zps;
}

Int8InputWindow1D load_input_window(
const int w_start,
const int w_end,
Expand Down Expand Up @@ -100,6 +118,34 @@ ivec4 load_weight_block(
#endif
}

/*
 * Accumulates the contribution of a single loaded input block and a single
 * weight block (for kernel column kx) into accum.
 *
 * For each of the 4 output x positions starting at out_x_start, the input x
 * coordinate is derived from conv2d_params (stride.x, padding.x, dilation.x)
 * and rebased against in_x_start. Only positions whose rebased coordinate
 * falls in [0, 4) - i.e. inside in_block - are updated; the others are left
 * untouched. For an updated position, every one of the 4 output channels gets
 *   accum.data[out_x][0][oc] =
 *       dotPacked4x8AccSatEXT(in_block[lane], weight_block[oc], previous).
 */
void conv1d_accumulate(
    inout Int32Accum accum,
    const ivec4 in_block,
    const ivec4 weight_block,
    const int kx,
    const int out_x_start,
    const int in_x_start) {
  [[unroll]] for (int out_x = 0; out_x < 4; ++out_x) {
    // Input x coordinate read by this output position for kernel column kx.
    const int in_x = (out_x_start + out_x) * conv2d_params.stride.x
        - conv2d_params.padding.x
        + (kx * conv2d_params.dilation.x);
    // Position of that coordinate relative to the first element of in_block.
    const int lane = in_x - in_x_start;

    // Named lane_valid (not in_bounds) to avoid hiding the in_bounds() helper.
    const bool lane_valid = lane >= 0 && lane < 4;

    if (lane_valid) {
      [[unroll]] for (int oc = 0; oc < 4; ++oc) {
        accum.data[out_x][0][oc] = dotPacked4x8AccSatEXT(
            in_block[lane],
            weight_block[oc],
            accum.data[out_x][0][oc]);
      }
    }
  }
}

void perform_conv1d(
inout Int32Accum accum,
const Int8InputWindow1D input_window,
Expand Down Expand Up @@ -146,6 +192,20 @@ void printWeightBlock(const ivec4 weight_block) {
}
}

// Debug helper: prints a header line, then for each of the 4 components of
// input_block the four values produced by unpack_int8x4() on that component.
void printInputBlock(const ivec4 input_block) {
  debugPrintfEXT("InputBlock contents: \\n");
  for (int idx = 0; idx < 4; ++idx) {
    const ivec4 vals = unpack_int8x4(input_block[idx]);
    debugPrintfEXT(
        " [%d]: (%d, %d, %d, %d) \\n",
        idx,
        vals[0],
        vals[1],
        vals[2],
        vals[3]);
  }
}

#endif // DEBUG_MODE

#endif // CONV2D_Q8_UTILS_GLSLH
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,7 @@ void main() {
return;
}

const int out_w = mul_4(out_block_idx.data.x);
const int w_start =
(out_w * conv2d_params.stride.x) - conv2d_params.padding.x;
const int w_end = ((out_w + 3) * conv2d_params.stride.x) -
conv2d_params.padding.x +
(conv2d_params.kernel_size.x - 1) * conv2d_params.dilation.x;
const int out_x_start = mul_4(out_block_idx.data.x);

Conv2dBlockExtents in_block_extents = make_block_extents(input_sizes);

Expand All @@ -99,24 +94,29 @@ void main() {

const int IC4_per_group = div_up_4(conv2d_params.in_channels_per_group);

const int n = mul_4(out_block_idx.data.z);
const int group_idx = n / conv2d_params.out_channels_per_group;
const int out_z = mul_4(out_block_idx.data.z);
const int group_idx = out_z / conv2d_params.out_channels_per_group;
const int group_ic4_offset = group_idx * IC4_per_group;

for (int ky = 0; ky < conv2d_params.kernel_size.y; ky++) {
const int h = out_block_idx.data.y * conv2d_params.stride.y -
const int in_y = out_block_idx.data.y * conv2d_params.stride.y -
conv2d_params.padding.y + ky * conv2d_params.dilation.y;

for (int ic4 = 0; ic4 < IC4_per_group; ic4++) {
Int8InputWindow1D int8_input_window = load_input_window(
w_start,
w_end,
h,
group_ic4_offset + ic4,
in_block_extents,
input_zps);
for (int kx = 0; kx < conv2d_params.kernel_size.x; kx++) {
int in_x_load_start =
(out_x_start * conv2d_params.stride.x)
- conv2d_params.padding.x
+ (kx * conv2d_params.dilation.x);

for (int kx = 0; kx < conv2d_params.kernel_size.x; kx++) {
int in_x_load_end =
((out_x_start + 3) * conv2d_params.stride.x)
- conv2d_params.padding.x
+ (kx * conv2d_params.dilation.x);

in_x_load_start = align_down_4(in_x_load_start);
in_x_load_end = align_down_4(in_x_load_end);

for (int ic4 = 0; ic4 < IC4_per_group; ic4++) {
const ivec4 weight_block = load_weight_block(
ic4,
kx,
Expand All @@ -127,7 +127,22 @@ void main() {
conv2d_params.kernel_size.y,
out_block_extents.data.z);

perform_conv1d(out_accum, int8_input_window, weight_block, kx);
for (int in_x = in_x_load_start; in_x <= in_x_load_end; in_x+=4) {
const ivec4 in_block = load_input_block(
div_4(in_x),
in_y,
group_ic4_offset + ic4,
in_block_extents,
input_zps);

conv1d_accumulate(
out_accum,
in_block,
weight_block,
kx,
out_x_start,
in_x);
}
}
}
}
Expand Down
Loading