import os import os.path as osp import sys import time from collections import defaultdict import matplotlib import numpy as np import soundfile as sf import torch from torch import nn import jiwer import matplotlib.pylab as plt def calc_wer(target, pred, ignore_indexes=[0]): target_chars = drop_duplicated(list(filter(lambda x: x not in ignore_indexes, map(str, list(target))))) pred_chars = drop_duplicated(list(filter(lambda x: x not in ignore_indexes, map(str, list(pred))))) target_str = ' '.join(target_chars) pred_str = ' '.join(pred_chars) error = jiwer.wer(target_str, pred_str) return error def drop_duplicated(chars): ret_chars = [chars[0]] for prev, curr in zip(chars[:-1], chars[1:]): if prev != curr: ret_chars.append(curr) return ret_chars def build_criterion(critic_params={}): criterion = { "ce": nn.CrossEntropyLoss(ignore_index=-1), "ctc": torch.nn.CTCLoss(**critic_params.get('ctc', {})), } return criterion def get_data_path_list(train_path=None, val_path=None): if train_path is None: train_path = "Data/train_list.txt" if val_path is None: val_path = "Data/val_list.txt" with open(train_path, 'r') as f: train_list = f.readlines() with open(val_path, 'r') as f: val_list = f.readlines() return train_list, val_list def plot_image(image): fig, ax = plt.subplots(figsize=(10, 2)) im = ax.imshow(image, aspect="auto", origin="lower", interpolation='none') fig.canvas.draw() plt.close() return fig