OFA-Image_Caption / fairseq /fairseq /models /multilingual_transformer.py
JustinLin610
update
8437114
raw history blame
No virus
9.57 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
from fairseq import utils
from fairseq.models import (
FairseqMultiModel,
register_model,
register_model_architecture,
)
from fairseq.models.transformer import (
Embedding,
TransformerDecoder,
TransformerEncoder,
TransformerModel,
base_architecture,
)
from fairseq.utils import safe_hasattr
@register_model("multilingual_transformer")
class MultilingualTransformerModel(FairseqMultiModel):
    """Train Transformer models for multiple language pairs simultaneously.

    Requires `--task multilingual_translation`.

    We inherit all arguments from TransformerModel and assume that all language
    pairs use a single Transformer architecture. In addition, we provide several
    options that are specific to the multilingual setting.

    Args:
        --share-encoder-embeddings: share encoder embeddings across all source languages
        --share-decoder-embeddings: share decoder embeddings across all target languages
        --share-encoders: share all encoder params (incl. embeddings) across all source languages
        --share-decoders: share all decoder params (incl. embeddings) across all target languages
    """

    def __init__(self, encoders, decoders):
        """Hand the per-language-pair encoders and decoders to the base class.

        Args:
            encoders: mapping from language pair to its encoder module.
            decoders: mapping from language pair to its decoder module.
        """
        super().__init__(encoders, decoders)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # Start from every TransformerModel option, then add the
        # multilingual sharing switches on top.
        TransformerModel.add_args(parser)
        parser.add_argument(
            "--share-encoder-embeddings",
            action="store_true",
            help="share encoder embeddings across languages",
        )
        parser.add_argument(
            "--share-decoder-embeddings",
            action="store_true",
            help="share decoder embeddings across languages",
        )
        parser.add_argument(
            "--share-encoders",
            action="store_true",
            help="share encoders across languages",
        )
        parser.add_argument(
            "--share-decoders",
            action="store_true",
            help="share decoders across languages",
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance.

        Args:
            args: model arguments; fields that are missing are filled in with
                the architecture defaults before they are read.
            task: the task the model is built for; must be a
                ``MultilingualTranslationTask``.

        Returns:
            An instance of ``cls`` holding one encoder and one decoder for
            every entry of ``task.model_lang_pairs``.

        Raises:
            ValueError: if ``--share-all-embeddings`` is combined with
                differing encoder/decoder embedding dims, or with a decoder
                embedding path that differs from the encoder one.
        """
        from fairseq.tasks.multilingual_translation import MultilingualTranslationTask

        assert isinstance(task, MultilingualTranslationTask)

        # make sure all arguments are present in older models
        base_multilingual_architecture(args)

        if not safe_hasattr(args, "max_source_positions"):
            args.max_source_positions = 1024
        if not safe_hasattr(args, "max_target_positions"):
            args.max_target_positions = 1024

        src_langs = [lang_pair.split("-")[0] for lang_pair in task.model_lang_pairs]
        tgt_langs = [lang_pair.split("-")[1] for lang_pair in task.model_lang_pairs]

        # Sharing a whole encoder/decoder implies sharing its embeddings.
        if args.share_encoders:
            args.share_encoder_embeddings = True
        if args.share_decoders:
            args.share_decoder_embeddings = True

        def build_embedding(dictionary, embed_dim, path=None):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            emb = Embedding(num_embeddings, embed_dim, padding_idx)
            # if provided, load from preloaded dictionaries
            if path:
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            return emb

        # build shared embeddings (if applicable)
        shared_encoder_embed_tokens, shared_decoder_embed_tokens = None, None
        if args.share_all_embeddings:
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                dicts=task.dicts,
                langs=task.langs,
                embed_dim=args.encoder_embed_dim,
                build_embedding=build_embedding,
                pretrained_embed_path=args.encoder_embed_path,
            )
            # One embedding table serves both sides of the model.
            shared_decoder_embed_tokens = shared_encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            if args.share_encoder_embeddings:
                shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                    dicts=task.dicts,
                    langs=src_langs,
                    embed_dim=args.encoder_embed_dim,
                    build_embedding=build_embedding,
                    pretrained_embed_path=args.encoder_embed_path,
                )
            if args.share_decoder_embeddings:
                shared_decoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                    dicts=task.dicts,
                    langs=tgt_langs,
                    embed_dim=args.decoder_embed_dim,
                    build_embedding=build_embedding,
                    pretrained_embed_path=args.decoder_embed_path,
                )

        # encoders/decoders for each language
        lang_encoders, lang_decoders = {}, {}

        def get_encoder(lang):
            # Built lazily and cached, so a language that appears in several
            # pairs reuses the same encoder module.
            if lang not in lang_encoders:
                if shared_encoder_embed_tokens is not None:
                    encoder_embed_tokens = shared_encoder_embed_tokens
                else:
                    encoder_embed_tokens = build_embedding(
                        task.dicts[lang],
                        args.encoder_embed_dim,
                        args.encoder_embed_path,
                    )
                lang_encoders[lang] = cls._get_module_class(
                    True, args, task.dicts[lang], encoder_embed_tokens, src_langs
                )
            return lang_encoders[lang]

        def get_decoder(lang):
            # Same per-language caching as get_encoder, for the target side.
            if lang not in lang_decoders:
                if shared_decoder_embed_tokens is not None:
                    decoder_embed_tokens = shared_decoder_embed_tokens
                else:
                    decoder_embed_tokens = build_embedding(
                        task.dicts[lang],
                        args.decoder_embed_dim,
                        args.decoder_embed_path,
                    )
                lang_decoders[lang] = cls._get_module_class(
                    False, args, task.dicts[lang], decoder_embed_tokens, tgt_langs
                )
            return lang_decoders[lang]

        # shared encoders/decoders (if applicable)
        shared_encoder, shared_decoder = None, None
        if args.share_encoders:
            shared_encoder = get_encoder(src_langs[0])
        if args.share_decoders:
            shared_decoder = get_decoder(tgt_langs[0])

        encoders, decoders = OrderedDict(), OrderedDict()
        for lang_pair, src, tgt in zip(task.model_lang_pairs, src_langs, tgt_langs):
            encoders[lang_pair] = (
                shared_encoder if shared_encoder is not None else get_encoder(src)
            )
            decoders[lang_pair] = (
                shared_decoder if shared_decoder is not None else get_decoder(tgt)
            )

        # Instantiate through ``cls`` (not the hard-coded base class) so that a
        # subclass inheriting build_model gets an instance of itself, in line
        # with the ``cls._get_module_class`` dispatch used above.
        return cls(encoders, decoders)

    @classmethod
    def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs):
        """Instantiate the encoder or decoder module for one language.

        Args:
            is_encoder: build a ``TransformerEncoder`` when true, otherwise a
                ``TransformerDecoder``.
            args: model arguments forwarded to the module constructor.
            lang_dict: dictionary of the language the module is built for.
            embed_tokens: embedding module to use for that language.
            langs: list of languages on this side of the model; not used by
                this implementation (presumably there for overriding
                subclasses -- confirm).

        Returns:
            The constructed encoder or decoder.
        """
        module_class = TransformerEncoder if is_encoder else TransformerDecoder
        return module_class(args, lang_dict, embed_tokens)

    def load_state_dict(self, state_dict, strict=True, model_cfg=None):
        """Load ``state_dict``, ignoring language pairs this model lacks.

        Every key is expected to start with ``models.`` followed by the
        language pair; entries whose language pair is not in ``self.models``
        are removed from a copy of ``state_dict`` before the remainder is
        passed to the parent implementation. The caller's mapping is not
        modified.

        Args:
            state_dict: mapping of parameter names to values.
            strict: forwarded to the parent ``load_state_dict``.
            model_cfg: forwarded to the parent ``load_state_dict``.

        Returns:
            Whatever the parent ``load_state_dict`` returns.
        """
        state_dict_subset = state_dict.copy()
        # Iterate the original mapping while deleting from the copy, so the
        # mapping being iterated never changes size.
        for k in state_dict:
            assert k.startswith("models.")
            lang_pair = k.split(".")[1]
            if lang_pair not in self.models:
                del state_dict_subset[k]
        return super().load_state_dict(
            state_dict_subset, strict=strict, model_cfg=model_cfg
        )
@register_model_architecture("multilingual_transformer", "multilingual_transformer")
def base_multilingual_architecture(args):
    """Fill in defaults for the base multilingual transformer architecture.

    Applies ``base_architecture`` first, then makes sure each of the four
    multilingual sharing flags exists on ``args``, defaulting to ``False``
    when it is not already set. Values already present are kept.
    """
    base_architecture(args)
    sharing_flags = (
        "share_encoder_embeddings",
        "share_decoder_embeddings",
        "share_encoders",
        "share_decoders",
    )
    for flag in sharing_flags:
        setattr(args, flag, getattr(args, flag, False))
@register_model_architecture(
    "multilingual_transformer", "multilingual_transformer_iwslt_de_en"
)
def multilingual_transformer_iwslt_de_en(args):
    """Fill in defaults for the ``multilingual_transformer_iwslt_de_en`` architecture.

    Sets encoder/decoder sizes that are not already present on ``args``
    (embedding dim 512, FFN dim 1024, 4 attention heads, 6 layers on each
    side), then defers to ``base_multilingual_architecture`` for everything
    else.
    """
    # Applied in this order; a value already set on ``args`` always wins.
    size_defaults = (
        ("encoder_embed_dim", 512),
        ("encoder_ffn_embed_dim", 1024),
        ("encoder_attention_heads", 4),
        ("encoder_layers", 6),
        ("decoder_embed_dim", 512),
        ("decoder_ffn_embed_dim", 1024),
        ("decoder_attention_heads", 4),
        ("decoder_layers", 6),
    )
    for name, fallback in size_defaults:
        setattr(args, name, getattr(args, name, fallback))
    base_multilingual_architecture(args)