# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Linformer: Self-Attention with Linear Complexity
"""

import logging

import torch

from fairseq import utils
from fairseq.models import register_model, register_model_architecture
from fairseq.models.roberta import (
    init_bert_params,
    roberta_base_architecture,
    roberta_large_architecture,
    RobertaEncoder,
    RobertaModel,
)
from fairseq.utils import safe_hasattr

from ..modules.linformer_sentence_encoder import LinformerTransformerEncoder


logger = logging.getLogger(__name__)
@register_model("linformer_roberta")
class LinformerModel(RobertaModel):
    @staticmethod
    def add_args(parser):
        RobertaModel.add_args(parser)

        # add args for Linformer
        parser.add_argument(
            "--compressed", type=int, help="compressed ratio of sequence length"
        )
        parser.add_argument(
            "--shared-kv-compressed",
            type=int,
            help="share compressed matrix between k and v, in each layer",
        )
        parser.add_argument(
            "--shared-layer-kv-compressed",
            type=int,
            help="share compressed matrix between k and v and across all layers",
        )
        parser.add_argument(
            "--freeze-compress",
            type=int,
            help="freeze the parameters in compressed layer",
        )
    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present
        base_architecture(args)

        if not safe_hasattr(args, "max_positions"):
            args.max_positions = args.tokens_per_sample

        encoder = LinformerEncoder(args, task.source_dictionary)
        return cls(args, encoder)
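
# Usage sketch (not part of the original file): the model is registered as
# "linformer_roberta", so it can be selected from the fairseq CLI. The data
# dir "data-bin/sample" below is a hypothetical placeholder; the --user-dir
# path assumes fairseq's examples/ tree layout.
#
#   fairseq-train data-bin/sample \
#       --user-dir examples/linformer/linformer_src \
#       --task masked_lm --arch linformer_roberta_base \
#       --compressed 4 --shared-kv-compressed 1 --shared-layer-kv-compressed 1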

class LinformerEncoder(RobertaEncoder):
    """Linformer encoder."""

    def __init__(self, args, dictionary):
        super().__init__(args, dictionary)
        self.register_buffer("version", torch.tensor(2))

    def build_encoder(self, args, dictionary, embed_tokens):
        encoder = LinformerTransformerEncoder(args, dictionary, embed_tokens)
        encoder.apply(init_bert_params)
        return encoder
    def upgrade_state_dict_named(self, state_dict, name):
        super().upgrade_state_dict_named(state_dict, name)
        prefix = name + "." if name != "" else ""

        # some old checkpoints had weight sharing implemented incorrectly
        # (note: this was correct in the original paper code)
        if utils.item(state_dict.get(f"{prefix}version", torch.tensor(1))) < 2:
            state_dict[f"{prefix}version"] = torch.tensor(1)
            # check if input embeddings and output embeddings were tied
            if not torch.allclose(
                state_dict[f"{prefix}sentence_encoder.embed_tokens.weight"],
                state_dict[f"{prefix}lm_head.weight"],
            ):
                # they weren't tied, re-init the LM head without weight sharing
                self.lm_head = self.build_lm_head(
                    embed_dim=self.args.encoder_embed_dim,
                    output_dim=len(self.dictionary),
                    activation_fn=self.args.activation_fn,
                    weight=None,  # don't share weights
                )

@register_model_architecture("linformer_roberta", "linformer_roberta")
def base_architecture(args):
    args.compressed = getattr(args, "compressed", 4)
    args.shared_kv_compressed = getattr(args, "shared_kv_compressed", 0)
    args.shared_layer_kv_compressed = getattr(args, "shared_layer_kv_compressed", 0)
    args.freeze_compress = getattr(args, "freeze_compress", 0)
    roberta_base_architecture(args)


@register_model_architecture("linformer_roberta", "linformer_roberta_base")
def linformer_roberta_base_architecture(args):
    base_architecture(args)


@register_model_architecture("linformer_roberta", "linformer_roberta_large")
def linformer_roberta_large_architecture(args):
    roberta_large_architecture(args)
    base_architecture(args)
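

# Illustrative sketch (not fairseq code) of what the --compressed ratio
# configures: Linformer projects keys and values from sequence length n down
# to k = n // compressed before attention, cutting the cost from O(n^2) to
# O(n * k). All shapes and names below are made up for the demo.
if __name__ == "__main__":
    import torch.nn as nn

    n, d, compressed = 512, 64, 4    # seq len, head dim, --compressed ratio
    k = n // compressed              # compressed key/value length

    # one projection shared by keys and values, as with --shared-kv-compressed 1
    E = nn.Linear(n, k, bias=False)

    q = torch.randn(1, n, d)
    key = torch.randn(1, n, d)
    val = torch.randn(1, n, d)

    key_c = E(key.transpose(1, 2)).transpose(1, 2)  # (1, k, d)
    val_c = E(val.transpose(1, 2)).transpose(1, 2)  # (1, k, d)

    # attention over k compressed positions instead of n
    attn = torch.softmax(q @ key_c.transpose(1, 2) / d ** 0.5, dim=-1)  # (1, n, k)
    out = attn @ val_c                                                  # (1, n, d)
    print(out.shape)  # torch.Size([1, 512, 64])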