# Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import torch.nn as nn import torch.nn.functional as F from fairseq.models import register_model, register_model_architecture from fairseq.models.transformer import TransformerEncoder, TransformerModel @register_model("gru_transformer") class GRUTransformerModel(TransformerModel): @classmethod def build_encoder(cls, args, src_dict, embed_tokens): return GRUTransformerEncoder(args, src_dict, embed_tokens) class GRUTransformerEncoder(TransformerEncoder): def __init__(self, args, dictionary, embed_tokens): super().__init__(args, dictionary, embed_tokens) self.emb_ctx = nn.GRU( input_size=embed_tokens.embedding_dim, hidden_size=embed_tokens.embedding_dim // 2, num_layers=1, bidirectional=True, ) def forward_embedding(self, src_tokens): # embed tokens and positions x = embed = self.embed_scale * self.embed_tokens(src_tokens) if self.embed_positions is not None: x = embed + self.embed_positions(src_tokens) # contextualize embeddings x = x.transpose(0, 1) x = self.dropout_module(x) x, _ = self.emb_ctx.forward(x) x = x.transpose(0, 1) if self.layernorm_embedding is not None: x = self.layernorm_embedding(x) x = self.dropout_module(x) return x, embed @register_model_architecture("gru_transformer", "gru_transformer") def gru_transformer_base_architecture(args): args.encoder_embed_path = getattr(args, "encoder_embed_path", None) args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) args.encoder_layers = getattr(args, "encoder_layers", 6) args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False) args.decoder_embed_path = getattr(args, "decoder_embed_path", None) args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) args.decoder_ffn_embed_dim = getattr( args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim ) args.decoder_layers = getattr(args, "decoder_layers", 6) args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) args.attention_dropout = getattr(args, "attention_dropout", 0.0) args.activation_dropout = getattr(args, "activation_dropout", 0.0) args.activation_fn = getattr(args, "activation_fn", "relu") args.dropout = getattr(args, "dropout", 0.1) args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) args.share_decoder_input_output_embed = getattr( args, "share_decoder_input_output_embed", False ) args.share_all_embeddings = getattr(args, "share_all_embeddings", False) args.no_token_positional_embeddings = getattr( args, "no_token_positional_embeddings", False ) args.adaptive_input = getattr(args, "adaptive_input", False) args.no_cross_attention = getattr(args, "no_cross_attention", False) args.cross_self_attention = getattr(args, "cross_self_attention", False) args.layer_wise_attention = getattr(args, "layer_wise_attention", False) args.decoder_output_dim = getattr( args, "decoder_output_dim", args.decoder_embed_dim ) args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) args.no_scale_embedding = getattr(args, "no_scale_embedding", False) args.layernorm_embedding = getattr(args, "layernorm_embedding", False) @register_model_architecture("gru_transformer", "gru_transformer_big") def gru_transformer_big(args): args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024) args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096) args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) args.dropout = getattr(args, "dropout", 0.3) gru_transformer_base_architecture(args)