Spaces:
Runtime error
Runtime error
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
// Linear search over the half-open range [begin, end) in global memory for
// the first element that compares equal to `value`.
//
// Returns a pointer to the first matching element, or `end` when nothing
// matches (which includes the empty range begin == end).
__global half *find(__global const half *begin, __global const half *end, half value)
{
    while (begin != end) {
        if (*begin == value) {
            // The declared return type is non-const, so drop the qualifier with an
            // explicit cast rather than relying on an implicit, diagnosed conversion.
            return (__global half *)begin;
        }
        ++begin;
    }
    return (__global half *)end;
}
// Greedy CTC decoder.
//
// For every batch and every time step inside that batch's sequence, the class
// with the highest probability is selected. The selected class is written to
// the output unless it is the last class (index C - 1, treated as the CTC
// blank) or equals the class selected at the previous time step. Output slots
// that are not written keep the initial value -1.
//
// Parameters:
//   probabilities       - input scores, read as [t][b][c] (see the `probs`
//                         pointer computation below).
//   sequence_indicators - per-batch markers; for batch b the T values starting
//                         at offset b * T are scanned, and the sequence length
//                         is the index of the first 0 found at position >= 1
//                         (or T when there is none).
//   output              - receives the decoded class indices as half values.
//   width               - number of classes (C).
//   height              - number of batches (B).
//   channels            - number of time steps (T).
//
// NOTE(review): the local buffers are sized with the literals 88 * 1 * 77 and
// 88 * 1 and nothing here checks width/height/channels against them -
// presumably the host guarantees T <= 88, B == 1, C <= 77; confirm.
__kernel void CTCDecoder(
    __global half *restrict probabilities,
    __global half *restrict sequence_indicators,
    __global half *restrict output,
    int width,
    int height,
    int channels)
{
    __local half local_src[88 * 1 * 77];
    __local half local_dst[88 * 1];

    // Stage the probabilities in local memory before decoding.
    // NOTE(review): both line strides evaluate to 0 when height == 1; their
    // meaning for height > 1 depends on the copy built-in - confirm.
    event_t e1 = async_work_group_copy_2D2D(
        local_src,            // dst
        probabilities,        // src
        width,                // num_elements_per_line,
        height * channels,    // num_lines,
        width * (height - 1), // src_line_stride,
        width * (height - 1), // dst_line_stride,
        0);
    wait_group_events(1, &e1);

    const int T = channels; // Time
    const int B = height;   // Batches
    const int C = width;    // Chars

    // Every output slot starts as -1 so that unused positions are recognisable.
    for (int i = 0; i < B * T; i++) {
        local_dst[i] = -1.h;
    }

    for (int b = 0; b < B; ++b) {
        __global const half *restrict seq_ind = sequence_indicators + b * T;
        // Length of this batch's sequence: distance to the first 0 marker,
        // searching from position 1; equals T when no 0 is present.
        const int seq_len = (int)(find(seq_ind + 1, seq_ind + T, 0.h) - seq_ind);
        const int time = min(seq_len, T);

        // The write position restarts for every batch: the row base b * T is
        // already added at the store, so carrying the counter over from the
        // previous batch would shift (and possibly overflow) this batch's row.
        int output_index = 0;
        int prev_class_idx = -1;

        for (int t = 0; t < time; ++t) {
            __local const half *restrict probs = local_src + b * C + t * C * B;

            // Arg-max over the C class scores of this (t, b) cell; ties keep
            // the lowest class index because the comparison is strict.
            int max_class_idx = 0;
            half max_prob = probs[0];
            for (int c = 1; c < C; ++c) {
                const half prob = probs[c];
                if (prob > max_prob) {
                    max_class_idx = c;
                    max_prob = prob;
                }
            }

            // Emit only non-blank classes that differ from the previous step.
            if (max_class_idx < C - 1 && max_class_idx != prev_class_idx) {
                local_dst[b * T + output_index] = (half)max_class_idx;
                output_index++;
            }

            // Updated even for the blank, so a repeated class separated by a
            // blank is emitted again.
            prev_class_idx = max_class_idx;
        }
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    // Write the decoded rows back out.
    event_t e2 = async_work_group_copy_2D2D(
        output,    // dst
        local_dst, // src
        channels,  // num_elements_per_line,
        height,    // num_lines,
        0,         // src_line_stride,
        0,         // dst_line_stride,
        0);
    wait_group_events(1, &e2);
}