# Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. # import math import torch import torch.nn as nn import torch.nn.functional as F from fairseq.modules.fairseq_dropout import FairseqDropout from fairseq.modules.scalar_bias import scalar_bias class SingleHeadAttention(nn.Module): """ Single-head attention that supports Gating and Downsampling """ def __init__( self, out_channels, embed_dim, head_dim, head_index, dropout=0.0, bias=True, project_input=True, gated=False, downsample=False, num_heads=1, ): super().__init__() self.embed_dim = embed_dim self.dropout_module = FairseqDropout( dropout, module_name=self.__class__.__name__ ) self.head_index = head_index self.head_dim = head_dim self.project_input = project_input self.gated = gated self.downsample = downsample self.num_heads = num_heads self.projection = None k_layers = [] v_layers = [] if self.downsample: k_layers.append(Downsample(self.head_index)) v_layers.append(Downsample(self.head_index)) out_proj_size = self.head_dim else: out_proj_size = self.head_dim * self.num_heads if self.gated: k_layers.append(GatedLinear(self.embed_dim, out_proj_size, bias=bias)) self.in_proj_q = GatedLinear(self.embed_dim, out_proj_size, bias=bias) v_layers.append(GatedLinear(self.embed_dim, out_proj_size, bias=bias)) else: k_layers.append(Linear(self.embed_dim, out_proj_size, bias=bias)) self.in_proj_q = Linear(self.embed_dim, out_proj_size, bias=bias) v_layers.append(Linear(self.embed_dim, out_proj_size, bias=bias)) self.in_proj_k = nn.Sequential(*k_layers) self.in_proj_v = nn.Sequential(*v_layers) if self.downsample: self.out_proj = Linear(out_proj_size, self.head_dim, bias=bias) else: self.out_proj = Linear(out_proj_size, out_channels, bias=bias) self.scaling = self.head_dim ** -0.5 def forward( self, query, key, value, mask_future_timesteps=False, key_padding_mask=None, use_scalar_bias=False, ): """Input shape: Time x Batch x Channel Self-attention can be implemented by passing in the same arguments for query, key and value. Future timesteps can be masked with the `mask_future_timesteps` argument. Padding elements can be excluded from the key by passing a binary ByteTensor (`key_padding_mask`) with shape: batch x src_len, where padding elements are indicated by 1s. """ src_len, bsz, out_channels = key.size() tgt_len = query.size(0) assert list(query.size()) == [tgt_len, bsz, out_channels] assert key.size() == value.size() if key_padding_mask is not None: assert key_padding_mask.size(0) == bsz assert key_padding_mask.size(1) == src_len if self.downsample: size = bsz else: size = bsz * self.num_heads k = key v = value q = query if self.project_input: q = self.in_proj_q(q) k = self.in_proj_k(k) v = self.in_proj_v(v) src_len = k.size()[0] q *= self.scaling if not self.downsample: q = q.view(tgt_len, size, self.head_dim) k = k.view(src_len, size, self.head_dim) v = v.view(src_len, size, self.head_dim) q = q.transpose(0, 1) k = k.transpose(0, 1) v = v.transpose(0, 1) attn_weights = torch.bmm(q, k.transpose(1, 2)) if mask_future_timesteps: assert ( query.size() == key.size() ), "mask_future_timesteps only applies to self-attention" attn_weights *= torch.tril( attn_weights.data.new([1]).expand(tgt_len, tgt_len).clone(), diagonal=-1, )[:, :: self.head_index + 1 if self.downsample else 1].unsqueeze(0) attn_weights += torch.triu( attn_weights.data.new([-math.inf]).expand(tgt_len, tgt_len).clone(), diagonal=0, )[:, :: self.head_index + 1 if self.downsample else 1].unsqueeze(0) tgt_size = tgt_len if use_scalar_bias: attn_weights = scalar_bias(attn_weights, 2) v = scalar_bias(v, 1) tgt_size += 1 if key_padding_mask is not None: # don't attend to padding symbols if key_padding_mask.max() > 0: if self.downsample: attn_weights = attn_weights.view(bsz, 1, tgt_len, src_len) else: attn_weights = attn_weights.view( size, self.num_heads, tgt_len, src_len ) attn_weights = attn_weights.masked_fill( key_padding_mask.unsqueeze(1).unsqueeze(2), -math.inf, ) attn_weights = attn_weights.view(size, tgt_len, src_len) attn_weights = F.softmax(attn_weights, dim=-1) attn_weights = self.dropout_module(attn_weights) attn = torch.bmm(attn_weights, v) if self.downsample: attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.head_dim) else: attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim) attn = self.out_proj(attn) return attn, attn_weights class DownsampledMultiHeadAttention(nn.ModuleList): """ Multi-headed attention with Gating and Downsampling """ def __init__( self, out_channels, embed_dim, num_heads, dropout=0.0, bias=True, project_input=True, gated=False, downsample=False, ): self.embed_dim = embed_dim self.num_heads = num_heads self.head_dim = embed_dim // num_heads self.downsample = downsample self.gated = gated self.project_input = project_input assert self.head_dim * num_heads == embed_dim if self.downsample: attention_heads = [] for index in range(self.num_heads): attention_heads.append( SingleHeadAttention( out_channels, self.embed_dim, self.head_dim, index, dropout, bias, self.project_input, self.gated, self.downsample, self.num_heads, ) ) super().__init__(modules=attention_heads) self.out_proj = Linear(embed_dim, out_channels, bias=bias) else: # either we have a list of attention heads, or just one attention head # if not being downsampled, we can do the heads with one linear layer instead of separate ones super().__init__() self.attention_module = SingleHeadAttention( out_channels, self.embed_dim, self.head_dim, 1, dropout, bias, self.project_input, self.gated, self.downsample, self.num_heads, ) def forward( self, query, key, value, mask_future_timesteps=False, key_padding_mask=None, use_scalar_bias=False, ): src_len, bsz, embed_dim = key.size() tgt_len = query.size(0) assert embed_dim == self.embed_dim assert list(query.size()) == [tgt_len, bsz, embed_dim] assert key.size() == value.size() tgt_size = tgt_len if use_scalar_bias: tgt_size += 1 attn = [] attn_weights = [] if self.downsample: for attention_head_number in range(self.num_heads): # call the forward of each attention head _attn, _attn_weight = self[attention_head_number]( query, key, value, mask_future_timesteps, key_padding_mask, use_scalar_bias, ) attn.append(_attn) attn_weights.append(_attn_weight) full_attn = torch.cat(attn, dim=2) full_attn = self.out_proj(full_attn) return full_attn, attn_weights[0].clone() else: _attn, _attn_weight = self.attention_module( query, key, value, mask_future_timesteps, key_padding_mask, use_scalar_bias, ) attn.append(_attn) attn_weights.append(_attn_weight) full_attn = torch.cat(attn, dim=2) full_attn_weights = torch.cat(attn_weights) full_attn_weights = full_attn_weights.view( bsz, self.num_heads, tgt_size, src_len ) full_attn_weights = full_attn_weights.sum(dim=1) / self.num_heads return full_attn, full_attn_weights class Downsample(nn.Module): """ Selects every nth element, where n is the index """ def __init__(self, index): super().__init__() self.index = index def forward(self, x): return x[:: self.index + 1] def Linear(in_features, out_features, dropout=0.0, bias=True): """Weight-normalized Linear layer (input: B x T x C)""" m = nn.Linear(in_features, out_features, bias=bias) m.weight.data.normal_(mean=0, std=math.sqrt((1 - dropout) / in_features)) m.bias.data.zero_() return nn.utils.weight_norm(m) def GatedLinear(in_features, out_features, dropout=0.0, bias=True): """Weight-normalized Linear layer (input: B x T x C) with interspersed GLU units""" return nn.Sequential( Linear(in_features, out_features * 4, dropout, bias), nn.GLU(), Linear(out_features * 2, out_features * 2, dropout, bias), nn.GLU(), Linear(out_features, out_features, dropout, bias), )