|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
import numpy as np |
|
import six |
|
|
|
|
|
class CTCPrefixScoreTH(object):
    """Batch processing of CTCPrefixScore

    which is based on Algorithm 2 in WATANABE et al.
    "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
    but extended to efficiently compute the label probabilities for multiple
    hypotheses simultaneously
    See also Seki et al. "Vectorized Beam Search for CTC-Attention-Based
    Speech Recognition," In INTERSPEECH (pp. 3825-3829), 2019.

    Shape notation used in the comments below:
    T: input length, B: batch size, W: hypotheses per utterance,
    O: output dimension, S: number of labels actually scored.
    """

    def __init__(self, x, xlens, blank, eos, margin=0):
        """Construct CTC prefix scorer

        NOTE: frames of ``x`` beyond each utterance length are overwritten
        in place (blank gets log-prob 0, every other label ``logzero``).

        :param torch.Tensor x: input label posterior sequences (B, T, O)
        :param torch.Tensor xlens: input lengths (B,)
        :param int blank: blank label id
        :param int eos: end-of-sequence id
        :param int margin: margin parameter for windowing (0 means no windowing)
        """
        self.logzero = -10000000000.0
        self.blank = blank
        self.eos = eos
        self.batch = x.size(0)
        self.input_length = x.size(1)
        self.odim = x.size(2)
        self.dtype = x.dtype
        # Keep every internally created tensor on the same device as x.
        self.device = x.device

        # Pad the tail of shorter utterances so that only blank can be
        # emitted there (this writes into the caller's tensor).
        for i, length in enumerate(xlens):
            if length < self.input_length:
                x[i, length:, :] = self.logzero
                x[i, length:, blank] = 0

        # Stack the posteriors with a copy of the blank posterior expanded
        # over O, so both recursions can be advanced with a single addition.
        xn = x.transpose(0, 1)  # (B, T, O) -> (T, B, O)
        xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
        self.x = torch.stack([xn, xb])  # (2, T, B, O)
        self.end_frames = torch.as_tensor(xlens) - 1

        # Setup CTC windowing
        self.margin = margin
        if margin > 0:
            self.frame_ids = torch.arange(
                self.input_length, dtype=self.dtype, device=self.device
            )

        # Base indices for index conversion
        self.idx_bh = None
        self.idx_b = torch.arange(self.batch, device=self.device)
        self.idx_bo = (self.idx_b * self.odim).unsqueeze(1)

    def __call__(self, y, state, scoring_ids=None, att_w=None):
        """Compute CTC prefix scores for next labels

        :param list y: prefix label sequences (one per hypothesis, B*W in total)
        :param tuple state: previous CTC state
            ``(r_prev, s_prev, f_min_prev, f_max_prev)`` or None for the first step
        :param torch.Tensor scoring_ids: ids of the labels to be scored (BW, S);
            all O labels are scored when None or empty
        :param torch.Tensor att_w: attention weights to decide CTC window
        :return: ``(log_psi - s_prev)`` of shape (BW, O) and the new state
            ``(r, log_psi, f_min, f_max, scoring_idmap)``
        """
        output_length = len(y[0]) - 1  # the first label of a prefix is not counted
        last_ids = [int(yi[-1]) for yi in y]  # last output label ids
        n_bh = len(last_ids)  # batch * hyps
        n_hyps = n_bh // self.batch  # same number of hyps for every utterance
        self.scoring_num = scoring_ids.size(-1) if scoring_ids is not None else 0

        # prepare state info
        if state is None:
            r_prev = torch.full(
                (self.input_length, 2, self.batch, n_hyps),
                self.logzero,
                dtype=self.dtype,
                device=self.device,
            )
            # Only the blank-ending path is alive for the empty prefix.
            r_prev[:, 1] = torch.cumsum(self.x[0, :, :, self.blank], 0).unsqueeze(2)
            r_prev = r_prev.view(-1, 2, n_bh)
            s_prev = 0.0
            f_min_prev = 0
            f_max_prev = 1
        else:
            r_prev, s_prev, f_min_prev, f_max_prev = state

        # select input dimensions for scoring
        if self.scoring_num > 0:
            # scoring_idmap[h, label] is the column of `label` among the
            # scored labels of hypothesis h, or -1 when it is not scored.
            scoring_idmap = torch.full(
                (n_bh, self.odim), -1, dtype=torch.long, device=self.device
            )
            snum = self.scoring_num
            if self.idx_bh is None or n_bh > len(self.idx_bh):
                self.idx_bh = torch.arange(n_bh, device=self.device).view(-1, 1)
            scoring_idmap[self.idx_bh[:n_bh], scoring_ids] = torch.arange(
                snum, device=self.device
            )
            scoring_idx = (
                scoring_ids + self.idx_bo.repeat(1, n_hyps).view(-1, 1)
            ).view(-1)
            x_ = torch.index_select(
                self.x.view(2, -1, self.batch * self.odim), 2, scoring_idx
            ).view(2, -1, n_bh, snum)
        else:
            scoring_ids = None
            scoring_idmap = None
            snum = self.odim
            x_ = self.x.unsqueeze(3).repeat(1, 1, 1, n_hyps, 1).view(2, -1, n_bh, snum)

        # new CTC forward probs are prepared as a (T, 2, BW, S) tensor;
        # index 0 / 1 on the second axis holds the non-blank / blank ending paths.
        r = torch.full(
            (self.input_length, 2, n_bh, snum),
            self.logzero,
            dtype=self.dtype,
            device=self.device,
        )
        if output_length == 0:
            r[0, 0] = x_[0, 0]

        # log_phi is the total prefix probability, except for the column that
        # repeats the last label of a hypothesis, where only the blank-ending
        # paths may be extended.
        r_sum = torch.logsumexp(r_prev, 1)  # (T, BW)
        last = torch.as_tensor(last_ids, dtype=torch.long, device=self.device)
        if scoring_ids is not None:
            column_labels = scoring_ids  # (BW, S)
        else:
            column_labels = torch.arange(self.odim, device=self.device).unsqueeze(0)
        repeats_last = (column_labels == last.unsqueeze(1)).unsqueeze(0)
        log_phi = torch.where(
            repeats_last, r_prev[:, 1].unsqueeze(2), r_sum.unsqueeze(2)
        )  # (T, BW, S)

        # decide start and end frames based on attention weights
        if att_w is not None and self.margin > 0:
            f_arg = torch.matmul(att_w, self.frame_ids)
            f_min = max(int(f_arg.min().cpu()), f_min_prev)
            f_max = max(int(f_arg.max().cpu()), f_max_prev)
            start = min(f_max_prev, max(f_min - self.margin, output_length, 1))
            end = min(f_max + self.margin, self.input_length)
        else:
            f_min = f_max = 0
            start = max(output_length, 1)
            end = self.input_length

        # compute forward probabilities frame by frame
        for t in range(start, end):
            rp = r[t - 1]
            rr = torch.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(
                2, 2, n_bh, snum
            )
            r[t] = torch.logsumexp(rr, 1) + x_[:, t]

        # compute log prefix probabilities log(psi)
        log_phi_x = torch.cat((log_phi[0].unsqueeze(0), log_phi[:-1]), dim=0) + x_[0]
        log_psi_scored = torch.logsumexp(
            torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
            dim=0,
        )  # (BW, S)
        if scoring_ids is not None:
            # Scatter the scored columns back to their label ids;
            # labels that were not scored keep logzero.
            log_psi = torch.full(
                (n_bh, self.odim), self.logzero, dtype=self.dtype, device=self.device
            )
            log_psi[self.idx_bh[:n_bh], scoring_ids] = log_psi_scored
        else:
            log_psi = log_psi_scored

        # eos score: total probability at the last valid frame of each utterance
        end_frames = (
            self.end_frames.to(device=self.device, dtype=torch.long)
            .unsqueeze(1)
            .expand(-1, n_hyps)
            .reshape(-1)
        )
        hyp_ids = torch.arange(n_bh, device=self.device)
        log_psi[:, self.eos] = r_sum[end_frames, hyp_ids]

        # exclude blank probs
        log_psi[:, self.blank] = self.logzero

        return (log_psi - s_prev), (r, log_psi, f_min, f_max, scoring_idmap)

    def index_select_state(self, state, best_ids):
        """Select CTC states according to best ids

        Relies on ``self.scoring_num`` as set by the most recent ``__call__``
        when the state carries a ``scoring_idmap``.

        :param state    : CTC state returned by ``__call__``
        :param best_ids : index numbers selected by beam pruning (B, W);
            each value encodes ``hypothesis * O + label`` within its utterance
        :return selected_state ``(r_new, s_new, f_min, f_max)``
        """
        r, s, f_min, f_max, scoring_idmap = state
        # convert ids to BHO space
        n_bh = len(s)
        n_hyps = n_bh // self.batch
        vidx = (best_ids + (self.idx_b * (n_hyps * self.odim)).view(-1, 1)).view(-1)
        # select hypothesis scores and broadcast them over all labels
        s_new = torch.index_select(s.view(-1), 0, vidx)
        s_new = s_new.view(-1, 1).repeat(1, self.odim).view(n_bh, self.odim)
        # convert ids to BHS space (S: scoring_num)
        if scoring_idmap is not None:
            snum = self.scoring_num
            hyp_idx = (best_ids // self.odim + (self.idx_b * n_hyps).view(-1, 1)).view(
                -1
            )
            label_ids = torch.fmod(best_ids, self.odim).view(-1)
            score_idx = scoring_idmap[hyp_idx, label_ids]
            # labels that were not scored fall back to column 0
            score_idx[score_idx == -1] = 0
            vidx = score_idx + hyp_idx * snum
        else:
            snum = self.odim
        # select forward probabilities
        r_new = torch.index_select(r.view(-1, 2, n_bh * snum), 2, vidx).view(
            -1, 2, n_bh
        )
        return r_new, s_new, f_min, f_max

    def extend_prob(self, x):
        """Extend CTC prob.

        Nothing happens unless ``x`` has more frames than the stored
        posteriors. Frames that were already stored are kept as they are.
        ``end_frames`` is reset to a single entry (the new last frame).

        :param torch.Tensor x: input label posterior sequences (B, T, O)
        """
        if self.x.shape[1] < x.shape[1]:  # self.x (2, T, B, O); x (B, T, O)
            # No padding is required here: the single length used is
            # x.size(1), which is never shorter than the stored input.
            prev_x = self.x
            xn = x.transpose(0, 1)  # (B, T, O) -> (T, B, O)
            xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
            self.x = torch.stack([xn, xb])  # (2, T, B, O)
            self.x[:, : prev_x.shape[1], :, :] = prev_x
            self.input_length = x.size(1)
            self.end_frames = torch.as_tensor([x.size(1)]) - 1
            if self.margin > 0:
                # keep the windowing frame index in sync with the new length
                self.frame_ids = torch.arange(
                    self.input_length, dtype=self.dtype, device=self.device
                )

    def extend_state(self, state):
        """Compute CTC prefix state.

        Extends the forward probabilities of a state to the current input
        length; the added frames only continue the blank-ending path.
        The new buffer is allocated as (T, 2), so ``r_prev`` is presumably a
        single-hypothesis state -- confirm against the caller.

        :param state    : CTC state ``(r_prev, s_prev, f_min_prev, f_max_prev)``
        :return ctc_state
        """
        if state is None:
            # nothing to do
            return state

        r_prev, s_prev, f_min_prev, f_max_prev = state

        r_prev_new = torch.full(
            (self.input_length, 2),
            self.logzero,
            dtype=self.dtype,
            device=self.device,
        )
        start = max(r_prev.shape[0], 1)
        r_prev_new[0:start] = r_prev
        for t in range(start, self.input_length):
            r_prev_new[t, 1] = r_prev_new[t - 1, 1] + self.x[0, t, :, self.blank]

        return (r_prev_new, s_prev, f_min_prev, f_max_prev)
