# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn

from fairseq.modules import TransformerSentenceEncoder
from fairseq.modules.sparse_transformer_sentence_encoder_layer import (
    SparseTransformerSentenceEncoderLayer,
)


class SparseTransformerSentenceEncoder(TransformerSentenceEncoder):
    """
    Sparse implementation of the TransformerSentenceEncoder
    - see SparseMultiheadAttention
    """

    def __init__(
        self,
        padding_idx: int,
        vocab_size: int,
        num_encoder_layers: int = 6,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        max_seq_len: int = 256,
        num_segments: int = 2,
        use_position_embeddings: bool = True,
        offset_positions_by_padding: bool = True,
        encoder_normalize_before: bool = False,
        apply_bert_init: bool = False,
        activation_fn: str = "relu",
        learned_pos_embedding: bool = True,
        embed_scale: float = None,
        freeze_embeddings: bool = False,
        n_trans_layers_to_freeze: int = 0,
        export: bool = False,
        is_bidirectional: bool = True,
        stride: int = 32,
        expressivity: int = 8,
    ) -> None:

        super().__init__(
            padding_idx,
            vocab_size,
            num_encoder_layers,
            embedding_dim,
            ffn_embedding_dim,
            num_attention_heads,
            dropout,
            attention_dropout,
            activation_dropout,
            max_seq_len,
            num_segments,
            use_position_embeddings,
            offset_positions_by_padding,
            encoder_normalize_before,
            apply_bert_init,
            activation_fn,
            learned_pos_embedding,
            embed_scale,
            freeze_embeddings,
            n_trans_layers_to_freeze,
            export,
        )

        # Replace the dense encoder layers built by the parent class with
        # layers that use sparse (fixed-pattern) self-attention.
        self.layers = nn.ModuleList(
            [
                SparseTransformerSentenceEncoderLayer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=ffn_embedding_dim,
                    num_attention_heads=num_attention_heads,
                    dropout=dropout,
                    attention_dropout=attention_dropout,
                    activation_dropout=activation_dropout,
                    activation_fn=activation_fn,
                    export=export,
                    is_bidirectional=is_bidirectional,
                    stride=stride,
                    expressivity=expressivity,
                )
                for _ in range(num_encoder_layers)
            ]
        )

        def freeze_module_params(m):
            if m is not None:
                for p in m.parameters():
                    p.requires_grad = False

        # Freeze the parameters of the first n_trans_layers_to_freeze encoder
        # layers, mirroring the behavior of the dense TransformerSentenceEncoder.
        for layer in range(n_trans_layers_to_freeze):
            freeze_module_params(self.layers[layer])
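
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; not part of the original module).
# The parameter values, dummy token ids, and shapes below are assumptions
# chosen for demonstration; forward() is inherited unchanged from
# TransformerSentenceEncoder and returns (inner_states, sentence_rep).
if __name__ == "__main__":
    import torch

    encoder = SparseTransformerSentenceEncoder(
        padding_idx=1,       # assumed pad index for the dummy vocabulary
        vocab_size=1000,     # assumed vocabulary size
        num_encoder_layers=2,
        embedding_dim=64,
        ffn_embedding_dim=128,
        num_attention_heads=4,
        max_seq_len=64,
        stride=8,            # fixed-pattern attention stride (example value)
        expressivity=4,      # example value; forwarded to SparseMultiheadAttention
    )
    # Dummy batch of token ids that avoids the padding index.
    tokens = torch.randint(2, 1000, (2, 16))  # (batch, seq_len)
    inner_states, sentence_rep = encoder(tokens)
    print(inner_states[-1].shape)  # final layer states: (seq_len, batch, embedding_dim)
    print(sentence_rep.shape)      # first-token representation: (batch, embedding_dim)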