"""RNN decoder definition."""

import random

import torch
import torch.nn.functional as F
from typeguard import check_argument_types

from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
from espnet.nets.pytorch_backend.nets_utils import to_device
from espnet.nets.pytorch_backend.rnn.attentions import initial_att
from espnet2.asr.decoder.abs_decoder import AbsDecoder
from espnet2.utils.get_default_kwargs import get_default_kwargs
|
def build_attention_list( |
|
eprojs: int, |
|
dunits: int, |
|
atype: str = "location", |
|
num_att: int = 1, |
|
num_encs: int = 1, |
|
aheads: int = 4, |
|
adim: int = 320, |
|
awin: int = 5, |
|
aconv_chans: int = 10, |
|
aconv_filts: int = 100, |
|
han_mode: bool = False, |
|
han_type=None, |
|
han_heads: int = 4, |
|
han_dim: int = 320, |
|
han_conv_chans: int = -1, |
|
han_conv_filts: int = 100, |
|
han_win: int = 5, |
|
): |
|
|
|
att_list = torch.nn.ModuleList() |
|
if num_encs == 1: |
|
for i in range(num_att): |
|
att = initial_att( |
|
atype, |
|
eprojs, |
|
dunits, |
|
aheads, |
|
adim, |
|
awin, |
|
aconv_chans, |
|
aconv_filts, |
|
) |
|
att_list.append(att) |
|
elif num_encs > 1: |
|
if han_mode: |
|
att = initial_att( |
|
han_type, |
|
eprojs, |
|
dunits, |
|
han_heads, |
|
han_dim, |
|
han_win, |
|
han_conv_chans, |
|
han_conv_filts, |
|
han_mode=True, |
|
) |
|
return att |
|
else: |
|
att_list = torch.nn.ModuleList() |
|
for idx in range(num_encs): |
|
att = initial_att( |
|
atype[idx], |
|
eprojs, |
|
dunits, |
|
aheads[idx], |
|
adim[idx], |
|
awin[idx], |
|
aconv_chans[idx], |
|
aconv_filts[idx], |
|
) |
|
att_list.append(att) |
|
else: |
|
raise ValueError( |
|
"Number of encoders needs to be more than one. {}".format(num_encs) |
|
) |
|
return att_list |
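
# A minimal usage sketch (illustrative, not part of the original module): for
# the common single-encoder setup, `build_attention_list` returns a ModuleList
# holding one location-aware attention. The sizes below are assumptions.
#
#     att_list = build_attention_list(eprojs=256, dunits=320, atype="location")
#     assert len(att_list) == 1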


class RNNDecoder(AbsDecoder):
    """RNN decoder with attention for encoder-decoder ASR.

    The last vocabulary index is shared between <sos> and <eos>.
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        rnn_type: str = "lstm",
        num_layers: int = 1,
        hidden_size: int = 320,
        sampling_probability: float = 0.0,
        dropout: float = 0.0,
        context_residual: bool = False,
        replace_sos: bool = False,
        num_encs: int = 1,
        att_conf: dict = get_default_kwargs(build_attention_list),
    ):
        assert check_argument_types()
        if rnn_type not in {"lstm", "gru"}:
            raise ValueError(f"Not supported: rnn_type={rnn_type}")

        super().__init__()
        eprojs = encoder_output_size
        self.dtype = rnn_type
        self.dunits = hidden_size
        self.dlayers = num_layers
        self.context_residual = context_residual
        self.sos = vocab_size - 1
        self.eos = vocab_size - 1
        self.odim = vocab_size
        self.sampling_probability = sampling_probability
        self.dropout = dropout
        self.num_encs = num_encs

        # Optionally replace <sos> (e.g. with a language id in multilingual setups).
        self.replace_sos = replace_sos

        self.embed = torch.nn.Embedding(vocab_size, hidden_size)
        self.dropout_emb = torch.nn.Dropout(p=dropout)

        # The first decoder cell consumes the token embedding concatenated with
        # the attention context; subsequent cells consume the previous layer.
        self.decoder = torch.nn.ModuleList()
        self.dropout_dec = torch.nn.ModuleList()
        self.decoder += [
            torch.nn.LSTMCell(hidden_size + eprojs, hidden_size)
            if self.dtype == "lstm"
            else torch.nn.GRUCell(hidden_size + eprojs, hidden_size)
        ]
        self.dropout_dec += [torch.nn.Dropout(p=dropout)]
        for _ in range(1, self.dlayers):
            self.decoder += [
                torch.nn.LSTMCell(hidden_size, hidden_size)
                if self.dtype == "lstm"
                else torch.nn.GRUCell(hidden_size, hidden_size)
            ]
            self.dropout_dec += [torch.nn.Dropout(p=dropout)]

        if context_residual:
            self.output = torch.nn.Linear(hidden_size + eprojs, vocab_size)
        else:
            self.output = torch.nn.Linear(hidden_size, vocab_size)

        self.att_list = build_attention_list(
            eprojs=eprojs, dunits=hidden_size, **att_conf
        )

    def zero_state(self, hs_pad):
        """Return an all-zero initial hidden state shaped (batch, dunits)."""
        return hs_pad.new_zeros(hs_pad.size(0), self.dunits)

    def rnn_forward(self, ey, z_list, c_list, z_prev, c_prev):
        """Advance every decoder layer by one step from the previous state."""
        if self.dtype == "lstm":
            z_list[0], c_list[0] = self.decoder[0](ey, (z_prev[0], c_prev[0]))
            for i in range(1, self.dlayers):
                z_list[i], c_list[i] = self.decoder[i](
                    self.dropout_dec[i - 1](z_list[i - 1]),
                    (z_prev[i], c_prev[i]),
                )
        else:
            # GRU cells carry no cell state; c_list is returned unchanged.
            z_list[0] = self.decoder[0](ey, z_prev[0])
            for i in range(1, self.dlayers):
                z_list[i] = self.decoder[i](
                    self.dropout_dec[i - 1](z_list[i - 1]), z_prev[i]
                )
        return z_list, c_list

    def forward(self, hs_pad, hlens, ys_in_pad, ys_in_lens, strm_idx=0):
        """Compute output logits from encoder output and teacher tokens."""
        # To support multiple encoders, treat the single-encoder case as a
        # one-element list.
        if self.num_encs == 1:
            hs_pad = [hs_pad]
            hlens = [hlens]

        # Attention index for the given stream.
        att_idx = min(strm_idx, len(self.att_list) - 1)

        hlens = [list(map(int, hlens[idx])) for idx in range(self.num_encs)]

        olength = ys_in_pad.size(1)

        # Initialize hidden (and LSTM cell) states for every layer.
        c_list = [self.zero_state(hs_pad[0])]
        z_list = [self.zero_state(hs_pad[0])]
        for _ in range(1, self.dlayers):
            c_list.append(self.zero_state(hs_pad[0]))
            z_list.append(self.zero_state(hs_pad[0]))
        z_all = []
        if self.num_encs == 1:
            att_w = None
            self.att_list[att_idx].reset()  # reset pre-computation of h
        else:
            att_w_list = [None] * (self.num_encs + 1)  # atts + han
            att_c_list = [None] * self.num_encs  # atts
            for idx in range(self.num_encs + 1):
                # reset pre-computation of h in atts and han
                self.att_list[idx].reset()

        # Pre-compute all token embeddings: (batch, olength, hidden_size)
        eys = self.dropout_emb(self.embed(ys_in_pad))

        for i in range(olength):
            if self.num_encs == 1:
                att_c, att_w = self.att_list[att_idx](
                    hs_pad[0], hlens[0], self.dropout_dec[0](z_list[0]), att_w
                )
            else:
                for idx in range(self.num_encs):
                    att_c_list[idx], att_w_list[idx] = self.att_list[idx](
                        hs_pad[idx],
                        hlens[idx],
                        self.dropout_dec[0](z_list[0]),
                        att_w_list[idx],
                    )
                # Hierarchical attention over the per-encoder contexts.
                hs_pad_han = torch.stack(att_c_list, dim=1)
                hlens_han = [self.num_encs] * len(ys_in_pad)
                att_c, att_w_list[self.num_encs] = self.att_list[self.num_encs](
                    hs_pad_han,
                    hlens_han,
                    self.dropout_dec[0](z_list[0]),
                    att_w_list[self.num_encs],
                )
            if i > 0 and random.random() < self.sampling_probability:
                # Scheduled sampling: feed back the model's own prediction
                # from the previous step instead of the ground-truth token.
                z_out = self.output(z_all[-1])
                z_out = z_out.detach().argmax(dim=1)
                z_out = self.dropout_emb(self.embed(to_device(self, z_out)))
                ey = torch.cat((z_out, att_c), dim=1)
            else:
                ey = torch.cat((eys[:, i, :], att_c), dim=1)
            z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
            if self.context_residual:
                z_all.append(
                    torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
                )
            else:
                z_all.append(self.dropout_dec[-1](z_list[-1]))

        z_all = torch.stack(z_all, dim=1)
        z_all = self.output(z_all)
        # Zero out logits at padded positions.
        z_all.masked_fill_(
            make_pad_mask(ys_in_lens, z_all, 1),
            0,
        )
        return z_all, ys_in_lens
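
    # A hedged usage sketch for `forward` (all sizes below are assumptions,
    # not part of the original module):
    #
    #     decoder = RNNDecoder(vocab_size=30, encoder_output_size=4)
    #     hs_pad = torch.randn(2, 10, 4)            # (batch, time, eprojs)
    #     hlens = torch.tensor([10, 8])             # encoder output lengths
    #     ys_in_pad = torch.randint(0, 30, (2, 7))  # <sos>-prefixed targets
    #     ys_in_lens = torch.tensor([7, 5])
    #     logits, lens = decoder(hs_pad, hlens, ys_in_pad, ys_in_lens)
    #     # logits: (2, 7, 30), zeroed beyond each target length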

    def init_state(self, x):
        """Create the initial beam-search state for one utterance ``x``."""
        # To support multiple encoders, treat the single-encoder case as a
        # one-element list.
        if self.num_encs == 1:
            x = [x]

        c_list = [self.zero_state(x[0].unsqueeze(0))]
        z_list = [self.zero_state(x[0].unsqueeze(0))]
        for _ in range(1, self.dlayers):
            c_list.append(self.zero_state(x[0].unsqueeze(0)))
            z_list.append(self.zero_state(x[0].unsqueeze(0)))

        strm_index = 0
        att_idx = min(strm_index, len(self.att_list) - 1)
        if self.num_encs == 1:
            a = None
            self.att_list[att_idx].reset()  # reset pre-computation of h
        else:
            a = [None] * (self.num_encs + 1)  # atts + han
            for idx in range(self.num_encs + 1):
                # reset pre-computation of h in atts and han
                self.att_list[idx].reset()
        return dict(
            c_prev=c_list[:],
            z_prev=z_list[:],
            a_prev=a,
            workspace=(att_idx, z_list, c_list),
        )

    def score(self, yseq, state, x):
        """Score the next token given the hypothesis ``yseq`` and ``state``."""
        # To support multiple encoders, treat the single-encoder case as a
        # one-element list.
        if self.num_encs == 1:
            x = [x]

        att_idx, z_list, c_list = state["workspace"]
        vy = yseq[-1].unsqueeze(0)
        ey = self.dropout_emb(self.embed(vy))
        if self.num_encs == 1:
            att_c, att_w = self.att_list[att_idx](
                x[0].unsqueeze(0),
                [x[0].size(0)],
                self.dropout_dec[0](state["z_prev"][0]),
                state["a_prev"],
            )
        else:
            att_w = [None] * (self.num_encs + 1)  # atts + han
            att_c_list = [None] * self.num_encs  # atts
            for idx in range(self.num_encs):
                att_c_list[idx], att_w[idx] = self.att_list[idx](
                    x[idx].unsqueeze(0),
                    [x[idx].size(0)],
                    self.dropout_dec[0](state["z_prev"][0]),
                    state["a_prev"][idx],
                )
            # Hierarchical attention over the per-encoder contexts.
            h_han = torch.stack(att_c_list, dim=1)
            att_c, att_w[self.num_encs] = self.att_list[self.num_encs](
                h_han,
                [self.num_encs],
                self.dropout_dec[0](state["z_prev"][0]),
                state["a_prev"][self.num_encs],
            )
        ey = torch.cat((ey, att_c), dim=1)
        z_list, c_list = self.rnn_forward(
            ey, z_list, c_list, state["z_prev"], state["c_prev"]
        )
        if self.context_residual:
            logits = self.output(
                torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
            )
        else:
            logits = self.output(self.dropout_dec[-1](z_list[-1]))
        logp = F.log_softmax(logits, dim=1).squeeze(0)
        return (
            logp,
            dict(
                c_prev=c_list[:],
                z_prev=z_list[:],
                a_prev=att_w,
                workspace=(att_idx, z_list, c_list),
            ),
        )
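

# A minimal, hedged decoding sketch (illustrative only; the sizes and the
# greedy loop below are assumptions, not part of the original module):
#
#     decoder = RNNDecoder(vocab_size=30, encoder_output_size=4)
#     x = torch.randn(10, 4)             # one utterance: (time, eprojs)
#     yseq = torch.tensor([decoder.sos])
#     state = decoder.init_state(x)
#     for _ in range(5):
#         logp, state = decoder.score(yseq, state, x)
#         yseq = torch.cat([yseq, logp.argmax().unsqueeze(0)])
#         if int(yseq[-1]) == decoder.eos:
#             break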