Zhu-FaceOnLive's picture
Upload 72 files
81efcf0
raw
history blame
No virus
5.99 kB
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_extended_async_copies : enable
// 3x3 convolution over fp16 data, accumulated in fp32.
//
// Each work-group produces one output row: row get_group_id(0) of output plane
// get_group_id(1), written at out + get_group_id(1) * OW * OH + get_group_id(0) * OW.
// Before computing it stages into local memory:
//   * three consecutive input rows per input channel, starting at input row
//     get_group_id(0) * stride_y (channel planes appear to be IW * IH elements apart),
//     laid out as in_local[ic * 3 * IW + r * IW + x];
//   * the IC * 3 * 3 weights belonging to output channel get_group_id(1).
// Nothing here checks IC * 3 * IW, IC * 9 or OW against the 8 * 1024-element local
// buffers. NOTE(review): presumably the host guarantees these fit; confirm.
//
// Columns are read through `in = in_local + 1`. In the scalar path, output column x,
// tap (ky, kx) reads in[ic * 3 * IW + ky * dilation_y * IW + x * stride_x - pad_x +
// kx * dilation_x]. The vector path hard-codes pad_x == 1 and dilation 1, and neither
// path applies pad_y. NOTE(review): this looks like it expects input that the host has
// already padded; confirm. With dilation_y > 1 the scalar path would index rows that
// were not staged.
//
// Every work-item in a group performs the same work and writes the same out_local
// values. NOTE(review): presumably launched with one work-item per group; confirm.
//
// Parameters:
//   in_param              input tensor (IC planes, IW-wide rows).
//   out                   output tensor; OW * OH halves per output channel.
//   w                     weights; IC * 3 * 3 halves per output channel.
//   IW, IH, IC            input geometry.
//   OW, OH                output geometry.
//   stride_x, stride_y    (1,1) and (2,2) take a vectorized path; any other stride is
//                         computed entirely by the scalar path.
//   pad_x, dilation_x/y   read by the scalar path only.
//   OC, KX, KY, pad_y     not read by this kernel (kept for the argument layout).
__kernel void Convolution3x3(
    const __global half *in_param,
    __global half *out,
    const __global half *w,
    int IW,
    int IH,
    int IC,
    int OW,
    int OH,
    int OC,
    int KX,
    int KY,
    int stride_x,
    int stride_y,
    int pad_x,
    int pad_y,
    int dilation_x,
    int dilation_y)
{
    __local half in_local[8 * 1024];
    __local half out_local[8 * 1024];
    __local half w_local[8 * 1024];

    // Stage 3 rows (3 * IW elements) from each of the IC input planes, packed densely.
    event_t e1 = async_work_group_copy_2D2D(
        in_local,                                    // dst
        in_param + get_group_id(0) * stride_y * IW,  // src
        3 * IW,                                      // num_elements_per_line
        IC,                                          // num_lines
        IW * IH - 3 * IW,                            // src_line_stride
        0,                                           // dst_line_stride
        0);
    wait_group_events(1, &e1);

    // Stage the 3x3 filters of every input channel for this output channel.
    const int sizeWeight = IC * 3 * 3;
    e1 = async_work_group_copy(w_local, w + get_group_id(1) * sizeWeight, sizeWeight, 0);
    wait_group_events(1, &e1);

    const __local half *in = in_local + 1;

    // Vector path: every iteration computes 8 adjacent stride-1 results. For stride (1,1)
    // all 8 lanes are outputs; for stride (2,2) the even lanes give 4 outputs. Any other
    // stride leaves vec_iters == 0 so the scalar path below handles the whole row
    // (previously the bound was left uninitialized in that case).
    int vec_iters = 0;
    int outs_per_iter = 0;
    if (stride_x == 1 && stride_y == 1) {
        vec_iters = OW / 8;
        outs_per_iter = 8;
    } else if (stride_x == 2 && stride_y == 2) {
        vec_iters = OW / 4;
        outs_per_iter = 4;
    }

    for (int ow = 0; ow < vec_iters; ow++) {
        float8 val = (float8)(0.0f);
        for (int ic = 0; ic < IC; ++ic) {
            const __local half *rows = in + ic * IW * 3 + ow * 8;
            const __local half *k = w_local + ic * 3 * 3;
            for (int ky = 0; ky < 3; ++ky) {
                const __local half *row = rows + ky * IW;
                const __local half *kr = k + ky * 3;
                // Lane j of the loads at row - 1, row, row + 1 holds row[j - 1], row[j],
                // row[j + 1]: the three horizontal taps of output lane j. vload8 needs only
                // half alignment and touches no element outside that window.
                val += convert_float8(vload8(0, row - 1)) * convert_float(kr[0]);
                val += convert_float8(vload8(0, row)) * convert_float(kr[1]);
                val += convert_float8(vload8(0, row + 1)) * convert_float(kr[2]);
            }
        }
        if (outs_per_iter == 4) {
            vstore4(convert_half4(val.s0246), ow, out_local);
        } else {
            vstore8(convert_half8(val), ow, out_local);
        }
    }

    // Scalar path. For the vectorized strides it starts at OW & ~7 as before (for stride 2
    // this recomputes, and overwrites, up to 4 outputs the vector loop already wrote); for
    // any other stride nothing was vectorized, so it covers the whole row.
    const int scalar_begin = (outs_per_iter != 0) ? (OW & ~7) : 0;
    for (int ow = scalar_begin; ow < OW; ow++) {
        float val = 0.0f;
        for (int ic = 0; ic < IC; ++ic) {
            for (int ky = 0; ky < 3; ++ky) {
                for (int kx = 0; kx < 3; ++kx) {
                    const int iw = ow * stride_x - pad_x + kx * dilation_x;
                    val += convert_float(in[ic * IW * 3 + (ky * dilation_y) * IW + iw])
                         * convert_float(w_local[ic * 3 * 3 + ky * 3 + kx]);
                }
            }
        }
        out_local[ow] = convert_half(val);
    }

    // Make all out_local writes visible before the group-wide copy reads them.
    barrier(CLK_LOCAL_MEM_FENCE);

    event_t e2 = async_work_group_copy(
        out + get_group_id(1) * OW * OH + get_group_id(0) * OW,
        out_local,
        OW,
        0);
    wait_group_events(1, &e2);
}