# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import argparse
import copy
import json
import logging
import os
import shutil
import tempfile

import numpy as np
import torch


# * -------------------- training iterator related -------------------- *


class CompareValueTrigger(object):
    """Trigger invoked when the value of a key becomes larger or smaller than before.

    Args:
        key (str): Key of value.
        compare_fn ((float, float) -> bool): Function to compare the values.
        trigger (tuple(int, str)): Trigger that decides the comparison interval.

    """

    def __init__(self, key, compare_fn, trigger=(1, "epoch")):
        from chainer import training

        self._key = key
        self._best_value = None
        self._interval_trigger = training.util.get_trigger(trigger)
        self._init_summary()
        self._compare_fn = compare_fn

    def __call__(self, trainer):
        """Get the value for the key and compare it with the current best."""
        observation = trainer.observation
        summary = self._summary
        key = self._key
        if key in observation:
            summary.add({key: observation[key]})

        if not self._interval_trigger(trainer):
            return False

        stats = summary.compute_mean()
        value = float(stats[key])  # copy to CPU
        self._init_summary()

        if self._best_value is None:
            # initialize best value
            self._best_value = value
            return False
        elif self._compare_fn(self._best_value, value):
            return True
        else:
            self._best_value = value
            return False

    def _init_summary(self):
        import chainer

        self._summary = chainer.reporter.DictSummary()
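

# Illustrative sketch (not part of the original module): CompareValueTrigger is
# typically paired with an extension such as adadelta_eps_decay() defined below.
# The reported key name follows common espnet recipes and is an assumption here.
def _example_eps_decay_setup(trainer):
    """Fire adadelta eps decay whenever validation loss stops improving (sketch)."""
    trainer.extend(
        adadelta_eps_decay(0.5),
        trigger=CompareValueTrigger(
            "validation/main/loss",
            # the trigger fires (returns True) when the current value is worse
            lambda best_value, current_value: best_value < current_value,
        ),
    )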


try:
    from chainer.training import extension
except ImportError:
    PlotAttentionReport = None
else:

    class PlotAttentionReport(extension.Extension):
        """Plot attention reporter.

        Args:
            att_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_attentions):
                Function of attention visualization.
            data (list[tuple(str, dict[str, list[Any]])]):
                List of (uttid, json entry) items.
            outdir (str): Directory to save figures.
            converter (espnet.asr.*_backend.asr.CustomConverter):
                Function to convert data.
            transform (callable): Transform function applied to the data.
            device (int | torch.device): Device.
            reverse (bool): If True, input and output lengths are reversed.
            ikey (str): Key to access input
                (for ASR/ST ikey="input", for MT ikey="output".)
            iaxis (int): Dimension to access input
                (for ASR/ST iaxis=0, for MT iaxis=1.)
            okey (str): Key to access output
                (for ASR/ST okey="output", for MT okey="output".)
            oaxis (int): Dimension to access output
                (for ASR/ST oaxis=0, for MT oaxis=0.)
            subsampling_factor (int): Subsampling factor in the encoder.

        """

        def __init__(
            self,
            att_vis_fn,
            data,
            outdir,
            converter,
            transform,
            device,
            reverse=False,
            ikey="input",
            iaxis=0,
            okey="output",
            oaxis=0,
            subsampling_factor=1,
        ):
            self.att_vis_fn = att_vis_fn
            self.data = copy.deepcopy(data)
            self.data_dict = {k: v for k, v in copy.deepcopy(data)}
            # key is utterance ID
            self.outdir = outdir
            self.converter = converter
            self.transform = transform
            self.device = device
            self.reverse = reverse
            self.ikey = ikey
            self.iaxis = iaxis
            self.okey = okey
            self.oaxis = oaxis
            self.factor = subsampling_factor
            if not os.path.exists(self.outdir):
                os.makedirs(self.outdir)

        def __call__(self, trainer):
            """Plot and save image files of att_ws matrices."""
            att_ws, uttid_list = self.get_attention_weights()
            if isinstance(att_ws, list):  # multi-encoder case
                num_encs = len(att_ws) - 1
                # atts
                for i in range(num_encs):
                    for idx, att_w in enumerate(att_ws[i]):
                        filename = "%s/%s.ep.{.updater.epoch}.att%d.png" % (
                            self.outdir,
                            uttid_list[idx],
                            i + 1,
                        )
                        att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                        np_filename = "%s/%s.ep.{.updater.epoch}.att%d.npy" % (
                            self.outdir,
                            uttid_list[idx],
                            i + 1,
                        )
                        np.save(np_filename.format(trainer), att_w)
                        self._plot_and_save_attention(att_w, filename.format(trainer))
                # han
                for idx, att_w in enumerate(att_ws[num_encs]):
                    filename = "%s/%s.ep.{.updater.epoch}.han.png" % (
                        self.outdir,
                        uttid_list[idx],
                    )
                    att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                    np_filename = "%s/%s.ep.{.updater.epoch}.han.npy" % (
                        self.outdir,
                        uttid_list[idx],
                    )
                    np.save(np_filename.format(trainer), att_w)
                    self._plot_and_save_attention(
                        att_w, filename.format(trainer), han_mode=True
                    )
            else:
                for idx, att_w in enumerate(att_ws):
                    filename = "%s/%s.ep.{.updater.epoch}.png" % (
                        self.outdir,
                        uttid_list[idx],
                    )
                    att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                    np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
                        self.outdir,
                        uttid_list[idx],
                    )
                    np.save(np_filename.format(trainer), att_w)
                    self._plot_and_save_attention(att_w, filename.format(trainer))

        def log_attentions(self, logger, step):
            """Add image files of att_ws matrices to the tensorboard."""
            att_ws, uttid_list = self.get_attention_weights()
            if isinstance(att_ws, list):  # multi-encoder case
                num_encs = len(att_ws) - 1
                # atts
                for i in range(num_encs):
                    for idx, att_w in enumerate(att_ws[i]):
                        att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                        plot = self.draw_attention_plot(att_w)
                        logger.add_figure(
                            "%s_att%d" % (uttid_list[idx], i + 1),
                            plot.gcf(),
                            step,
                        )
                # han
                for idx, att_w in enumerate(att_ws[num_encs]):
                    att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                    plot = self.draw_han_plot(att_w)
                    logger.add_figure(
                        "%s_han" % (uttid_list[idx]),
                        plot.gcf(),
                        step,
                    )
            else:
                for idx, att_w in enumerate(att_ws):
                    att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                    plot = self.draw_attention_plot(att_w)
                    logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)

        def get_attention_weights(self):
            """Return attention weights.

            Returns:
                numpy.ndarray: Attention weights (float). Its shape differs
                depending on the backend.
                * pytorch-> 1) multi-head case => (B, H, Lmax, Tmax), 2)
                  other case => (B, Lmax, Tmax).
                * chainer-> (B, Lmax, Tmax)

            """
            return_batch, uttid_list = self.transform(self.data, return_uttid=True)
            batch = self.converter([return_batch], self.device)
            if isinstance(batch, tuple):
                att_ws = self.att_vis_fn(*batch)
            else:
                att_ws = self.att_vis_fn(**batch)
            return att_ws, uttid_list

        def trim_attention_weight(self, uttid, att_w):
            """Transform attention matrix with regard to self.reverse."""
            if self.reverse:
                enc_key, enc_axis = self.okey, self.oaxis
                dec_key, dec_axis = self.ikey, self.iaxis
            else:
                enc_key, enc_axis = self.ikey, self.iaxis
                dec_key, dec_axis = self.okey, self.oaxis
            dec_len = int(self.data_dict[uttid][dec_key][dec_axis]["shape"][0])
            enc_len = int(self.data_dict[uttid][enc_key][enc_axis]["shape"][0])
            if self.factor > 1:
                enc_len //= self.factor
            if len(att_w.shape) == 3:
                att_w = att_w[:, :dec_len, :enc_len]
            else:
                att_w = att_w[:dec_len, :enc_len]
            return att_w

        def draw_attention_plot(self, att_w):
            """Plot the att_w matrix.

            Returns:
                matplotlib.pyplot: pyplot object with attention matrix image.

            """
            import matplotlib

            matplotlib.use("Agg")
            import matplotlib.pyplot as plt

            plt.clf()
            att_w = att_w.astype(np.float32)
            if len(att_w.shape) == 3:
                for h, aw in enumerate(att_w, 1):
                    plt.subplot(1, len(att_w), h)
                    plt.imshow(aw, aspect="auto")
                    plt.xlabel("Encoder Index")
                    plt.ylabel("Decoder Index")
            else:
                plt.imshow(att_w, aspect="auto")
                plt.xlabel("Encoder Index")
                plt.ylabel("Decoder Index")
            plt.tight_layout()
            return plt

        def draw_han_plot(self, att_w):
            """Plot the att_w matrix for hierarchical attention.

            Returns:
                matplotlib.pyplot: pyplot object with attention matrix image.

            """
            import matplotlib

            matplotlib.use("Agg")
            import matplotlib.pyplot as plt

            plt.clf()
            if len(att_w.shape) == 3:
                for h, aw in enumerate(att_w, 1):
                    legends = []
                    plt.subplot(1, len(att_w), h)
                    for i in range(aw.shape[1]):
                        plt.plot(aw[:, i])
                        legends.append("Att{}".format(i))
                    plt.ylim([0, 1.0])
                    plt.xlim([0, aw.shape[0]])
                    plt.grid(True)
                    plt.ylabel("Attention Weight")
                    plt.xlabel("Decoder Index")
                    plt.legend(legends)
            else:
                legends = []
                for i in range(att_w.shape[1]):
                    plt.plot(att_w[:, i])
                    legends.append("Att{}".format(i))
                plt.ylim([0, 1.0])
                plt.xlim([0, att_w.shape[0]])
                plt.grid(True)
                plt.ylabel("Attention Weight")
                plt.xlabel("Decoder Index")
                plt.legend(legends)
            plt.tight_layout()
            return plt

        def _plot_and_save_attention(self, att_w, filename, han_mode=False):
            if han_mode:
                plt = self.draw_han_plot(att_w)
            else:
                plt = self.draw_attention_plot(att_w)
            plt.savefig(filename)
            plt.close()
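

# Illustrative sketch (not part of the original module): how the attention
# plotter is usually attached to a trainer. `model.calculate_all_attentions`,
# the converter/transform objects, and the output path are assumptions taken
# from typical espnet recipes, and chainer is assumed to be importable.
def _example_plot_attention_setup(
    trainer, model, data, converter, transform, device, outdir
):
    """Attach PlotAttentionReport to a chainer trainer (sketch)."""
    report = PlotAttentionReport(
        model.calculate_all_attentions,
        data,  # list of (uttid, json entry) pairs
        os.path.join(outdir, "att_ws"),
        converter=converter,
        transform=transform,
        device=device,
    )
    trainer.extend(report, trigger=(1, "epoch"))
    return report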


try:
    from chainer.training import extension
except ImportError:
    PlotCTCReport = None
else:

    class PlotCTCReport(extension.Extension):
        """Plot CTC reporter.

        Args:
            ctc_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_ctc_probs):
                Function of CTC visualization.
            data (list[tuple(str, dict[str, list[Any]])]):
                List of (uttid, json entry) items.
            outdir (str): Directory to save figures.
            converter (espnet.asr.*_backend.asr.CustomConverter):
                Function to convert data.
            transform (callable): Transform function applied to the data.
            device (int | torch.device): Device.
            reverse (bool): If True, input and output lengths are reversed.
            ikey (str): Key to access input
                (for ASR/ST ikey="input", for MT ikey="output".)
            iaxis (int): Dimension to access input
                (for ASR/ST iaxis=0, for MT iaxis=1.)
            okey (str): Key to access output
                (for ASR/ST okey="output", for MT okey="output".)
            oaxis (int): Dimension to access output
                (for ASR/ST oaxis=0, for MT oaxis=0.)
            subsampling_factor (int): Subsampling factor in the encoder.

        """

        def __init__(
            self,
            ctc_vis_fn,
            data,
            outdir,
            converter,
            transform,
            device,
            reverse=False,
            ikey="input",
            iaxis=0,
            okey="output",
            oaxis=0,
            subsampling_factor=1,
        ):
            self.ctc_vis_fn = ctc_vis_fn
            self.data = copy.deepcopy(data)
            self.data_dict = {k: v for k, v in copy.deepcopy(data)}
            # key is utterance ID
            self.outdir = outdir
            self.converter = converter
            self.transform = transform
            self.device = device
            self.reverse = reverse
            self.ikey = ikey
            self.iaxis = iaxis
            self.okey = okey
            self.oaxis = oaxis
            self.factor = subsampling_factor
            if not os.path.exists(self.outdir):
                os.makedirs(self.outdir)

        def __call__(self, trainer):
            """Plot and save image files of ctc probs."""
            ctc_probs, uttid_list = self.get_ctc_probs()
            if isinstance(ctc_probs, list):  # multi-encoder case
                num_encs = len(ctc_probs) - 1
                for i in range(num_encs):
                    for idx, ctc_prob in enumerate(ctc_probs[i]):
                        filename = "%s/%s.ep.{.updater.epoch}.ctc%d.png" % (
                            self.outdir,
                            uttid_list[idx],
                            i + 1,
                        )
                        ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
                        np_filename = "%s/%s.ep.{.updater.epoch}.ctc%d.npy" % (
                            self.outdir,
                            uttid_list[idx],
                            i + 1,
                        )
                        np.save(np_filename.format(trainer), ctc_prob)
                        self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
            else:
                for idx, ctc_prob in enumerate(ctc_probs):
                    filename = "%s/%s.ep.{.updater.epoch}.png" % (
                        self.outdir,
                        uttid_list[idx],
                    )
                    ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
                    np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
                        self.outdir,
                        uttid_list[idx],
                    )
                    np.save(np_filename.format(trainer), ctc_prob)
                    self._plot_and_save_ctc(ctc_prob, filename.format(trainer))

        def log_ctc_probs(self, logger, step):
            """Add image files of ctc probs to the tensorboard."""
            ctc_probs, uttid_list = self.get_ctc_probs()
            if isinstance(ctc_probs, list):  # multi-encoder case
                num_encs = len(ctc_probs) - 1
                for i in range(num_encs):
                    for idx, ctc_prob in enumerate(ctc_probs[i]):
                        ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
                        plot = self.draw_ctc_plot(ctc_prob)
                        logger.add_figure(
                            "%s_ctc%d" % (uttid_list[idx], i + 1),
                            plot.gcf(),
                            step,
                        )
            else:
                for idx, ctc_prob in enumerate(ctc_probs):
                    ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
                    plot = self.draw_ctc_plot(ctc_prob)
                    logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)

        def get_ctc_probs(self):
            """Return CTC probabilities.

            Returns:
                numpy.ndarray: CTC probs (float) of shape (B, Tmax, vocab);
                the exact array type depends on the backend.

            """
            return_batch, uttid_list = self.transform(self.data, return_uttid=True)
            batch = self.converter([return_batch], self.device)
            if isinstance(batch, tuple):
                probs = self.ctc_vis_fn(*batch)
            else:
                probs = self.ctc_vis_fn(**batch)
            return probs, uttid_list

        def trim_ctc_prob(self, uttid, prob):
            """Trim CTC posteriors according to input lengths."""
            enc_len = int(self.data_dict[uttid][self.ikey][self.iaxis]["shape"][0])
            if self.factor > 1:
                enc_len //= self.factor
            prob = prob[:enc_len]
            return prob

        def draw_ctc_plot(self, ctc_prob):
            """Plot the ctc_prob matrix.

            Returns:
                matplotlib.pyplot: pyplot object with CTC prob matrix image.

            """
            import matplotlib

            matplotlib.use("Agg")
            import matplotlib.pyplot as plt

            ctc_prob = ctc_prob.astype(np.float32)
            plt.clf()
            topk_ids = np.argsort(ctc_prob, axis=1)
            n_frames, vocab = ctc_prob.shape
            times_probs = np.arange(n_frames)
            plt.figure(figsize=(20, 8))
            # NOTE: index 0 is reserved for blank
            for idx in set(topk_ids.reshape(-1).tolist()):
                if idx == 0:
                    plt.plot(
                        times_probs, ctc_prob[:, 0], ":", label="<blank>", color="grey"
                    )
                else:
                    plt.plot(times_probs, ctc_prob[:, idx])
            plt.xlabel("Input [frame]", fontsize=12)
            plt.ylabel("Posteriors", fontsize=12)
            plt.xticks(list(range(0, int(n_frames) + 1, 10)))
            plt.yticks(list(range(0, 2, 1)))
            plt.tight_layout()
            return plt

        def _plot_and_save_ctc(self, ctc_prob, filename):
            plt = self.draw_ctc_plot(ctc_prob)
            plt.savefig(filename)
            plt.close()


def restore_snapshot(model, snapshot, load_fn=None):
    """Extension to restore snapshot.

    Returns:
        An extension function.

    """
    import chainer
    from chainer import training

    if load_fn is None:
        load_fn = chainer.serializers.load_npz

    # register as a chainer extension so trainer.extend() runs it each epoch
    @training.make_extension(trigger=(1, "epoch"))
    def restore_snapshot(trainer):
        _restore_snapshot(model, snapshot, load_fn)

    return restore_snapshot


def _restore_snapshot(model, snapshot, load_fn=None):
    if load_fn is None:
        import chainer

        load_fn = chainer.serializers.load_npz
    load_fn(snapshot, model)
    logging.info("restored from " + str(snapshot))


def adadelta_eps_decay(eps_decay):
    """Extension to perform adadelta eps decay.

    Args:
        eps_decay (float): Decay rate of eps.

    Returns:
        An extension function.

    """
    from chainer import training

    @training.make_extension(trigger=(1, "epoch"))
    def adadelta_eps_decay(trainer):
        _adadelta_eps_decay(trainer, eps_decay)

    return adadelta_eps_decay


def _adadelta_eps_decay(trainer, eps_decay):
    optimizer = trainer.updater.get_optimizer("main")
    # for chainer
    if hasattr(optimizer, "eps"):
        current_eps = optimizer.eps
        setattr(optimizer, "eps", current_eps * eps_decay)
        logging.info("adadelta eps decayed to " + str(optimizer.eps))
    # pytorch
    else:
        for p in optimizer.param_groups:
            p["eps"] *= eps_decay
            logging.info("adadelta eps decayed to " + str(p["eps"]))


def adam_lr_decay(eps_decay):
    """Extension to perform adam lr decay.

    Args:
        eps_decay (float): Decay rate of lr.

    Returns:
        An extension function.

    """
    from chainer import training

    @training.make_extension(trigger=(1, "epoch"))
    def adam_lr_decay(trainer):
        _adam_lr_decay(trainer, eps_decay)

    return adam_lr_decay


def _adam_lr_decay(trainer, eps_decay):
    optimizer = trainer.updater.get_optimizer("main")
    # for chainer
    if hasattr(optimizer, "lr"):
        current_lr = optimizer.lr
        setattr(optimizer, "lr", current_lr * eps_decay)
        logging.info("adam lr decayed to " + str(optimizer.lr))
    # pytorch
    else:
        for p in optimizer.param_groups:
            p["lr"] *= eps_decay
            logging.info("adam lr decayed to " + str(p["lr"]))


def torch_snapshot(savefun=torch.save, filename="snapshot.ep.{.updater.epoch}"):
    """Extension to take snapshot of the trainer for pytorch.

    Returns:
        An extension function.

    """
    from chainer.training import extension

    @extension.make_extension(trigger=(1, "epoch"), priority=-100)
    def torch_snapshot(trainer):
        _torch_snapshot_object(trainer, trainer, filename.format(trainer), savefun)

    return torch_snapshot


def _torch_snapshot_object(trainer, target, filename, savefun):
    from chainer.serializers import DictionarySerializer

    # make snapshot_dict dictionary
    s = DictionarySerializer()
    s.save(trainer)
    if hasattr(trainer.updater.model, "model"):
        # (for TTS)
        if hasattr(trainer.updater.model.model, "module"):
            model_state_dict = trainer.updater.model.model.module.state_dict()
        else:
            model_state_dict = trainer.updater.model.model.state_dict()
    else:
        # (for ASR)
        if hasattr(trainer.updater.model, "module"):
            model_state_dict = trainer.updater.model.module.state_dict()
        else:
            model_state_dict = trainer.updater.model.state_dict()
    snapshot_dict = {
        "trainer": s.target,
        "model": model_state_dict,
        "optimizer": trainer.updater.get_optimizer("main").state_dict(),
    }

    # save snapshot dictionary to a temporary directory, then move atomically
    fn = filename.format(trainer)
    prefix = "tmp" + fn
    tmpdir = tempfile.mkdtemp(prefix=prefix, dir=trainer.out)
    tmppath = os.path.join(tmpdir, fn)
    try:
        savefun(snapshot_dict, tmppath)
        shutil.move(tmppath, os.path.join(trainer.out, fn))
    finally:
        shutil.rmtree(tmpdir)


def add_gradient_noise(model, iteration, duration=100, eta=1.0, scale_factor=0.55):
    """Add noise from a standard normal distribution to the gradients.

    The standard deviation (`sigma`) is controlled by the three hyper-parameters
    below. `sigma` goes to zero (no noise) with more iterations.

    Args:
        model (torch.nn.Module): Model.
        iteration (int): Number of iterations.
        duration (int) {100, 1000}:
            Number of durations to control the interval of the `sigma` change.
        eta (float) {0.01, 0.3, 1.0}: The magnitude of `sigma`.
        scale_factor (float) {0.55}: The scale of `sigma`.

    """
    interval = (iteration // duration) + 1
    sigma = eta / interval**scale_factor
    for param in model.parameters():
        if param.grad is not None:
            _shape = param.grad.size()
            noise = sigma * torch.randn(_shape).to(param.device)
            param.grad += noise
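

# Illustrative sketch (not part of the original module): where
# add_gradient_noise() sits in a pytorch training step. The model, optimizer,
# loss, and iteration counter come from the caller's training loop and are
# assumptions here.
def _example_training_step(model, optimizer, loss, iteration):
    """Run one update with annealed gradient noise (sketch)."""
    optimizer.zero_grad()
    loss.backward()
    # noise is added to the gradients after backprop, before the update
    add_gradient_noise(model, iteration, duration=100, eta=1.0, scale_factor=0.55)
    optimizer.step()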


# * -------------------- general -------------------- *


def get_model_conf(model_path, conf_path=None):
    """Get model config information by reading a model config file (model.json).

    Args:
        model_path (str): Model path.
        conf_path (str): Optional model config path.

    Returns:
        argparse.Namespace (for lm) or tuple[int, int, argparse.Namespace]
        (for asr, tts, mt): Config information loaded from the json file.

    """
    if conf_path is None:
        model_conf = os.path.dirname(model_path) + "/model.json"
    else:
        model_conf = conf_path
    with open(model_conf, "rb") as f:
        logging.info("reading a config file from " + model_conf)
        confs = json.load(f)
    if isinstance(confs, dict):
        # for lm
        args = confs
        return argparse.Namespace(**args)
    else:
        # for asr, tts, mt
        idim, odim, args = confs
        return idim, odim, argparse.Namespace(**args)


def chainer_load(path, model):
    """Load chainer model parameters.

    Args:
        path (str): Model path or snapshot file path to be loaded.
        model (chainer.Chain): Chainer model.

    """
    import chainer

    if "snapshot" in os.path.basename(path):
        chainer.serializers.load_npz(path, model, path="updater/model:main/")
    else:
        chainer.serializers.load_npz(path, model)


def torch_save(path, model):
    """Save torch model states.

    Args:
        path (str): Model path to be saved.
        model (torch.nn.Module): Torch model.

    """
    if hasattr(model, "module"):
        torch.save(model.module.state_dict(), path)
    else:
        torch.save(model.state_dict(), path)


def snapshot_object(target, filename):
    """Return a trainer extension to take snapshots of a given object.

    Args:
        target (model): Object to serialize.
        filename (str): Name of the file into which the object is serialized. It can
            be a format string, where the trainer object is passed to
            the :meth:`str.format` method. For example,
            ``'snapshot_{.updater.iteration}'`` is converted to
            ``'snapshot_10000'`` at the 10,000th iteration.

    Returns:
        An extension function.

    """
    from chainer.training import extension

    @extension.make_extension(trigger=(1, "epoch"), priority=-100)
    def snapshot_object(trainer):
        torch_save(os.path.join(trainer.out, filename.format(trainer)), target)

    return snapshot_object
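

# Illustrative sketch (not part of the original module): registering the two
# snapshot extensions above on a chainer trainer. The filenames follow the
# patterns used elsewhere in this file; trainer and model come from the caller.
def _example_snapshot_setup(trainer, model):
    """Save a full trainer snapshot and the bare model every epoch (sketch)."""
    trainer.extend(torch_snapshot(), trigger=(1, "epoch"))
    trainer.extend(
        snapshot_object(model, "model.ep.{.updater.epoch}"), trigger=(1, "epoch")
    )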


def torch_load(path, model):
    """Load torch model states.

    Args:
        path (str): Model path or snapshot file path to be loaded.
        model (torch.nn.Module): Torch model.

    """
    if "snapshot" in os.path.basename(path):
        model_state_dict = torch.load(path, map_location=lambda storage, loc: storage)[
            "model"
        ]
    else:
        model_state_dict = torch.load(path, map_location=lambda storage, loc: storage)
    if hasattr(model, "module"):
        model.module.load_state_dict(model_state_dict)
    else:
        model.load_state_dict(model_state_dict)
    del model_state_dict
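

# Illustrative sketch (not part of the original module): rebuild a network from
# the model.json written at training time and load its weights. The E2E import
# path and constructor signature are assumptions based on the espnet pytorch
# backend.
def _example_rebuild_and_load(model_path):
    """Reconstruct a trained ASR model from disk (sketch)."""
    from espnet.nets.pytorch_backend.e2e_asr import E2E  # assumed location

    idim, odim, train_args = get_model_conf(model_path)
    model = E2E(idim, odim, train_args)
    torch_load(model_path, model)
    return model, train_args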


def torch_resume(snapshot_path, trainer):
    """Resume from snapshot for pytorch.

    Args:
        snapshot_path (str): Snapshot file path.
        trainer (chainer.training.Trainer): Chainer's trainer instance.

    """
    from chainer.serializers import NpzDeserializer

    # load snapshot
    snapshot_dict = torch.load(snapshot_path, map_location=lambda storage, loc: storage)

    # restore trainer states
    d = NpzDeserializer(snapshot_dict["trainer"])
    d.load(trainer)

    # restore model states
    if hasattr(trainer.updater.model, "model"):
        # (for TTS model)
        if hasattr(trainer.updater.model.model, "module"):
            trainer.updater.model.model.module.load_state_dict(snapshot_dict["model"])
        else:
            trainer.updater.model.model.load_state_dict(snapshot_dict["model"])
    else:
        # (for ASR model)
        if hasattr(trainer.updater.model, "module"):
            trainer.updater.model.module.load_state_dict(snapshot_dict["model"])
        else:
            trainer.updater.model.load_state_dict(snapshot_dict["model"])

    # restore optimizer states
    trainer.updater.get_optimizer("main").load_state_dict(snapshot_dict["optimizer"])

    # delete opened snapshot
    del snapshot_dict


# * ------------------ recognition related ------------------ *


def parse_hypothesis(hyp, char_list):
    """Parse hypothesis.

    Args:
        hyp (dict[str, Any]): Recognition hypothesis.
        char_list (list[str]): List of characters.

    Returns:
        tuple(str, str, str, float)

    """
    # remove sos and get results
    tokenid_as_list = list(map(int, hyp["yseq"][1:]))
    token_as_list = [char_list[idx] for idx in tokenid_as_list]
    score = float(hyp["score"])

    # convert to string
    tokenid = " ".join([str(idx) for idx in tokenid_as_list])
    token = " ".join(token_as_list)
    text = "".join(token_as_list).replace("<space>", " ")

    return text, token, tokenid, score


def add_results_to_json(nbest_hyps, char_list):
    """Return the 1-best recognition result as text.

    Args:
        nbest_hyps (list[dict[str, Any]]): List of hypotheses (1-best only).
        char_list (list[str]): List of characters.

    Returns:
        str: 1-best result.

    """
    assert len(nbest_hyps) == 1, "only 1-best result is supported."

    # parse hypothesis
    rec_text, rec_token, rec_tokenid, score = parse_hypothesis(nbest_hyps[0], char_list)

    return rec_text
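

# Illustrative sketch (not part of the original module): the hypothesis layout
# parse_hypothesis()/add_results_to_json() expect. The character list and the
# token ids below are made-up toy values.
def _example_parse_hypothesis():
    """Turn a toy 1-best hypothesis into text (sketch)."""
    char_list = ["<blank>", "<space>", "a", "b", "c"]
    # yseq starts with sos, which parse_hypothesis() strips
    nbest_hyps = [{"yseq": [0, 2, 1, 3], "score": -1.23}]
    return add_results_to_json(nbest_hyps, char_list)  # -> "a b"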


def plot_spectrogram(
    plt,
    spec,
    mode="db",
    fs=None,
    frame_shift=None,
    bottom=True,
    left=True,
    right=True,
    top=False,
    labelbottom=True,
    labelleft=True,
    labelright=True,
    labeltop=False,
    cmap="inferno",
):
    """Plot spectrogram using matplotlib.

    Args:
        plt (matplotlib.pyplot): pyplot object.
        spec (numpy.ndarray): Input stft (Freq, Time).
        mode (str): db or linear.
        fs (int): Sample frequency. To convert y-axis to kHz unit.
        frame_shift (int): The frame shift of stft. To convert x-axis to second unit.
        bottom (bool): Whether to draw the respective ticks.
        left (bool):
        right (bool):
        top (bool):
        labelbottom (bool): Whether to draw the respective tick labels.
        labelleft (bool):
        labelright (bool):
        labeltop (bool):
        cmap (str): Colormap defined in matplotlib.

    """
    spec = np.abs(spec)
    if mode == "db":
        x = 20 * np.log10(spec + np.finfo(spec.dtype).eps)
    elif mode == "linear":
        x = spec
    else:
        raise ValueError(mode)

    if fs is not None:
        ytop = fs / 2000
        ylabel = "kHz"
    else:
        ytop = x.shape[0]
        ylabel = "bin"

    if frame_shift is not None and fs is not None:
        xtop = x.shape[1] * frame_shift / fs
        xlabel = "s"
    else:
        xtop = x.shape[1]
        xlabel = "frame"

    extent = (0, xtop, 0, ytop)
    plt.imshow(x[::-1], cmap=cmap, extent=extent)

    if labelbottom:
        plt.xlabel("time [{}]".format(xlabel))
    if labelleft:
        plt.ylabel("freq [{}]".format(ylabel))
    plt.colorbar().set_label("{}".format(mode))

    plt.tick_params(
        bottom=bottom,
        left=left,
        right=right,
        top=top,
        labelbottom=labelbottom,
        labelleft=labelleft,
        labelright=labelright,
        labeltop=labeltop,
    )
    plt.axis("auto")
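

# Illustrative sketch (not part of the original module): rendering a magnitude
# spectrogram to a png with the Agg backend. The 16 kHz sample rate and 160-sample
# (10 ms) frame shift are assumptions; spec is any (Freq, Time) array.
def _example_save_spectrogram(spec, path):
    """Save a dB-scaled spectrogram image (sketch)."""
    import matplotlib

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    plot_spectrogram(plt, spec, mode="db", fs=16000, frame_shift=160)
    plt.savefig(path)
    plt.close()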


# * ------------------ multi-encoder related ------------------ *


def format_mulenc_args(args):
    """Format args for multi-encoder setup.

    It deals with the following situations (when args.num_encs=2):
        1. args.elayers = None -> args.elayers = [4, 4];
        2. args.elayers = 4 -> args.elayers = [4, 4];
        3. args.elayers = [4, 4, 4] -> args.elayers = [4, 4].

    """
    # default values when None is assigned.
    default_dict = {
        "etype": "blstmp",
        "elayers": 4,
        "eunits": 300,
        "subsample": "1",
        "dropout_rate": 0.0,
        "atype": "dot",
        "adim": 320,
        "awin": 5,
        "aheads": 4,
        "aconv_chans": -1,
        "aconv_filts": 100,
    }
    for k in default_dict.keys():
        if isinstance(vars(args)[k], list):
            if len(vars(args)[k]) != args.num_encs:
                logging.warning(
                    "Length mismatch {}: Convert {} to {}.".format(
                        k, vars(args)[k], vars(args)[k][: args.num_encs]
                    )
                )
                vars(args)[k] = vars(args)[k][: args.num_encs]
        else:
            if not vars(args)[k]:
                # assign default value if it is None
                vars(args)[k] = default_dict[k]
                logging.warning(
                    "{} is not specified, use default value {}.".format(
                        k, default_dict[k]
                    )
                )
            # duplicate
            logging.warning(
                "Type mismatch {}: Convert {} to {}.".format(
                    k, vars(args)[k], [vars(args)[k] for _ in range(args.num_encs)]
                )
            )
            vars(args)[k] = [vars(args)[k] for _ in range(args.num_encs)]
    return args
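

# Illustrative sketch (not part of the original module): the expansion
# performed by format_mulenc_args() for two encoders. All option values are
# toy values; a real args namespace comes from the training argument parser.
def _example_format_mulenc_args():
    """Show scalar options being duplicated per encoder (sketch)."""
    args = argparse.Namespace(
        num_encs=2,
        etype="blstmp",
        elayers=None,  # unset -> default 4, then duplicated to [4, 4]
        eunits=300,  # scalar -> duplicated to [300, 300]
        subsample="1",
        dropout_rate=0.2,
        atype="dot",
        adim=320,
        awin=5,
        aheads=4,
        aconv_chans=-1,
        aconv_filts=100,
    )
    return format_mulenc_args(args)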