|
import torch |
|
import fastwer |
|
import numpy as np |
|
from scipy.special import logsumexp |
|
|
|
|
|
""" |
|
------------- |
|
CTC decoder |
|
------------- |
|
""" |
|
|
|
# Negative infinity: log(0). Used as the neutral starting value when scores are
# combined with logsumexp.
NINF = float("-inf")

# Default per-step pruning threshold for beam search, expressed as a plain
# probability; beam_search_decode converts it with np.log before comparing
# against log-probabilities.
DEFAULT_EMISSION_THRESHOLD = 0.01
|
|
|
|
|
def _reconstruct(labels, blank=0): |
|
new_labels = [] |
|
|
|
previous = None |
|
for l in labels: |
|
if l != previous: |
|
new_labels.append(l) |
|
previous = l |
|
|
|
new_labels = [l for l in new_labels if l != blank] |
|
return new_labels |
|
|
|
|
|
def beam_search_decode(emission_log_prob, blank=0, **kwargs):
    """Decode one sequence with a path-level beam search, then merge CTC paths.

    At each time step every surviving path is extended by every class whose
    log-probability is not below ``emission_threshold``, and only the
    ``beam_size`` highest-scoring paths are kept. After the last step, the
    surviving paths that collapse (via ``_reconstruct``) to the same label
    sequence have their scores combined with log-sum-exp, and the best label
    sequence is returned.

    Paths are merged only after the final step, so this approximates (rather
    than implements) a full CTC prefix beam search.

    Args:
        emission_log_prob: 2-D array indexed as ``[t, c]`` holding per-step
            log-probabilities, shape ``(length, class_count)``.
        blank: label index of the CTC blank symbol.
        **kwargs:
            beam_size (int, required): number of paths kept after each step.
            emission_threshold (float, optional): log-probability below which
                a class is not used to extend a path. Defaults to
                ``np.log(DEFAULT_EMISSION_THRESHOLD)``. If no class reaches the
                threshold at some step, the single highest-scoring class of
                that step is used instead so the beam never empties.

    Returns:
        list[int]: decoded labels with repeats collapsed and blanks removed.

    Raises:
        KeyError: if ``beam_size`` is not supplied.
    """
    beam_size = kwargs["beam_size"]
    emission_threshold = kwargs.get(
        "emission_threshold", np.log(DEFAULT_EMISSION_THRESHOLD)
    )

    length, class_count = emission_log_prob.shape

    beams = [([], 0)]
    for t in range(length):
        step_log_prob = emission_log_prob[t]
        # The set of admissible classes depends only on t, so compute it once
        # per step instead of once per beam. ``~(x < thr)`` mirrors the
        # original "skip if below threshold" test exactly; the result is in
        # ascending class order, like iterating range(class_count).
        candidates = np.flatnonzero(~(step_log_prob < emission_threshold)).tolist()
        if not candidates and class_count:
            # Without this fallback the beam would become empty here and the
            # final lookup would fail on an empty result.
            candidates = [int(np.argmax(step_log_prob))]

        new_beams = [
            (prefix + [c], accumulated_log_prob + emission_log_prob[t, c])
            for prefix, accumulated_log_prob in beams
            for c in candidates
        ]
        # Stable sort: among equal scores, earlier-generated paths win.
        new_beams.sort(key=lambda beam: beam[1], reverse=True)
        beams = new_beams[:beam_size]

    # Merge surviving paths that map to the same label sequence.
    total_accu_log_prob = {}
    for prefix, accu_log_prob in beams:
        labels = tuple(_reconstruct(prefix, blank))
        total_accu_log_prob[labels] = logsumexp(
            [accu_log_prob, total_accu_log_prob.get(labels, NINF)]
        )

    # max() returns the first maximal key in insertion order, matching the
    # stable descending sort followed by taking element 0.
    best_labels = max(total_accu_log_prob, key=total_accu_log_prob.get)
    return list(best_labels)
