# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import checkpoint_utils
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.models import (
    CompositeEncoder,
    FairseqDecoder,
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    register_model,
    register_model_architecture,
)
from fairseq.modules import (
    DownsampledMultiHeadAttention,
    FairseqDropout,
    GradMultiply,
    LayerNorm,
    LearnedPositionalEmbedding,
    LinearizedConvolution,
)


logger = logging.getLogger(__name__)


@register_model("fconv_self_att")
class FConvModelSelfAtt(FairseqEncoderDecoderModel):
    @classmethod
    def hub_models(cls):
        return {
            "conv.stories.pretrained": {
                "path": "https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.gz",
                "checkpoint_file": "pretrained_checkpoint.pt",
                "tokenizer": "nltk",
            },
            "conv.stories": {
                "path": "https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.gz",
                "checkpoint_file": "fusion_checkpoint.pt",
                "tokenizer": "nltk",
                "pretrained": "True",
                "pretrained_checkpoint": "./pretrained_checkpoint.pt",
            },
            # Test set containing dictionaries
            "data.stories": "https://dl.fbaipublicfiles.com/fairseq/data/stories_test.tar.bz2",
        }

    def __init__(self, encoder, decoder, pretrained_encoder=None):
        super().__init__(encoder, decoder)
        self.encoder.num_attention_layers = sum(
            layer is not None for layer in decoder.attention
        )
        self.pretrained_encoder = pretrained_encoder
        if self.pretrained_encoder is None:
            encoders = {"encoder": encoder}
        else:
            encoders = {"encoder": encoder, "pretrained": self.pretrained_encoder}

        # for fusion model, CompositeEncoder contains both pretrained and training encoders
        # these are forwarded and then combined in the decoder
        self.encoder = CompositeEncoder(encoders)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--dropout', type=float, metavar='D',
                            help='dropout probability')
        parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension')
        parser.add_argument('--encoder-layers', type=str, metavar='EXPR',
                            help='encoder layers [(dim, kernel_size), ...]')
        parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension')
        parser.add_argument('--decoder-layers', type=str, metavar='EXPR',
                            help='decoder layers [(dim, kernel_size), ...]')
        parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N',
                            help='decoder output embedding dimension')
        parser.add_argument('--decoder-attention', type=str, metavar='EXPR',
                            help='decoder attention [True, ...]')
        parser.add_argument('--self-attention', type=str, metavar='EXPR',
                            help='decoder self-attention layers, ex: [True] + [False]*5')
        parser.add_argument('--multihead-attention-nheads', type=int,
                            help='Number of heads to use in attention')
        parser.add_argument('--multihead-self-attention-nheads', type=int,
                            help='Number of heads to use in self-attention')
        parser.add_argument('--encoder-attention', type=str, metavar='EXPR',
                            help='encoder attention [True, ...]')
        parser.add_argument('--encoder-attention-nheads', type=int,
                            help='Number of heads to use in encoder attention')
        parser.add_argument('--project-input', type=str, metavar='EXPR',
                            help='Use projections in self-attention [True, ...]')
        parser.add_argument('--gated-attention', type=str, metavar='EXPR',
                            help='Use GLU layers in self-attention projections [True, ...]')
        parser.add_argument('--downsample', type=str, metavar='EXPR',
                            help='Use downsampling in self-attention [True, ...]')
        parser.add_argument('--pretrained-checkpoint', metavar='DIR',
                            help='path to load checkpoint from pretrained model')
        parser.add_argument('--pretrained', type=str, metavar='EXPR',
                            help='use pretrained model when training [True, ...]')
        # fmt: on

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        trained_encoder, trained_decoder = None, None
        pretrained = eval(args.pretrained)
        if pretrained:
            logger.info("loading pretrained model")
            if not os.path.exists(args.pretrained_checkpoint):
                new_pretrained_checkpoint = os.path.join(
                    args.data, args.pretrained_checkpoint
                )
                if os.path.exists(new_pretrained_checkpoint):
                    args.pretrained_checkpoint = new_pretrained_checkpoint
            trained_model = checkpoint_utils.load_model_ensemble(
                filenames=[args.pretrained_checkpoint],
                task=task,
            )[0][0]
            trained_decoder = list(trained_model.children())[1]
            trained_encoder = list(trained_model.children())[0]

            # freeze pretrained model
            for param in trained_decoder.parameters():
                param.requires_grad = False
            for param in trained_encoder.parameters():
                param.requires_grad = False

        encoder = FConvEncoder(
            task.source_dictionary,
            embed_dim=args.encoder_embed_dim,
            convolutions=eval(args.encoder_layers),
            dropout=args.dropout,
            max_positions=args.max_source_positions,
            attention=eval(args.encoder_attention),
            attention_nheads=args.encoder_attention_nheads,
        )

        decoder = FConvDecoder(
            task.target_dictionary,
            embed_dim=args.decoder_embed_dim,
            convolutions=eval(args.decoder_layers),
            out_embed_dim=args.decoder_out_embed_dim,
            attention=eval(args.decoder_attention),
            dropout=args.dropout,
            max_positions=args.max_target_positions,
            selfattention=eval(args.self_attention),
            attention_nheads=args.multihead_attention_nheads,
            selfattention_nheads=args.multihead_self_attention_nheads,
            project_input=eval(args.project_input),
            gated_attention=eval(args.gated_attention),
            downsample=eval(args.downsample),
            pretrained=pretrained,
            trained_decoder=trained_decoder,
        )
        model = FConvModelSelfAtt(encoder, decoder, trained_encoder)

        return model

    @property
    def pretrained(self):
        return self.pretrained_encoder is not None

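# Rough usage sketch for the hub entries in hub_models() above (not executed
# here; assumes the generic fairseq torch.hub interface and an NLTK install
# for the "nltk" tokenizer):
#
#   import torch
#   model = torch.hub.load(
#       "pytorch/fairseq", "conv.stories.pretrained", tokenizer="nltk"
#   )
#   model.sample("A story prompt ...")  # generate a continuation
#
# "conv.stories" additionally loads the fusion checkpoint and expects
# ./pretrained_checkpoint.pt to be available, as configured above.

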
class FConvEncoder(FairseqEncoder):
    """Convolutional encoder"""

    def __init__(
        self,
        dictionary,
        embed_dim=512,
        max_positions=1024,
        convolutions=((512, 3),) * 20,
        dropout=0.1,
        attention=False,
        attention_nheads=1,
    ):
        super().__init__(dictionary)
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )
        self.num_attention_layers = None

        num_embeddings = len(dictionary)
        self.padding_idx = dictionary.pad()
        self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx)
        self.embed_positions = PositionalEmbedding(
            max_positions,
            embed_dim,
            self.padding_idx,
        )

        def expand_bool_array(val):
            if isinstance(val, bool):
                # expand True into [True, True, ...] and do the same with False
                return [val] * len(convolutions)
            return val

        attention = expand_bool_array(attention)

        in_channels = convolutions[0][0]
        self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
        self.projections = nn.ModuleList()
        self.convolutions = nn.ModuleList()
        self.attention = nn.ModuleList()
        self.attproj = nn.ModuleList()
        for i, (out_channels, kernel_size) in enumerate(convolutions):
            self.projections.append(
                Linear(in_channels, out_channels)
                if in_channels != out_channels
                else None
            )
            self.convolutions.append(
                ConvTBC(in_channels, out_channels * 2, kernel_size, dropout=dropout)
            )
            self.attention.append(
                SelfAttention(out_channels, embed_dim, attention_nheads)
                if attention[i]
                else None
            )
            in_channels = out_channels

        self.fc2 = Linear(in_channels, embed_dim)
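    # Illustrative example (not executed): build_model() eval()s the EXPR-style
    # command-line flags before constructing this encoder, e.g.
    #
    #   convolutions = eval("[(512, 3)] * 3")       # three (dim, kernel_size) layers
    #   attention = eval("[True] + [False] * 2")    # one flag per conv layer
    #
    # and a bare bool such as attention=False is broadcast to all layers by
    # expand_bool_array() in __init__ above.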
    def forward(self, src_tokens, src_lengths):
        # embed tokens and positions
        x = self.embed_tokens(src_tokens) + self.embed_positions(src_tokens)
        x = self.dropout_module(x)
        input_embedding = x.transpose(0, 1)

        # project to size of convolution
        x = self.fc1(x)

        encoder_padding_mask = src_tokens.eq(self.padding_idx).t()  # -> T x B
        if not encoder_padding_mask.any():
            encoder_padding_mask = None

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # temporal convolutions
        for proj, conv, attention in zip(
            self.projections, self.convolutions, self.attention
        ):
            residual = x if proj is None else proj(x)

            if encoder_padding_mask is not None:
                x = x.masked_fill(encoder_padding_mask.unsqueeze(-1), 0)

            x = self.dropout_module(x)
            padding_l = (conv.kernel_size[0] - 1) // 2
            padding_r = conv.kernel_size[0] // 2
            x = F.pad(x, (0, 0, 0, 0, padding_l, padding_r))
            x = conv(x)
            x = F.glu(x, dim=2)

            if attention is not None:
                x = attention(x)
            x = (x + residual) * math.sqrt(0.5)

        # T x B x C -> B x T x C
        x = x.transpose(1, 0)

        # project back to size of embedding
        x = self.fc2(x)

        if encoder_padding_mask is not None:
            encoder_padding_mask = encoder_padding_mask.t()  # -> B x T
            x = x.masked_fill(encoder_padding_mask.unsqueeze(-1), 0)

        # scale gradients (this only affects backward, not forward)
        x = GradMultiply.apply(x, 1.0 / (2.0 * self.num_attention_layers))

        # add output to input embedding for attention
        y = (x + input_embedding.transpose(0, 1)) * math.sqrt(0.5)

        return {
            "encoder_out": (x, y),
            "encoder_padding_mask": encoder_padding_mask,  # B x T
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        encoder_out["encoder_out"] = tuple(
            eo.index_select(0, new_order) for eo in encoder_out["encoder_out"]
        )

        if encoder_out["encoder_padding_mask"] is not None:
            encoder_out["encoder_padding_mask"] = encoder_out[
                "encoder_padding_mask"
            ].index_select(0, new_order)

        if "pretrained" in encoder_out:
            encoder_out["pretrained"]["encoder_out"] = tuple(
                eo.index_select(0, new_order)
                for eo in encoder_out["pretrained"]["encoder_out"]
            )

        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return self.embed_positions.max_positions

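# Shape sketch of the encoder output consumed by FConvDecoder below: for a
# batch of B sequences of length T and embed_dim C,
#
#   out = encoder(src_tokens, src_lengths)
#   x, y = out["encoder_out"]   # both B x T x C
#   # x: final conv states; y = (x + input embeddings) * sqrt(0.5)
#
# The decoder's _split_encoder_out() transposes these once and passes them as
# the key/value inputs of its DownsampledMultiHeadAttention layers.

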
@with_incremental_state
class FConvDecoder(FairseqDecoder):
    """Convolutional decoder"""

    def __init__(
        self,
        dictionary,
        embed_dim=512,
        out_embed_dim=256,
        max_positions=1024,
        convolutions=((512, 3),) * 8,
        attention=True,
        dropout=0.1,
        selfattention=False,
        attention_nheads=1,
        selfattention_nheads=1,
        project_input=False,
        gated_attention=False,
        downsample=False,
        pretrained=False,
        trained_decoder=None,
    ):
        super().__init__(dictionary)
        self.register_buffer("version", torch.Tensor([2]))
        self.pretrained = pretrained
        self.pretrained_decoder = trained_decoder
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )
        self.need_attn = True
        in_channels = convolutions[0][0]

        def expand_bool_array(val):
            if isinstance(val, bool):
                # expand True into [True, True, ...] and do the same with False
                return [val] * len(convolutions)
            return val

        attention = expand_bool_array(attention)
        selfattention = expand_bool_array(selfattention)

        if not isinstance(attention, list) or len(attention) != len(convolutions):
            raise ValueError(
                "Attention is expected to be a list of booleans of "
                "length equal to the number of layers."
            )

        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
        self.embed_positions = PositionalEmbedding(
            max_positions,
            embed_dim,
            padding_idx,
        )

        self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
        self.projections = nn.ModuleList()
        self.convolutions = nn.ModuleList()
        self.attention = nn.ModuleList()
        self.selfattention = nn.ModuleList()
        self.attproj = nn.ModuleList()
        for i, (out_channels, kernel_size) in enumerate(convolutions):
            self.projections.append(
                Linear(in_channels, out_channels)
                if in_channels != out_channels
                else None
            )
            self.convolutions.append(
                LinearizedConv1d(
                    in_channels,
                    out_channels * 2,
                    kernel_size,
                    padding=(kernel_size - 1),
                    dropout=dropout,
                )
            )

            self.attention.append(
                DownsampledMultiHeadAttention(
                    out_channels,
                    embed_dim,
                    attention_nheads,
                    project_input=project_input,
                    gated=False,
                    downsample=False,
                )
                if attention[i]
                else None
            )

            self.attproj.append(
                Linear(out_channels, embed_dim, dropout=dropout)
                if attention[i]
                else None
            )
            self.selfattention.append(
                SelfAttention(
                    out_channels,
                    embed_dim,
                    selfattention_nheads,
                    project_input=project_input,
                    gated=gated_attention,
                    downsample=downsample,
                )
                if selfattention[i]
                else None
            )
            in_channels = out_channels

        self.fc2 = Linear(in_channels, out_embed_dim)
        self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)

        # model fusion
        if self.pretrained:
            # independent gates are learned from the concatenated input
            self.gate1 = nn.Sequential(
                Linear(out_embed_dim * 2, out_embed_dim), nn.Sigmoid()
            )
            self.gate2 = nn.Sequential(
                Linear(out_embed_dim * 2, out_embed_dim), nn.Sigmoid()
            )
            # pretrained and trained models are joined
            self.joining = nn.Sequential(
                Linear(out_embed_dim * 2, out_embed_dim * 2),
                LayerNorm(out_embed_dim * 2),
                nn.GLU(),
                Linear(out_embed_dim, out_embed_dim * 2),
                LayerNorm(out_embed_dim * 2),
                nn.GLU(),
                Linear(out_embed_dim, out_embed_dim),
                LayerNorm(out_embed_dim),
            )
            # pretrained model contains an output layer that is nhid -> vocab size
            # but the models are combined in their hidden state
            # the hook stores the output of the pretrained model forward
            self.pretrained_outputs = {}

            def save_output():
                def hook(a, b, output):
                    self.pretrained_outputs["out"] = output

                return hook

            self.pretrained_decoder.fc2.register_forward_hook(save_output())
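    # Fusion sketch (mirrors the forward pass below): when self.pretrained is
    # set, the trained and pretrained hidden states are gated and joined as
    #
    #   y   = cat([x, x_pre], dim=-1)          # x_pre: pretrained fc2 output
    #   out = fc3(joining(cat([gate1(y) * x, gate2(y) * x_pre], dim=-1)))
    #
    # where x_pre is captured by the forward hook registered on
    # self.pretrained_decoder.fc2 in __init__ above.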
    def forward(self, prev_output_tokens, encoder_out):
        trained_encoder_out = encoder_out["pretrained"] if self.pretrained else None
        encoder_out = encoder_out["encoder"]["encoder_out"]

        encoder_a, encoder_b = self._split_encoder_out(encoder_out)

        # embed positions
        positions = self.embed_positions(prev_output_tokens)

        # embed tokens and positions
        x = self.embed_tokens(prev_output_tokens) + positions
        x = self.dropout_module(x)
        target_embedding = x.transpose(0, 1)

        # project to size of convolution
        x = self.fc1(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # temporal convolutions
        avg_attn_scores = None
        for proj, conv, attention, selfattention, attproj in zip(
            self.projections,
            self.convolutions,
            self.attention,
            self.selfattention,
            self.attproj,
        ):
            residual = x if proj is None else proj(x)

            x = self.dropout_module(x)
            x = conv(x)
            x = F.glu(x, dim=2)

            # attention
            if attention is not None:
                r = x
                x, attn_scores = attention(
                    attproj(x) + target_embedding, encoder_a, encoder_b
                )
                x = x + r
                if not self.training and self.need_attn:
                    if avg_attn_scores is None:
                        avg_attn_scores = attn_scores
                    else:
                        avg_attn_scores.add_(attn_scores)

            if selfattention is not None:
                x = selfattention(x)

            x = (x + residual) * math.sqrt(0.5)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        # project back to size of vocabulary
        x = self.fc2(x)
        x = self.dropout_module(x)
        if not self.pretrained:
            x = self.fc3(x)

        # fusion gating
        if self.pretrained:
            trained_x, _ = self.pretrained_decoder.forward(
                prev_output_tokens, trained_encoder_out
            )
            y = torch.cat([x, self.pretrained_outputs["out"]], dim=-1)
            gate1 = self.gate1(y)
            gate2 = self.gate2(y)
            gated_x1 = gate1 * x
            gated_x2 = gate2 * self.pretrained_outputs["out"]
            fusion = torch.cat([gated_x1, gated_x2], dim=-1)
            fusion = self.joining(fusion)
            fusion_output = self.fc3(fusion)
            return fusion_output, avg_attn_scores
        else:
            return x, avg_attn_scores

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return self.embed_positions.max_positions

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn

    def _split_encoder_out(self, encoder_out):
        """Split and transpose encoder outputs."""
        # transpose only once to speed up attention layers
        encoder_a, encoder_b = encoder_out
        encoder_a = encoder_a.transpose(0, 1).contiguous()
        encoder_b = encoder_b.transpose(0, 1).contiguous()
        result = (encoder_a, encoder_b)
        return result


class SelfAttention(nn.Module):
    def __init__(
        self,
        out_channels,
        embed_dim,
        num_heads,
        project_input=False,
        gated=False,
        downsample=False,
    ):
        super().__init__()
        self.attention = DownsampledMultiHeadAttention(
            out_channels,
            embed_dim,
            num_heads,
            dropout=0,
            bias=True,
            project_input=project_input,
            gated=gated,
            downsample=downsample,
        )
        self.in_proj_q = Linear(out_channels, embed_dim)
        self.in_proj_k = Linear(out_channels, embed_dim)
        self.in_proj_v = Linear(out_channels, embed_dim)
        self.ln = LayerNorm(out_channels)

    def forward(self, x):
        residual = x
        query = self.in_proj_q(x)
        key = self.in_proj_k(x)
        value = self.in_proj_v(x)
        x, _ = self.attention(
            query, key, value, mask_future_timesteps=True, use_scalar_bias=True
        )
        return self.ln(x + residual)


def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    m.weight.data.normal_(0, 0.1)
    return m


def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx):
    m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)
    m.weight.data.normal_(0, 0.1)
    return m


def Linear(in_features, out_features, dropout=0.0):
    """Weight-normalized Linear layer (input: N x T x C)"""
    m = nn.Linear(in_features, out_features)
    m.weight.data.normal_(mean=0, std=math.sqrt((1 - dropout) / in_features))
    m.bias.data.zero_()
    return m


def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0.0, **kwargs):
    """Weight-normalized Conv1d layer optimized for decoding"""
    m = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs)
    std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
    m.weight.data.normal_(mean=0, std=std)
    m.bias.data.zero_()
    return m

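# Initialization note for the two convolution wrappers (LinearizedConv1d above
# and ConvTBC below): weights are drawn from N(0, std) with
# std = sqrt(4 * (1 - dropout) / (kernel_size * in_channels)), e.g. for
# kernel_size=3, in_channels=512, dropout=0.1:
#
#   std = math.sqrt(4 * 0.9 / (3 * 512))  # ~= 0.048

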
def ConvTBC(in_channels, out_channels, kernel_size, dropout=0.0, **kwargs):
    """Weight-normalized Conv1d layer"""
    from fairseq.modules import ConvTBC

    m = ConvTBC(in_channels, out_channels, kernel_size, **kwargs)
    std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
    m.weight.data.normal_(mean=0, std=std)
    m.bias.data.zero_()
    return m


@register_model_architecture("fconv_self_att", "fconv_self_att")
def base_architecture(args):
    args.dropout = getattr(args, "dropout", 0.1)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_layers = getattr(args, "encoder_layers", "[(512, 3)] * 3")
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
    args.decoder_layers = getattr(args, "decoder_layers", "[(512, 3)] * 8")
    args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 256)
    args.decoder_attention = getattr(args, "decoder_attention", "True")
    args.self_attention = getattr(args, "self_attention", "False")
    args.encoder_attention = getattr(args, "encoder_attention", "False")
    args.multihead_attention_nheads = getattr(args, "multihead_attention_nheads", 1)
    args.multihead_self_attention_nheads = getattr(
        args, "multihead_self_attention_nheads", 1
    )
    args.encoder_attention_nheads = getattr(args, "encoder_attention_nheads", 1)
    args.project_input = getattr(args, "project_input", "False")
    args.gated_attention = getattr(args, "gated_attention", "False")
    args.downsample = getattr(args, "downsample", "False")
    args.pretrained_checkpoint = getattr(args, "pretrained_checkpoint", "")
    args.pretrained = getattr(args, "pretrained", "False")


@register_model_architecture("fconv_self_att", "fconv_self_att_wp")
def fconv_self_att_wp(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
    args.encoder_layers = getattr(
        args, "encoder_layers", "[(128, 3)] * 2 + [(512,3)] * 1"
    )
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 256)
    args.decoder_layers = getattr(
        args, "decoder_layers", "[(512, 4)] * 4 + [(768, 4)] * 2 + [(1024, 4)] * 1"
    )
    args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 256)
    args.self_attention = getattr(args, "self_attention", "True")
    args.multihead_self_attention_nheads = getattr(
        args, "multihead_self_attention_nheads", 4
    )
    args.project_input = getattr(args, "project_input", "True")
    args.gated_attention = getattr(args, "gated_attention", "True")
    args.downsample = getattr(args, "downsample", "True")
    base_architecture(args)
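
# Illustrative training sketch (paths and flag values are examples only, not
# taken from this file): the architectures registered above are selected with
# --arch, and the EXPR-style options take quoted Python expressions, e.g.
#
#   fairseq-train data-bin/writingPrompts \
#       --arch fconv_self_att_wp \
#       --decoder-attention True \
#       --encoder-layers '[(128, 3)] * 2 + [(512, 3)] * 1'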