# OFA-Image_Caption/fairseq/fairseq/tasks/translation_multi_simple_epoch.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import datetime
import logging
import time
import torch
from fairseq.data import (
FairseqDataset,
LanguagePairDataset,
ListDataset,
data_utils,
iterators,
)
from fairseq.data.multilingual.multilingual_data_manager import (
MultilingualDatasetManager,
)
from fairseq.data.multilingual.sampling_method import SamplingMethod
from fairseq.tasks import LegacyFairseqTask, register_task
from fairseq.utils import FileContentsAction
###
def get_time_gap(s, e):
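    # Return the elapsed time between two time.time() timestamps as a
    # human-readable timedelta string (e.g. "0:01:23.456789").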
return (
datetime.datetime.fromtimestamp(e) - datetime.datetime.fromtimestamp(s)
).__str__()
###
logger = logging.getLogger(__name__)
@register_task("translation_multi_simple_epoch")
class TranslationMultiSimpleEpochTask(LegacyFairseqTask):
"""
Translate from one (source) language to another (target) language.
Args:
langs (List[str]): a list of languages that are being supported
dicts (Dict[str, fairseq.data.Dictionary]): mapping from supported languages to their dictionaries
training (bool): whether the task should be configured for training or not
.. note::
The translation task is compatible with :mod:`fairseq-train`,
:mod:`fairseq-generate` and :mod:`fairseq-interactive`.
The translation task provides the following additional command-line
arguments:
.. argparse::
:ref: fairseq.tasks.translation_parser
:prog:
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
# fmt: off
parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
help='inference source language')
parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
help='inference target language')
parser.add_argument('--lang-pairs', default=None, metavar='PAIRS',
help='comma-separated list of language pairs (in training order): en-de,en-fr,de-fr',
action=FileContentsAction)
parser.add_argument('--keep-inference-langtok', action='store_true',
help='keep language tokens in inference output (e.g. for analysis or debugging)')
SamplingMethod.add_arguments(parser)
MultilingualDatasetManager.add_args(parser)
# fmt: on
def __init__(self, args, langs, dicts, training):
super().__init__(args)
self.langs = langs
self.dicts = dicts
self.training = training
if training:
self.lang_pairs = args.lang_pairs
else:
self.lang_pairs = ["{}-{}".format(args.source_lang, args.target_lang)]
        # eval_lang_pairs for multilingual translation is usually all of the
        # lang_pairs. However, for other multitask settings, or when we want to
        # optimize for certain languages, we want to use a different subset. Thus
        # the eval_lang_pairs class variable is provided for classes that extend
        # this class.
self.eval_lang_pairs = self.lang_pairs
        # model_lang_pairs will be used to build encoder-decoder model pairs in
        # models.build_model(). This allows multitask sub-classes to build
        # models for lang_pairs other than the input lang_pairs.
self.model_lang_pairs = self.lang_pairs
self.source_langs = [d.split("-")[0] for d in self.lang_pairs]
self.target_langs = [d.split("-")[1] for d in self.lang_pairs]
self.check_dicts(self.dicts, self.source_langs, self.target_langs)
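        # Build the sampling strategy over language pairs and the data manager
        # that loads and combines the per-pair datasets.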
self.sampling_method = SamplingMethod.build_sampler(args, self)
self.data_manager = MultilingualDatasetManager.setup_data_manager(
args, self.lang_pairs, langs, dicts, self.sampling_method
)
def check_dicts(self, dicts, source_langs, target_langs):
if self.args.source_dict is not None or self.args.target_dict is not None:
# no need to check whether the source side and target side are sharing dictionaries
return
src_dict = dicts[source_langs[0]]
tgt_dict = dicts[target_langs[0]]
        for src_lang in source_langs:
            assert src_dict == dicts[src_lang], (
                "Different dictionaries are specified for different source languages; "
                "TranslationMultiSimpleEpochTask only supports one shared dictionary across all source languages"
            )
        for tgt_lang in target_langs:
            assert tgt_dict == dicts[tgt_lang], (
                "Different dictionaries are specified for different target languages; "
                "TranslationMultiSimpleEpochTask only supports one shared dictionary across all target languages"
            )
@classmethod
def setup_task(cls, args, **kwargs):
langs, dicts, training = MultilingualDatasetManager.prepare(
cls.load_dictionary, args, **kwargs
)
return cls(args, langs, dicts, training)
def has_sharded_data(self, split):
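        # True if the data for this split is stored in multiple shards that are
        # loaded across (virtual) epochs.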
return self.data_manager.has_sharded_data(split)
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
if split in self.datasets:
dataset = self.datasets[split]
if self.has_sharded_data(split):
if self.args.virtual_epoch_size is not None:
if dataset.load_next_shard:
shard_epoch = dataset.shard_epoch
else:
                        # no need to load the next shard, so skip loading;
                        # this also avoids always loading from the beginning of the data
return
else:
shard_epoch = epoch
else:
# estimate the shard epoch from virtual data size and virtual epoch size
shard_epoch = self.data_manager.estimate_global_pass_epoch(epoch)
logger.info(f"loading data for {split} epoch={epoch}/{shard_epoch}")
logger.info(f"mem usage: {data_utils.get_mem_usage()}")
if split in self.datasets:
del self.datasets[split]
logger.info("old dataset deleted manually")
logger.info(f"mem usage: {data_utils.get_mem_usage()}")
self.datasets[split] = self.data_manager.load_dataset(
split,
self.training,
epoch=epoch,
combine=combine,
shard_epoch=shard_epoch,
**kwargs,
)
def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
if constraints is not None:
raise NotImplementedError(
"Constrained decoding with the multilingual_translation task is not supported"
)
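        # Wrap the raw source tokens in a LanguagePairDataset and attach language
        # tokens, either by replacing BOS/EOS or via the source-side transform,
        # depending on --lang-tok-replacing-bos-eos.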
src_data = ListDataset(src_tokens, src_lengths)
dataset = LanguagePairDataset(src_data, src_lengths, self.source_dictionary)
src_langtok_spec, tgt_langtok_spec = self.args.langtoks["main"]
if self.args.lang_tok_replacing_bos_eos:
dataset = self.data_manager.alter_dataset_langtok(
dataset,
src_eos=self.source_dictionary.eos(),
src_lang=self.args.source_lang,
tgt_eos=self.target_dictionary.eos(),
tgt_lang=self.args.target_lang,
src_langtok_spec=src_langtok_spec,
tgt_langtok_spec=tgt_langtok_spec,
)
else:
dataset.src = self.data_manager.src_dataset_tranform_func(
self.args.source_lang,
self.args.target_lang,
dataset=dataset.src,
spec=src_langtok_spec,
)
return dataset
def build_generator(
self,
models,
args,
seq_gen_cls=None,
extra_gen_cls_kwargs=None,
):
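        # Unless --keep-inference-langtok is set, strip the target language token
        # from generated hypotheses by adding it to symbols_to_strip_from_output.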
if not getattr(args, "keep_inference_langtok", False):
_, tgt_langtok_spec = self.args.langtoks["main"]
if tgt_langtok_spec:
tgt_lang_tok = self.data_manager.get_decoder_langtok(
self.args.target_lang, tgt_langtok_spec
)
extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
extra_gen_cls_kwargs["symbols_to_strip_from_output"] = {tgt_lang_tok}
return super().build_generator(
models, args, seq_gen_cls=None, extra_gen_cls_kwargs=extra_gen_cls_kwargs
)
def build_model(self, args):
return super().build_model(args)
def valid_step(self, sample, model, criterion):
loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
return loss, sample_size, logging_output
def inference_step(
self, generator, models, sample, prefix_tokens=None, constraints=None
):
with torch.no_grad():
_, tgt_langtok_spec = self.args.langtoks["main"]
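            # Without --lang-tok-replacing-bos-eos, force the target language token
            # as the first decoded token via prefix_tokens; otherwise pass it as
            # the decoder BOS token.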
if not self.args.lang_tok_replacing_bos_eos:
if prefix_tokens is None and tgt_langtok_spec:
tgt_lang_tok = self.data_manager.get_decoder_langtok(
self.args.target_lang, tgt_langtok_spec
)
src_tokens = sample["net_input"]["src_tokens"]
bsz = src_tokens.size(0)
prefix_tokens = (
torch.LongTensor([[tgt_lang_tok]]).expand(bsz, 1).to(src_tokens)
)
return generator.generate(
models,
sample,
prefix_tokens=prefix_tokens,
constraints=constraints,
)
else:
return generator.generate(
models,
sample,
prefix_tokens=prefix_tokens,
bos_token=self.data_manager.get_decoder_langtok(
self.args.target_lang, tgt_langtok_spec
)
if tgt_langtok_spec
else self.target_dictionary.eos(),
)
def reduce_metrics(self, logging_outputs, criterion):
super().reduce_metrics(logging_outputs, criterion)
def max_positions(self):
"""Return the max sentence length allowed by the task."""
return (self.args.max_source_positions, self.args.max_target_positions)
@property
def source_dictionary(self):
return self.data_manager.get_source_dictionary(self.source_langs[0])
@property
def target_dictionary(self):
return self.data_manager.get_target_dictionary(self.target_langs[0])
def create_batch_sampler_func(
self,
max_positions,
ignore_invalid_inputs,
max_tokens,
max_sentences,
required_batch_size_multiple=1,
seed=1,
):
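        # Return a closure that builds a fresh batch sampler for a given dataset
        # and epoch; EpochBatchIterator calls it lazily so batches can be
        # regenerated every epoch.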
def construct_batch_sampler(dataset, epoch):
splits = [
s for s, _ in self.datasets.items() if self.datasets[s] == dataset
]
split = splits[0] if len(splits) > 0 else None
# NEW implementation
if epoch is not None:
# initialize the dataset with the correct starting epoch
dataset.set_epoch(epoch)
# get indices ordered by example size
start_time = time.time()
logger.info(f"start batch sampler: mem usage: {data_utils.get_mem_usage()}")
with data_utils.numpy_seed(seed):
indices = dataset.ordered_indices()
logger.info(
f"[{split}] @batch_sampler order indices time: {get_time_gap(start_time, time.time())}"
)
logger.info(f"mem usage: {data_utils.get_mem_usage()}")
# filter examples that are too large
if max_positions is not None:
my_time = time.time()
indices = self.filter_indices_by_size(
indices, dataset, max_positions, ignore_invalid_inputs
)
logger.info(
f"[{split}] @batch_sampler filter_by_size time: {get_time_gap(my_time, time.time())}"
)
logger.info(f"mem usage: {data_utils.get_mem_usage()}")
# create mini-batches with given size constraints
my_time = time.time()
batch_sampler = dataset.batch_by_size(
indices,
max_tokens=max_tokens,
max_sentences=max_sentences,
required_batch_size_multiple=required_batch_size_multiple,
)
logger.info(
f"[{split}] @batch_sampler batch_by_size time: {get_time_gap(my_time, time.time())}"
)
logger.info(
f"[{split}] per epoch batch_sampler set-up time: {get_time_gap(start_time, time.time())}"
)
logger.info(f"mem usage: {data_utils.get_mem_usage()}")
return batch_sampler
return construct_batch_sampler
# we need to override get_batch_iterator because we want to reset the epoch iterator each time
def get_batch_iterator(
self,
dataset,
max_tokens=None,
max_sentences=None,
max_positions=None,
ignore_invalid_inputs=False,
required_batch_size_multiple=1,
seed=1,
num_shards=1,
shard_id=0,
num_workers=0,
epoch=1,
data_buffer_size=0,
disable_iterator_cache=False,
):
"""
Get an iterator that yields batches of data from the given dataset.
Args:
dataset (~fairseq.data.FairseqDataset): dataset to batch
max_tokens (int, optional): max number of tokens in each batch
(default: None).
max_sentences (int, optional): max number of sentences in each
batch (default: None).
max_positions (optional): max sentence length supported by the
model (default: None).
ignore_invalid_inputs (bool, optional): don't raise Exception for
sentences that are too long (default: False).
required_batch_size_multiple (int, optional): require batch size to
be a multiple of N (default: 1).
seed (int, optional): seed for random number generator for
reproducibility (default: 1).
num_shards (int, optional): shard the data iterator into N
shards (default: 1).
shard_id (int, optional): which shard of the data iterator to
return (default: 0).
num_workers (int, optional): how many subprocesses to use for data
loading. 0 means the data will be loaded in the main process
(default: 0).
            epoch (int, optional): the epoch to start the iterator from
                (default: 1).
data_buffer_size (int, optional): number of batches to
preload (default: 0).
disable_iterator_cache (bool, optional): don't cache the
EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`)
(default: False).
Returns:
~fairseq.iterators.EpochBatchIterator: a batched iterator over the
given dataset split
"""
# initialize the dataset with the correct starting epoch
assert isinstance(dataset, FairseqDataset)
if dataset in self.dataset_to_epoch_iter:
return self.dataset_to_epoch_iter[dataset]
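        # Only RoundRobin sampling reuses the default cached epoch iterator;
        # other sampling methods construct a per-epoch batch sampler below so
        # batches are regenerated each epoch.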
if self.args.sampling_method == "RoundRobin":
batch_iter = super().get_batch_iterator(
dataset,
max_tokens=max_tokens,
max_sentences=max_sentences,
max_positions=max_positions,
ignore_invalid_inputs=ignore_invalid_inputs,
required_batch_size_multiple=required_batch_size_multiple,
seed=seed,
num_shards=num_shards,
shard_id=shard_id,
num_workers=num_workers,
epoch=epoch,
data_buffer_size=data_buffer_size,
disable_iterator_cache=disable_iterator_cache,
)
self.dataset_to_epoch_iter[dataset] = batch_iter
return batch_iter
construct_batch_sampler = self.create_batch_sampler_func(
max_positions,
ignore_invalid_inputs,
max_tokens,
max_sentences,
required_batch_size_multiple=required_batch_size_multiple,
seed=seed,
)
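        # Pass the sampler factory as a callable so EpochBatchIterator re-invokes
        # it at the start of each epoch instead of caching a fixed set of batches.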
epoch_iter = iterators.EpochBatchIterator(
dataset=dataset,
collate_fn=dataset.collater,
batch_sampler=construct_batch_sampler,
seed=seed,
num_shards=num_shards,
shard_id=shard_id,
num_workers=num_workers,
epoch=epoch,
)
return epoch_iter