Spaces:
Runtime error
Runtime error
File size: 5,986 Bytes
81efcf0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_extended_async_copies : enable
/*
 * 3x3 fp16 convolution; each work-group produces one output row of one
 * output channel.
 *
 *   get_group_id(0) - output row. For every input channel, the three input
 *                     rows starting at row get_group_id(0) * stride_y are
 *                     staged into local memory.
 *   get_group_id(1) - output channel. Its IC * 3 * 3 weights are staged into
 *                     local memory.
 *
 * Vector fast paths exist for stride 1x1 (eight outputs per half8 step) and
 * stride 2x2 (eight stride-1 results per step, every other one kept). Columns
 * the vector path does not produce - and all columns for other strides - are
 * computed by a scalar loop. Results are accumulated in float.
 *
 * Parameters KX, KY, OC and pad_y are accepted but not used by this kernel.
 *
 * NOTE(review): the vector paths hard-code a one-column left offset and unit
 * horizontal dilation; they appear to agree with the scalar formula only when
 * pad_x == 1 and dilation_x == 1. Presumably the host only selects this
 * kernel for such layers - confirm.
 * NOTE(review): only three rows per channel are staged, so dilation_y > 1
 * would read outside the staged rows in the scalar path.
 * NOTE(review): nothing checks that 3 * IW * IC, IC * 9 and OW fit in the
 * 8 * 1024-element local buffers; the caller must guarantee it.
 * NOTE(review): every work-item in a group computes and writes the whole
 * row; presumably the kernel is launched with one work-item per group.
 */
