Spaces:
Runtime error
Runtime error
File size: 2,244 Bytes
81efcf0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_extended_async_copies : enable
// Global Response Normalization (GRN) across the channel axis.
//
// For every spatial position p handled by a work-item:
//     sumsq        = bias + 1e-9 + sum_{c < C} src[c][p]^2
//     dst[c][p]    = src[c][p] / sqrt(sumsq)          for every c in [0, C)
//
// Each work-group copies a get_local_size(0) x get_local_size(1) tile of all C
// planes into local memory with one 3D async copy. Each work-item normalizes
// its own column of C values. The tile is then written back with a second 3D
// async copy.
//
// NOTE(review): the stride arithmetic assumes the global NDRange is (W, H) and
// that src_data/dst_data hold C consecutive planes of W*H halves (plane stride
// get_global_size(0) * get_global_size(1)) - confirm against the host dispatch.
// NOTE(review): the local buffers hold 8 * 1024 halves, so the host must keep
// C * get_local_size(0) * get_local_size(1) <= 8192; nothing here checks it.
//
// Params:
//   src_data - input tensor (half)
//   dst_data - output tensor (half), same layout as src_data
//   C        - number of channels (planes) reduced over
//   bias     - added to the sum of squares; presumably >= 0 (a negative total
//              makes the reciprocal square root NaN)
__kernel void grn(__global const half *restrict src_data, __global half *restrict dst_data, int C, float bias)
{
    __local half src[8 * 1024];
    __local half dst[8 * 1024];

    // Offset of this work-group's tile inside a single plane.
    const size_t index = get_group_id(0) * get_local_size(0) + get_group_id(1) * get_local_size(1) * get_global_size(0);

    event_t e1 = async_work_group_copy_3D3D(
        src,                                                           // dst
        src_data + index,                                              // src
        get_local_size(0),                                             // num_elements_per_line
        get_local_size(1),                                             // num_lines
        get_global_size(0) - get_local_size(0),                        // src_line_stride
        0,                                                             // dst_line_stride
        C,                                                             // num_planes
        get_global_size(0) * (get_global_size(1) - get_local_size(1)), // src_plane_stride
        0,                                                             // dst_plane_stride
        0);
    wait_group_events(1, &e1);

    // The local tile is packed as [C][local_size(1)][local_size(0)].
    const size_t plane_size = get_local_size(0) * get_local_size(1);
    const size_t pixel = get_local_id(1) * get_local_size(0) + get_local_id(0);

    float variance = bias + 1e-9f;
#pragma unroll 8
    for (int c = 0; c < C; c++) {
        const float val = (float)src[c * plane_size + pixel];
        variance += val * val;
    }

    // Take the reciprocal square root and apply it in float. Narrowing the sum
    // of squares to half first would underflow small sums to 0 (losing the
    // 1e-9 epsilon, so rsqrt = inf and 0 * inf = NaN) and overflow large sums
    // to inf (rsqrt = 0, output forced to 0). With bias >= 0,
    // |src| <= sqrt(variance), so each product is about [-1, 1] and fits half.
    const float inv_norm = native_rsqrt(variance);

#pragma unroll 8
    for (int c = 0; c < C; c++) {
        const size_t i = c * plane_size + pixel;
        dst[i] = (half)((float)src[i] * inv_norm);
    }

    // Every work-item must finish writing dst before the group copy reads it.
    barrier(CLK_LOCAL_MEM_FENCE);

    event_t e2 = async_work_group_copy_3D3D(
        dst_data + index,                                              // dst
        dst,                                                           // src
        get_local_size(0),                                             // num_elements_per_line
        get_local_size(1),                                             // num_lines
        0,                                                             // src_line_stride
        get_global_size(0) - get_local_size(0),                        // dst_line_stride
        C,                                                             // num_planes
        0,                                                             // src_plane_stride
        get_global_size(0) * (get_global_size(1) - get_local_size(1)), // dst_plane_stride
        0);
    wait_group_events(1, &e2);
}
|