zyingt's picture
Upload 685 files
0d80816
raw
history blame
7.82 kB
# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
# ## Citations
# ```bibtex
# @inproceedings{yao2021wenet,
# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
# booktitle={Proc. Interspeech},
# year={2021},
# address={Brno, Czech Republic },
# organization={IEEE}
# }
# @article{zhang2022wenet,
# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
# journal={arXiv preprint arXiv:2203.15455},
# year={2022}
# }
#
"""Utility functions for Transformer."""
import math
from typing import List, Tuple
import torch
from torch.nn.utils.rnn import pad_sequence
# Module-level label constant (-1). It is not referenced by the helpers visible
# in this section; presumably it is the default value callers pass as
# `ignore_id` / `ignore_label` for padded target positions -- TODO confirm.
IGNORE_ID = -1
def pad_list(xs: List[torch.Tensor], pad_value: int):
    """Perform padding for the list of tensors.

    The tensors are stacked along a new batch dimension and every tensor is
    right-padded along its first dimension up to the longest length in the
    list. Trailing dimensions (`*`) are preserved.

    Args:
        xs (List): Non-empty list of Tensors
            [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)]. All tensors must share
            the same trailing dimensions `*`.
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`), created with the dtype and
            device of ``xs[0]``.

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])

    """
    n_batch = len(xs)
    max_len = max([x.size(0) for x in xs])
    # Include the trailing feature dimensions of the first tensor so that
    # (T, *) inputs are supported, not only 1-D sequences.
    out_shape = [n_batch, max_len] + list(xs[0].shape[1:])
    pad = torch.zeros(out_shape, dtype=xs[0].dtype, device=xs[0].device)
    pad = pad.fill_(pad_value)
    for i, x in enumerate(xs):
        # Copy the valid frames; the remainder of the row keeps pad_value.
        pad[i, : x.size(0)] = x
    return pad
def add_blank(ys_pad: torch.Tensor, blank: int, ignore_id: int) -> torch.Tensor:
    """Prepend a blank column and replace padding labels with blank.

    Used to build the input of a transducer predictor.

    Args:
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
        blank (int): index of <blank>
        ignore_id (int): label value marking padded positions; every
            occurrence in the result is replaced by ``blank``

    Returns:
        ys_in (torch.Tensor) : (B, Lmax + 1)

    Examples:
        >>> blank = 0
        >>> ignore_id = -1
        >>> ys_pad
        tensor([[ 1,  2,  3,  4,  5],
                [ 4,  5,  6, -1, -1],
                [ 7,  8,  9, -1, -1]], dtype=torch.int32)
        >>> ys_in = add_blank(ys_pad, 0, -1)
        >>> ys_in
        tensor([[0, 1, 2, 3, 4, 5],
                [0, 4, 5, 6, 0, 0],
                [0, 7, 8, 9, 0, 0]])
    """
    batch_size = ys_pad.size(0)
    # One blank per utterance, shaped [bs, 1] so it can be concatenated in
    # front of the targets along the time axis.
    blank_column = torch.full(
        (batch_size, 1), blank, dtype=torch.long, device=ys_pad.device
    )
    with_blank = torch.cat([blank_column, ys_pad], dim=1)  # [bs, Lmax+1]
    is_padding = with_blank == ignore_id
    return torch.where(is_padding, blank, with_blank)