|
|
|
|
|
def greedy_decode(emission_log_prob, blank=0):
    """Best-path decoding: pick the top class at every step, then collapse.

    Args:
        emission_log_prob: array of scores whose last axis indexes classes.
        blank: index of the CTC blank label.

    Returns:
        list: argmax labels (NumPy integers) with consecutive repeats merged
        and blanks removed.
    """
    best_path = np.argmax(emission_log_prob, axis=-1)
    return _reconstruct(best_path, blank=blank)
|
|
|
|
|
def ctc_decode(
    log_probs, which_ctc_decoder="beam_search", label_2_char=None, blank=0, beam_size=25
):
    """Decode a batch of CTC outputs into label (or symbol) sequences.

    Args:
        log_probs: 3-D log-probabilities laid out as
            ``(length, batch, class_count)``; axes 0 and 1 are swapped so that
            each batch item becomes a ``(length, class_count)`` matrix. A
            ``torch.Tensor`` is detached and copied to CPU first; any other
            array-like goes through ``np.asarray``.
        which_ctc_decoder: ``"beam_search"`` or ``"greedy"``.
        label_2_char: optional mapping from label index to output symbol.
            When it is truthy, each decoded label is replaced by its mapping.
        blank: index of the CTC blank label.
        beam_size: beam width passed to the beam-search decoder.

    Returns:
        list: one decoded sequence per batch item.

    Raises:
        ValueError: if ``which_ctc_decoder`` is not a recognised option.
    """
    if which_ctc_decoder not in ("beam_search", "greedy"):
        # Fail fast with an exception the caller can handle, rather than
        # terminating the whole process from inside a library function.
        raise ValueError(
            f"unidentified option for which_ctc_decoder : {which_ctc_decoder}"
        )

    if isinstance(log_probs, torch.Tensor):
        # detach(): Tensor.numpy() refuses tensors that require grad, which
        # is the usual state of raw model outputs.
        log_probs = log_probs.detach().cpu().numpy()
    else:
        log_probs = np.asarray(log_probs)
    emission_log_probs = np.transpose(log_probs, (1, 0, 2))

    decoded_list = []
    for emission_log_prob in emission_log_probs:
        if which_ctc_decoder == "beam_search":
            decoded = beam_search_decode(
                emission_log_prob, blank=blank, beam_size=beam_size
            )
        else:
            decoded = greedy_decode(emission_log_prob, blank=blank)

        if label_2_char:
            decoded = [label_2_char[label] for label in decoded]
        decoded_list.append(decoded)
    return decoded_list
|
|
|
|
|
""" |
|
-------------------- |
|
Evaluation Metrics |
|
-------------------- |
|
""" |
|
|
|
|
|
def compute_wer_and_cer_for_batch(batch_preds, batch_gts):
    """Score a batch of predictions against references with fastwer.

    Args:
        batch_preds: predicted strings (passed to fastwer as hypotheses).
        batch_gts: ground-truth strings (passed to fastwer as references).

    Returns:
        tuple: ``(cer, wer)``, the character-level and word-level scores
        returned by ``fastwer.score`` (units as reported by fastwer --
        presumably percentages; confirm).
    """
    # Character-level score is computed first, then word-level.
    return (
        fastwer.score(batch_preds, batch_gts, char_level=True),
        fastwer.score(batch_preds, batch_gts),
    )
|
|
|
|
|
def compute_wer_and_cer_for_sample(str_pred, str_gt):
    """Score a single prediction against its reference with fastwer.

    Args:
        str_pred: predicted string (hypothesis).
        str_gt: ground-truth string (reference).

    Returns:
        tuple: ``(cer, wer)``, the character-level and word-level scores
        returned by ``fastwer.score_sent`` (units as reported by fastwer --
        presumably percentages; confirm).
    """
    # Character-level score is computed first, then word-level.
    return (
        fastwer.score_sent(str_pred, str_gt, char_level=True),
        fastwer.score_sent(str_pred, str_gt),
    )
|
|