# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.models import (
FairseqEncoder,
FairseqEncoderModel,
register_model,
register_model_architecture,
)
from fairseq.modules import (
LayerNorm,
SinusoidalPositionalEmbedding,
TransformerSentenceEncoder,
)
from fairseq.modules.transformer_sentence_encoder import init_bert_params
from fairseq.utils import safe_hasattr


logger = logging.getLogger(__name__)


@register_model("masked_lm")
class MaskedLMModel(FairseqEncoderModel):
"""
Class for training a Masked Language Model. It also supports an
additional sentence level prediction if the sent-loss argument is set.
"""
def __init__(self, args, encoder):
super().__init__(encoder)
self.args = args
        # if specified, apply BERT initialization to the model. We need to
        # explicitly call this to make sure that the output embeddings and
        # projection layers are also correctly initialized
if getattr(args, "apply_bert_init", False):
self.apply(init_bert_params)

    @staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
# Arguments related to dropout
parser.add_argument(
"--dropout", type=float, metavar="D", help="dropout probability"
)
parser.add_argument(
"--attention-dropout",
type=float,
metavar="D",
help="dropout probability for" " attention weights",
)
parser.add_argument(
"--act-dropout",
type=float,
metavar="D",
help="dropout probability after" " activation in FFN",
)
# Arguments related to hidden states and self-attention
parser.add_argument(
"--encoder-ffn-embed-dim",
type=int,
metavar="N",
help="encoder embedding dimension for FFN",
)
parser.add_argument(
"--encoder-layers", type=int, metavar="N", help="num encoder layers"
)
parser.add_argument(
"--encoder-attention-heads",
type=int,
metavar="N",
help="num encoder attention heads",
)
# Arguments related to input and output embeddings
parser.add_argument(
"--encoder-embed-dim",
type=int,
metavar="N",
help="encoder embedding dimension",
)
parser.add_argument(
"--share-encoder-input-output-embed",
action="store_true",
help="share encoder input" " and output embeddings",
)
parser.add_argument(
"--encoder-learned-pos",
action="store_true",
help="use learned positional embeddings in the encoder",
)
parser.add_argument(
"--no-token-positional-embeddings",
action="store_true",
help="if set, disables positional embeddings" " (outside self attention)",
)
parser.add_argument(
"--num-segment", type=int, metavar="N", help="num segment in the input"
)
parser.add_argument(
"--max-positions", type=int, help="number of positional embeddings to learn"
)
# Arguments related to sentence level prediction
parser.add_argument(
"--sentence-class-num",
type=int,
metavar="N",
help="number of classes for sentence task",
)
parser.add_argument(
"--sent-loss",
action="store_true",
help="if set," " calculate sentence level predictions",
)
# Arguments related to parameter initialization
parser.add_argument(
"--apply-bert-init",
action="store_true",
help="use custom param initialization for BERT",
)
# misc params
parser.add_argument(
"--activation-fn",
choices=utils.get_available_activation_fns(),
help="activation function to use",
)
parser.add_argument(
"--pooler-activation-fn",
choices=utils.get_available_activation_fns(),
help="Which activation function to use for pooler layer.",
)
parser.add_argument(
"--encoder-normalize-before",
action="store_true",
help="apply layernorm before each encoder block",
)

    def forward(self, src_tokens, segment_labels=None, **kwargs):
        return self.encoder(src_tokens, segment_labels=segment_labels, **kwargs)

    def max_positions(self):
        return self.encoder.max_positions

@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure all arguments are present in older models
base_architecture(args)
if not safe_hasattr(args, "max_positions"):
args.max_positions = args.tokens_per_sample
logger.info(args)
encoder = MaskedLMEncoder(args, task.dictionary)
return cls(args, encoder)


class MaskedLMEncoder(FairseqEncoder):
"""
Encoder for Masked Language Modelling.
"""
def __init__(self, args, dictionary):
super().__init__(dictionary)
self.padding_idx = dictionary.pad()
        self.vocab_size = len(dictionary)
self.max_positions = args.max_positions
self.sentence_encoder = TransformerSentenceEncoder(
padding_idx=self.padding_idx,
vocab_size=self.vocab_size,
num_encoder_layers=args.encoder_layers,
embedding_dim=args.encoder_embed_dim,
ffn_embedding_dim=args.encoder_ffn_embed_dim,
num_attention_heads=args.encoder_attention_heads,
dropout=args.dropout,
attention_dropout=args.attention_dropout,
activation_dropout=args.act_dropout,
max_seq_len=self.max_positions,
num_segments=args.num_segment,
use_position_embeddings=not args.no_token_positional_embeddings,
encoder_normalize_before=args.encoder_normalize_before,
apply_bert_init=args.apply_bert_init,
activation_fn=args.activation_fn,
learned_pos_embedding=args.encoder_learned_pos,
)
self.share_input_output_embed = args.share_encoder_input_output_embed
self.embed_out = None
self.sentence_projection_layer = None
self.sentence_out_dim = args.sentence_class_num
        # remove_head is set to True during fine-tuning
        self.load_softmax = not getattr(args, "remove_head", False)
self.masked_lm_pooler = nn.Linear(
args.encoder_embed_dim, args.encoder_embed_dim
)
self.pooler_activation = utils.get_activation_fn(args.pooler_activation_fn)
self.lm_head_transform_weight = nn.Linear(
args.encoder_embed_dim, args.encoder_embed_dim
)
self.activation_fn = utils.get_activation_fn(args.activation_fn)
self.layer_norm = LayerNorm(args.encoder_embed_dim)
self.lm_output_learned_bias = None
if self.load_softmax:
self.lm_output_learned_bias = nn.Parameter(torch.zeros(self.vocab_size))
if not self.share_input_output_embed:
self.embed_out = nn.Linear(
args.encoder_embed_dim, self.vocab_size, bias=False
)
if args.sent_loss:
self.sentence_projection_layer = nn.Linear(
args.encoder_embed_dim, self.sentence_out_dim, bias=False
)

    def forward(self, src_tokens, segment_labels=None, masked_tokens=None, **unused):
"""
Forward pass for Masked LM encoder. This first computes the token
embedding using the token embedding matrix, position embeddings (if
specified) and segment embeddings (if specified).
Here we assume that the sentence representation corresponds to the
output of the classification_token (see bert_task or cross_lingual_lm
task for more details).
Args:
- src_tokens: B x T matrix representing sentences
- segment_labels: B x T matrix representing segment label for tokens
Returns:
- a tuple of the following:
- logits for predictions in format B x T x C to be used in
softmax afterwards
- a dictionary of additional data, where 'pooled_output' contains
the representation for classification_token and 'inner_states'
is a list of internal model states used to compute the
                  predictions (similar to ELMo). 'sentence_logits'
                  contains the prediction logits for the NSP task and is only
                  computed if the sent-loss argument is set.
"""
inner_states, sentence_rep = self.sentence_encoder(
src_tokens,
segment_labels=segment_labels,
)
x = inner_states[-1].transpose(0, 1)
# project masked tokens only
if masked_tokens is not None:
x = x[masked_tokens, :]
x = self.layer_norm(self.activation_fn(self.lm_head_transform_weight(x)))
pooled_output = self.pooler_activation(self.masked_lm_pooler(sentence_rep))
# project back to size of vocabulary
if self.share_input_output_embed and hasattr(
self.sentence_encoder.embed_tokens, "weight"
):
x = F.linear(x, self.sentence_encoder.embed_tokens.weight)
elif self.embed_out is not None:
x = self.embed_out(x)
if self.lm_output_learned_bias is not None:
x = x + self.lm_output_learned_bias
sentence_logits = None
if self.sentence_projection_layer:
sentence_logits = self.sentence_projection_layer(pooled_output)
return x, {
"inner_states": inner_states,
"pooled_output": pooled_output,
"sentence_logits": sentence_logits,
}

    def max_positions(self):
        """Maximum output length supported by the encoder."""
        # NOTE: the integer attribute of the same name assigned in __init__
        # shadows this method on instances, so ``encoder.max_positions`` is
        # read as a plain int by callers such as MaskedLMModel.max_positions.
        return self.max_positions

    def upgrade_state_dict_named(self, state_dict, name):
if isinstance(
self.sentence_encoder.embed_positions, SinusoidalPositionalEmbedding
):
state_dict[
name + ".sentence_encoder.embed_positions._float_tensor"
] = torch.FloatTensor(1)
if not self.load_softmax:
for k in list(state_dict.keys()):
if (
"embed_out.weight" in k
or "sentence_projection_layer.weight" in k
or "lm_output_learned_bias" in k
):
del state_dict[k]
return state_dict
@register_model_architecture("masked_lm", "masked_lm")
def base_architecture(args):
args.dropout = getattr(args, "dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
args.act_dropout = getattr(args, "act_dropout", 0.0)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
args.encoder_layers = getattr(args, "encoder_layers", 6)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
args.share_encoder_input_output_embed = getattr(
args, "share_encoder_input_output_embed", False
)
args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.num_segment = getattr(args, "num_segment", 2)
args.sentence_class_num = getattr(args, "sentence_class_num", 2)
args.sent_loss = getattr(args, "sent_loss", False)
args.apply_bert_init = getattr(args, "apply_bert_init", False)
args.activation_fn = getattr(args, "activation_fn", "relu")
args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
@register_model_architecture("masked_lm", "bert_base")
def bert_base_architecture(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
args.share_encoder_input_output_embed = getattr(
args, "share_encoder_input_output_embed", True
)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True)
args.num_segment = getattr(args, "num_segment", 2)
args.encoder_layers = getattr(args, "encoder_layers", 12)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072)
args.sentence_class_num = getattr(args, "sentence_class_num", 2)
args.sent_loss = getattr(args, "sent_loss", True)
args.apply_bert_init = getattr(args, "apply_bert_init", True)
args.activation_fn = getattr(args, "activation_fn", "gelu")
args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
base_architecture(args)
@register_model_architecture("masked_lm", "bert_large")
def bert_large_architecture(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
args.encoder_layers = getattr(args, "encoder_layers", 24)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
bert_base_architecture(args)
@register_model_architecture("masked_lm", "xlm_base")
def xlm_architecture(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
args.share_encoder_input_output_embed = getattr(
args, "share_encoder_input_output_embed", True
)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True)
args.num_segment = getattr(args, "num_segment", 1)
args.encoder_layers = getattr(args, "encoder_layers", 6)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
args.sent_loss = getattr(args, "sent_loss", False)
args.activation_fn = getattr(args, "activation_fn", "gelu")
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
args.apply_bert_init = getattr(args, "apply_bert_init", True)
base_architecture(args)
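

# A hedged sketch of how a further architecture could be layered on these
# defaults, following the registration pattern above. The name "bert_small" and
# its sizes are illustrative assumptions, so the block is left commented out:
#
# @register_model_architecture("masked_lm", "bert_small")
# def bert_small_architecture(args):
#     args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
#     args.encoder_layers = getattr(args, "encoder_layers", 4)
#     args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
#     bert_base_architecture(args)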