"""Decoder definition.""" |
|
from typing import Tuple, List, Optional |
|
|
|
import torch |
|
|
|
from modules.wenet_extractor.transformer.attention import MultiHeadedAttention |
|
from modules.wenet_extractor.transformer.decoder_layer import DecoderLayer |
|
from modules.wenet_extractor.transformer.embedding import PositionalEncoding |
|
from modules.wenet_extractor.transformer.embedding import NoPositionalEncoding |
|
from modules.wenet_extractor.transformer.positionwise_feed_forward import ( |
|
PositionwiseFeedForward, |
|
) |
|
from modules.wenet_extractor.utils.mask import subsequent_mask, make_pad_mask |
|
|
|
|
|
class TransformerDecoder(torch.nn.Module):
    """Base class of Transformer decoder module.

    Args:
        vocab_size: output dim
        encoder_output_size: dimension of attention
        attention_heads: the number of heads of multi head attention
        linear_units: the hidden units number of position-wise feedforward
        num_blocks: the number of decoder blocks
        dropout_rate: dropout rate
        self_attention_dropout_rate: dropout rate for attention
        input_layer: input layer type
        use_output_layer: whether to use output layer
        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
        normalize_before:
            True: use layer_norm before each sub-block of a layer.
            False: use layer_norm after each sub-block of a layer.
        src_attention: if False, encoder-decoder cross attention is not
            applied, as in the CIF model
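
    Example:
        Construction sketch; 5000 and 256 are arbitrary sizes, the remaining
        arguments just spell out the defaults for illustration:

        >>> decoder = TransformerDecoder(
        ...     vocab_size=5000,
        ...     encoder_output_size=256,
        ...     attention_heads=4,
        ...     linear_units=2048,
        ...     num_blocks=6,
        ...     src_attention=True,
        ... )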
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        normalize_before: bool = True,
        src_attention: bool = True,
    ):
        super().__init__()
        attention_dim = encoder_output_size

        if input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
                PositionalEncoding(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "none":
            self.embed = NoPositionalEncoding(attention_dim, positional_dropout_rate)
        else:
            raise ValueError(f"only 'embed' or 'none' is supported: {input_layer}")

        self.normalize_before = normalize_before
        self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5)
        self.use_output_layer = use_output_layer
        self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
        self.num_blocks = num_blocks
        self.decoders = torch.nn.ModuleList(
            [
                DecoderLayer(
                    attention_dim,
                    MultiHeadedAttention(
                        attention_heads, attention_dim, self_attention_dropout_rate
                    ),
                    (
                        MultiHeadedAttention(
                            attention_heads, attention_dim, src_attention_dropout_rate
                        )
                        if src_attention
                        else None
                    ),
                    PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                    dropout_rate,
                    normalize_before,
                )
                for _ in range(self.num_blocks)
            ]
        )

    def forward(
        self,
        memory: torch.Tensor,
        memory_mask: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        r_ys_in_pad: torch.Tensor = torch.empty(0),
        reverse_weight: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            memory_mask: encoder memory mask, (batch, 1, maxlen_in)
            ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
            ys_in_lens: input lengths of this batch (batch)
            r_ys_in_pad: not used in the transformer decoder, kept to unify
                the api with the bidirectional decoder
            reverse_weight: not used in the transformer decoder, kept to unify
                the api with the bidirectional decoder
        Returns:
            (tuple): tuple containing:
                x: decoded token scores before softmax (batch, maxlen_out,
                    vocab_size) if use_output_layer is True
                torch.tensor(0.0), to unify the api with the bidirectional decoder
                olens: (batch, )
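
        Example:
            Teacher-forcing sketch with dummy tensors (all sizes are
            arbitrary, for illustration only):

            >>> decoder = TransformerDecoder(vocab_size=500, encoder_output_size=256)
            >>> memory = torch.randn(2, 50, 256)  # (batch, maxlen_in, feat)
            >>> memory_mask = torch.ones(2, 1, 50, dtype=torch.bool)
            >>> ys_in_pad = torch.randint(0, 500, (2, 10))  # (batch, maxlen_out)
            >>> ys_in_lens = torch.tensor([10, 8])
            >>> logits, _, _ = decoder(memory, memory_mask, ys_in_pad, ys_in_lens)
            >>> logits.shape
            torch.Size([2, 10, 500])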
        """
        tgt = ys_in_pad
        maxlen = tgt.size(1)

        # tgt_mask: (B, 1, L), True at non-padded positions
        tgt_mask = ~make_pad_mask(ys_in_lens, maxlen).unsqueeze(1)
        tgt_mask = tgt_mask.to(tgt.device)

        # m: (1, L, L), causal (subsequent) mask
        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)

        # tgt_mask: (B, L, L), padding mask combined with the causal mask
        tgt_mask = tgt_mask & m
        x, _ = self.embed(tgt)
        for layer in self.decoders:
            x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory, memory_mask)
        if self.normalize_before:
            x = self.after_norm(x)
        if self.use_output_layer:
            x = self.output_layer(x)
        olens = tgt_mask.sum(1)
        return x, torch.tensor(0.0), olens

    def forward_one_step(
        self,
        memory: torch.Tensor,
        memory_mask: torch.Tensor,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        cache: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.

        This is only used for decoding.

        Args:
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            memory_mask: encoded memory mask, (batch, 1, maxlen_in)
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask, (batch, maxlen_out)
                dtype=torch.uint8 before PyTorch 1.2,
                dtype=torch.bool in PyTorch 1.2 and later
            cache: cached output list of (batch, max_time_out-1, size)
        Returns:
            y, cache: NN output value and cache per `self.decoders`.
                `y.shape` is (batch, vocab_size): log-probabilities for the
                last position only (when `use_output_layer` is True).
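
        Example:
            Sketch of step-by-step greedy decoding (sizes and the start token
            id ``1`` are arbitrary, for illustration only):

            >>> decoder = TransformerDecoder(vocab_size=100, encoder_output_size=256)
            >>> decoder = decoder.eval()
            >>> memory = torch.randn(1, 20, 256)
            >>> memory_mask = torch.ones(1, 1, 20, dtype=torch.bool)
            >>> hyp = torch.tensor([[1]], dtype=torch.long)  # running hypothesis
            >>> cache = None
            >>> for _ in range(5):
            ...     tgt_mask = subsequent_mask(hyp.size(1)).unsqueeze(0)
            ...     logp, cache = decoder.forward_one_step(
            ...         memory, memory_mask, hyp, tgt_mask, cache
            ...     )
            ...     next_token = logp.argmax(dim=-1, keepdim=True)
            ...     hyp = torch.cat([hyp, next_token], dim=1)
            >>> hyp.shape
            torch.Size([1, 6])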
        """
        x, _ = self.embed(tgt)
        new_cache = []
        for i, decoder in enumerate(self.decoders):
            if cache is None:
                c = None
            else:
                c = cache[i]
            x, tgt_mask, memory, memory_mask = decoder(
                x, tgt_mask, memory, memory_mask, cache=c
            )
            # keep each layer's output as the cache for the next decoding step
            new_cache.append(x)
        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.use_output_layer:
            y = torch.log_softmax(self.output_layer(y), dim=-1)
        return y, new_cache


class BiTransformerDecoder(torch.nn.Module):
    """Bidirectional Transformer decoder module.

    Wraps a left-to-right decoder and a right-to-left decoder.

    Args:
        vocab_size: output dim
        encoder_output_size: dimension of attention
        attention_heads: the number of heads of multi head attention
        linear_units: the hidden units number of position-wise feedforward
        num_blocks: the number of decoder blocks
        r_num_blocks: the number of right to left decoder blocks
        dropout_rate: dropout rate
        self_attention_dropout_rate: dropout rate for attention
        input_layer: input layer type
        use_output_layer: whether to use output layer
        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
        normalize_before:
            True: use layer_norm before each sub-block of a layer.
            False: use layer_norm after each sub-block of a layer.
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        r_num_blocks: int = 0,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        normalize_before: bool = True,
    ):
        super().__init__()
        self.left_decoder = TransformerDecoder(
            vocab_size,
            encoder_output_size,
            attention_heads,
            linear_units,
            num_blocks,
            dropout_rate,
            positional_dropout_rate,
            self_attention_dropout_rate,
            src_attention_dropout_rate,
            input_layer,
            use_output_layer,
            normalize_before,
        )

        self.right_decoder = TransformerDecoder(
            vocab_size,
            encoder_output_size,
            attention_heads,
            linear_units,
            r_num_blocks,
            dropout_rate,
            positional_dropout_rate,
            self_attention_dropout_rate,
            src_attention_dropout_rate,
            input_layer,
            use_output_layer,
            normalize_before,
        )

    def forward(
        self,
        memory: torch.Tensor,
        memory_mask: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        r_ys_in_pad: torch.Tensor,
        reverse_weight: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            memory_mask: encoder memory mask, (batch, 1, maxlen_in)
            ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
            ys_in_lens: input lengths of this batch (batch)
            r_ys_in_pad: padded input token ids, int64 (batch, maxlen_out),
                used for the right to left decoder
            reverse_weight: weight of the right to left decoder; that decoder
                is only run when reverse_weight > 0.0
        Returns:
            (tuple): tuple containing:
                x: decoded token scores before softmax (batch, maxlen_out,
                    vocab_size) if use_output_layer is True
                r_x: decoded token scores of the right to left decoder, before
                    softmax (batch, maxlen_out, vocab_size) if use_output_layer
                    is True
                olens: (batch, )
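
        Example:
            Illustrative sketch only (dummy sizes; the reversed targets are a
            plain flipped copy here, whereas in practice each label sequence
            is reversed up to its own length before padding):

            >>> decoder = BiTransformerDecoder(
            ...     vocab_size=100, encoder_output_size=256, r_num_blocks=3
            ... )
            >>> memory = torch.randn(2, 30, 256)
            >>> memory_mask = torch.ones(2, 1, 30, dtype=torch.bool)
            >>> ys_in_pad = torch.randint(0, 100, (2, 8))
            >>> r_ys_in_pad = ys_in_pad.flip(dims=[1])
            >>> ys_in_lens = torch.tensor([8, 8])
            >>> l_x, r_x, olens = decoder(
            ...     memory, memory_mask, ys_in_pad, ys_in_lens, r_ys_in_pad,
            ...     reverse_weight=0.3,
            ... )
            >>> l_x.shape, r_x.shape
            (torch.Size([2, 8, 100]), torch.Size([2, 8, 100]))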
        """
        l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad, ys_in_lens)
        r_x = torch.tensor(0.0)
        if reverse_weight > 0.0:
            r_x, _, olens = self.right_decoder(
                memory, memory_mask, r_ys_in_pad, ys_in_lens
            )
        return l_x, r_x, olens

    def forward_one_step(
        self,
        memory: torch.Tensor,
        memory_mask: torch.Tensor,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        cache: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.

        This is only used for decoding; only the left to right decoder is used.

        Args:
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            memory_mask: encoded memory mask, (batch, 1, maxlen_in)
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask, (batch, maxlen_out)
                dtype=torch.uint8 before PyTorch 1.2,
                dtype=torch.bool in PyTorch 1.2 and later
            cache: cached output list of (batch, max_time_out-1, size)
        Returns:
            y, cache: NN output value and cache per `self.left_decoder.decoders`.
                `y.shape` is (batch, vocab_size): log-probabilities for the
                last position only (when `use_output_layer` is True).
        """
        return self.left_decoder.forward_one_step(
            memory, memory_mask, tgt, tgt_mask, cache
        )