"""CTC decoding helpers (greedy / beam search) and WER/CER evaluation metrics."""

import numpy as np
import torch
from scipy.special import logsumexp

try:
    import fastwer
except ImportError:
    # Only the metric helpers need fastwer; keep the decoders importable
    # without it and fail with a clear message when a metric is requested.
    fastwer = None


# -------------
# CTC decoder
# -------------

NINF = -1 * float("inf")
DEFAULT_EMISSION_THRESHOLD = 0.01


def _reconstruct(labels, blank=0):
    """Collapse a frame-level label path into a CTC label sequence.

    Consecutive repeats are merged first, then every ``blank`` is removed.

    Args:
        labels: Iterable of per-frame class indices.
        blank: Index of the CTC blank class.

    Returns:
        A list with the collapsed labels (elements keep their input type).
    """
    merged = []
    previous = None
    for label in labels:
        if label != previous:
            merged.append(label)
            previous = label
    return [label for label in merged if label != blank]


def beam_search_decode(emission_log_prob, blank=0, **kwargs):
    """Decode one sequence with a path-level beam search.

    Beams are raw frame paths; paths that collapse to the same label
    sequence are only merged (via log-sum-exp) after the last frame.

    Args:
        emission_log_prob: 2-D array indexed as ``[time, class]`` holding
            log-probabilities.
        blank: Index of the CTC blank class.
        **kwargs:
            beam_size (int, required): number of paths kept per time step.
            emission_threshold (float, optional): classes whose
                log-probability at a time step is below this value are not
                expanded. Defaults to ``log(DEFAULT_EMISSION_THRESHOLD)``.

    Returns:
        The most probable collapsed label sequence as a list of ints.

    Raises:
        KeyError: if ``beam_size`` is not supplied.
        ValueError: if ``beam_size`` is smaller than 1.
    """
    beam_size = kwargs["beam_size"]
    if beam_size < 1:
        raise ValueError(f"beam_size must be >= 1, got {beam_size}")
    emission_threshold = kwargs.get(
        "emission_threshold", np.log(DEFAULT_EMISSION_THRESHOLD)
    )

    emission_log_prob = np.asarray(emission_log_prob)
    length, _ = emission_log_prob.shape

    beams = [([], 0)]  # (prefix, accumulated_log_prob)
    for t in range(length):
        frame = emission_log_prob[t]

        # The threshold test does not depend on the prefix, so evaluate it
        # once per frame instead of once per (beam, class) pair.
        candidates = np.flatnonzero(~(frame < emission_threshold)).tolist()
        if not candidates:
            # No class clears the threshold (e.g. a flat distribution over a
            # large vocabulary). Keep the most likely class so the beam never
            # becomes empty.
            candidates = [int(np.argmax(frame))]

        # log(p1 * p2) = log_p1 + log_p2
        new_beams = [
            (prefix + [c], accumulated_log_prob + frame[c])
            for prefix, accumulated_log_prob in beams
            for c in candidates
        ]
        new_beams.sort(key=lambda beam: beam[1], reverse=True)
        beams = new_beams[:beam_size]

    # Merge the surviving paths that collapse to the same label sequence.
    total_accu_log_prob = {}
    for prefix, accu_log_prob in beams:
        labels = tuple(_reconstruct(prefix, blank))
        # log(p1 + p2) = logsumexp([log_p1, log_p2])
        total_accu_log_prob[labels] = logsumexp(
            [accu_log_prob, total_accu_log_prob.get(labels, NINF)]
        )

    best_labels, _ = max(total_accu_log_prob.items(), key=lambda item: item[1])
    return list(best_labels)


def greedy_decode(emission_log_prob, blank=0):
    """Decode one sequence by taking the arg-max class of every frame.

    Args:
        emission_log_prob: Array whose last axis indexes classes.
        blank: Index of the CTC blank class.

    Returns:
        The collapsed label sequence as a list.
    """
    best_path = np.argmax(emission_log_prob, axis=-1)
    return _reconstruct(best_path, blank=blank)


def ctc_decode(
    log_probs, which_ctc_decoder="beam_search", label_2_char=None, blank=0, beam_size=25
):
    """Decode a batch of CTC outputs.

    Args:
        log_probs: Tensor that is transposed with ``(1, 0, 2)`` before
            decoding, i.e. laid out as (length, batch, class).
        which_ctc_decoder: ``"beam_search"`` or ``"greedy"``.
        label_2_char: Optional mapping from label index to output symbol;
            when falsy the raw label indices are returned.
        blank: Index of the CTC blank class.
        beam_size: Beam width, used by the beam-search decoder only.

    Returns:
        A list with one decoded sequence (list) per batch element.

    Raises:
        ValueError: if ``which_ctc_decoder`` is not a known option.
    """
    if which_ctc_decoder not in ("beam_search", "greedy"):
        raise ValueError(
            f"unidentified option for which_ctc_decoder : {which_ctc_decoder}"
        )

    # detach() lets tensors that still require grad be converted to numpy.
    # size of emission_log_probs: (batch, length, class)
    emission_log_probs = np.transpose(
        log_probs.detach().cpu().numpy(), (1, 0, 2)
    )

    decoded_list = []
    for emission_log_prob in emission_log_probs:
        if which_ctc_decoder == "beam_search":
            decoded = beam_search_decode(
                emission_log_prob, blank=blank, beam_size=beam_size
            )
        else:
            decoded = greedy_decode(emission_log_prob, blank=blank)

        if label_2_char:
            decoded = [label_2_char[label] for label in decoded]
        decoded_list.append(decoded)
    return decoded_list


# --------------------
# Evaluation Metrics
# --------------------


def _require_fastwer():
    """Return the fastwer module or raise a clear error when it is missing."""
    if fastwer is None:
        raise ImportError(
            "fastwer is required to compute WER/CER but is not installed"
        )
    return fastwer


def compute_wer_and_cer_for_batch(batch_preds, batch_gts):
    """Score a batch of predictions against ground truths with fastwer.

    Args:
        batch_preds: Predicted strings, passed as fastwer's first argument.
        batch_gts: Ground-truth strings, passed as fastwer's second argument.

    Returns:
        ``(cer, wer)``. Note the order: character error rate first, despite
        the function name.

    Raises:
        ImportError: if fastwer is not installed.
    """
    scorer = _require_fastwer()
    cer_batch = scorer.score(batch_preds, batch_gts, char_level=True)
    wer_batch = scorer.score(batch_preds, batch_gts)
    return cer_batch, wer_batch


def compute_wer_and_cer_for_sample(str_pred, str_gt):
    """Score a single prediction against its ground truth with fastwer.

    Args:
        str_pred: Predicted string, passed as fastwer's first argument.
        str_gt: Ground-truth string, passed as fastwer's second argument.

    Returns:
        ``(cer, wer)``. Note the order: character error rate first, despite
        the function name.

    Raises:
        ImportError: if fastwer is not installed.
    """
    scorer = _require_fastwer()
    cer_sample = scorer.score_sent(str_pred, str_gt, char_level=True)
    wer_sample = scorer.score_sent(str_pred, str_gt)
    return cer_sample, wer_sample