# OFA-Image_Caption/fairseq/fairseq/tasks/translation_from_pretrained_bart.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch

from fairseq import utils
from fairseq.data import LanguagePairDataset

from . import register_task
from .translation import TranslationTask, load_langpair_dataset

@register_task("translation_from_pretrained_bart")
class TranslationFromPretrainedBARTTask(TranslationTask):
"""
    Translate from source language to target language with a model initialized
    from a multilingual pretrained checkpoint (e.g. mBART).

    Args:
        src_dict (~fairseq.data.Dictionary): dictionary for the source language
        tgt_dict (~fairseq.data.Dictionary): dictionary for the target language

    .. note::

        The translation task is compatible with :mod:`fairseq-train`,
        :mod:`fairseq-generate` and :mod:`fairseq-interactive`.

    The translation task provides the following additional command-line
    arguments:

    .. argparse::
        :ref: fairseq.tasks.translation_parser
:prog:
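
    Example (a sketch of a finetuning command, not taken from this module;
    the language list shown is the mBART-25 one and is an assumption)::

        fairseq-train data-bin --task translation_from_pretrained_bart
            --arch mbart_large --prepend-bos --source-lang en_XX
            --target-lang ro_RO --langs ar_AR,cs_CZ,de_DE,en_XX,...,zh_CN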
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
# fmt: off
TranslationTask.add_args(parser)
        parser.add_argument('--langs', type=str, metavar='LANG',
                            help='comma-separated list of monolingual languages, '
                                 'for example, "en,de,fr". These should match the '
                                 'langs from pretraining (and be in the same order). '
                                 'You should always include all pretraining language '
                                 'tokens during finetuning.')
parser.add_argument('--prepend-bos', action='store_true',
help='prepend bos token to each sentence, which matches '
'mBART pretraining')
# fmt: on
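
    # A hypothetical --langs value (an assumption based on the mBART-25 setup):
    #   ar_AR,cs_CZ,de_DE,en_XX,...,zh_CN
    # The list and its order must match pretraining exactly, so that the [lang]
    # symbols added in __init__ below land on the same dictionary indices as
    # the pretrained embedding rows.
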
def __init__(self, args, src_dict, tgt_dict):
super().__init__(args, src_dict, tgt_dict)
self.langs = args.langs.split(",")
for d in [src_dict, tgt_dict]:
for l in self.langs:
d.add_symbol("[{}]".format(l))
d.add_symbol("<mask>")
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
paths = utils.split_paths(self.args.data)
assert len(paths) > 0
data_path = paths[(epoch - 1) % len(paths)]
# infer langcode
src, tgt = self.args.source_lang, self.args.target_lang
self.datasets[split] = load_langpair_dataset(
data_path,
split,
src,
self.src_dict,
tgt,
self.tgt_dict,
combine=combine,
dataset_impl=self.args.dataset_impl,
upsample_primary=self.args.upsample_primary,
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
max_source_positions=getattr(self.args, "max_source_positions", 1024),
max_target_positions=getattr(self.args, "max_target_positions", 1024),
load_alignments=self.args.load_alignments,
prepend_bos=getattr(self.args, "prepend_bos", False),
append_source_id=True,
)
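
    # With append_source_id=True, load_langpair_dataset ends every source
    # sample with its language token "[src]" and every target sample with
    # "[tgt]", giving the "tokens </s> [lang]" layout used for mBART
    # finetuning.
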
def build_generator(self, models, args, **unused):
if getattr(args, "score_reference", False):
from fairseq.sequence_scorer import SequenceScorer
return SequenceScorer(
self.target_dictionary,
eos=self.tgt_dict.index("[{}]".format(self.args.target_lang)),
)
else:
from fairseq.sequence_generator import SequenceGenerator
return SequenceGenerator(
models,
self.target_dictionary,
beam_size=getattr(args, "beam", 5),
max_len_a=getattr(args, "max_len_a", 0),
max_len_b=getattr(args, "max_len_b", 200),
min_len=getattr(args, "min_len", 1),
normalize_scores=(not getattr(args, "unnormalized", False)),
len_penalty=getattr(args, "lenpen", 1),
unk_penalty=getattr(args, "unkpen", 0),
temperature=getattr(args, "temperature", 1.0),
match_source_len=getattr(args, "match_source_len", False),
no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
eos=self.tgt_dict.index("[{}]".format(self.args.target_lang)),
)
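
    # Overriding eos above makes decoding stop on the target-language token
    # "[tgt_lang]" instead of the ordinary </s>, matching the target-side
    # layout produced by load_dataset.
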
def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
src_lang_id = self.source_dictionary.index("[{}]".format(self.args.source_lang))
source_tokens = []
for s_t in src_tokens:
s_t = torch.cat([s_t, s_t.new(1).fill_(src_lang_id)])
source_tokens.append(s_t)
dataset = LanguagePairDataset(
source_tokens,
src_lengths,
self.source_dictionary,
tgt_dict=self.target_dictionary,
constraints=constraints,
)
return dataset
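

if __name__ == "__main__":
    # Minimal runnable sketch (assumes torch and fairseq are installed; the
    # toy vocabulary and "[en_XX]" tag are illustrative, not part of this
    # module). It reproduces the source-side layout that
    # build_dataset_for_inference creates: the language tag goes after eos.
    from fairseq.data import Dictionary

    d = Dictionary()
    for tok in ["hello", "world"]:
        d.add_symbol(tok)
    d.add_symbol("[en_XX]")

    src = torch.tensor([d.index("hello"), d.index("world"), d.eos()])
    tagged = torch.cat([src, src.new(1).fill_(d.index("[en_XX]"))])
    # Dictionary.string() skips eos by default, so this prints:
    # hello world [en_XX]
    print(d.string(tagged))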