# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List, Optional

import torch
import torch.nn as nn
from fairseq import utils
from fairseq.modules import LayerNorm
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
from torch import Tensor

from .unify_multihead_attention import MultiheadAttention


def drop_path(x, drop_prob: Optional[float] = 0.0, training: bool = False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    The binary keep-mask has size ``x.shape[1]`` along dimension 1 and size 1
    along every other dimension, so one draw is shared by all entries with the
    same dimension-1 index. The layers in this file document their inputs as
    `(seq_len, batch, embed_dim)`, which makes dimension 1 the batch dimension.

    Args:
        x (Tensor): input with at least two dimensions
        drop_prob (float, optional): probability of dropping a path; ``None``
            and ``0.0`` both disable dropping
        training (bool): dropping is only applied when this is ``True``

    Returns:
        `x` itself when dropping is disabled, otherwise a tensor of the same
        shape where kept paths are rescaled by ``1 / (1 - drop_prob)`` and
        dropped paths are zero.
    """
    # ``DropPath`` defaults its probability to None; treat that like 0.0
    # instead of failing on ``1 - None`` below.
    if drop_prob is None or drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    if keep_prob <= 0.0:
        # Every path is dropped. Dividing by ``keep_prob`` would produce
        # inf * 0 = NaN, so return zeros explicitly.
        return torch.zeros_like(x)
    # One mask value per index of dim 1, broadcast over all other dims.
    # For 3-D input this is exactly (1, x.shape[1], 1).
    shape = [1, x.shape[1]] + [1] * (x.dim() - 2)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


def _upgrade_layer_state_dict(module, state_dict, name, layer_norm_map):
    """Shared implementation of ``upgrade_state_dict_named`` for the layers below.

    Mutates `state_dict` in place in two passes:

    1. For every ``old -> new`` entry of `layer_norm_map`, the keys
       ``<name>.layer_norms.<old>.{weight,bias}`` are renamed to
       ``<name>.<new>.{weight,bias}``. If the new key is still missing and
       `module` owns such an entry, the module's current value is used.
    2. Every remaining entry of ``module.state_dict()`` that is absent from
       `state_dict` (under the `name` prefix) is filled in with the module's
       current value.

    Args:
        module (nn.Module): the layer whose current state provides defaults
        state_dict (dict): checkpoint state dict, modified in place
        name (str): key prefix of this layer inside `state_dict`; may be empty
        layer_norm_map (dict): maps old ``layer_norms`` indices (as strings)
            to the current attribute names
    """
    # Snapshot once: ``state_dict()`` builds a fresh dict on every call.
    own_state = module.state_dict()
    # An empty name must not produce keys with a leading ".".
    prefix = name + "." if name != "" else ""

    for old, new in layer_norm_map.items():
        for m in ("weight", "bias"):
            old_key = "{}layer_norms.{}.{}".format(prefix, old, m)
            new_key = "{}{}.{}".format(prefix, new, m)
            own_key = "{}.{}".format(new, m)
            if old_key in state_dict:
                state_dict[new_key] = state_dict[old_key]
                del state_dict[old_key]
            if new_key not in state_dict and own_key in own_state:
                state_dict[new_key] = own_state[own_key]

    for param_name, param_value in own_state.items():
        if (prefix + param_name) not in state_dict:
            state_dict[prefix + param_name] = param_value


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Thin module wrapper around :func:`drop_path` that uses the module's own
    ``training`` flag.

    Args:
        drop_prob (float, optional): probability of dropping a path;
            ``None`` disables dropping
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return "p={}".format(self.drop_prob)


class TransformerEncoderLayer(nn.Module):
    """Encoder layer block.

    In the original paper each operation (multi-head attention or FFN) is
    postprocessed with: `dropout -> add residual -> layernorm`. In the
    tensor2tensor code they suggest that learning is more robust when
    preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.encoder_normalize_before* to ``True``.

    Optional extras, each switched on by an attribute of *args*:
    ``scale_attn`` adds a LayerNorm on the self-attention output,
    ``scale_fc`` adds a LayerNorm between the two feed-forward projections and
    ``scale_resids`` adds a learned per-channel scale on the feed-forward
    residual.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        drop_path_rate (float, optional): stochastic-depth probability applied
            to each residual branch; values ``<= 0`` disable it (default: 0.0)
    """

    def __init__(self, args, drop_path_rate=0.0):
        super().__init__()
        self.args = args
        self.embed_dim = args.encoder_embed_dim
        # ``or`` fallbacks guard against the attribute being present but None.
        self.quant_noise = getattr(args, 'quant_noise_pq', 0) or 0
        self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8
        self.self_attn = self.build_self_attention(self.embed_dim, args)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__
        )
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, 'activation_fn', 'relu') or "relu"
        )
        activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
        if activation_dropout_p == 0:
            # for backwards compatibility with models that use args.relu_dropout
            activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
        self.activation_dropout_module = FairseqDropout(
            float(activation_dropout_p), module_name=self.__class__.__name__
        )
        self.normalize_before = args.encoder_normalize_before
        self.fc1 = self.build_fc1(
            self.embed_dim,
            args.encoder_ffn_embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )
        self.fc2 = self.build_fc2(
            args.encoder_ffn_embed_dim,
            self.embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )

        # Optional LayerNorm applied to the self-attention output.
        self.attn_ln = LayerNorm(self.embed_dim) if getattr(args, 'scale_attn', False) else None
        self.nh = self.self_attn.num_heads
        self.head_dim = self.self_attn.head_dim

        # Optional LayerNorm between fc1 and fc2, and optional learned
        # per-channel scale for the feed-forward residual.
        self.ffn_layernorm = LayerNorm(args.encoder_ffn_embed_dim) if getattr(args, 'scale_fc', False) else None
        self.w_resid = nn.Parameter(torch.ones(self.embed_dim, ), requires_grad=True) if getattr(args, 'scale_resids', False) else None

        self.final_layer_norm = LayerNorm(self.embed_dim)

        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        """Build the first feed-forward projection, wrapped with quant-noise."""
        return quant_noise(
            nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
        )

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        """Build the second feed-forward projection, wrapped with quant-noise."""
        return quant_noise(
            nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
        )

    def build_self_attention(self, embed_dim, args):
        """Build the self-attention module from *args*."""
        return MultiheadAttention(
            embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
            scale_factor=args.attn_scale_factor,
            scale_heads=getattr(args, 'scale_heads', False)
        )

    def residual_connection(self, x, residual):
        """Add `residual` to the branch output `x` after applying drop-path to `x`."""
        return residual + self.drop_path(x)

    def upgrade_state_dict_named(self, state_dict, name):
        """
        Rename layer norm states from `...layer_norms.0.weight` to
        `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
        `...final_layer_norm.weight`

        Any entry of this layer's own state dict that is still missing from
        `state_dict` afterwards is filled in with the layer's current value.
        """
        layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"}
        _upgrade_layer_state_dict(self, state_dict, name, layer_norm_map)

    def forward(
        self,
        x,
        encoder_padding_mask: Optional[Tensor],
        attn_mask: Optional[Tensor] = None,
        self_attn_bias: Optional[Tensor] = None
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, seq_len)` where padding elements are indicated by ``1``.
            attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`,
                where `tgt_len` is the length of output and `src_len` is the
                length of input, though here both are equal to `seq_len`.
                `attn_mask[tgt_i, src_j] = 1` means that when calculating the
                embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
                useful for strided self-attention.
            self_attn_bias (Tensor, optional): forwarded unchanged to the
                self-attention module as its `attn_bias` argument.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        # anything in original attn_mask = 1, becomes -1e8
        # anything in original attn_mask = 0, becomes 0
        # Note that we cannot use -inf here, because at some edge cases,
        # the attention weight (before softmax) for some padded element in query
        # will become -inf, which results in NaN in model parameters
        if attn_mask is not None:
            attn_mask = attn_mask.masked_fill(
                attn_mask.to(torch.bool),
                -1e8 if x.dtype == torch.float32 else -1e4
            )

        # --- self-attention sub-layer ---
        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=encoder_padding_mask,
            need_weights=False,
            attn_mask=attn_mask,
            attn_bias=self_attn_bias
        )
        if self.attn_ln is not None:
            x = self.attn_ln(x)
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        # --- feed-forward sub-layer ---
        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        if self.ffn_layernorm is not None:
            x = self.ffn_layernorm(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        if self.w_resid is not None:
            # Learned per-channel scaling of the residual branch.
            residual = torch.mul(self.w_resid, residual)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        return x


class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): if ``True`` the layer is built
            without an encoder-attention sub-layer (default: False).
        add_bias_kv (bool, optional): forwarded to the self-attention module
            (default: False).
        add_zero_attn (bool, optional): forwarded to the self-attention module
            (default: False).
        drop_path_rate (float, optional): stochastic-depth probability applied
            to each residual branch; values ``<= 0`` disable it (default: 0.0)
    """

    def __init__(
        self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False, drop_path_rate=0.0
    ):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__
        )
        # ``or`` fallbacks guard against the attribute being present but None
        # (kept consistent with TransformerEncoderLayer).
        self.quant_noise = getattr(args, "quant_noise_pq", 0) or 0
        self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8) or 8

        self.cross_self_attention = getattr(args, "cross_self_attention", False)

        self.self_attn = self.build_self_attention(
            self.embed_dim,
            args,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
        )
        # Optional LayerNorms applied to the self-/cross-attention outputs.
        self.self_attn_ln = LayerNorm(self.embed_dim) if getattr(args, 'scale_attn', False) else None
        self.cross_attn_ln = LayerNorm(self.embed_dim) if getattr(args, 'scale_attn', False) else None
        self.nh = self.self_attn.num_heads
        self.head_dim = self.self_attn.head_dim

        self.activation_fn = utils.get_activation_fn(
            activation=str(args.activation_fn)
            if getattr(args, "activation_fn", None) is not None
            else "relu"
        )
        activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
        if activation_dropout_p == 0:
            # for backwards compatibility with models that use args.relu_dropout
            activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
        self.activation_dropout_module = FairseqDropout(
            float(activation_dropout_p), module_name=self.__class__.__name__
        )
        self.normalize_before = args.decoder_normalize_before

        # use layerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO remove this once we update apex with the fix
        export = getattr(args, "char_inputs", False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        # Optional LayerNorm between fc1 and fc2, and optional learned
        # per-channel scale for the feed-forward residual.
        self.ffn_layernorm = LayerNorm(args.decoder_ffn_embed_dim) if getattr(args, 'scale_fc', False) else None
        self.w_resid = nn.Parameter(torch.ones(self.embed_dim, ), requires_grad=True) if getattr(args, 'scale_resids', False) else None

        self.fc1 = self.build_fc1(
            self.embed_dim,
            args.decoder_ffn_embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )
        self.fc2 = self.build_fc2(
            args.decoder_ffn_embed_dim,
            self.embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        """Build the first feed-forward projection, wrapped with quant-noise."""
        return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        """Build the second feed-forward projection, wrapped with quant-noise."""
        return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)

    def build_self_attention(
        self, embed_dim, args, add_bias_kv=False, add_zero_attn=False
    ):
        """Build the decoder self-attention module from *args*."""
        return MultiheadAttention(
            embed_dim,
            args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not getattr(args, "cross_self_attention", False),
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
            scale_factor=args.attn_scale_factor,
            scale_heads=getattr(args, 'scale_heads', False)
        )

    def build_encoder_attention(self, embed_dim, args):
        """Build the encoder-decoder attention module from *args*."""
        return MultiheadAttention(
            embed_dim,
            args.decoder_attention_heads,
            kdim=getattr(args, "encoder_embed_dim", None),
            vdim=getattr(args, "encoder_embed_dim", None),
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
            scale_factor=args.attn_scale_factor,
            scale_heads=getattr(args, 'scale_heads', False)
        )

    def prepare_for_onnx_export_(self):
        """Switch `forward` to also return the cached self-attention state."""
        self.onnx_trace = True

    def residual_connection(self, x, residual):
        """Add `residual` to the branch output `x` after applying drop-path to `x`."""
        return residual + self.drop_path(x)

    def forward(
        self,
        x,
        encoder_out: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[torch.Tensor]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
        self_attn_bias: Optional[Tensor] = None,
        cross_attn_bias: Optional[Tensor] = None
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_out (Tensor, optional): encoder output; the
                encoder-attention sub-layer only runs when this is given and
                the layer was built with encoder attention.
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            incremental_state (dict, optional): cache passed through to the
                attention modules.
            prev_self_attn_state (list of Tensor, optional): `[prev_key,
                prev_value]` plus an optional key padding mask, written into
                the self-attention input buffer before attention runs.
            prev_attn_state (list of Tensor, optional): same as above, for the
                encoder-attention module.
            self_attn_mask (Tensor, optional): forwarded to self-attention as
                `attn_mask`.
            self_attn_padding_mask (Tensor, optional): forwarded to
                self-attention as `key_padding_mask`.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).
            self_attn_bias (Tensor, optional): forwarded to self-attention as
                `attn_bias`.
            cross_attn_bias (Tensor, optional): forwarded to encoder attention
                as `attn_bias`.

        Returns:
            a tuple `(x, attn, self_attn_state)`: the encoded output of shape
            `(seq_len, batch, embed_dim)`, the attention value returned by the
            last attention module that ran, and - only when ONNX tracing is
            enabled and `incremental_state` is given - the cached
            self-attention state (otherwise ``None``).
        """
        if need_head_weights:
            need_attn = True

        # --- self-attention sub-layer ---
        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        if prev_self_attn_state is not None:
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            assert incremental_state is not None
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
        if self.cross_self_attention and not (
            incremental_state is not None
            and _self_attn_input_buffer is not None
            and "prev_key" in _self_attn_input_buffer
        ):
            # Cross-self-attention without cached keys: prepend the encoder
            # output to the keys/values and extend both masks to match.
            if self_attn_mask is not None:
                assert encoder_out is not None
                self_attn_mask = torch.cat(
                    (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
                )
            if self_attn_padding_mask is not None:
                if encoder_padding_mask is None:
                    assert encoder_out is not None
                    encoder_padding_mask = self_attn_padding_mask.new_zeros(
                        encoder_out.size(1), encoder_out.size(0)
                    )
                self_attn_padding_mask = torch.cat(
                    (encoder_padding_mask, self_attn_padding_mask), dim=1
                )
            assert encoder_out is not None
            y = torch.cat((encoder_out, x), dim=0)
        else:
            y = x

        x, attn = self.self_attn(
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
            attn_bias=self_attn_bias
        )
        if self.self_attn_ln is not None:
            x = self.self_attn_ln(x)
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        # --- encoder-attention sub-layer ---
        if self.encoder_attn is not None and encoder_out is not None:
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x)
            if prev_attn_state is not None:
                prev_key, prev_value = prev_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                assert incremental_state is not None
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
                attn_bias=cross_attn_bias
            )
            if self.cross_attn_ln is not None:
                x = self.cross_attn_ln(x)
            x = self.dropout_module(x)
            x = self.residual_connection(x, residual)
            if not self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

        # --- feed-forward sub-layer ---
        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)

        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        if self.ffn_layernorm is not None:
            x = self.ffn_layernorm(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        if self.w_resid is not None:
            # Learned per-channel scaling of the residual branch.
            residual = torch.mul(self.w_resid, residual)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            assert saved_state is not None
            if self_attn_padding_mask is not None:
                self_attn_state = [
                    saved_state["prev_key"],
                    saved_state["prev_value"],
                    saved_state["prev_key_padding_mask"],
                ]
            else:
                self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
            return x, attn, self_attn_state
        return x, attn, None

    def make_generation_fast_(self, need_attn: bool = False, **kwargs):
        """Set whether encoder-attention weights are computed outside training."""
        self.need_attn = need_attn

    def upgrade_state_dict_named(self, state_dict, name):
        """
        Rename layer norm states from `...layer_norms.0.weight` to
        `...self_attn_layer_norm.weight`, `...layer_norms.1.weight` to
        `...encoder_attn_layer_norm.weight` and `...layer_norms.2.weight` to
        `...final_layer_norm.weight`

        Any entry of this layer's own state dict that is still missing from
        `state_dict` afterwards is filled in with the layer's current value.
        """
        # update layer norms
        layer_norm_map = {
            "0": "self_attn_layer_norm",
            "1": "encoder_attn_layer_norm",
            "2": "final_layer_norm",
        }
        _upgrade_layer_state_dict(self, state_dict, name, layer_norm_map)