OFA-OCR / fairseq /examples /linformer /linformer_src /modules /linformer_sentence_encoder.py
JustinLin610's picture
first commit
ee21b96
raw
history blame
No virus
2.15 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch.nn as nn
from fairseq.models.transformer import TransformerEncoder
from .linformer_sentence_encoder_layer import LinformerTransformerEncoderLayer
class LinformerTransformerEncoder(TransformerEncoder):
    """Bi-directional Linformer sentence encoder for BERT/XLM-style models.

    This class only customises how encoder layers are built: every layer is
    a ``LinformerTransformerEncoderLayer``. When
    ``args.shared_layer_kv_compressed == 1`` a single ``nn.Linear``
    compression projection is created the first time a layer is built and
    the same module object is handed to every layer built afterwards;
    otherwise ``None`` is handed to each layer.

    Everything else (embedding, forward pass, inputs and outputs) is
    inherited unchanged from ``TransformerEncoder``.

    NOTE(review): an earlier version of this docstring described a
    ``segment_labels`` input and a ``(inner states, sentence representation)``
    tuple output. Nothing in this class implements either; the actual
    forward contract is whatever ``TransformerEncoder`` defines -- confirm
    against the parent class.

    Args:
        args: configuration namespace. This class reads
            ``shared_layer_kv_compressed``, ``max_positions``,
            ``compressed`` and ``freeze_compress`` from ``self.args``.
        dictionary: forwarded untouched to ``TransformerEncoder``.
        embed_tokens: forwarded untouched to ``TransformerEncoder``.
    """

    def __init__(self, args, dictionary, embed_tokens):
        # The attribute has to exist before the parent constructor runs:
        # presumably the parent calls ``build_encoder_layer`` while it is
        # constructing itself, and that method reads ``self.compress_layer``
        # (TODO confirm against TransformerEncoder.__init__).
        self.compress_layer = None
        super().__init__(args, dictionary, embed_tokens)

    def _new_compress_layer(self):
        """Create and initialise the shared key/value compression projection.

        Returns:
            An ``nn.Linear`` mapping ``max_positions`` input features to
            ``max_positions // compressed`` output features.
        """
        cfg = self.args
        seq_len = cfg.max_positions
        projection = nn.Linear(seq_len, seq_len // cfg.compressed)
        # Initialize the projection weight (Xavier uniform, gain 1/sqrt(2)).
        nn.init.xavier_uniform_(projection.weight, gain=1 / math.sqrt(2))
        if cfg.freeze_compress == 1:
            # Only the weight is frozen here; the bias is left as created.
            projection.weight.requires_grad = False
        return projection

    def build_encoder_layer(self, args):
        """Build one Linformer encoder layer.

        Args:
            args: configuration passed on to the layer constructor. Note
                that the sharing/compression options are read from
                ``self.args``, not from this argument.

        Returns:
            A ``LinformerTransformerEncoderLayer`` constructed with ``args``
            and the current ``self.compress_layer`` (``None`` when sharing
            is disabled).
        """
        sharing_enabled = self.args.shared_layer_kv_compressed == 1
        if sharing_enabled and self.compress_layer is None:
            # First layer with sharing on: create the projection once and
            # keep it so later layers receive the very same module.
            self.compress_layer = self._new_compress_layer()
        return LinformerTransformerEncoderLayer(args, self.compress_layer)