import math
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from esm.modules import SinusoidalPositionalEmbedding
from .transformer_layer import TransformerDecoderLayer


def fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf."""
    return t.float().fill_(float("-inf")).type_as(t)


class TransformerDecoder(nn.Module):
    """
    Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
    """

    def __init__(
        self,
        args,
        dictionary,
        embed_tokens,
    ):
        super().__init__()
        self.args = args
        self.dictionary = dictionary
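        # Cache for the causal (future) self-attention mask; built lazily and
        # resized on demand in buffered_future_mask().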
        self._future_mask = torch.empty(0)

        self.dropout_module = nn.Dropout(args.dropout)

        input_embed_dim = embed_tokens.embedding_dim
        embed_dim = args.decoder_embed_dim
        self.embed_dim = embed_dim

        self.padding_idx = embed_tokens.padding_idx

        self.embed_tokens = embed_tokens
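        # Token embeddings are scaled by sqrt(embed_dim), as in the original
        # Transformer (Vaswani et al., 2017).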
        self.embed_scale = math.sqrt(embed_dim)
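
        # Optional linear projection, used only when the token embedding
        # dimension differs from the decoder embedding dimension.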
        self.project_in_dim = (
            nn.Linear(input_embed_dim, embed_dim, bias=False)
            if embed_dim != input_embed_dim
            else None
        )
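        # Fixed sinusoidal positional embeddings (not learned).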
        self.embed_positions = SinusoidalPositionalEmbedding(
            embed_dim,
            self.padding_idx,
        )

        self.layers = nn.ModuleList([])
        self.layers.extend(
            [
                self.build_decoder_layer(args)
                for _ in range(args.decoder_layers)
            ]
        )
        self.num_layers = len(self.layers)
        self.layer_norm = nn.LayerNorm(embed_dim)

        self.build_output_projection(args, dictionary)

    def build_output_projection(self, args, dictionary):
        self.output_projection = nn.Linear(
            args.decoder_embed_dim, len(dictionary), bias=False
        )
        nn.init.normal_(
            self.output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5
        )

    def build_decoder_layer(self, args):
        return TransformerDecoderLayer(args)

    def forward(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        features_only: bool = False,
        return_all_hiddens: bool = False,
    ):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (optional): output from the encoder, used for
                encoder-side attention; should be of size T x B x C
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`
            features_only (bool, optional): only return features without
                applying the output layer (default: False).

        Returns:
            tuple:
                - the decoder's output of shape `(batch, vocab, tgt_len)`
                  (`(batch, embed_dim, tgt_len)` if *features_only* is True);
                  the last two dimensions are swapped before returning
                - a dictionary with any model-specific outputs
        """

        x, extra = self.extract_features(
            prev_output_tokens,
            encoder_out=encoder_out,
            incremental_state=incremental_state,
        )

        if not features_only:
            x = self.output_layer(x)
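        # B x T x C -> B x C x T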
        x = x.transpose(1, 2)
        return x, extra

    def extract_features(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]],
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
    ):
        """
        Similar to *forward* but only returns features.

        Includes several features from "Jointly Learning to Align and
        Translate with Transformer Models" (Garg et al., EMNLP 2019).

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        bs, slen = prev_output_tokens.size()

        enc: Optional[Tensor] = None
        padding_mask: Optional[Tensor] = None
        if encoder_out is not None and len(encoder_out["encoder_out"]) > 0:
            enc = encoder_out["encoder_out"][0]
            assert (
                enc.size()[1] == bs
            ), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}"
        if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0:
            padding_mask = encoder_out["encoder_padding_mask"][0]
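
        # Embed positions for the full target prefix.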
        positions = self.embed_positions(
            prev_output_tokens
        )
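
        # During incremental decoding, only the most recent token needs to be
        # embedded; earlier positions are served from the layers' cached state.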
        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            positions = positions[:, -1:]
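
        # Embed tokens (scaled) and add positional embeddings.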
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        x += positions

        x = self.dropout_module(x)
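
        # B x T x C -> T x B x C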
        x = x.transpose(0, 1)

        self_attn_padding_mask: Optional[Tensor] = None
        if prev_output_tokens.eq(self.padding_idx).any():
            self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
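
        # Decoder layers. A causal mask is only needed when the whole prefix is
        # processed in one pass; incremental steps attend to cached states only.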
        attn: Optional[Tensor] = None
        inner_states: List[Optional[Tensor]] = [x]
        for idx, layer in enumerate(self.layers):
            if incremental_state is None:
                self_attn_mask = self.buffered_future_mask(x)
            else:
                self_attn_mask = None

            x, layer_attn, _ = layer(
                x,
                enc,
                padding_mask,
                incremental_state,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
                need_attn=False,
                need_head_weights=False,
            )
            inner_states.append(x)

        if self.layer_norm is not None:
            x = self.layer_norm(x)
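
        # T x B x C -> B x T x C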
        x = x.transpose(0, 1)

        return x, {"inner_states": inner_states}

    def output_layer(self, features):
        """Project features to the vocabulary size."""
        return self.output_projection(features)

    def buffered_future_mask(self, tensor):
        dim = tensor.size(0)
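
        # Rebuild the cached mask if it is empty, lives on the wrong device,
        # or is smaller than the current sequence length.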
        if (
            self._future_mask.size(0) == 0
            or (not self._future_mask.device == tensor.device)
            or self._future_mask.size(0) < dim
        ):
            self._future_mask = torch.triu(
                fill_with_neg_inf(torch.zeros([dim, dim])), 1
            )
        self._future_mask = self._future_mask.to(tensor)
        return self._future_mask[:dim, :dim]
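

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): it exercises just the masking
    # helper. Running the full TransformerDecoder would additionally require a
    # fairseq-style `args` namespace and a concrete TransformerDecoderLayer
    # configuration, which are not assumed here.
    tgt_len = 5
    causal_mask = torch.triu(fill_with_neg_inf(torch.zeros([tgt_len, tgt_len])), 1)
    # Entries strictly above the diagonal are -inf, so position i can attend
    # only to positions <= i once the mask is added to the attention scores.
    print(causal_mask)
    assert torch.isinf(causal_mask[0, 1]).item()
    assert causal_mask[1, 0].item() == 0.0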