mpc001's picture
Upload 125 files
09481f3
raw
history blame
No virus
10.6 kB
from distutils.version import LooseVersion
import logging
import numpy as np
import six
import torch
import torch.nn.functional as F
from espnet.nets.pytorch_backend.nets_utils import to_device
class CTC(torch.nn.Module):
    """CTC module

    :param int odim: dimension of outputs
    :param int eprojs: number of encoder projection units
    :param float dropout_rate: dropout rate (0.0 ~ 1.0)
    :param str ctc_type: builtin, cudnnctc, warpctc or gtnctc
        (always replaced by builtin when PyTorch >= 1.7.0)
    :param bool reduce: reduce the CTC loss into a scalar
    """

    def __init__(self, odim, eprojs, dropout_rate, ctc_type="warpctc", reduce=True):
        super().__init__()
        self.dropout_rate = dropout_rate
        self.loss = None
        self.ctc_lo = torch.nn.Linear(eprojs, odim)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.probs = None  # for visualization

        # In case of Pytorch >= 1.7.0, CTC will be always builtin
        self.ctc_type = (
            ctc_type
            if LooseVersion(torch.__version__) < LooseVersion("1.7.0")
            else "builtin"
        )

        if self.ctc_type == "builtin":
            reduction_type = "sum" if reduce else "none"
            # zero_infinity=True: per PyTorch's CTCLoss docs, infinite losses
            # (and their gradients) are replaced by zero.
            self.ctc_loss = torch.nn.CTCLoss(
                reduction=reduction_type, zero_infinity=True
            )
        elif self.ctc_type == "cudnnctc":
            reduction_type = "sum" if reduce else "none"
            self.ctc_loss = torch.nn.CTCLoss(reduction=reduction_type)
        elif self.ctc_type == "warpctc":
            import warpctc_pytorch as warp_ctc

            self.ctc_loss = warp_ctc.CTCLoss(size_average=True, reduce=reduce)
        elif self.ctc_type == "gtnctc":
            from espnet.nets.pytorch_backend.gtn_ctc import GTNCTCLossFunction

            self.ctc_loss = GTNCTCLossFunction.apply
        else:
            supported = '"builtin", "cudnnctc", "warpctc" or "gtnctc"'
            raise ValueError(
                "ctc_type must be {}: {}".format(supported, self.ctc_type)
            )

        # padding value in target sequences; stripped in forward()
        self.ignore_id = -1
        self.reduce = reduce

    def loss_fn(self, th_pred, th_target, th_ilen, th_olen):
        """Compute the CTC loss with the configured backend.

        :param torch.Tensor th_pred: unnormalized output scores; time-major
            (Tmax, B, odim) except for gtnctc, which receives (B, Tmax, odim)
        :param th_target: targets (flat concatenated id tensor; a list of id
            sequences for gtnctc)
        :param torch.Tensor th_ilen: input lengths (B)
        :param torch.Tensor th_olen: target lengths (B); unused by gtnctc
        :return: loss tensor, not necessarily a scalar (forward() sums it
            when ``reduce`` is set)
        :raises NotImplementedError: for an unknown ``ctc_type``
        """
        if self.ctc_type in ["builtin", "cudnnctc"]:
            th_pred = th_pred.log_softmax(2)
            # Use the deterministic CuDNN implementation of CTC loss to avoid
            # [issue#17798](https://github.com/pytorch/pytorch/issues/17798)
            with torch.backends.cudnn.flags(deterministic=True):
                loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
            # Batch-size average (dim 1 is the batch axis of time-major input)
            loss = loss / th_pred.size(1)
            return loss
        elif self.ctc_type == "warpctc":
            return self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
        elif self.ctc_type == "gtnctc":
            targets = [t.tolist() for t in th_target]
            log_probs = torch.nn.functional.log_softmax(th_pred, dim=2)
            return self.ctc_loss(log_probs, targets, th_ilen, 0, "none")
        else:
            raise NotImplementedError

    def forward(self, hs_pad, hlens, ys_pad):
        """CTC forward

        :param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D)
        :param torch.Tensor hlens: batch of lengths of hidden state sequences (B)
        :param torch.Tensor ys_pad:
            batch of padded character id sequence tensor (B, Lmax)
        :return: ctc loss value (also stored in ``self.loss``)
        :rtype: torch.Tensor
        """
        # TODO(kan-bayashi): need to make more smart way
        ys = [y[y != self.ignore_id] for y in ys_pad]  # parse padded ys

        # project (dropped-out) encoder states onto the output vocabulary
        ys_hat = self.ctc_lo(self.dropout(hs_pad))
        if self.ctc_type != "gtnctc":
            # every backend except gtnctc takes time-major (Tmax, B, odim) input
            ys_hat = ys_hat.transpose(0, 1)

        if self.ctc_type == "builtin":
            olens = to_device(ys_hat, torch.LongTensor([len(s) for s in ys]))
            hlens = hlens.long()
            ys_pad = torch.cat(ys)  # without this the code breaks for asr_mix
            self.loss = self.loss_fn(ys_hat, ys_pad, hlens, olens)
        else:
            self.loss = None
            hlens = torch.from_numpy(np.fromiter(hlens, dtype=np.int32))
            olens = torch.from_numpy(
                np.fromiter((x.size(0) for x in ys), dtype=np.int32)
            )

            # concatenate the unpadded targets into one flat int32 CPU tensor
            ys_true = torch.cat(ys).cpu().int()  # batch x olen

            # get ctc loss
            # expected shape of seqLength x batchSize x alphabet_size
            dtype = ys_hat.dtype
            if self.ctc_type == "warpctc" or dtype == torch.float16:
                # warpctc only supports float32
                # torch.ctc does not support float16 (#1751)
                ys_hat = ys_hat.to(dtype=torch.float32)
            if self.ctc_type == "cudnnctc":
                # use GPU when using the cuDNN implementation
                ys_true = to_device(hs_pad, ys_true)
            if self.ctc_type == "gtnctc":
                # keep as list for gtn
                ys_true = ys
            # move the loss back to the input's device and original dtype
            self.loss = to_device(
                hs_pad, self.loss_fn(ys_hat, ys_true, hlens, olens)
            ).to(dtype=dtype)

        # get length info
        logging.info(
            self.__class__.__name__
            + " input lengths: "
            + "".join(str(hlens).split("\n"))
        )
        logging.info(
            self.__class__.__name__
            + " output lengths: "
            + "".join(str(olens).split("\n"))
        )

        if self.reduce:
            # NOTE: sum() is needed to keep consistency
            # since warpctc return as tensor w/ shape (1,)
            # but builtin return as tensor w/o shape (scalar).
            self.loss = self.loss.sum()
            logging.info("ctc loss:" + str(float(self.loss)))

        return self.loss

    def softmax(self, hs_pad):
        """softmax of frame activations

        The result is also stored in ``self.probs`` for visualization.

        :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        :return: softmax applied 3d tensor (B, Tmax, odim)
        :rtype: torch.Tensor
        """
        self.probs = F.softmax(self.ctc_lo(hs_pad), dim=2)
        return self.probs

    def log_softmax(self, hs_pad):
        """log_softmax of frame activations

        :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        :return: log softmax applied 3d tensor (B, Tmax, odim)
        :rtype: torch.Tensor
        """
        return F.log_softmax(self.ctc_lo(hs_pad), dim=2)

    def argmax(self, hs_pad):
        """argmax of frame activations

        :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        :return: argmax applied 2d tensor (B, Tmax)
        :rtype: torch.Tensor
        """
        return torch.argmax(self.ctc_lo(hs_pad), dim=2)

    def forced_align(self, h, y, blank_id=0):
        """forced alignment.

        Viterbi alignment of ``y`` (with a blank inserted before, between and
        after the labels) against the CTC log-posteriors of ``h``.

        :param torch.Tensor h: hidden state sequence, 3d tensor (1, T, D);
            a 2d tensor (T, D) is treated as a batch of one
        :param y: id sequence (1d tensor, array or list) of length L
        :param int blank_id: blank symbol index
        :return: best alignment results, one label id (or blank_id) per frame
        :rtype: list
        """

        def interpolate_blank(label, blank_id=0):
            """Insert a blank token before, between and after every label."""
            label = np.asarray(label, dtype=np.int64).reshape(-1)
            interpolated = np.full(2 * len(label) + 1, blank_id, dtype=np.int64)
            interpolated[1::2] = label
            return interpolated

        if h.dim() == 2:
            h = h.unsqueeze(0)
        # Decode on a detached numpy copy: per-element tensor indexing inside
        # the DP loop is slow and would fail for GPU / grad-tracking tensors.
        with torch.no_grad():
            lpz = self.log_softmax(h).squeeze(0).float().cpu().numpy()
        if torch.is_tensor(y):
            y = y.detach().cpu().numpy()
        y_int = interpolate_blank(y, blank_id)

        num_frames = lpz.shape[0]
        num_states = len(y_int)
        if num_frames == 0:
            return []
        if num_states == 1:
            # Empty label sequence: the only valid CTC path is all blanks.
            return [blank_id] * num_frames

        logdelta = np.full((num_frames, num_states), -100000000000.0)  # log of zero
        state_path = np.full((num_frames, num_states), -1, dtype=np.int64)

        # A path may start on the leading blank or on the first label.
        logdelta[0, 0] = lpz[0, y_int[0]]
        logdelta[0, 1] = lpz[0, y_int[1]]

        for t in range(1, num_frames):
            for s in range(num_states):
                if s == 0:
                    # The leading blank is only reachable from itself; using
                    # s - 1 here would wrap around to the last state.
                    prev_state = [0]
                elif s == 1 or y_int[s] == blank_id or y_int[s] == y_int[s - 2]:
                    # No skip over the preceding blank onto a blank or a
                    # repeated label.
                    prev_state = [s, s - 1]
                else:
                    prev_state = [s, s - 1, s - 2]
                candidates = logdelta[t - 1, prev_state]
                best = int(np.argmax(candidates))  # ties favour staying put
                logdelta[t, s] = candidates[best] + lpz[t, y_int[s]]
                state_path[t, s] = prev_state[best]

        # A path must end on the trailing blank or on the last label.
        final_states = [num_states - 1, num_states - 2]
        state_seq = np.empty(num_frames, dtype=np.int64)
        state_seq[-1] = final_states[int(np.argmax(logdelta[-1, final_states]))]
        for t in range(num_frames - 2, -1, -1):
            state_seq[t] = state_path[t + 1, state_seq[t + 1]]

        return y_int[state_seq].tolist()
