File size: 3,557 Bytes
bd421ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44066b7
bd421ea
 
 
 
 
 
 
 
 
 
 
 
44066b7
bd421ea
 
44066b7
 
 
bd421ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44066b7
 
 
 
 
 
 
 
bd421ea
 
 
 
 
44066b7
bd421ea
 
 
 
 
44066b7
 
 
 
bd421ea
 
 
 
 
 
44066b7
 
 
bd421ea
 
 
 
 
 
 
 
 
 
 
44066b7
bd421ea
 
 
 
 
44066b7
 
bd421ea
 
 
 
 
44066b7
bd421ea
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import torch
import fastwer
import numpy as np
from scipy.special import logsumexp


"""
-------------
 CTC decoder
-------------
"""

NINF = -1 * float("inf")
DEFAULT_EMISSION_THRESHOLD = 0.01


def _reconstruct(labels, blank=0):
    new_labels = []
    # merge same labels
    previous = None
    for l in labels:
        if l != previous:
            new_labels.append(l)
            previous = l
    # delete blank
    new_labels = [l for l in new_labels if l != blank]
    return new_labels


def beam_search_decode(emission_log_prob, blank=0, **kwargs):
    """Beam-search decode a single CTC emission matrix.

    Args:
        emission_log_prob: np.ndarray of shape (length, class_count) with
            per-frame log-probabilities.
        blank: index of the CTC blank symbol.
        **kwargs:
            beam_size (required): number of prefixes kept per frame.
            emission_threshold: log-probability below which a class is
                pruned at a frame; defaults to log(DEFAULT_EMISSION_THRESHOLD).

    Returns:
        list of int labels (repeats merged, blanks removed) for the most
        probable collapsed sequence.
    """
    beam_size = kwargs["beam_size"]
    emission_threshold = kwargs.get(
        "emission_threshold", np.log(DEFAULT_EMISSION_THRESHOLD)
    )

    length, class_count = emission_log_prob.shape

    beams = [([], 0)]  # (prefix, accumulated_log_prob)
    for t in range(length):
        new_beams = []
        for prefix, accumulated_log_prob in beams:
            for c in range(class_count):
                log_prob = emission_log_prob[t, c]
                # prune unlikely emissions to keep the expansion cheap
                if log_prob < emission_threshold:
                    continue
                new_prefix = prefix + [c]
                # log(p1 * p2) = log_p1 + log_p2
                new_accu_log_prob = accumulated_log_prob + log_prob
                new_beams.append((new_prefix, new_accu_log_prob))

        # BUGFIX: if the threshold pruned every class at this frame, the
        # old code left `beams` empty and later crashed with IndexError on
        # labels_beams[0][0]. Skip such frames and keep the current beams.
        if not new_beams:
            continue

        # keep only the `beam_size` most probable prefixes
        new_beams.sort(key=lambda x: x[1], reverse=True)
        beams = new_beams[:beam_size]

    # merge beams whose prefixes collapse to the same label sequence
    total_accu_log_prob = {}
    for prefix, accu_log_prob in beams:
        labels = tuple(_reconstruct(prefix, blank))
        # log(p1 + p2) = logsumexp([log_p1, log_p2])
        total_accu_log_prob[labels] = logsumexp(
            [accu_log_prob, total_accu_log_prob.get(labels, NINF)]
        )

    labels_beams = [
        (list(labels), accu_log_prob)
        for labels, accu_log_prob in total_accu_log_prob.items()
    ]
    labels_beams.sort(key=lambda x: x[1], reverse=True)
    labels = labels_beams[0][0]

    return labels


def greedy_decode(emission_log_prob, blank=0):
    """Best-path CTC decode: take the argmax class per frame, then collapse.

    Args:
        emission_log_prob: np.ndarray of shape (length, class_count).
        blank: index of the CTC blank symbol.

    Returns:
        list of int labels with repeats merged and blanks removed.
    """
    best_path = np.argmax(emission_log_prob, axis=-1)
    return _reconstruct(best_path, blank=blank)


def ctc_decode(
    log_probs, which_ctc_decoder="beam_search", label_2_char=None, blank=0, beam_size=25
):
    """Decode a batch of CTC log-probabilities into label/char sequences.

    Args:
        log_probs: torch tensor of shape (length, batch, class), e.g. the
            log-softmax output of a CTC model.
        which_ctc_decoder: "beam_search" or "greedy".
        label_2_char: optional mapping from label index to character; when
            given, each decoded sequence is mapped through it.
        blank: index of the CTC blank symbol.
        beam_size: beam width used by the beam-search decoder.

    Returns:
        list with one decoded sequence per batch item.

    Raises:
        ValueError: if `which_ctc_decoder` is not a recognized option.
    """
    # BUGFIX: the old error path called sys.exit(0) without importing sys,
    # producing a NameError; validate once up front and raise clearly.
    if which_ctc_decoder not in ("beam_search", "greedy"):
        raise ValueError(
            f"unidentified option for which_ctc_decoder : {which_ctc_decoder}"
        )

    emission_log_probs = np.transpose(log_probs.cpu().numpy(), (1, 0, 2))
    # size of emission_log_probs: (batch, length, class)

    decoded_list = []
    for emission_log_prob in emission_log_probs:
        if which_ctc_decoder == "beam_search":
            decoded = beam_search_decode(
                emission_log_prob, blank=blank, beam_size=beam_size
            )
        else:
            decoded = greedy_decode(emission_log_prob, blank=blank)

        if label_2_char:
            decoded = [label_2_char[l] for l in decoded]
        decoded_list.append(decoded)
    return decoded_list


"""
--------------------
 Evaluation Metrics
--------------------
"""


def compute_wer_and_cer_for_batch(batch_preds, batch_gts):
    """Compute corpus-level (CER, WER) for lists of hypothesis/reference strings.

    Args:
        batch_preds: list of predicted strings.
        batch_gts: list of ground-truth strings, same length.

    Returns:
        tuple (cer, wer) as reported by fastwer.score.
    """
    return (
        fastwer.score(batch_preds, batch_gts, char_level=True),
        fastwer.score(batch_preds, batch_gts),
    )


def compute_wer_and_cer_for_sample(str_pred, str_gt):
    """Compute sentence-level (CER, WER) for one hypothesis/reference pair.

    Args:
        str_pred: predicted string.
        str_gt: ground-truth string.

    Returns:
        tuple (cer, wer) as reported by fastwer.score_sent.
    """
    return (
        fastwer.score_sent(str_pred, str_gt, char_level=True),
        fastwer.score_sent(str_pred, str_gt),
    )