__kernel void Convolution3x3(
    const __global half *in_param,
    __global half *out, /* written by the final async copy, so not const */
    const __global half *w,
    int IW,
    int IH,
    int IC,
    int OW,
    int OH,
    int OC,
    int KX,
    int KY,
    int stride_x,
    int stride_y,
    int pad_x,
    int pad_y,
    int dilation_x,
    int dilation_y)
{
    __local half in_local[8 * 1024];
    __local half out_local[8 * 1024];
    __local half w_local[8 * 1024];

    /* Stage three input rows per channel, packed back to back
     * (dst_line_stride 0): channel ic starts at in_local[ic * 3 * IW]. */
    event_t e1 = async_work_group_copy_2D2D(
        in_local,                                   // dst
        in_param + get_group_id(0) * stride_y * IW, // src
        3 * IW,                                     // num_elements_per_line
        IC,                                         // num_lines
        IW * IH - 3 * IW,                           // src_line_stride
        0,                                          // dst_line_stride
        0);
    wait_group_events(1, &e1);

    /* Stage the 3x3 weights of every input channel for this output channel. */
    const int sizeWeight = IC * 3 * 3;
    e1 = async_work_group_copy(w_local, w + get_group_id(1) * sizeWeight, sizeWeight, 0);
    wait_group_events(1, &e1);

    /* Shifted by one element: the vector path treats the half8 at in + 8*i as
     * the middle tap column and builds the left/right tap columns by shifting
     * one lane in from the neighbouring half8 (alignvec below). */
    __local half *in = in_local + 1;

    /* Vector-loop trip count and number of output columns it produces.
     * Zero for strides without a vector path, so those fall through entirely
     * to the scalar loop instead of using an uninitialised bound. */
    int vec_steps = 0;
    int vec_outputs = 0;
    int write_output = 0;
    if ((stride_x == 1) && (stride_y == 1)) {
        vec_steps = OW / 8;
        vec_outputs = vec_steps * 8;
        write_output = 1;
    } else if ((stride_x == 2) && (stride_y == 2)) {
        vec_steps = OW / 4;
        vec_outputs = vec_steps * 4;
        write_output = 2;
    }

    for (int ow = 0; ow < vec_steps; ow++) {
        float8 val = (float8)(0.0f);
        for (int ic = 0; ic < IC; ++ic) {
            __local half *src = (__local half *)((__local half8 *)(in + ic * IW * 3) + ow);
            __local half *k = w_local + ic * 3 * 3;

            /* Three consecutive half8 loads per staged row.
             * NOTE(review): for ow == 0 and ic == 0 the "- 1" loads start 7
             * elements before in_local; only the last lane is consumed by
             * alignvec, but the load itself is out of bounds. */
            short8 in00 = as_short8(*((__local half8 *)src - 1));
            short8 in01 = as_short8(*((__local half8 *)src + 0));
            short8 in02 = as_short8(*((__local half8 *)src + 1));
            short8 in10 = as_short8(*((__local half8 *)(src + IW) - 1));
            short8 in11 = as_short8(*((__local half8 *)(src + IW) + 0));
            short8 in12 = as_short8(*((__local half8 *)(src + IW) + 1));
            short8 in20 = as_short8(*((__local half8 *)(src + IW * 2) - 1));
            short8 in21 = as_short8(*((__local half8 *)(src + IW * 2) + 0));
            short8 in22 = as_short8(*((__local half8 *)(src + IW * 2) + 1));

            /* Byte offsets 14 and 2 shift by one half (2 bytes) towards the
             * previous / next vector to form the left and right tap columns. */
            half8 aux00 = as_half8(__builtin_shave_cmu_alignvec_rri_short8(in00, in01, 14));
            half8 aux01 = as_half8(in01);
            half8 aux02 = as_half8(__builtin_shave_cmu_alignvec_rri_short8(in01, in02, 2));
            half8 aux10 = as_half8(__builtin_shave_cmu_alignvec_rri_short8(in10, in11, 14));
            half8 aux11 = as_half8(in11);
            half8 aux12 = as_half8(__builtin_shave_cmu_alignvec_rri_short8(in11, in12, 2));
            half8 aux20 = as_half8(__builtin_shave_cmu_alignvec_rri_short8(in20, in21, 14));
            half8 aux21 = as_half8(in21);
            half8 aux22 = as_half8(__builtin_shave_cmu_alignvec_rri_short8(in21, in22, 2));

            /* Broadcast each weight; half -> float conversion is exact. */
            float8 w00 = convert_float8((half8)(k[0]));
            float8 w01 = convert_float8((half8)(k[1]));
            float8 w02 = convert_float8((half8)(k[2]));
            float8 w10 = convert_float8((half8)(k[3]));
            float8 w11 = convert_float8((half8)(k[4]));
            float8 w12 = convert_float8((half8)(k[5]));
            float8 w20 = convert_float8((half8)(k[6]));
            float8 w21 = convert_float8((half8)(k[7]));
            float8 w22 = convert_float8((half8)(k[8]));

            /* Same accumulation order as the scalar path (row-major taps). */
            val += convert_float8(aux00) * w00;
            val += convert_float8(aux01) * w01;
            val += convert_float8(aux02) * w02;
            val += convert_float8(aux10) * w10;
            val += convert_float8(aux11) * w11;
            val += convert_float8(aux12) * w12;
            val += convert_float8(aux20) * w20;
            val += convert_float8(aux21) * w21;
            val += convert_float8(aux22) * w22;
        }
        if (write_output == 2) {
            /* Stride 2: keep every other stride-1 result. */
            *((__local half4 *)out_local + ow) = convert_half4(val.s0246);
        } else if (write_output == 1) {
            *((__local half8 *)out_local + ow) = convert_half8(val);
        }
    }

    /* Scalar path: every column the vector loop did not produce. */
    for (int ow = vec_outputs; ow < OW; ow++) {
        float val = 0.0f;
        for (int ic = 0; ic < IC; ++ic) {
            for (int ky = 0; ky < 3; ++ky) {
                for (int kx = 0; kx < 3; ++kx) {
                    int iw = ow * stride_x - pad_x + kx * dilation_x;
                    val += convert_float(in[ic * IW * 3 + (ky * dilation_y) * IW + iw])
                         * convert_float(w_local[ic * 3 * 3 + ky * 3 + kx]);
                }
            }
        }
        out_local[ow] = convert_half(val);
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    /* Output layout: channel-major planes of OH rows x OW columns. */
    event_t e2 = async_work_group_copy(
        out + get_group_id(1) * OW * OH + get_group_id(0) * OW,
        out_local,
        OW,
        0);
    wait_group_events(1, &e2);
}
|