def ctc_for(args, odim, reduce=True):
    """Returns the CTC module for the given args and output dimension

    :param Namespace args: the program args
    :param int odim : The output dimension
    :param bool reduce : return the CTC loss in a scalar
    :return: the corresponding CTC module; when ``args.num_encs`` > 1, a
        ``torch.nn.ModuleList`` holding a single shared CTC (``args.share_ctc``)
        or one CTC per encoder
    :raises ValueError: if ``args.num_encs`` is smaller than one
    """
    num_encs = getattr(args, "num_encs", 1)  # use getattr to keep compatibility
    if num_encs < 1:
        raise ValueError(
            "Number of encoders needs to be at least one. {}".format(num_encs)
        )

    if num_encs == 1:
        # compatible with single encoder asr mode
        return CTC(
            odim, args.eprojs, args.dropout_rate, ctc_type=args.ctc_type, reduce=reduce
        )

    # Multi-encoder mode: args.dropout_rate is indexed per encoder.
    if args.share_ctc:
        # a single shared CTC uses the dropout_rate of the first encoder
        enc_indices = [0]
    else:
        enc_indices = range(num_encs)
    return torch.nn.ModuleList(
        [
            CTC(
                odim,
                args.eprojs,
                args.dropout_rate[idx],
                ctc_type=args.ctc_type,
                reduce=reduce,
            )
            for idx in enc_indices
        ]
    )