import logging
import random

import six

import chainer
import chainer.functions as F
import chainer.links as L
import numpy as np

import espnet.nets.chainer_backend.deterministic_embed_id as DL

from argparse import Namespace

from espnet.nets.ctc_prefix_score import CTCPrefixScore
from espnet.nets.e2e_asr_common import end_detect

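# CTC_SCORING_RATIO: during joint CTC/attention beam search, only the top
# (beam * CTC_SCORING_RATIO) attention candidates are rescored by the CTC
# prefix scorer.
# MAX_DECODER_OUTPUT: maximum number of example transcriptions written to the
# log in verbose mode.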
CTC_SCORING_RATIO = 1.5
MAX_DECODER_OUTPUT = 5


class Decoder(chainer.Chain):
    """Decoder layer.

    Args:
        eprojs (int): Dimension of the encoder projections fed to the decoder.
        odim (int): The output dimension.
        dtype (str): Decoder RNN type ("lstm" or "gru").
        dlayers (int): Number of decoder layers.
        dunits (int): Number of hidden units in each decoder layer.
        sos (int): Number to indicate the start of sequences.
        eos (int): Number to indicate the end of sequences.
        att (Module): Attention module defined at
            `espnet.nets.chainer_backend.attentions`.
        verbose (int): Verbosity level.
        char_list (List[str]): List of all characters.
        labeldist (numpy.array): Label distribution used for label smoothing.
        lsm_weight (float): Weight of the label-smoothing term in the training loss.
        sampling_probability (float): Probability of scheduled sampling.

    """

    def __init__(
        self,
        eprojs,
        odim,
        dtype,
        dlayers,
        dunits,
        sos,
        eos,
        att,
        verbose=0,
        char_list=None,
        labeldist=None,
        lsm_weight=0.0,
        sampling_probability=0.0,
    ):
        super(Decoder, self).__init__()
        with self.init_scope():
            self.embed = DL.EmbedID(odim, dunits)
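            # The first decoder layer consumes the target embedding
            # concatenated with the attention context vector.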
            self.rnn0 = (
                L.StatelessLSTM(dunits + eprojs, dunits)
                if dtype == "lstm"
                else L.StatelessGRU(dunits + eprojs, dunits)
            )
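            # Remaining decoder layers are registered dynamically as
            # attributes rnn1 ... rnn{dlayers - 1}.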
            for i in six.moves.range(1, dlayers):
                setattr(
                    self,
                    "rnn%d" % i,
                    L.StatelessLSTM(dunits, dunits)
                    if dtype == "lstm"
                    else L.StatelessGRU(dunits, dunits),
                )
            self.output = L.Linear(dunits, odim)
        self.dtype = dtype
        self.loss = None
        self.att = att
        self.dlayers = dlayers
        self.dunits = dunits
        self.sos = sos
        self.eos = eos
        self.verbose = verbose
        self.char_list = char_list

        self.labeldist = labeldist
        self.vlabeldist = None
        self.lsm_weight = lsm_weight
        self.sampling_probability = sampling_probability

    def rnn_forward(self, ey, z_list, c_list, z_prev, c_prev):
        if self.dtype == "lstm":
            c_list[0], z_list[0] = self.rnn0(c_prev[0], z_prev[0], ey)
            for i in six.moves.range(1, self.dlayers):
                c_list[i], z_list[i] = self["rnn%d" % i](
                    c_prev[i], z_prev[i], z_list[i - 1]
                )
        else:
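            # StatelessGRU keeps no cell state and does not create its hidden
            # state automatically, so initialize it to zeros on the first step.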
            if z_prev[0] is None:
                xp = self.xp
                with chainer.backends.cuda.get_device_from_id(self._device_id):
                    z_prev[0] = chainer.Variable(
                        xp.zeros((ey.shape[0], self.dunits), dtype=ey.dtype)
                    )
            z_list[0] = self.rnn0(z_prev[0], ey)
            for i in six.moves.range(1, self.dlayers):
                if z_prev[i] is None:
                    xp = self.xp
                    with chainer.backends.cuda.get_device_from_id(self._device_id):
                        z_prev[i] = chainer.Variable(
                            xp.zeros(
                                (z_list[i - 1].shape[0], self.dunits),
                                dtype=z_list[i - 1].dtype,
                            )
                        )
                z_list[i] = self["rnn%d" % i](z_prev[i], z_list[i - 1])
        return z_list, c_list

    def __call__(self, hs, ys):
        """Core function of Decoder layer.

        Args:
            hs (list of chainer.Variable | N-dimensional array):
                Input variable from encoder.
            ys (list of chainer.Variable | N-dimensional array):
                Input variable of decoder.

        Returns:
            chainer.Variable: A variable holding a scalar array of the training loss.
            chainer.Variable: A variable holding a scalar array of the accuracy.

        """
        self.loss = None
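        # prepend sos to the decoder inputs and append eos to the targets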
        eos = self.xp.array([self.eos], "i")
        sos = self.xp.array([self.sos], "i")
        ys_in = [F.concat([sos, y], axis=0) for y in ys]
        ys_out = [F.concat([y, eos], axis=0) for y in ys]

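        # pad inputs with eos and targets with -1, which
        # F.softmax_cross_entropy and F.accuracy treat as the ignore label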
        pad_ys_in = F.pad_sequence(ys_in, padding=self.eos)
        pad_ys_out = F.pad_sequence(ys_out, padding=-1)

        batch = pad_ys_out.shape[0]
        olength = pad_ys_out.shape[1]
        logging.info(
            self.__class__.__name__
            + " input lengths: "
            + str(self.xp.array([h.shape[0] for h in hs]))
        )
        logging.info(
            self.__class__.__name__
            + " output lengths: "
            + str(self.xp.array([y.shape[0] for y in ys_out]))
        )

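        # initialize hidden states of all decoder layers and reset attention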
        c_list = [None]
        z_list = [None]
        for _ in six.moves.range(1, self.dlayers):
            c_list.append(None)
            z_list.append(None)
        att_w = None
        z_all = []
        self.att.reset()

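        # pre-compute all target embeddings and split them per time step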
        eys = self.embed(pad_ys_in)
        eys = F.separate(eys, axis=1)

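        # loop over output time steps; with probability sampling_probability,
        # feed back the model's own previous prediction (scheduled sampling)
        # instead of the ground-truth token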
        for i in six.moves.range(olength):
            att_c, att_w = self.att(hs, z_list[0], att_w)
            if i > 0 and random.random() < self.sampling_probability:
                logging.info(" scheduled sampling ")
                z_out = self.output(z_all[-1])
                z_out = F.argmax(F.log_softmax(z_out), axis=1)
                z_out = self.embed(z_out)
                ey = F.hstack((z_out, att_c))
            else:
                ey = F.hstack((eys[i], att_c))
            z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
            z_all.append(z_list[-1])

        z_all = F.stack(z_all, axis=1).reshape(batch * olength, self.dunits)

        y_all = self.output(z_all)
        self.loss = F.softmax_cross_entropy(y_all, F.flatten(pad_ys_out))
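        # scale the token-averaged loss by the mean target length
        # (the -1 offsets the sos token prepended to each ys_in sequence)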
        self.loss *= np.mean([len(x) for x in ys_in]) - 1
        acc = F.accuracy(y_all, F.flatten(pad_ys_out), ignore_label=-1)
        logging.info("att loss:" + str(self.loss.data))

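        # show predicted character sequences for debugging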
        if self.verbose > 0 and self.char_list is not None:
            y_hat = y_all.reshape(batch, olength, -1)
            y_true = pad_ys_out
            for (i, y_hat_), y_true_ in zip(enumerate(y_hat.data), y_true.data):
                if i == MAX_DECODER_OUTPUT:
                    break
                idx_hat = self.xp.argmax(y_hat_[y_true_ != -1], axis=1)
                idx_true = y_true_[y_true_ != -1]
                seq_hat = [self.char_list[int(idx)] for idx in idx_hat]
                seq_true = [self.char_list[int(idx)] for idx in idx_true]
                seq_hat = "".join(seq_hat).replace("<space>", " ")
                seq_true = "".join(seq_true).replace("<space>", " ")
                logging.info("groundtruth[%d]: " % i + seq_true)
                logging.info("prediction [%d]: " % i + seq_hat)

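        # label smoothing: interpolate the cross-entropy loss with the
        # cross-entropy against the unigram label distribution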
        if self.labeldist is not None:
            if self.vlabeldist is None:
                self.vlabeldist = chainer.Variable(self.xp.asarray(self.labeldist))
            loss_reg = -F.sum(
                F.scale(F.log_softmax(y_all), self.vlabeldist, axis=1)
            ) / len(ys_in)
            self.loss = (1.0 - self.lsm_weight) * self.loss + self.lsm_weight * loss_reg

        return self.loss, acc

    def recognize_beam(self, h, lpz, recog_args, char_list, rnnlm=None):
        """Beam search implementation.

        Args:
            h (chainer.Variable): Encoder output for one utterance.
            lpz (chainer.Variable | None): CTC log probabilities from the network,
                or None to disable CTC scoring.
            recog_args (Namespace): The recognition arguments.
            char_list (List[str]): List of all characters.
            rnnlm (Module): RNNLM module. Defined at `espnet.lm.chainer_backend.lm`.

        Returns:
            List[Dict[str, Any]]: N-best hypotheses of the recognition result.

        """
        logging.info("input lengths: " + str(h.shape[0]))

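        # initialize decoder states and reset attention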
        c_list = [None]
        z_list = [None]
        for _ in six.moves.range(1, self.dlayers):
            c_list.append(None)
            z_list.append(None)
        a = None
        self.att.reset()

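        # search parameters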
        beam = recog_args.beam_size
        penalty = recog_args.penalty
        ctc_weight = recog_args.ctc_weight

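        # output length bounds; maxlenratio == 0 means decoding may run for
        # the full input length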
        y = self.xp.full(1, self.sos, "i")
        if recog_args.maxlenratio == 0:
            maxlen = h.shape[0]
        else:
            maxlen = max(1, int(recog_args.maxlenratio * h.shape[0]))
        minlen = int(recog_args.minlenratio * h.shape[0])
        logging.info("max output length: " + str(maxlen))
        logging.info("min output length: " + str(minlen))

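        # initialize the single starting hypothesis; keep an RNNLM state slot
        # only when a language model is given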
        if rnnlm:
            hyp = {
                "score": 0.0,
                "yseq": [y],
                "c_prev": c_list,
                "z_prev": z_list,
                "a_prev": a,
                "rnnlm_prev": None,
            }
        else:
            hyp = {
                "score": 0.0,
                "yseq": [y],
                "c_prev": c_list,
                "z_prev": z_list,
                "a_prev": a,
            }
        if lpz is not None:
            ctc_prefix_score = CTCPrefixScore(lpz, 0, self.eos, self.xp)
            hyp["ctc_state_prev"] = ctc_prefix_score.initial_state()
            hyp["ctc_score_prev"] = 0.0
            if ctc_weight != 1.0:
                ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
            else:
                ctc_beam = lpz.shape[-1]
        hyps = [hyp]
        ended_hyps = []

        for i in six.moves.range(maxlen):
            logging.debug("position " + str(i))

            hyps_best_kept = []
            for hyp in hyps:
                ey = self.embed(hyp["yseq"][i])
                att_c, att_w = self.att([h], hyp["z_prev"][0], hyp["a_prev"])
                ey = F.hstack((ey, att_c))

                z_list, c_list = self.rnn_forward(
                    ey, z_list, c_list, hyp["z_prev"], hyp["c_prev"]
                )

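                # attention decoder scores, optionally fused with RNNLM scores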
                local_att_scores = F.log_softmax(self.output(z_list[-1])).data
                if rnnlm:
                    rnnlm_state, local_lm_scores = rnnlm.predict(
                        hyp["rnnlm_prev"], hyp["yseq"][i]
                    )
                    local_scores = (
                        local_att_scores + recog_args.lm_weight * local_lm_scores
                    )
                else:
                    local_scores = local_att_scores

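                # rescore the top ctc_beam attention candidates with the CTC
                # prefix score and combine the two scores with ctc_weight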
                if lpz is not None:
                    local_best_ids = self.xp.argsort(local_scores, axis=1)[0, ::-1][
                        :ctc_beam
                    ]
                    ctc_scores, ctc_states = ctc_prefix_score(
                        hyp["yseq"], local_best_ids, hyp["ctc_state_prev"]
                    )
                    local_scores = (1.0 - ctc_weight) * local_att_scores[
                        :, local_best_ids
                    ] + ctc_weight * (ctc_scores - hyp["ctc_score_prev"])
                    if rnnlm:
                        local_scores += (
                            recog_args.lm_weight * local_lm_scores[:, local_best_ids]
                        )
                    joint_best_ids = self.xp.argsort(local_scores, axis=1)[0, ::-1][
                        :beam
                    ]
                    local_best_scores = local_scores[:, joint_best_ids]
                    local_best_ids = local_best_ids[joint_best_ids]
                else:
                    local_best_ids = self.xp.argsort(local_scores, axis=1)[0, ::-1][
                        :beam
                    ]
                    local_best_scores = local_scores[:, local_best_ids]

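                # extend the hypothesis with each of the beam best tokens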
                for j in six.moves.range(beam):
                    new_hyp = {}

                    new_hyp["z_prev"] = z_list[:]
                    new_hyp["c_prev"] = c_list[:]
                    new_hyp["a_prev"] = att_w
                    new_hyp["score"] = hyp["score"] + local_best_scores[0, j]
                    new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"]))
                    new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"]
                    new_hyp["yseq"][len(hyp["yseq"])] = self.xp.full(
                        1, local_best_ids[j], "i"
                    )
                    if rnnlm:
                        new_hyp["rnnlm_prev"] = rnnlm_state
                    if lpz is not None:
                        new_hyp["ctc_state_prev"] = ctc_states[joint_best_ids[j]]
                        new_hyp["ctc_score_prev"] = ctc_scores[joint_best_ids[j]]

                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(
                    hyps_best_kept, key=lambda x: x["score"], reverse=True
                )[:beam]

            hyps = hyps_best_kept
            logging.debug("number of pruned hypotheses: " + str(len(hyps)))
            logging.debug(
                "best hypo: "
                + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]]).replace(
                    "<space>", " "
                )
            )

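            # add eos in the final loop to avoid ending up with no ended hyps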
            if i == maxlen - 1:
                logging.info("adding <eos> in the last position in the loop")
                for hyp in hyps:
                    hyp["yseq"].append(self.xp.full(1, self.eos, "i"))

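            # move hypotheses that emitted <eos> to ended_hyps; a hypothesis
            # is only accepted once it is longer than minlen, and its score
            # gets the length penalty and the RNNLM final score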
            remained_hyps = []
            for hyp in hyps:
                if hyp["yseq"][-1] == self.eos:
                    if len(hyp["yseq"]) > minlen:
                        hyp["score"] += (i + 1) * penalty
                        if rnnlm:
                            hyp["score"] += recog_args.lm_weight * rnnlm.final(
                                hyp["rnnlm_prev"]
                            )
                        ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

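            # stop decoding when end_detect judges that all promising
            # hypotheses have already ended (only when maxlenratio == 0.0)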
            if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
                logging.info("end detected at %d", i)
                break

            hyps = remained_hyps
            if len(hyps) > 0:
                logging.debug("remaining hypotheses: " + str(len(hyps)))
            else:
                logging.info("no hypothesis. Finish decoding.")
                break

            for hyp in hyps:
                logging.debug(
                    "hypo: "
                    + "".join([char_list[int(x)] for x in hyp["yseq"][1:]]).replace(
                        "<space>", " "
                    )
                )

        logging.debug("number of ended hypotheses: " + str(len(ended_hyps)))

        nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[
            : min(len(ended_hyps), recog_args.nbest)
        ]

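        # if no hypothesis survived, retry with a relaxed minimum length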
        if len(nbest_hyps) == 0:
            logging.warning(
                "there are no N-best results, "
                "performing recognition again with a smaller minlenratio."
            )
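            # copy recog_args so the Namespace is not mutated globally
            # across the recursive retries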
            recog_args = Namespace(**vars(recog_args))
            recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
            return self.recognize_beam(h, lpz, recog_args, char_list, rnnlm)

        logging.info("total log probability: " + str(nbest_hyps[0]["score"]))
        logging.info(
            "normalized log probability: "
            + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"]))
        )

        return nbest_hyps

    def calculate_all_attentions(self, hs, ys):
        """Calculate all attention weights.

        Args:
            hs (list of chainer.Variable | N-dimensional array):
                Input variable from encoder.
            ys (list of chainer.Variable | N-dimensional array):
                Input variable of decoder.

        Returns:
            numpy.ndarray: Attention weights.

        """
        eos = self.xp.array([self.eos], "i")
        sos = self.xp.array([self.sos], "i")
        ys_in = [F.concat([sos, y], axis=0) for y in ys]
        ys_out = [F.concat([y, eos], axis=0) for y in ys]

        pad_ys_in = F.pad_sequence(ys_in, padding=self.eos)
        pad_ys_out = F.pad_sequence(ys_out, padding=-1)

        olength = pad_ys_out.shape[1]

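        # run the decoder with teacher forcing, recording the attention
        # weights produced at every output position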
        c_list = [None]
        z_list = [None]
        for _ in six.moves.range(1, self.dlayers):
            c_list.append(None)
            z_list.append(None)
        att_w = None
        att_ws = []
        self.att.reset()

        eys = self.embed(pad_ys_in)
        eys = F.separate(eys, axis=1)

        for i in six.moves.range(olength):
            att_c, att_w = self.att(hs, z_list[0], att_w)
            ey = F.hstack((eys[i], att_c))
            z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
            att_ws.append(att_w)

        att_ws = F.stack(att_ws, axis=1)
        att_ws.to_cpu()

        return att_ws.data


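# Minimal usage sketch (hypothetical call sites; in practice args, odim, sos,
# eos, att, and labeldist come from the E2E model construction):
#
#     att = ...  # an attention module from espnet.nets.chainer_backend.attentions
#     dec = decoder_for(args, odim, sos, eos, att, labeldist)
#     loss, acc = dec(hs, ys)  # hs: encoder outputs, ys: target id sequences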
def decoder_for(args, odim, sos, eos, att, labeldist):
    """Return the decoder module corresponding to the args.

    Args:
        args (Namespace): The program arguments.
        odim (int): The output dimension.
        sos (int): Number to indicate the start of sequences.
        eos (int): Number to indicate the end of sequences.
        att (Module):
            Attention module defined at `espnet.nets.chainer_backend.attentions`.
        labeldist (numpy.array): Label distribution used for label smoothing.

    Returns:
        chainer.Chain: The decoder module.

    """
    return Decoder(
        args.eprojs,
        odim,
        args.dtype,
        args.dlayers,
        args.dunits,
        sos,
        eos,
        att,
        args.verbose,
        args.char_list,
        labeldist,
        args.lsm_weight,
        args.sampling_probability,
    )