def add_sos_eos(
    ys_pad: torch.Tensor, sos: int, eos: int, ignore_id: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build decoder input/output targets by adding <sos> and <eos> labels.

    Every element equal to ``ignore_id`` is dropped from each row, <sos> is
    prepended to form the decoder input and <eos> is appended to form the
    decoder output. The two lists are then re-padded with ``pad_list``: the
    inputs with ``eos`` and the outputs with ``ignore_id``.

    Args:
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
        sos (int): index of <sos>
        eos (int): index of <eos>
        ignore_id (int): index of padding

    Returns:
        ys_in (torch.Tensor) : (B, Lmax + 1)
        ys_out (torch.Tensor) : (B, Lmax + 1)

    Examples:
        >>> sos_id = 10
        >>> eos_id = 11
        >>> ignore_id = -1
        >>> ys_pad
        tensor([[ 1,  2,  3,  4,  5],
                [ 4,  5,  6, -1, -1],
                [ 7,  8,  9, -1, -1]], dtype=torch.int32)
        >>> ys_in,ys_out=add_sos_eos(ys_pad, sos_id , eos_id, ignore_id)
        >>> ys_in
        tensor([[10,  1,  2,  3,  4,  5],
                [10,  4,  5,  6, 11, 11],
                [10,  7,  8,  9, 11, 11]])
        >>> ys_out
        tensor([[ 1,  2,  3,  4,  5, 11],
                [ 4,  5,  6, 11, -1, -1],
                [ 7,  8,  9, 11, -1, -1]])
    """
    device = ys_pad.device
    sos_token = torch.tensor(
        [sos], dtype=torch.long, requires_grad=False, device=device
    )
    eos_token = torch.tensor(
        [eos], dtype=torch.long, requires_grad=False, device=device
    )

    decoder_inputs = []
    decoder_targets = []
    for row in ys_pad:
        # Strip the padding labels to recover the variable-length sequence.
        tokens = row[row != ignore_id]
        decoder_inputs.append(torch.cat([sos_token, tokens], dim=0))
        decoder_targets.append(torch.cat([tokens, eos_token], dim=0))

    return pad_list(decoder_inputs, eos), pad_list(decoder_targets, ignore_id)
def reverse_pad_list(
    ys_pad: torch.Tensor, ys_lens: torch.Tensor, pad_value: float = -1.0
) -> torch.Tensor:
    """Reverse the valid part of each padded sequence and re-pad it.

    For every row, the first ``ys_lens[b]`` tokens are cast with ``.int()``,
    flipped, and the flipped sequences are padded on the right with
    ``pad_value``.

    Args:
        ys_pad (tensor): The padded tensor (B, Tokenmax).
        ys_lens (tensor): The lens of token seqs (B)
        pad_value (int): Value for padding.

    Returns:
        Tensor: Padded tensor (B, max(ys_lens)) holding the reversed
            sequences.

    Examples:
        >>> x
        tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
        >>> reverse_pad_list(x, torch.tensor([4, 3, 2]), 0)
        tensor([[4, 3, 2, 1],
                [7, 6, 5, 0],
                [9, 8, 0, 0]], dtype=torch.int32)

    """
    flipped_seqs = []
    for seq, seq_len in zip(ys_pad, ys_lens):
        valid_tokens = seq.int()[:seq_len]  # drop the padded tail
        flipped_seqs.append(torch.flip(valid_tokens, [0]))
    # Second positional argument is batch_first=True -> output is (B, T).
    return pad_sequence(flipped_seqs, True, pad_value)
def th_accuracy(
    pad_outputs: torch.Tensor, pad_targets: torch.Tensor, ignore_label: int
) -> float:
    """Calculate accuracy.

    Predictions are the argmax over the last dimension of ``pad_outputs``;
    positions whose target equals ``ignore_label`` are excluded from both the
    numerator and the denominator.

    Args:
        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
        ignore_label (int): Ignore label id.

    Returns:
        float: Accuracy value (0.0 - 1.0). Returns 0.0 when every target is
            ``ignore_label`` (there is nothing to score), instead of raising
            ZeroDivisionError.

    """
    batch_size = pad_targets.size(0)
    max_len = pad_targets.size(1)
    # Restore the (B, Lmax, D) layout so predictions line up with targets.
    pad_pred = pad_outputs.view(batch_size, max_len, pad_outputs.size(1)).argmax(2)
    mask = pad_targets != ignore_label
    denominator = float(torch.sum(mask))
    if denominator == 0.0:
        # No valid (non-ignored) target position in this batch.
        return 0.0
    numerator = torch.sum(
        pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
    )
    return float(numerator) / denominator
def get_rnn(rnn_type: str) -> torch.nn.Module:
    """Return the torch.nn recurrent layer class selected by ``rnn_type``.

    Args:
        rnn_type (str): One of "rnn", "lstm" or "gru".

    Returns:
        The corresponding class (``torch.nn.RNN``, ``torch.nn.LSTM`` or
        ``torch.nn.GRU``) -- the class itself, not an instance.

    Raises:
        AssertionError: If ``rnn_type`` is not one of the supported names.
    """
    assert rnn_type in ["rnn", "lstm", "gru"]
    rnn_classes = {"rnn": torch.nn.RNN, "lstm": torch.nn.LSTM}
    # Anything that passed the assertion and is not in the table is "gru".
    return rnn_classes.get(rnn_type, torch.nn.GRU)
def get_activation(act):
    """Return activation function.

    Args:
        act (str): Name of the activation; one of "hardtanh", "tanh", "relu",
            "selu", "swish" or "gelu".

    Returns:
        A newly constructed activation module (instantiated with no
        arguments).

    Raises:
        KeyError: If ``act`` is not a supported name.
    """
    # Prefer the built-in SiLU. The project-local Swish is imported only when
    # torch.nn has no SiLU, so the fallback module is not loaded (or required)
    # when it is not needed.
    swish_cls = getattr(torch.nn, "SiLU", None)
    if swish_cls is None:
        from modules.wenet_extractor.transformer.swish import Swish

        swish_cls = Swish

    activation_funcs = {
        "hardtanh": torch.nn.Hardtanh,
        "tanh": torch.nn.Tanh,
        "relu": torch.nn.ReLU,
        "selu": torch.nn.SELU,
        "swish": swish_cls,
        "gelu": torch.nn.GELU,
    }

    return activation_funcs[act]()
def get_subsample(config):
    """Return the subsampling factor implied by the encoder input layer.

    Args:
        config: Mapping with ``config["encoder_conf"]["input_layer"]`` set to
            "conv2d", "conv2d6" or "conv2d8".

    Returns:
        int: 4 for "conv2d", 6 for "conv2d6", 8 for "conv2d8".

    Raises:
        AssertionError: If the input layer is not one of the names above.
    """
    input_layer = config["encoder_conf"]["input_layer"]
    assert input_layer in ["conv2d", "conv2d6", "conv2d8"]
    subsample_rates = {"conv2d": 4, "conv2d6": 6, "conv2d8": 8}
    return subsample_rates.get(input_layer)
def remove_duplicates_and_blank(hyp: List[int], blank_id: int = 0) -> List[int]:
    """Collapse consecutive repeats and drop blanks from a token sequence.

    Each run of identical consecutive tokens is reduced to a single token,
    and runs of ``blank_id`` are removed entirely.

    Args:
        hyp (List[int]): Token id sequence.
        blank_id (int): Id treated as blank. Defaults to 0, which matches the
            previously hard-coded value.

    Returns:
        List[int]: New list with repeats merged and blanks removed.

    Examples:
        >>> remove_duplicates_and_blank([1, 1, 0, 1, 2, 2, 0, 0, 3])
        [1, 1, 2, 3]
    """
    new_hyp: List[int] = []
    for i, token in enumerate(hyp):
        # Only the first token of each run is a candidate for output.
        if i > 0 and token == hyp[i - 1]:
            continue
        if token != blank_id:
            new_hyp.append(token)
    return new_hyp
def replace_duplicates_with_blank(hyp: List[int], blank_id: int = 0) -> List[int]:
    """Replace repeated consecutive tokens with blank, keeping the length.

    The first token of every run of identical consecutive tokens is kept and
    each following token of that run is replaced by ``blank_id``.

    Args:
        hyp (List[int]): Token id sequence.
        blank_id (int): Id written in place of repeats. Defaults to 0, which
            matches the previously hard-coded value.

    Returns:
        List[int]: New list with the same length as ``hyp``.

    Examples:
        >>> replace_duplicates_with_blank([1, 1, 2, 0, 0, 3, 3, 3])
        [1, 0, 2, 0, 0, 3, 0, 0]
    """
    new_hyp: List[int] = []
    for i, token in enumerate(hyp):
        if i > 0 and token == hyp[i - 1]:
            # Repeat of the previous token (a repeated blank stays blank).
            new_hyp.append(blank_id)
        else:
            new_hyp.append(token)
    return new_hyp
def log_add(args: List[float]) -> float:
    """Stable log add.

    Computes ``log(sum(exp(a) for a in args))`` by factoring out the largest
    argument so the exponentials cannot overflow.

    Args:
        args (List[float]): Values in the log domain.

    Returns:
        float: The log of the summed exponentials. Returns ``-inf`` for an
            empty list or when every value is ``-inf``, and ``+inf`` when any
            value is ``+inf``.
    """
    if not args:
        return -float("inf")
    a_max = max(args)
    # An infinite maximum decides the result on its own: all -inf gives -inf,
    # any +inf gives +inf. Subtracting it below would produce inf - inf = NaN.
    if math.isinf(a_max):
        return a_max
    lsp = math.log(sum(math.exp(a - a_max) for a in args))
    return a_max + lsp