from typing import Dict, List, Optional, Tuple, Union

import torch
import torchaudio
from torch import nn
from torch.nn.utils.rnn import pad_sequence

from modules.wenet_extractor.transducer.predictor import PredictorBase
from modules.wenet_extractor.transducer.search.greedy_search import basic_greedy_search
from modules.wenet_extractor.transducer.search.prefix_beam_search import (
    PrefixBeamSearch,
)
from modules.wenet_extractor.transformer.asr_model import ASRModel
from modules.wenet_extractor.transformer.ctc import CTC
from modules.wenet_extractor.transformer.decoder import (
    BiTransformerDecoder,
    TransformerDecoder,
)
from modules.wenet_extractor.transformer.label_smoothing_loss import LabelSmoothingLoss
from modules.wenet_extractor.utils.common import (
    IGNORE_ID,
    add_blank,
    add_sos_eos,
    reverse_pad_list,
)

class Transducer(ASRModel):
    """Transducer/CTC/attention hybrid encoder-predictor-decoder model."""

    def __init__(
        self,
        vocab_size: int,
        blank: int,
        encoder: nn.Module,
        predictor: PredictorBase,
        joint: nn.Module,
        attention_decoder: Optional[
            Union[TransformerDecoder, BiTransformerDecoder]
        ] = None,
        ctc: Optional[CTC] = None,
        ctc_weight: float = 0.0,
        ignore_id: int = IGNORE_ID,
        reverse_weight: float = 0.0,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        transducer_weight: float = 1.0,
        attention_weight: float = 0.0,
    ) -> None:
        # The three branch weights must form a convex combination.
        assert attention_weight + ctc_weight + transducer_weight == 1.0
        super().__init__(
            vocab_size,
            encoder,
            attention_decoder,
            ctc,
            ctc_weight,
            ignore_id,
            reverse_weight,
            lsm_weight,
            length_normalized_loss,
        )

        self.blank = blank
        self.transducer_weight = transducer_weight
        self.attention_decoder_weight = 1 - self.transducer_weight - self.ctc_weight

        self.predictor = predictor
        self.joint = joint
        # Prefix beam search helper, built lazily in init_bs().
        self.bs = None

        # ASRModel always builds criterion_att; rebuild it only when an
        # attention decoder is actually attached.
        del self.criterion_att
        if attention_decoder is not None:
            self.criterion_att = LabelSmoothingLoss(
                size=vocab_size,
                padding_idx=ignore_id,
                smoothing=lsm_weight,
                normalize_length=length_normalized_loss,
            )

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Frontend + Encoder + Predictor + Joint + Loss.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch,)
            text: (Batch, Length)
            text_lengths: (Batch,)
        Returns:
            Dict with the combined loss and the individual branch losses
            ("loss", "loss_att", "loss_ctc", "loss_rnnt").
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)

        # Encoder
        encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)

        # Predictor: prepend the blank symbol as the start-of-sequence context.
        ys_in_pad = add_blank(text, self.blank, self.ignore_id)
        predictor_out = self.predictor(ys_in_pad)

        # Joint network
        joint_out = self.joint(encoder_out, predictor_out)

        # RNN-T loss: replace padding (ignore_id) with 0 and cast to the
        # integer dtypes expected by torchaudio.functional.rnnt_loss.
        rnnt_text = text.to(torch.int64)
        rnnt_text = torch.where(rnnt_text == self.ignore_id, 0, rnnt_text).to(
            torch.int32
        )
        rnnt_text_lengths = text_lengths.to(torch.int32)
        encoder_out_lens = encoder_out_lens.to(torch.int32)
        loss = torchaudio.functional.rnnt_loss(
            joint_out,
            rnnt_text,
            encoder_out_lens,
            rnnt_text_lengths,
            blank=self.blank,
            reduction="mean",
        )
        loss_rnnt = loss

        loss = self.transducer_weight * loss

        # Optional attention-decoder loss
        loss_att: Optional[torch.Tensor] = None
        if self.attention_decoder_weight != 0.0 and self.decoder is not None:
            loss_att, _ = self._calc_att_loss(
                encoder_out, encoder_mask, text, text_lengths
            )

        # Optional CTC loss
        loss_ctc: Optional[torch.Tensor] = None
        if self.ctc_weight != 0.0 and self.ctc is not None:
            loss_ctc = self.ctc(encoder_out, encoder_out_lens, text, text_lengths)

        if loss_ctc is not None:
            loss = loss + self.ctc_weight * loss_ctc.sum()
        if loss_att is not None:
            loss = loss + self.attention_decoder_weight * loss_att.sum()

        return {
            "loss": loss,
            "loss_att": loss_att,
            "loss_ctc": loss_ctc,
            "loss_rnnt": loss_rnnt,
        }

    def init_bs(self):
        # Lazily build the prefix beam search helper on first use.
        if self.bs is None:
            self.bs = PrefixBeamSearch(
                self.encoder, self.predictor, self.joint, self.ctc, self.blank
            )

    def _cal_transducer_score(
        self,
        encoder_out: torch.Tensor,
        encoder_mask: torch.Tensor,
        hyps_lens: torch.Tensor,
        hyps_pad: torch.Tensor,
    ):
        # Score each padded hypothesis with the transducer branch: the negated
        # per-utterance RNN-T loss serves as a log-likelihood-style score for
        # rescoring.
        hyps_pad_blank = add_blank(hyps_pad, self.blank, self.ignore_id)
        xs_in_lens = encoder_mask.squeeze(1).sum(1).int()

        # Predictor and joint forward over the hypotheses
        predictor_out = self.predictor(hyps_pad_blank)
        joint_out = self.joint(encoder_out, predictor_out)
        rnnt_text = hyps_pad.to(torch.int64)
        rnnt_text = torch.where(rnnt_text == self.ignore_id, 0, rnnt_text).to(
            torch.int32
        )

        loss_td = torchaudio.functional.rnnt_loss(
            joint_out,
            rnnt_text,
            xs_in_lens,
            hyps_lens.int(),
            blank=self.blank,
            reduction="none",
        )
        return loss_td * -1

    def _cal_attn_score(
        self,
        encoder_out: torch.Tensor,
        encoder_mask: torch.Tensor,
        hyps_pad: torch.Tensor,
        hyps_lens: torch.Tensor,
    ):
        # Keep the original (un-prefixed) hypotheses for the reversed pass.
        ori_hyps_pad = hyps_pad

        # Add <sos>/<eos> for the left-to-right decoder.
        hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
        hyps_lens = hyps_lens + 1

        # Reversed hypotheses for the optional right-to-left decoder.
        r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens, self.ignore_id)
        r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos, self.ignore_id)
        decoder_out, r_decoder_out, _ = self.decoder(
            encoder_out,
            encoder_mask,
            hyps_pad,
            hyps_lens,
            r_hyps_pad,
            self.reverse_weight,
        )
        decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
        decoder_out = decoder_out.cpu().numpy()

        # r_decoder_out is only meaningful when reverse_weight > 0.
        r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
        r_decoder_out = r_decoder_out.cpu().numpy()
        return decoder_out, r_decoder_out

    def beam_search(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        decoding_chunk_size: int = -1,
        beam_size: int = 5,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
        ctc_weight: float = 0.3,
        transducer_weight: float = 0.7,
    ):
        """Transducer prefix beam search.

        Args:
            speech (torch.Tensor): (batch=1, max_len, feat_dim)
            speech_lengths (torch.Tensor): (batch, )
            beam_size (int): beam size for beam search
            decoding_chunk_size (int): decoding chunk for dynamic chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here
            simulate_streaming (bool): whether to do encoder forward in a
                streaming fashion
            ctc_weight (float): ctc probability weight in transducer
                prefix beam search.
                final_prob = ctc_weight * ctc_prob + transducer_weight * transducer_prob
            transducer_weight (float): transducer probability weight in
                prefix beam search
        Returns:
            Tuple of the best hypothesis (List[int], leading context symbol
            removed) and its score.
        """
        self.init_bs()
        beam, _ = self.bs.prefix_beam_search(
            speech,
            speech_lengths,
            decoding_chunk_size,
            beam_size,
            num_decoding_left_chunks,
            simulate_streaming,
            ctc_weight,
            transducer_weight,
        )
        return beam[0].hyp[1:], beam[0].score

    def transducer_attention_rescoring(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        beam_size: int,
        decoding_chunk_size: int = -1,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
        reverse_weight: float = 0.0,
        ctc_weight: float = 0.0,
        attn_weight: float = 0.0,
        transducer_weight: float = 0.0,
        search_ctc_weight: float = 1.0,
        search_transducer_weight: float = 0.0,
        beam_search_type: str = "transducer",
    ) -> Tuple[List[int], float]:
        """Beam search followed by attention/transducer rescoring.

        Args:
            speech (torch.Tensor): (batch=1, max_len, feat_dim)
            speech_lengths (torch.Tensor): (batch, )
            beam_size (int): beam size for beam search
            decoding_chunk_size (int): decoding chunk for dynamic chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here
            simulate_streaming (bool): whether to do encoder forward in a
                streaming fashion
            ctc_weight (float): ctc probability weight used in rescoring.
                rescore_prob = ctc_weight * ctc_prob +
                               transducer_weight * (transducer_loss * -1) +
                               attn_weight * attn_prob
            attn_weight (float): attention probability weight used in rescoring.
            transducer_weight (float): transducer probability weight used in
                rescoring
            search_ctc_weight (float): ctc weight used in the first-pass
                beam search (see self.beam_search)
            search_transducer_weight (float): transducer weight used in the
                first-pass beam search (see self.beam_search)
            beam_search_type (str): "transducer" or "ctc", selects the
                first-pass search
        Returns:
            Tuple of the best hypothesis (List[int]) and its rescored score.
        """

        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        if reverse_weight > 0.0:
            # The decoder must have a right-to-left branch for reverse scoring.
            assert hasattr(self.decoder, "right_decoder")
        device = speech.device
        batch_size = speech.shape[0]
        # Only batch_size == 1 is supported for rescoring.
        assert batch_size == 1

        # First pass: collect beam_size hypotheses and their search scores.
        self.init_bs()
        if beam_search_type == "transducer":
            beam, encoder_out = self.bs.prefix_beam_search(
                speech,
                speech_lengths,
                decoding_chunk_size=decoding_chunk_size,
                beam_size=beam_size,
                num_decoding_left_chunks=num_decoding_left_chunks,
                ctc_weight=search_ctc_weight,
                transducer_weight=search_transducer_weight,
            )
            beam_score = [s.score for s in beam]
            hyps = [s.hyp[1:] for s in beam]
        elif beam_search_type == "ctc":
            hyps, encoder_out = self._ctc_prefix_beam_search(
                speech,
                speech_lengths,
                beam_size=beam_size,
                decoding_chunk_size=decoding_chunk_size,
                num_decoding_left_chunks=num_decoding_left_chunks,
                simulate_streaming=simulate_streaming,
            )
            beam_score = [hyp[1] for hyp in hyps]
            hyps = [hyp[0] for hyp in hyps]
        assert len(hyps) == beam_size

        # Pad hypotheses and expand the encoder output to the beam dimension.
        hyps_pad = pad_sequence(
            [torch.tensor(hyp, device=device, dtype=torch.long) for hyp in hyps],
            True,
            self.ignore_id,
        )
        hyps_lens = torch.tensor(
            [len(hyp) for hyp in hyps], device=device, dtype=torch.long
        )

        encoder_out = encoder_out.repeat(beam_size, 1, 1)
        encoder_mask = torch.ones(
            beam_size, 1, encoder_out.size(1), dtype=torch.bool, device=device
        )

        # Second pass: transducer and attention scores for each hypothesis.
        td_score = self._cal_transducer_score(
            encoder_out,
            encoder_mask,
            hyps_lens,
            hyps_pad,
        )
        decoder_out, r_decoder_out = self._cal_attn_score(
            encoder_out,
            encoder_mask,
            hyps_pad,
            hyps_lens,
        )

        # Select the hypothesis with the best combined score.
        best_score = -float("inf")
        best_index = 0
        for i, hyp in enumerate(hyps):
            score = 0.0
            for j, w in enumerate(hyp):
                score += decoder_out[i][j][w]
            score += decoder_out[i][len(hyp)][self.eos]
            td_s = td_score[i]

            # Add the right-to-left decoder score if requested.
            if reverse_weight > 0:
                r_score = 0.0
                for j, w in enumerate(hyp):
                    r_score += r_decoder_out[i][len(hyp) - j - 1][w]
                r_score += r_decoder_out[i][len(hyp)][self.eos]
                score = score * (1 - reverse_weight) + r_score * reverse_weight

            score = (
                score * attn_weight
                + beam_score[i] * ctc_weight
                + td_s * transducer_weight
            )
            if score > best_score:
                best_score = score
                best_index = i

        return hyps[best_index], best_score

    def greedy_search(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        decoding_chunk_size: int = -1,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
        n_steps: int = 64,
    ) -> List[List[int]]:
        """Greedy search.

        Args:
            speech (torch.Tensor): (batch=1, max_len, feat_dim)
            speech_lengths (torch.Tensor): (batch, )
            decoding_chunk_size (int): decoding chunk for dynamic chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here
            simulate_streaming (bool): whether to do encoder forward in a
                streaming fashion
            n_steps (int): forwarded to basic_greedy_search
        Returns:
            List[List[int]]: best path result
        """

        # Only batch_size == 1 is supported.
        assert speech.size(0) == 1
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        # Streaming simulation is not implemented here; the flag is accepted
        # for API compatibility.
        _ = simulate_streaming

        encoder_out, encoder_mask = self.encoder(
            speech,
            speech_lengths,
            decoding_chunk_size,
            num_decoding_left_chunks,
        )
        encoder_out_lens = encoder_mask.squeeze(1).sum()
        hyps = basic_greedy_search(self, encoder_out, encoder_out_lens, n_steps=n_steps)

        return hyps

    @torch.jit.export
    def forward_encoder_chunk(
        self,
        xs: torch.Tensor,
        offset: int,
        required_cache_size: int,
        att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Streaming encoder forward for one chunk; returns the chunk output
        # plus the updated attention and convolution caches.
        return self.encoder.forward_chunk(
            xs, offset, required_cache_size, att_cache, cnn_cache
        )

    @torch.jit.export
    def forward_predictor_step(
        self, xs: torch.Tensor, cache: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        assert len(cache) == 2
        # Dummy padding tensor required by the predictor step interface.
        padding = torch.zeros(1, 1)
        return self.predictor.forward_step(xs, padding, cache)

    @torch.jit.export
    def forward_joint_step(
        self, enc_out: torch.Tensor, pred_out: torch.Tensor
    ) -> torch.Tensor:
        return self.joint(enc_out, pred_out)

    @torch.jit.export
    def forward_predictor_init_state(self) -> List[torch.Tensor]:
        return self.predictor.init_state(1, device=torch.device("cpu"))
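
# Hedged streaming-decoding sketch built only from the exported entry points
# above; chunking, cache sizes, and all variable names here are hypothetical
# and depend on the configured encoder/predictor:
#
#   att_cache = torch.zeros(0, 0, 0, 0)
#   cnn_cache = torch.zeros(0, 0, 0, 0)
#   pred_cache = model.forward_predictor_init_state()
#   offset, last_token = 0, blank_id
#   for chunk in feature_chunks:
#       enc, att_cache, cnn_cache = model.forward_encoder_chunk(
#           chunk, offset, required_cache_size, att_cache, cnn_cache
#       )
#       offset += enc.size(1)
#       for t in range(enc.size(1)):
#           pred, new_cache = model.forward_predictor_step(
#               torch.tensor([[last_token]]), pred_cache
#           )
#           logits = model.forward_joint_step(enc[:, t:t + 1, :], pred)
#           token = int(logits.argmax(dim=-1))
#           if token != blank_id:
#               last_token, pred_cache = token, new_cache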