Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math | |
import torch.nn as nn | |
from fairseq.models.transformer import TransformerEncoder | |
from .linformer_sentence_encoder_layer import LinformerTransformerEncoderLayer | |
class LinformerTransformerEncoder(TransformerEncoder):
    """
    Bi-directional, Linformer-based sentence encoder for BERT/XLM style
    pre-trained models.

    Token embeddings are computed from the token embedding matrix, together
    with position embeddings and segment embeddings when those are specified.
    The configured number of Linformer encoder layers is then applied, and the
    encoder returns every internal state along with the final representation
    of the first token (usually the CLS token).

    Input:
        - tokens: B x T matrix representing sentences
        - segment_labels: B x T matrix representing segment label for tokens

    Output:
        - a tuple of the following:
            - a list of internal model states used to compute the
              predictions, where each tensor has shape T x B x C
            - sentence representation associated with the first input
              token, in format B x C.

    NOTE(review): the forward pass is inherited from ``TransformerEncoder``
    and is not defined in this class; confirm the input/output description
    above against the parent implementation.
    """

    def __init__(self, args, dictionary, embed_tokens):
        """Set up the shared-projection slot, then defer to the parent.

        The attribute is assigned before the parent constructor runs,
        presumably because that constructor invokes
        ``build_encoder_layer``, which reads ``self.compress_layer``.
        """
        self.compress_layer = None
        super().__init__(args, dictionary, embed_tokens)

    def build_encoder_layer(self, args):
        """Create one Linformer encoder layer.

        When ``self.args.shared_layer_kv_compressed == 1`` a single linear
        projection of size ``max_positions -> max_positions // compressed``
        is built on the first call, stored on ``self.compress_layer`` and
        handed to every layer built afterwards. Otherwise
        ``self.compress_layer`` stays ``None`` and that is what each layer
        receives.

        Args:
            args: configuration object forwarded unchanged to
                ``LinformerTransformerEncoderLayer``.

        Returns:
            A new ``LinformerTransformerEncoderLayer``.
        """
        cfg = self.args
        needs_shared_projection = (
            cfg.shared_layer_kv_compressed == 1 and self.compress_layer is None
        )

        if needs_shared_projection:
            seq_len = cfg.max_positions
            projection = nn.Linear(seq_len, seq_len // cfg.compressed)

            # Initialize the weights of the shared compression projection.
            nn.init.xavier_uniform_(projection.weight, gain=1 / math.sqrt(2))

            # Optionally keep the projection weights fixed during training.
            if cfg.freeze_compress == 1:
                projection.weight.requires_grad = False

            self.compress_layer = projection

        return LinformerTransformerEncoderLayer(args, self.compress_layer)