# OFA-Image_Caption/fairseq/fairseq/modules/transformer_sentence_encoder_layer.py
# JustinLin610
# update
# 8437114
# raw history blame
# No virus
# 4.33 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable, Optional
import torch
import torch.nn as nn
from fairseq import utils
from fairseq.modules import LayerNorm, MultiheadAttention
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
class TransformerSentenceEncoderLayer(nn.Module):
    """
    Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
    models.

    The layer runs a self-attention sub-block followed by a position-wise
    feed-forward sub-block. Each sub-block is followed by dropout, a residual
    connection, and then a LayerNorm (i.e. normalization is applied *after*
    the residual add).

    Args:
        embedding_dim: model width; used for the attention module, both
            layer norms, the input of ``fc1`` and the output of ``fc2``.
        ffn_embedding_dim: hidden width of the feed-forward sub-block
            (output of ``fc1`` / input of ``fc2``).
        num_attention_heads: number of heads passed to the attention module.
        dropout: rate given to the ``FairseqDropout`` applied to the output
            of the attention and feed-forward sub-blocks.
        attention_dropout: value passed as ``dropout`` to the attention
            module.
        activation_dropout: rate given to the ``FairseqDropout`` applied
            right after the feed-forward activation.
        activation_fn: activation name resolved through
            ``utils.get_activation_fn``.
        export: forwarded to both ``LayerNorm`` constructors.
        q_noise: forwarded to ``quant_noise`` and to the attention module.
        qn_block_size: forwarded to ``quant_noise`` and to the attention
            module.
        init_fn: optional zero-argument callable invoked right after
            ``nn.Module`` initialization and before any sub-module is built.
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        export: bool = False,
        q_noise: float = 0.0,
        qn_block_size: int = 8,
        init_fn: Optional[Callable] = None,
    ) -> None:
        super().__init__()

        # Give the caller a hook that runs before any parameter is created.
        if init_fn is not None:
            init_fn()

        # Initialize parameters
        self.embedding_dim = embedding_dim
        self.num_attention_heads = num_attention_heads
        self.attention_dropout = attention_dropout
        self.q_noise = q_noise
        self.qn_block_size = qn_block_size

        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )
        self.activation_dropout_module = FairseqDropout(
            activation_dropout, module_name=self.__class__.__name__
        )

        # Initialize blocks
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.self_attn = self.build_self_attention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
            self_attention=True,
            q_noise=q_noise,
            qn_block_size=qn_block_size,
        )

        # layer norm associated with the self attention layer
        self.self_attn_layer_norm = LayerNorm(self.embedding_dim, export=export)

        self.fc1 = self.build_fc1(
            self.embedding_dim,
            ffn_embedding_dim,
            q_noise=q_noise,
            qn_block_size=qn_block_size,
        )
        self.fc2 = self.build_fc2(
            ffn_embedding_dim,
            self.embedding_dim,
            q_noise=q_noise,
            qn_block_size=qn_block_size,
        )

        # layer norm associated with the position wise feed-forward NN
        self.final_layer_norm = LayerNorm(self.embedding_dim, export=export)

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        """Build the first feed-forward projection, wrapped in ``quant_noise``."""
        return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        """Build the second feed-forward projection, wrapped in ``quant_noise``."""
        return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)

    def build_self_attention(
        self,
        embed_dim,
        num_attention_heads,
        dropout,
        self_attention,
        q_noise,
        qn_block_size,
    ):
        """
        Build the attention module used by this layer.

        All arguments are forwarded to ``MultiheadAttention``. The
        ``self_attention`` flag is passed through as given rather than being
        overridden, so callers and subclasses get what they asked for; this
        class's ``__init__`` always passes ``True``.
        """
        return MultiheadAttention(
            embed_dim,
            num_attention_heads,
            dropout=dropout,
            self_attention=self_attention,
            q_noise=q_noise,
            qn_block_size=qn_block_size,
        )

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
    ):
        """
        Apply the self-attention and feed-forward sub-blocks to ``x``.

        LayerNorm is applied after each residual connection (post-norm).

        Args:
            x: input tensor; used as query, key and value of the attention
                module. (Layout is whatever ``self.self_attn`` expects --
                not checked here.)
            self_attn_mask: optional mask forwarded as ``attn_mask``.
            self_attn_padding_mask: optional mask forwarded as
                ``key_padding_mask``.

        Returns:
            A tuple ``(x, attn)``: the transformed tensor and the second
            value returned by the attention module. The attention module is
            called with ``need_weights=False``, so ``attn`` is presumably
            ``None`` -- confirm against ``MultiheadAttention``.
        """
        # --- self-attention sub-block: attn -> dropout -> residual -> norm
        residual = x
        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = self.dropout_module(x)
        x = residual + x
        x = self.self_attn_layer_norm(x)

        # --- feed-forward sub-block: fc1 -> act -> dropout -> fc2 -> dropout
        #     -> residual -> norm
        residual = x
        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        x = residual + x
        x = self.final_layer_norm(x)
        return x, attn