|
|
|
|
|
class CTCPrefixScore(object):
    """Compute CTC label sequence scores

    which is based on Algorithm 2 in WATANABE et al.
    "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
    but extended to efficiently compute the probabilities of multiple labels
    simultaneously
    """

    def __init__(self, x, blank, eos, xp):
        """Construct CTC prefix scorer

        :param x    : label log posteriors indexed as ``x[frame, label]``
        :param blank: blank label id
        :param eos  : end-of-sequence id
        :param xp   : array module providing ``full``, ``logaddexp``,
            ``where`` and ``rollaxis`` (e.g. numpy)
        """
        self.xp = xp
        self.logzero = -10000000000.0
        self.blank = blank
        self.eos = eos
        self.input_length = len(x)
        self.x = x

    def initial_state(self):
        """Obtain an initial CTC state

        :return: CTC state of shape (T, 2); column 0 / 1 holds the forward
            log-probabilities of paths ending in a non-blank / blank
        """
        # For the empty prefix only the all-blank path exists, so the
        # non-blank column stays at logzero and the blank column is the
        # running sum of the blank log posteriors.
        r = self.xp.full((self.input_length, 2), self.logzero, dtype=np.float32)
        r[0, 1] = self.x[0, self.blank]
        for i in range(1, self.input_length):
            r[i, 1] = r[i - 1, 1] + self.x[i, self.blank]
        return r

    def __call__(self, y, cs, r_prev):
        """Compute CTC prefix scores for next labels

        :param y     : prefix label sequence
        :param cs    : array of next labels
        :param r_prev: previous CTC state
        :return ctc_scores, ctc_states
        """
        # the first label of the prefix is not counted
        output_length = len(y) - 1

        # New forward log-probabilities, one (T, 2) state per candidate label.
        # Starting from logzero keeps the frames that are never computed
        # (t < output_length) deterministic.
        r = self.xp.full(
            (self.input_length, 2, len(cs)), self.logzero, dtype=np.float32
        )
        xs = self.x[:, cs]
        if output_length == 0:
            r[0, 0] = xs[0]

        # log_phi is the total probability of the prefix at every frame,
        # except for the candidate repeating the last label, where only
        # blank-ending paths may be extended.
        r_sum = self.xp.logaddexp(r_prev[:, 0], r_prev[:, 1])
        last = y[-1]
        if output_length > 0 and last in cs:
            log_phi = self.xp.full(
                (self.input_length, len(cs)), self.logzero, dtype=np.float32
            )
            for i, c in enumerate(cs):
                log_phi[:, i] = r_sum if c != last else r_prev[:, 1]
        else:
            log_phi = r_sum

        # Accumulate forward probabilities and the prefix probability log_psi.
        start = max(output_length, 1)
        # Copy: log_psi is edited below and must not alias the returned state
        # when the loop does not run.
        log_psi = r[start - 1, 0].copy()
        for t in range(start, self.input_length):
            r[t, 0] = self.xp.logaddexp(r[t - 1, 0], log_phi[t - 1]) + xs[t]
            r[t, 1] = (
                self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) + self.x[t, self.blank]
            )
            log_psi = self.xp.logaddexp(log_psi, log_phi[t - 1] + xs[t])

        # eos: the score is the total probability of the prefix at the last frame
        eos_pos = self.xp.where(cs == self.eos)[0]
        if len(eos_pos) > 0:
            log_psi[eos_pos] = r_sum[-1]

        # blank is never emitted as an output label
        blank_pos = self.xp.where(cs == self.blank)[0]
        if len(blank_pos) > 0:
            log_psi[blank_pos] = self.logzero

        # return the scores and the states with the candidate axis first
        return log_psi, self.xp.rollaxis(r, 2)
|
|