File size: 3,740 Bytes
81efcf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_extended_async_copies : enable

// Selects the cheap rounding variant below; remove this define to fall back to round().
#define USE_OPTIMIZED_ROUND

// ROUND(x): converts a float to the nearest int.
//  - Optimized variant: adds 0.5 and lets the int cast truncate toward zero. This only
//    behaves like rounding-to-nearest for non-negative x; negative inputs are pulled
//    toward zero instead.
//  - Fallback variant: uses round() and casts the result to int.
// Both expansions are fully parenthesized so the macro can be embedded in any expression.
#ifdef USE_OPTIMIZED_ROUND
#define ROUND(x) ((int)((x) + 0.5f))
#else
#define ROUND(x) ((int)round(x))
#endif

// Maps an output coordinate to an input coordinate.
//   ox : output coordinate (the caller passes a row index)
//   f  : scale applied to the output pixel centre (the caller passes 1 / factor)
// Returns (ox + 0.5) * f with the fractional part dropped by the int cast
// (truncation toward zero).
//
// Declared `static inline`: a plain `inline` definition provides no external definition
// under C99 inline rules, so a call that is not inlined could fail to resolve.
static inline int out_to_in(float ox, float f)
{
    return (int)((ox + 0.5f) * f);
}

// Nearest-neighbour resampling along the width axis, from one __local buffer to another.
//
// For each of the C channels the source is read as a run of IW halfs starting at
// psrc + c * IW, and OW halfs are written starting at pdst + c * OW. Output column w
// takes the source column ROUND(rw * w + alpha), clamped to IW - 1.
//
//   psrc : source buffer, C * IW elements
//   pdst : destination buffer, C * OW elements
//   OW   : number of output columns
//   IW   : number of input columns
//   C    : number of channels
//   rw   : input/output ratio along the width
//   rh   : input/output ratio along the height; unused here, kept so the signature
//          stays compatible with existing callers
void interpolationCHW_nn(__local half *psrc, __local half *pdst, int OW, int IW, int C, float rw, float rh)
{
    // Half-pixel centre offset of the horizontal mapping: ix = (ox + 0.5) * rw - 0.5.
    // It has to be derived from the width ratio; using rh here is only correct when
    // both ratios happen to be equal.
    const float alpha = rw / 2.0f - 0.5f;

    const int last_col   = IW - 1;
    const int vec_blocks = OW / 8;

    // Main part: eight output columns per iteration, stored as a single half8.
    for (int blk = 0; blk < vec_blocks; blk++) {
        const int ox = blk * 8;

        // Source columns are independent of the channel, so compute them once per block.
        const int ix0 = min(ROUND(rw * (ox + 0) + alpha), last_col);
        const int ix1 = min(ROUND(rw * (ox + 1) + alpha), last_col);
        const int ix2 = min(ROUND(rw * (ox + 2) + alpha), last_col);
        const int ix3 = min(ROUND(rw * (ox + 3) + alpha), last_col);

        const int ix4 = min(ROUND(rw * (ox + 4) + alpha), last_col);
        const int ix5 = min(ROUND(rw * (ox + 5) + alpha), last_col);
        const int ix6 = min(ROUND(rw * (ox + 6) + alpha), last_col);
        const int ix7 = min(ROUND(rw * (ox + 7) + alpha), last_col);

        for (int c = 0; c < C; c++) {
            __local const half *src_row = psrc + c * IW;

            half8 val = {
                src_row[ix0],
                src_row[ix1],
                src_row[ix2],
                src_row[ix3],

                src_row[ix4],
                src_row[ix5],
                src_row[ix6],
                src_row[ix7],
            };

            // NOTE(review): the half8 store goes through a pointer cast, so the
            // destination address is presumably expected to be suitably aligned for
            // half8 (c * OW + ox elements from pdst) -- confirm for OW not divisible by 8.
            *((__local half8 *)(pdst + c * OW + ox)) = val;
        }
    }

    // Tail: the remaining OW % 8 columns, one element at a time.
    for (int ox = vec_blocks * 8; ox < OW; ox++) {
        const int ix = min(ROUND(rw * ox + alpha), last_col);

        for (int c = 0; c < C; c++) {
            pdst[c * OW + ox] = psrc[c * IW + ix];
        }
    }
}

// Nearest-neighbour resample kernel.
//
// Each work-group:
//   1. derives the output rows it covers from its group id / local size in dimension 1
//      and maps the first and last of them to input rows with out_to_in(),
//   2. copies those input rows for every channel from `src` into local memory,
//   3. runs interpolationCHW_nn() on the local buffers,
//   4. copies the result from local memory to `dst`.
//
//   src      : input tensor
//   dst      : output tensor
//   iw, ih   : input width / height
//   factor   : scale factor; 1 / factor is used as the output-to-input ratio
//   ow       : output width
//   oh       : output height (not referenced in the body; the output row count is
//              taken from get_global_size(1))
//   channels : number of channels copied and interpolated per work-group
//
// get_group_id(2) selects a block of channels * ih * iw input elements
// (presumably the batch index -- confirm against the host-side launch code).
kernel void resample_nearest(
    __global const half *restrict src,
    __global half *restrict dst,
    int iw,
    int ih,
    float factor,
    int ow,
    int oh,
    int channels)
{
    __local half local_src[14 * 1024];
    __local half local_dst[14 * 1024];

    // Output-to-input ratio. A float literal keeps the division in single precision:
    // every consumer takes a float, and a `1.0` literal would make the expression
    // double precision for no benefit.
    const float ratio = 1.0f / factor;

    const int oy_first = get_group_id(1) * get_local_size(1);
    const int oy_last  = (get_group_id(1) + 1) * get_local_size(1) - 1;
    const int iy_first = out_to_in(oy_first, ratio);
    const int iy_last  = out_to_in(oy_last, ratio);

    // Number of input rows needed to produce this group's output rows.
    const int iy_size = iy_last - iy_first + 1;

    event_t e1 = async_work_group_copy_2D2D(
        local_src, // dst
        src + get_group_id(2) * channels * ih * iw + iy_first * iw, // src
        iy_size * iw, // num_elements_per_line,
        channels, // num_lines,
        ih * iw - iy_size * iw, // src_line_stride,
        0, // dst_line_stride,
        0);

    wait_group_events(1, &e1);

    // NOTE(review): interpolationCHW_nn() indexes one run of iw inputs / ow outputs per
    // channel, while the copies above and below move iy_size * iw and
    // get_local_size(1) * ow elements per channel. These agree when each work-group
    // covers a single row -- confirm the launch configuration.
    interpolationCHW_nn(local_src, local_dst, ow, iw, channels, ratio, ratio);

    event_t e2 = async_work_group_copy_2D2D(
        dst + get_group_id(2) * channels * get_global_size(1) * ow + get_group_id(1) * get_local_size(1) * ow, // dst
        local_dst, // src
        get_local_size(1) * ow, // size_t num_elements_per_line,
        channels, // size_t num_lines,
        0, // size_t src_line_stride,
        get_global_size(1) * ow - get_local_size(1) * ow, // size_t dst_line_stride,
        0);

    wait_group_events(1, &e2);
}