// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_extended_async_copies : enable

// Computes one row of OW output values of a fixed-size 3x3 convolution.
//
// Work decomposition, as read from the indexing below:
//   get_group_id(0)  selects the row: input rows [group_id(0) * stride_y, +3)
//                    are staged, and the result lands at row offset
//                    group_id(0) * OW inside the output plane.
//   get_group_id(1)  selects the output channel: its IC * 3 * 3 weights are
//                    staged, and the result lands in plane group_id(1).
//
// Column indexing: `in` points one half past the start of each staged row, so
// in[c] holds staged input column c + 1.  Both the vector path and the scalar
// path read in[ow * stride_x - 1 .. ow * stride_x + 1] for output column ow.
// For the scalar path this holds only when pad_x == 1 and dilation_x == 1.
// NOTE(review): the vector path hard-codes that layout and ignores pad_x,
// dilation_x and dilation_y. Presumably the host guarantees pad 1 / dilation 1
// here; confirm against the host-side binding.
//
// Paths:
//   stride 1x1  vector loop, 8 outputs per iteration
//   stride 2x2  vector loop computes 8 stride-1 outputs and keeps lanes
//               0, 2, 4, 6 (4 outputs per iteration)
//   otherwise   no vector iterations
// Any columns the vector loop does not cover are computed by the scalar loop.
// For an unsupported stride, that is the whole row.
//
// KX, KY, OC and pad_y are accepted but not read; the filter size is fixed at 3x3.
// The local buffers hold 8 * 1024 halves each. Nothing here checks that
// 3 * IW * IC, IC * 9 and OW fit, so the host must ensure it.
__kernel void Convolution3x3(
    const __global half *in_param,
    __global half *out,
    const __global half *w,
    int IW,
    int IH,
    int IC,
    int OW,
    int OH,
    int OC,
    int KX,
    int KY,
    int stride_x,
    int stride_y,
    int pad_x,
    int pad_y,
    int dilation_x,
    int dilation_y)
{
    __local half in_local[8 * 1024];
    __local half out_local[8 * 1024];
    __local half w_local[8 * 1024];

    const int sizePlane = IW * IH;

    // Stage 3 consecutive rows of every input channel as a dense [IC][3][IW]
    // block. Each copied line is 3 * IW halves, and consecutive channels are
    // sizePlane halves apart in global memory.
    event_t e1 = async_work_group_copy_2D2D(
        in_local,                                   // dst
        in_param + get_group_id(0) * stride_y * IW, // src
        3 * IW,                                     // num_elements_per_line
        IC,                                         // num_lines
        sizePlane - 3 * IW,                         // src_line_stride
        0,                                          // dst_line_stride
        0);
    wait_group_events(1, &e1);

    // Stage the IC x 3 x 3 weights of the output channel chosen by group_id(1).
    const int sizeWeight = IC * 3 * 3;
    e1 = async_work_group_copy(w_local, w + get_group_id(1) * sizeWeight, sizeWeight, 0);
    wait_group_events(1, &e1);

    __local half *in = (__local half *)in_local + 1;

    // Choose the vector path. Every bound is initialised, so an unsupported
    // stride runs zero vector iterations and the scalar loop covers [0, OW).
    int vec_iters = 0;    // iterations of the vector loop
    int scalar_begin = 0; // first output column left to the scalar loop
    int write_output = 0; // 1: half8 store per iteration, 2: half4 store per iteration
    if ((stride_x == 1) && (stride_y == 1)) {
        vec_iters = OW / 8;
        scalar_begin = vec_iters * 8;
        write_output = 1;
    } else if ((stride_x == 2) && (stride_y == 2)) {
        vec_iters = OW / 4;
        scalar_begin = vec_iters * 4;
        write_output = 2;
    }

    for (int ow = 0; ow < vec_iters; ow++) {
        float8 val = (float8)(0.0f);

        for (int ic = 0; ic < IC; ++ic) {
            // Row 0 of channel ic, starting at the ow-th 8-half chunk: in[ow * 8].
            __local half *src = (__local half *)((__local half8 *)(in + ic * IW * 3) + ow);
            __local half *k = (__local half *)(w_local + ic * 3 * 3);

            // For each of the 3 staged rows, load the previous, current and next
            // 8-half chunk.
            // NOTE(review): for ic == 0 && ow == 0 the "previous" chunk starts
            // 7 halves before in_local. Only its last lane (in_local[0]) is
            // consumed by the alignvec below, but the load itself falls outside
            // the array. The loads are also not 16-byte aligned, because `in`
            // is offset by one half.
            short8 in00 = as_short8(*((__local half8 *)src - 1));
            short8 in01 = as_short8(*((__local half8 *)src + 0));
            short8 in02 = as_short8(*((__local half8 *)src + 1));
            short8 in10 = as_short8(*((__local half8 *)(src + IW) - 1));
            short8 in11 = as_short8(*((__local half8 *)(src + IW) + 0));
            short8 in12 = as_short8(*((__local half8 *)(src + IW) + 1));
            short8 in20 = as_short8(*((__local half8 *)(src + IW * 2) - 1));
            short8 in21 = as_short8(*((__local half8 *)(src + IW * 2) + 0));
            short8 in22 = as_short8(*((__local half8 *)(src + IW * 2) + 1));

            // kx = 0 / kx = 2 taps: the current chunk shifted by one half
            // towards the previous / next chunk.
            // Assumed semantics (SHAVE built-in, confirm):
            //   alignvec(a, b, n) returns 8 lanes of a:b starting at byte n.
            half8 a00 = as_half8(__builtin_shave_cmu_alignvec_rri_short8(in00, in01, 14));
            half8 a01 = as_half8(in01);
            half8 a02 = as_half8(__builtin_shave_cmu_alignvec_rri_short8(in01, in02, 2));
            half8 a10 = as_half8(__builtin_shave_cmu_alignvec_rri_short8(in10, in11, 14));
            half8 a11 = as_half8(in11);
            half8 a12 = as_half8(__builtin_shave_cmu_alignvec_rri_short8(in11, in12, 2));
            half8 a20 = as_half8(__builtin_shave_cmu_alignvec_rri_short8(in20, in21, 14));
            half8 a21 = as_half8(in21);
            half8 a22 = as_half8(__builtin_shave_cmu_alignvec_rri_short8(in21, in22, 2));

            // Broadcast each of the 9 weights of this channel.
            half8 w00 = (half8)(k[0]);
            half8 w01 = (half8)(k[1]);
            half8 w02 = (half8)(k[2]);
            half8 w10 = (half8)(k[3]);
            half8 w11 = (half8)(k[4]);
            half8 w12 = (half8)(k[5]);
            half8 w20 = (half8)(k[6]);
            half8 w21 = (half8)(k[7]);
            half8 w22 = (half8)(k[8]);

            // Accumulate in float, in the same (ky, kx) order as the scalar loop.
            val += convert_float8(a00) * convert_float8(w00);
            val += convert_float8(a01) * convert_float8(w01);
            val += convert_float8(a02) * convert_float8(w02);
            val += convert_float8(a10) * convert_float8(w10);
            val += convert_float8(a11) * convert_float8(w11);
            val += convert_float8(a12) * convert_float8(w12);
            val += convert_float8(a20) * convert_float8(w20);
            val += convert_float8(a21) * convert_float8(w21);
            val += convert_float8(a22) * convert_float8(w22);
        }

        if (write_output == 2) {
            // Stride 2: keep every other stride-1 result.
            *((__local half4 *)(out_local) + ow) = convert_half4(val.s0246);
        }
        if (write_output == 1) {
            *((__local half8 *)(out_local) + ow) = convert_half8(val);
        }
    }

    // Generic path: remaining columns, or the whole row for unsupported strides.
    for (int ow = scalar_begin; ow < OW; ow++) {
        float val = 0.0f;
        for (int ic = 0; ic < IC; ++ic) {
            for (int ky = 0; ky < 3; ++ky) {
                for (int kx = 0; kx < 3; ++kx) {
                    int iw = ow * stride_x - pad_x + kx * dilation_x;
                    val += convert_float(in[ic * IW * 3 + (ky * dilation_y) * IW + iw]) *
                           convert_float(w_local[ic * 3 * 3 + ky * 3 + kx]);
                }
            }
        }
        out_local[ow] = convert_half(val);
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    // Write the finished row to plane group_id(1), row offset group_id(0) * OW.
    event_t e2 = async_work_group_copy(
        out + get_group_id(1) * OW * OH + get_group_id(0) * OW,
        out_local,
        OW,
        0);
    wait_group_events(1, &e2);
}