from typing import Dict, List, Optional, Tuple, Union

import torch
import torchaudio
from torch import nn
from torch.nn.utils.rnn import pad_sequence

from modules.wenet_extractor.transducer.predictor import PredictorBase
from modules.wenet_extractor.transducer.search.greedy_search import basic_greedy_search
from modules.wenet_extractor.transducer.search.prefix_beam_search import (
    PrefixBeamSearch,
)
from modules.wenet_extractor.transformer.asr_model import ASRModel
from modules.wenet_extractor.transformer.ctc import CTC
from modules.wenet_extractor.transformer.decoder import (
    BiTransformerDecoder,
    TransformerDecoder,
)
from modules.wenet_extractor.transformer.label_smoothing_loss import LabelSmoothingLoss
from modules.wenet_extractor.utils.common import (
    IGNORE_ID,
    add_blank,
    add_sos_eos,
    reverse_pad_list,
)


class Transducer(ASRModel):
    """Transducer-CTC-attention hybrid Encoder-Predictor-Decoder model."""
    def __init__(
        self,
        vocab_size: int,
        blank: int,
        encoder: nn.Module,
        predictor: PredictorBase,
        joint: nn.Module,
        attention_decoder: Optional[
            Union[TransformerDecoder, BiTransformerDecoder]
        ] = None,
        ctc: Optional[CTC] = None,
        ctc_weight: float = 0.0,
        ignore_id: int = IGNORE_ID,
        reverse_weight: float = 0.0,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        transducer_weight: float = 1.0,
        attention_weight: float = 0.0,
    ) -> None:
        assert attention_weight + ctc_weight + transducer_weight == 1.0
        super().__init__(
            vocab_size,
            encoder,
            attention_decoder,
            ctc,
            ctc_weight,
            ignore_id,
            reverse_weight,
            lsm_weight,
            length_normalized_loss,
        )

        self.blank = blank
        self.transducer_weight = transducer_weight
        self.attention_decoder_weight = 1 - self.transducer_weight - self.ctc_weight
        self.predictor = predictor
        self.joint = joint
        self.bs = None

        # NOTE(Mddct): in transducer terminology the predictor is often called
        # the decoder; here `decoder` refers to the attention decoder.
        del self.criterion_att
        if attention_decoder is not None:
            self.criterion_att = LabelSmoothingLoss(
                size=vocab_size,
                padding_idx=ignore_id,
                smoothing=lsm_weight,
                normalize_length=length_normalized_loss,
            )
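        # The training objective combines three weighted losses (see forward()):
        #     loss = transducer_weight * loss_rnnt
        #          + ctc_weight * loss_ctc
        #          + attention_decoder_weight * loss_att
        # The assert above guarantees the three weights sum to 1.0.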

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Frontend + Encoder + Predictor + Joint + loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch,)
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        # Encoder
        encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)

        # Predictor
        ys_in_pad = add_blank(text, self.blank, self.ignore_id)
        predictor_out = self.predictor(ys_in_pad)

        # Joint network
        joint_out = self.joint(encoder_out, predictor_out)
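        # torchaudio.functional.rnnt_loss expects joiner logits of shape
        # (B, T, U + 1, vocab_size): T encoder frames and U target tokens,
        # plus one for the blank prepended by add_blank above.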
        # NOTE(Mddct): some loss implementations require padded positions to be
        # zero, and torchaudio's rnnt_loss requires torch.int32 targets/lengths.
        rnnt_text = text.to(torch.int64)
        rnnt_text = torch.where(rnnt_text == self.ignore_id, 0, rnnt_text).to(
            torch.int32
        )
        rnnt_text_lengths = text_lengths.to(torch.int32)
        encoder_out_lens = encoder_out_lens.to(torch.int32)
        loss = torchaudio.functional.rnnt_loss(
            joint_out,
            rnnt_text,
            encoder_out_lens,
            rnnt_text_lengths,
            blank=self.blank,
            reduction="mean",
        )
        loss_rnnt = loss
        loss = self.transducer_weight * loss
        # Optional attention-decoder loss
        loss_att: Optional[torch.Tensor] = None
        if self.attention_decoder_weight != 0.0 and self.decoder is not None:
            loss_att, _ = self._calc_att_loss(
                encoder_out, encoder_mask, text, text_lengths
            )

        # Optional CTC loss
        loss_ctc: Optional[torch.Tensor] = None
        if self.ctc_weight != 0.0 and self.ctc is not None:
            loss_ctc = self.ctc(encoder_out, encoder_out_lens, text, text_lengths)

        if loss_ctc is not None:
            loss = loss + self.ctc_weight * loss_ctc.sum()
        if loss_att is not None:
            loss = loss + self.attention_decoder_weight * loss_att.sum()
        # NOTE: 'loss' must be in the returned dict
        return {
            "loss": loss,
            "loss_att": loss_att,
            "loss_ctc": loss_ctc,
            "loss_rnnt": loss_rnnt,
        }
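
    # A minimal smoke-test sketch for forward() (illustrative only: the shapes
    # and token ids below are assumptions, not fixed by this file):
    #
    #     feats = torch.randn(2, 100, 80)           # (B, T, feat_dim)
    #     feats_lens = torch.tensor([100, 80])      # (B,)
    #     text = torch.tensor([[12, 7, 9], [4, 2, IGNORE_ID]])
    #     text_lens = torch.tensor([3, 2])
    #     losses = model(feats, feats_lens, text, text_lens)
    #     losses["loss"].backward()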

    def init_bs(self):
        # Lazily build the prefix beam search helper on first use.
        if self.bs is None:
            self.bs = PrefixBeamSearch(
                self.encoder, self.predictor, self.joint, self.ctc, self.blank
            )

    def _cal_transducer_score(
        self,
        encoder_out: torch.Tensor,
        encoder_mask: torch.Tensor,
        hyps_lens: torch.Tensor,
        hyps_pad: torch.Tensor,
    ):
        # ignore_id -> blank, add blank at head
        hyps_pad_blank = add_blank(hyps_pad, self.blank, self.ignore_id)
        xs_in_lens = encoder_mask.squeeze(1).sum(1).int()

        # 1. Forward predictor
        predictor_out = self.predictor(hyps_pad_blank)
        # 2. Forward joint
        joint_out = self.joint(encoder_out, predictor_out)
        rnnt_text = hyps_pad.to(torch.int64)
        rnnt_text = torch.where(rnnt_text == self.ignore_id, 0, rnnt_text).to(
            torch.int32
        )
        # 3. Compute transducer loss
        loss_td = torchaudio.functional.rnnt_loss(
            joint_out,
            rnnt_text,
            xs_in_lens,
            hyps_lens.int(),
            blank=self.blank,
            reduction="none",
        )
        # The negated per-hypothesis loss acts as a log-probability-like score.
        return -loss_td

    def _cal_attn_score(
        self,
        encoder_out: torch.Tensor,
        encoder_mask: torch.Tensor,
        hyps_pad: torch.Tensor,
        hyps_lens: torch.Tensor,
    ):
        # (beam_size, max_hyps_len)
        ori_hyps_pad = hyps_pad

        hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
        hyps_lens = hyps_lens + 1  # Add <sos> at beginning
        # used by the right-to-left decoder
        r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens, self.ignore_id)
        r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos, self.ignore_id)
        decoder_out, r_decoder_out, _ = self.decoder(
            encoder_out,
            encoder_mask,
            hyps_pad,
            hyps_lens,
            r_hyps_pad,
            self.reverse_weight,
        )  # (beam_size, max_hyps_len, vocab_size)
        decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
        decoder_out = decoder_out.cpu().numpy()
        # r_decoder_out will be all zeros if reverse_weight is 0.0 or the
        # decoder is a conventional (left-to-right) transformer decoder.
        r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
        r_decoder_out = r_decoder_out.cpu().numpy()
        return decoder_out, r_decoder_out
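
    # NOTE: during rescoring the two directions are combined as
    #     score = (1 - reverse_weight) * left_to_right + reverse_weight * right_to_left
    # mirroring how the bidirectional decoder is trained.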

    def beam_search(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        decoding_chunk_size: int = -1,
        beam_size: int = 5,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
        ctc_weight: float = 0.3,
        transducer_weight: float = 0.7,
    ):
        """Beam search.

        Args:
            speech (torch.Tensor): (batch=1, max_len, feat_dim)
            speech_lengths (torch.Tensor): (batch, )
            decoding_chunk_size (int): decoding chunk size for a dynamic-chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use the fixed chunk size as set.
                0: used for training, prohibited here.
            beam_size (int): beam size for beam search
            num_decoding_left_chunks (int): number of left chunks for decoding
            simulate_streaming (bool): whether to run the encoder forward in a
                streaming fashion
            ctc_weight (float): CTC probability weight in transducer
                prefix beam search.
                final_prob = ctc_weight * ctc_prob + transducer_weight * transducer_prob
            transducer_weight (float): transducer probability weight in
                prefix beam search

        Returns:
            Tuple[List[int], float]: best hypothesis (token ids) and its score
        """
        self.init_bs()
        beam, _ = self.bs.prefix_beam_search(
            speech,
            speech_lengths,
            decoding_chunk_size,
            beam_size,
            num_decoding_left_chunks,
            simulate_streaming,
            ctc_weight,
            transducer_weight,
        )
        # Drop the leading blank/start token from the best hypothesis.
        return beam[0].hyp[1:], beam[0].score
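
    # Illustrative call (a sketch; `model`, `feats`, and `feats_lens` are
    # assumed to be a trained Transducer and a single-utterance batch):
    #
    #     hyp, score = model.beam_search(feats, feats_lens, beam_size=5)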

    def transducer_attention_rescoring(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        beam_size: int,
        decoding_chunk_size: int = -1,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
        reverse_weight: float = 0.0,
        ctc_weight: float = 0.0,
        attn_weight: float = 0.0,
        transducer_weight: float = 0.0,
        search_ctc_weight: float = 1.0,
        search_transducer_weight: float = 0.0,
        beam_search_type: str = "transducer",
    ) -> Tuple[List[int], float]:
        """Beam search, then rescore the n-best hypotheses.

        Args:
            speech (torch.Tensor): (batch=1, max_len, feat_dim)
            speech_lengths (torch.Tensor): (batch, )
            beam_size (int): beam size for beam search
            decoding_chunk_size (int): decoding chunk size for a dynamic-chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use the fixed chunk size as set.
                0: used for training, prohibited here.
            simulate_streaming (bool): whether to run the encoder forward in a
                streaming fashion
            ctc_weight (float): CTC probability weight used in rescoring.
                rescore_prob = ctc_weight * ctc_prob +
                               transducer_weight * (transducer_loss * -1) +
                               attn_weight * attn_prob
            attn_weight (float): attention probability weight used in rescoring.
            transducer_weight (float): transducer probability weight used in
                rescoring
            search_ctc_weight (float): CTC weight used
                in RNN-T beam search (see self.beam_search)
            search_transducer_weight (float): transducer weight used
                in RNN-T beam search (see self.beam_search)

        Returns:
            Tuple[List[int], float]: best hypothesis (token ids) and its score
        """
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        if reverse_weight > 0.0:
            # decoder should be a bitransformer decoder if reverse_weight > 0.0
            assert hasattr(self.decoder, "right_decoder")
        device = speech.device
        batch_size = speech.shape[0]
        # For attention rescoring we only support batch_size=1
        assert batch_size == 1

        # encoder_out: (1, maxlen, encoder_dim), len(hyps) = beam_size
        self.init_bs()
        if beam_search_type == "transducer":
            beam, encoder_out = self.bs.prefix_beam_search(
                speech,
                speech_lengths,
                decoding_chunk_size=decoding_chunk_size,
                beam_size=beam_size,
                num_decoding_left_chunks=num_decoding_left_chunks,
                ctc_weight=search_ctc_weight,
                transducer_weight=search_transducer_weight,
            )
            beam_score = [s.score for s in beam]
            hyps = [s.hyp[1:] for s in beam]
        elif beam_search_type == "ctc":
            hyps, encoder_out = self._ctc_prefix_beam_search(
                speech,
                speech_lengths,
                beam_size=beam_size,
                decoding_chunk_size=decoding_chunk_size,
                num_decoding_left_chunks=num_decoding_left_chunks,
                simulate_streaming=simulate_streaming,
            )
            beam_score = [hyp[1] for hyp in hyps]
            hyps = [hyp[0] for hyp in hyps]
        else:
            raise ValueError(f"Unsupported beam_search_type: {beam_search_type}")
        assert len(hyps) == beam_size

        # build hyps and encoder output
        hyps_pad = pad_sequence(
            [torch.tensor(hyp, device=device, dtype=torch.long) for hyp in hyps],
            True,
            self.ignore_id,
        )  # (beam_size, max_hyps_len)
        hyps_lens = torch.tensor(
            [len(hyp) for hyp in hyps], device=device, dtype=torch.long
        )  # (beam_size,)
        encoder_out = encoder_out.repeat(beam_size, 1, 1)
        encoder_mask = torch.ones(
            beam_size, 1, encoder_out.size(1), dtype=torch.bool, device=device
        )

        # 2.1 Calculate transducer score
        td_score = self._cal_transducer_score(
            encoder_out,
            encoder_mask,
            hyps_lens,
            hyps_pad,
        )
        # 2.2 Calculate attention score
        decoder_out, r_decoder_out = self._cal_attn_score(
            encoder_out,
            encoder_mask,
            hyps_pad,
            hyps_lens,
        )
        # Rescore each hypothesis with a weighted sum of the attention score,
        # the beam-search (CTC) score, and the transducer score.
        best_score = -float("inf")
        best_index = 0
        for i, hyp in enumerate(hyps):
            score = 0.0
            for j, w in enumerate(hyp):
                score += decoder_out[i][j][w]
            score += decoder_out[i][len(hyp)][self.eos]
            td_s = td_score[i]
            # add the right-to-left decoder score
            if reverse_weight > 0:
                r_score = 0.0
                for j, w in enumerate(hyp):
                    r_score += r_decoder_out[i][len(hyp) - j - 1][w]
                r_score += r_decoder_out[i][len(hyp)][self.eos]
                score = score * (1 - reverse_weight) + r_score * reverse_weight
            # combine the three scores
            score = (
                score * attn_weight
                + beam_score[i] * ctc_weight
                + td_s * transducer_weight
            )
            if score > best_score:
                best_score = score
                best_index = i
        return hyps[best_index], best_score
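
    # Illustrative call (a sketch; the weights below are assumptions, not
    # recommended settings):
    #
    #     hyp, score = model.transducer_attention_rescoring(
    #         feats, feats_lens, beam_size=10,
    #         ctc_weight=0.5, attn_weight=0.5, transducer_weight=0.0,
    #     )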

    def greedy_search(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        decoding_chunk_size: int = -1,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
        n_steps: int = 64,
    ) -> List[List[int]]:
        """Greedy search.

        Args:
            speech (torch.Tensor): (batch=1, max_len, feat_dim)
            speech_lengths (torch.Tensor): (batch, )
            decoding_chunk_size (int): decoding chunk size for a dynamic-chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use the fixed chunk size as set.
                0: used for training, prohibited here.
            simulate_streaming (bool): whether to run the encoder forward in a
                streaming fashion
            n_steps (int): cap on the number of symbols emitted per encoder
                frame (see basic_greedy_search)

        Returns:
            List[List[int]]: best path result
        """
        # TODO(Mddct): batch decode
        assert speech.size(0) == 1
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        # TODO(Mddct): forward chunk by chunk
        _ = simulate_streaming
        # Let's assume B = batch_size
        encoder_out, encoder_mask = self.encoder(
            speech,
            speech_lengths,
            decoding_chunk_size,
            num_decoding_left_chunks,
        )
        encoder_out_lens = encoder_mask.squeeze(1).sum()
        hyps = basic_greedy_search(self, encoder_out, encoder_out_lens, n_steps=n_steps)
        return hyps

    def forward_encoder_chunk(
        self,
        xs: torch.Tensor,
        offset: int,
        required_cache_size: int,
        att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Step-wise encoder forward for streaming/export; the empty cache
        # tensors are the conventional "no cache yet" placeholders.
        return self.encoder.forward_chunk(
            xs, offset, required_cache_size, att_cache, cnn_cache
        )

    def forward_predictor_step(
        self, xs: torch.Tensor, cache: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        assert len(cache) == 2
        # fake padding: step-wise prediction has no real padding to mask
        padding = torch.zeros(1, 1)
        return self.predictor.forward_step(xs, padding, cache)

    def forward_joint_step(
        self, enc_out: torch.Tensor, pred_out: torch.Tensor
    ) -> torch.Tensor:
        return self.joint(enc_out, pred_out)

    def forward_predictor_init_state(self) -> List[torch.Tensor]:
        return self.predictor.init_state(1, device=torch.device("cpu"))
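

# A sketch of how the step-wise helpers above compose into a streaming greedy
# decode loop (illustrative; `feature_chunks`, chunk sizes, and the cache
# handling are assumptions, not defined in this file):
#
#     att_cache = torch.zeros(0, 0, 0, 0)
#     cnn_cache = torch.zeros(0, 0, 0, 0)
#     pred_cache = model.forward_predictor_init_state()
#     offset, hyp = 0, [model.blank]
#     for chunk in feature_chunks:  # each (1, chunk_len, feat_dim)
#         enc, att_cache, cnn_cache = model.forward_encoder_chunk(
#             chunk, offset, -1, att_cache, cnn_cache
#         )
#         offset += enc.size(1)
#         for t in range(enc.size(1)):
#             pred, new_cache = model.forward_predictor_step(
#                 torch.tensor([[hyp[-1]]]), pred_cache
#             )
#             logits = model.forward_joint_step(enc[:, t : t + 1, :], pred)
#             token = int(logits.argmax(dim=-1))
#             if token != model.blank:
#                 hyp.append(token)
#                 pred_cache = new_cache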