File size: 2,558 Bytes
81efcf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_extended_async_copies : enable

// Linear search over the half-precision range [begin, end).
//
// Returns a pointer to the first element that compares equal to `value`,
// or `end` if no element matches. The result points into the same read-only
// range that was passed in, so it keeps the `const` qualifier; returning it
// as a non-const pointer would silently discard that qualifier.
//
// Precondition: `end` must be reachable from `begin` by incrementing
// (begin <= end). Otherwise the loop walks past the range and never stops.
__global const half *find(__global const half *begin, __global const half *end, half value)
{
    for (; begin != end; ++begin) {
        if (*begin == value) {
            return begin;
        }
    }
    return end;
}

// Greedy (best-path) CTC decoding.
//
// For each batch entry b:
//   - At every time step t, take the arg-max class over the C scores.
//   - Skip the class if it equals the previous step's arg-max (repeat
//     collapsing).
//   - Skip the class if it is the last class, C - 1 (presumably the CTC
//     blank).
//   - Otherwise append its index, as a half, to output row b.
// Output slots that receive no decoded class keep the value -1.
//
// Parameters (as indexed below):
//   probabilities       - scores read as [t][b][c]: element offset
//                         t * C * B + b * C + c after the local copy.
//   sequence_indicators - read as B rows of T halves; row b starts at b * T.
//                         The decoded length of row b is the index of the
//                         first 0.h after element 0 (element 0 itself is never
//                         tested), or T if there is none.
//   output              - B rows of T halves; row b starts at b * T.
//   width               - C, number of classes.
//   height              - B, number of batch entries.
//   channels            - T, number of time steps.
//
// NOTE(review): local_src / local_dst have fixed sizes (88 * 77 and 88
// halves). The kernel therefore assumes C * B * T <= 88 * 77 and
// B * T <= 88; larger shapes overrun local memory. The "* 1" factors suggest
// it is only built for B == 1 — confirm against the host-side setup.
// NOTE(review): the input copy uses width * (height - 1) as its line stride,
// which is 0 only when height == 1. Confirm this async_work_group_copy_2D2D
// variant's stride semantics before relying on height > 1.
__kernel void CTCDecoder(
    __global half *restrict probabilities,
    __global half *restrict sequence_indicators,
    __global half *restrict output,
    int width,
    int height,
    int channels)
{
    __local half local_src[88 * 1 * 77];
    __local half local_dst[88 * 1];

    event_t e1 = async_work_group_copy_2D2D(
        local_src, // dst
        probabilities, // src
        width, // num_elements_per_line,
        height * channels, // num_lines,
        width * (height - 1), // src_line_stride,
        width * (height - 1), // dst_line_stride,
        0);

    wait_group_events(1, &e1);

    const int T = channels; // Time
    const int B = height; // Batches
    const int C = width; // Chars

    // Every output slot starts as the -1 "no class" sentinel.
    #pragma unroll 4
    for (int i = 0; i < B * T; i++) {
        local_dst[i] = -1.h;
    }

    for (int b = 0; b < B; ++b) {
        __global const half *restrict seq_ind = sequence_indicators + b * T;
        // With T == 0 the search range would start one past its end pointer
        // and find() would never terminate, so there is nothing to decode.
        const int seq_len = (T > 0) ? (int)(find(seq_ind + 1, seq_ind + T, 0.h) - seq_ind) : 0;
        const int time    = min(seq_len, T);

        // Each batch entry owns its own row of T slots, so the write cursor
        // restarts for every row. A cursor shared across rows would shift
        // row b by the decoded lengths of rows 0..b-1 and could overrun
        // local_dst.
        int output_index   = 0;
        int prev_class_idx = -1;

        #pragma unroll 4
        for (int t = 0; t < time; ++t) {
            __local const half *restrict probs = local_src + b * C + t * C * B;

            // Arg-max over classes; ties keep the lowest index.
            int max_class_idx = 0;
            half max_prob     = probs[0];
            for (int c = 1; c < C; ++c) {
                const half prob = probs[c];
                if (prob > max_prob) {
                    max_class_idx = c;
                    max_prob      = prob;
                }
            }

            // Drop class C - 1 and collapse consecutive repeats.
            if (max_class_idx < C - 1 && max_class_idx != prev_class_idx) {
                local_dst[b * T + output_index] = (half)max_class_idx;
                output_index++;
            }

            prev_class_idx = max_class_idx;
        }
    }

    // Make the local_dst writes visible before the work-group copy reads them.
    barrier(CLK_LOCAL_MEM_FENCE);

    event_t e2 = async_work_group_copy_2D2D(
        output, // dst
        local_dst, // src
        channels, // num_elements_per_line,
        height, // num_lines,
        0, // src_line_stride,
        0, // dst_line_stride,
        0);

    wait_group_events(1, &e2);
}