#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

import logging

import matplotlib.pyplot as plt
import numpy

from espnet.asr import asr_utils


def _plot_and_save_attention(att_w, filename, xtokens=None, ytokens=None):
    # dynamically import matplotlib due to not found error
    from matplotlib.ticker import MaxNLocator
    import os

    d = os.path.dirname(filename)
    if not os.path.exists(d):
        os.makedirs(d)
    w, h = plt.figaspect(1.0 / len(att_w))
    fig = plt.Figure(figsize=(w * 2, h * 2))
    axes = fig.subplots(1, len(att_w))
    if len(att_w) == 1:
        axes = [axes]
    for ax, aw in zip(axes, att_w):
        # plt.subplot(1, len(att_w), h)
        ax.imshow(aw.astype(numpy.float32), aspect="auto")
        ax.set_xlabel("Input")
        ax.set_ylabel("Output")
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        ax.yaxis.set_major_locator(MaxNLocator(integer=True))
        # Labels for major ticks
        if xtokens is not None:
            ax.set_xticks(numpy.linspace(0, len(xtokens) - 1, len(xtokens)))
            ax.set_xticks(numpy.linspace(0, len(xtokens) - 1, 1), minor=True)
            ax.set_xticklabels(xtokens + [""], rotation=40)
        if ytokens is not None:
            ax.set_yticks(numpy.linspace(0, len(ytokens) - 1, len(ytokens)))
            ax.set_yticks(numpy.linspace(0, len(ytokens) - 1, 1), minor=True)
            ax.set_yticklabels(ytokens + [""])
    fig.tight_layout()
    return fig


def savefig(plot, filename):
    plot.savefig(filename)
    plt.clf()


def plot_multi_head_attention(
    data,
    attn_dict,
    outdir,
    suffix="png",
    savefn=savefig,
    ikey="input",
    iaxis=0,
    okey="output",
    oaxis=0,
):
    """Plot multi head attentions.

    :param data: list of ``(utterance id, info dict)`` pairs taken from the data json
    :param dict[str, torch.Tensor] attn_dict: multi head attention dict.
        values should be torch.Tensor (head, input_length, output_length)
    :param str outdir: directory to save figures in
    :param str suffix: filename suffix including the image type (e.g., png)
    :param savefn: function called as ``savefn(fig, filename)`` to save each figure
    :param str ikey: key of the input entry in the info dict
    :param int iaxis: index of the input entry to use
    :param str okey: key of the output (target) entry in the info dict
    :param int oaxis: index of the output entry to use

    A minimal usage sketch is given in the ``__main__`` block at the end of this file.
    """
    for name, att_ws in attn_dict.items():
        for idx, att_w in enumerate(att_ws):
            filename = "%s/%s.%s.%s" % (outdir, data[idx][0], name, suffix)
            dec_len = int(data[idx][1][okey][oaxis]["shape"][0])
            enc_len = int(data[idx][1][ikey][iaxis]["shape"][0])
            xtokens, ytokens = None, None
            if "encoder" in name:
                att_w = att_w[:, :enc_len, :enc_len]
                # for MT
                if "token" in data[idx][1][ikey][iaxis].keys():
                    xtokens = data[idx][1][ikey][iaxis]["token"].split()
                    ytokens = xtokens[:]
            elif "decoder" in name:
                if "self" in name:
                    att_w = att_w[:, : dec_len + 1, : dec_len + 1]  # +1 for <sos>
                else:
                    att_w = att_w[:, : dec_len + 1, :enc_len]  # +1 for <sos>
                # for MT
                if "token" in data[idx][1][ikey][iaxis].keys():
                    xtokens = data[idx][1][ikey][iaxis]["token"].split()
                # for ASR/ST/MT
                if "token" in data[idx][1][okey][oaxis].keys():
                    ytokens = ["<sos>"] + data[idx][1][okey][oaxis]["token"].split()
                    if "self" in name:
                        xtokens = ytokens[:]
            else:
                logging.warning("unknown name for shaping attention")
            fig = _plot_and_save_attention(att_w, filename, xtokens, ytokens)
            savefn(fig, filename)


class PlotAttentionReport(asr_utils.PlotAttentionReport):
    """Attention reporter specialized for multi head attention."""

    def plotfn(self, *args, **kwargs):
        """Plot multi head attentions using the configured input/output keys."""
        kwargs["ikey"] = self.ikey
        kwargs["iaxis"] = self.iaxis
        kwargs["okey"] = self.okey
        kwargs["oaxis"] = self.oaxis
        plot_multi_head_attention(*args, **kwargs)

    def __call__(self, trainer):
        """Plot and save attention figures, tagging filenames with the current epoch."""
        attn_dict = self.get_attention_weights()
        suffix = "ep.{.updater.epoch}.png".format(trainer)
        self.plotfn(self.data, attn_dict, self.outdir, suffix, savefig)

    def get_attention_weights(self):
        """Run the attention visualization function on one batch and return its output."""
        batch = self.converter([self.transform(self.data)], self.device)
        if isinstance(batch, tuple):
            att_ws = self.att_vis_fn(*batch)
        elif isinstance(batch, dict):
            att_ws = self.att_vis_fn(**batch)
        return att_ws

    def log_attentions(self, logger, step):
        """Add attention figures to a logger (e.g., a TensorBoard writer) at the given step."""

        def log_fig(plot, filename):
            from os.path import basename

            logger.add_figure(basename(filename), plot, step)
            plt.clf()

        attn_dict = self.get_attention_weights()
        self.plotfn(self.data, attn_dict, self.outdir, "", log_fig)
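

if __name__ == "__main__":
    # Minimal usage sketch, runnable on its own.  The layout of ``demo_data``
    # and ``demo_attn`` below is hypothetical and inferred from the indexing in
    # plot_multi_head_attention above; in training, the data entries come from
    # the ESPnet data json and the attention weights from the model.  numpy
    # arrays are used here so the sketch does not require torch.
    import tempfile

    demo_data = [
        (
            "utt1",
            {
                "input": [{"shape": [5, 83]}],   # 5 input frames, 83-dim features
                "output": [{"shape": [3, 52]}],  # 3 output tokens
            },
        )
    ]
    # Weights for one (hypothetical) encoder self-attention module; the leading
    # axis pairs up with the entries of demo_data, each entry being
    # (head, enc_len, enc_len).
    demo_attn = {"encoder.encoders.0.self_attn": numpy.random.rand(1, 2, 5, 5)}

    outdir = tempfile.mkdtemp()
    plot_multi_head_attention(demo_data, demo_attn, outdir, suffix="png")
    print("wrote attention plots under", outdir)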