# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
import itertools
import json
import logging
import os
from typing import Optional
from argparse import Namespace
from omegaconf import II

import numpy as np

from fairseq import metrics, utils
from fairseq.data import (
    AppendTokenDataset,
    ConcatDataset,
    LanguagePairDataset,
    PrependTokenDataset,
    StripTokenDataset,
    TruncateDataset,
    data_utils,
    encoders,
    indexed_dataset,
)
from fairseq.data.indexed_dataset import get_available_dataset_impl
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
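
# BLEU aggregates modified n-gram precisions up to order 4, so validation
# below tracks four pairs of n-gram match/total counts per batch.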
EVAL_BLEU_ORDER = 4


logger = logging.getLogger(__name__)

def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    num_buckets=0,
    shuffle=True,
    pad_to_multiple=1,
    prepend_bos_src=None,
):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, data_path)
                )

        src_dataset = data_utils.load_indexed_dataset(
            prefix + src, src_dict, dataset_impl
        )
        if truncate_source:
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(
            prefix + tgt, tgt_dict, dataset_impl
        )
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        logger.info(
            "{} {} {}-{} {} examples".format(
                data_path, split_k, src, tgt, len(src_datasets[-1])
            )
        )

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
    elif prepend_bos_src is not None:
        logger.info(f"prepending src bos: {prepend_bos_src}")
        src_dataset = PrependTokenDataset(src_dataset, prepend_bos_src)

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(
            src_dataset, src_dict.index("[{}]".format(src))
        )
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(
                tgt_dataset, tgt_dict.index("[{}]".format(tgt))
            )
        eos = tgt_dict.index("[{}]".format(tgt))

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path, "{}.align.{}-{}".format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(
                align_path, None, dataset_impl
            )

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        align_dataset=align_dataset,
        eos=eos,
        num_buckets=num_buckets,
        shuffle=shuffle,
        pad_to_multiple=pad_to_multiple,
    )
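

# A minimal usage sketch of ``load_langpair_dataset`` (not part of the original
# module). The paths are hypothetical; it assumes binarized de-en data and
# dictionaries produced by fairseq-preprocess under ``data-bin/``.
def _example_load_langpair_dataset():
    from fairseq.data import Dictionary

    src_dict = Dictionary.load("data-bin/dict.de.txt")
    tgt_dict = Dictionary.load("data-bin/dict.en.txt")
    return load_langpair_dataset(
        "data-bin",
        "train",
        "de",
        src_dict,
        "en",
        tgt_dict,
        combine=True,
        dataset_impl="mmap",
        upsample_primary=1,
        left_pad_source=True,
        left_pad_target=False,
        max_source_positions=1024,
        max_target_positions=1024,
    )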


@dataclass
class TranslationConfig(FairseqDataclass):
    data: Optional[str] = field(
        default=None,
        metadata={
            "help": "colon-separated list of data directory paths, iterated over "
            "in round-robin order across epochs; however, valid and test data are "
            "always taken from the first directory to avoid the need for "
            "repeating them in all directories"
        },
    )
    source_lang: Optional[str] = field(
        default=None,
        metadata={
            "help": "source language",
            "argparse_alias": "-s",
        },
    )
    target_lang: Optional[str] = field(
        default=None,
        metadata={
            "help": "target language",
            "argparse_alias": "-t",
        },
    )
    load_alignments: bool = field(
        default=False, metadata={"help": "load the binarized alignments"}
    )
    left_pad_source: bool = field(
        default=True, metadata={"help": "pad the source on the left"}
    )
    left_pad_target: bool = field(
        default=False, metadata={"help": "pad the target on the left"}
    )
    max_source_positions: int = field(
        default=1024, metadata={"help": "max number of tokens in the source sequence"}
    )
    max_target_positions: int = field(
        default=1024, metadata={"help": "max number of tokens in the target sequence"}
    )
    upsample_primary: int = field(
        default=-1, metadata={"help": "amount to upsample the primary dataset"}
    )
    truncate_source: bool = field(
        default=False, metadata={"help": "truncate source to max-source-positions"}
    )
    num_batch_buckets: int = field(
        default=0,
        metadata={
            "help": "if >0, then bucket source and target lengths into "
            "N buckets and pad accordingly; this is useful on TPUs to minimize "
            "the number of compilations"
        },
    )
    train_subset: str = II("dataset.train_subset")
    dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II(
        "dataset.dataset_impl"
    )
    required_seq_len_multiple: int = II("dataset.required_seq_len_multiple")

    # options for reporting BLEU during validation
    eval_bleu: bool = field(
        default=False, metadata={"help": "evaluation with BLEU scores"}
    )
    eval_bleu_args: Optional[str] = field(
        default="{}",
        metadata={
            "help": 'generation args for BLEU scoring, e.g., \'{"beam": 4, "lenpen": 0.6}\', as JSON string'
        },
    )
    eval_bleu_detok: str = field(
        default="space",
        metadata={
            "help": "detokenize before computing BLEU (e.g., 'moses'); required if "
            "using --eval-bleu; use 'space' to disable detokenization; see "
            "fairseq.data.encoders for other options"
        },
    )
    eval_bleu_detok_args: Optional[str] = field(
        default="{}",
        metadata={"help": "args for building the tokenizer, if needed, as JSON string"},
    )
    eval_tokenized_bleu: bool = field(
        default=False, metadata={"help": "compute tokenized BLEU instead of sacrebleu"}
    )
    eval_bleu_remove_bpe: Optional[str] = field(
        default=None,
        metadata={
            "help": "remove BPE before computing BLEU",
            "argparse_const": "@@ ",
        },
    )
    eval_bleu_print_samples: bool = field(
        default=False, metadata={"help": "print sample generations during validation"}
    )
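

# A small sketch of constructing the config directly in Python (values are
# hypothetical; via fairseq's dataclass/argparse bridge these fields map to
# CLI flags such as --source-lang/-s and --eval-bleu-args).
def _example_translation_config():
    return TranslationConfig(
        data="data-bin",
        source_lang="de",
        target_lang="en",
        eval_bleu=True,
        eval_bleu_args='{"beam": 4, "lenpen": 0.6}',
        eval_bleu_detok="moses",
    )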


@register_task("translation", dataclass=TranslationConfig)
class TranslationTask(FairseqTask):
""" | |
Translate from one (source) language to another (target) language. | |
Args: | |
src_dict (~fairseq.data.Dictionary): dictionary for the source language | |
tgt_dict (~fairseq.data.Dictionary): dictionary for the target language | |
.. note:: | |
The translation task is compatible with :mod:`fairseq-train`, | |
:mod:`fairseq-generate` and :mod:`fairseq-interactive`. | |
""" | |
cfg: TranslationConfig | |
def __init__(self, cfg: TranslationConfig, src_dict, tgt_dict): | |
super().__init__(cfg) | |
self.src_dict = src_dict | |
self.tgt_dict = tgt_dict | |

    @classmethod
    def setup_task(cls, cfg: TranslationConfig, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            cfg (TranslationConfig): configuration for this task
        """
        paths = utils.split_paths(cfg.data)
        assert len(paths) > 0
        # find language pair automatically
        if cfg.source_lang is None or cfg.target_lang is None:
            cfg.source_lang, cfg.target_lang = data_utils.infer_language_pair(paths[0])
        if cfg.source_lang is None or cfg.target_lang is None:
            raise Exception(
                "Could not infer language pair, please provide it explicitly"
            )

        # load dictionaries
        src_dict = cls.load_dictionary(
            os.path.join(paths[0], "dict.{}.txt".format(cfg.source_lang))
        )
        tgt_dict = cls.load_dictionary(
            os.path.join(paths[0], "dict.{}.txt".format(cfg.target_lang))
        )
        assert src_dict.pad() == tgt_dict.pad()
        assert src_dict.eos() == tgt_dict.eos()
        assert src_dict.unk() == tgt_dict.unk()
        logger.info("[{}] dictionary: {} types".format(cfg.source_lang, len(src_dict)))
        logger.info("[{}] dictionary: {} types".format(cfg.target_lang, len(tgt_dict)))

        return cls(cfg, src_dict, tgt_dict)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.cfg.data)
        assert len(paths) > 0
        if split != self.cfg.train_subset:
            # if not training data set, use the first shard for valid and test
            paths = paths[:1]
        data_path = paths[(epoch - 1) % len(paths)]

        # infer langcode
        src, tgt = self.cfg.source_lang, self.cfg.target_lang

        self.datasets[split] = load_langpair_dataset(
            data_path,
            split,
            src,
            self.src_dict,
            tgt,
            self.tgt_dict,
            combine=combine,
            dataset_impl=self.cfg.dataset_impl,
            upsample_primary=self.cfg.upsample_primary,
            left_pad_source=self.cfg.left_pad_source,
            left_pad_target=self.cfg.left_pad_target,
            max_source_positions=self.cfg.max_source_positions,
            max_target_positions=self.cfg.max_target_positions,
            load_alignments=self.cfg.load_alignments,
            truncate_source=self.cfg.truncate_source,
            num_buckets=self.cfg.num_batch_buckets,
            shuffle=(split != "test"),
            pad_to_multiple=self.cfg.required_seq_len_multiple,
        )
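
    # Worked example of the round-robin sharding above: with ``data="d0:d1:d2"``
    # the training split resolves to d0 at epoch 1, d1 at epoch 2, d2 at
    # epoch 3, then wraps back to d0 at epoch 4 via ``(epoch - 1) % 3``;
    # valid and test always read from d0.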

    def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
        return LanguagePairDataset(
            src_tokens,
            src_lengths,
            self.source_dictionary,
            tgt_dict=self.target_dictionary,
            constraints=constraints,
        )

    def build_model(self, cfg):
        model = super().build_model(cfg)
        if self.cfg.eval_bleu:
            detok_args = json.loads(self.cfg.eval_bleu_detok_args)
            self.tokenizer = encoders.build_tokenizer(
                Namespace(tokenizer=self.cfg.eval_bleu_detok, **detok_args)
            )

            gen_args = json.loads(self.cfg.eval_bleu_args)
            self.sequence_generator = self.build_generator(
                [model], Namespace(**gen_args)
            )
        return model

    def valid_step(self, sample, model, criterion):
        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
        if self.cfg.eval_bleu:
            bleu = self._inference_with_bleu(self.sequence_generator, sample, model)
            logging_output["_bleu_sys_len"] = bleu.sys_len
            logging_output["_bleu_ref_len"] = bleu.ref_len
            # we split counts into separate entries so that they can be
            # summed efficiently across workers using fast-stat-sync
            assert len(bleu.counts) == EVAL_BLEU_ORDER
            for i in range(EVAL_BLEU_ORDER):
                logging_output["_bleu_counts_" + str(i)] = bleu.counts[i]
                logging_output["_bleu_totals_" + str(i)] = bleu.totals[i]
        return loss, sample_size, logging_output

    def reduce_metrics(self, logging_outputs, criterion):
        super().reduce_metrics(logging_outputs, criterion)
        if self.cfg.eval_bleu:

            def sum_logs(key):
                import torch

                result = sum(log.get(key, 0) for log in logging_outputs)
                if torch.is_tensor(result):
                    result = result.cpu()
                return result

            counts, totals = [], []
            for i in range(EVAL_BLEU_ORDER):
                counts.append(sum_logs("_bleu_counts_" + str(i)))
                totals.append(sum_logs("_bleu_totals_" + str(i)))

            if max(totals) > 0:
                # log counts as numpy arrays -- log_scalar will sum them correctly
                metrics.log_scalar("_bleu_counts", np.array(counts))
                metrics.log_scalar("_bleu_totals", np.array(totals))
                metrics.log_scalar("_bleu_sys_len", sum_logs("_bleu_sys_len"))
                metrics.log_scalar("_bleu_ref_len", sum_logs("_bleu_ref_len"))

                def compute_bleu(meters):
                    import inspect

                    try:
                        from sacrebleu.metrics import BLEU

                        comp_bleu = BLEU.compute_bleu
                    except ImportError:
                        # compatibility API for sacrebleu 1.x
                        import sacrebleu

                        comp_bleu = sacrebleu.compute_bleu

                    fn_sig = inspect.getfullargspec(comp_bleu)[0]
                    if "smooth_method" in fn_sig:
                        smooth = {"smooth_method": "exp"}
                    else:
                        smooth = {"smooth": "exp"}
                    bleu = comp_bleu(
                        correct=meters["_bleu_counts"].sum,
                        total=meters["_bleu_totals"].sum,
                        sys_len=meters["_bleu_sys_len"].sum,
                        ref_len=meters["_bleu_ref_len"].sum,
                        **smooth,
                    )
                    return round(bleu.score, 2)

                metrics.log_derived("bleu", compute_bleu)
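
    # ``log_derived`` registers ``compute_bleu`` as a lazy metric: corpus BLEU
    # is recomputed from the summed n-gram counts only when the metric is read,
    # so workers only ever exchange counts, never full hypothesis strings.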

    def max_positions(self):
        """Return the max sentence length allowed by the task."""
        return (self.cfg.max_source_positions, self.cfg.max_target_positions)

    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary`."""
        return self.src_dict

    @property
    def target_dictionary(self):
        """Return the target :class:`~fairseq.data.Dictionary`."""
        return self.tgt_dict

    def _inference_with_bleu(self, generator, sample, model):
        import sacrebleu

        def decode(toks, escape_unk=False):
            s = self.tgt_dict.string(
                toks.int().cpu(),
                self.cfg.eval_bleu_remove_bpe,
                # The default unknown string in fairseq is `<unk>`, but
                # this is tokenized by sacrebleu as `< unk >`, inflating
                # BLEU scores. Instead, we use a somewhat more verbose
                # alternative that is unlikely to appear in the real
                # reference, but doesn't get split into multiple tokens.
                unk_string=("UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"),
            )
            if self.tokenizer:
                s = self.tokenizer.decode(s)
            return s

        gen_out = self.inference_step(generator, [model], sample, prefix_tokens=None)
        hyps, refs = [], []
        for i in range(len(gen_out)):
            hyps.append(decode(gen_out[i][0]["tokens"]))
            refs.append(
                decode(
                    utils.strip_pad(sample["target"][i], self.tgt_dict.pad()),
                    escape_unk=True,  # don't count <unk> as matches to the hypo
                )
            )
        if self.cfg.eval_bleu_print_samples:
            logger.info("example hypothesis: " + hyps[0])
            logger.info("example reference: " + refs[0])
        if self.cfg.eval_tokenized_bleu:
            return sacrebleu.corpus_bleu(hyps, [refs], tokenize="none")
        else:
            return sacrebleu.corpus_bleu(hyps, [refs])
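

# A standalone sketch of the sacrebleu call made in ``_inference_with_bleu``
# (toy strings; assumes sacrebleu is installed). ``corpus_bleu`` takes a list
# of hypotheses and a list of reference streams; the returned score object
# exposes the per-order ``counts``/``totals`` and ``sys_len``/``ref_len``
# fields that ``valid_step`` logs above.
def _example_corpus_bleu():
    import sacrebleu

    hyps = ["the cat sat on the mat", "hello world"]
    refs = ["the cat sat on the mat", "hello there world"]
    bleu = sacrebleu.corpus_bleu(hyps, [refs])
    print(round(bleu.score, 2), bleu.counts, bleu.totals, bleu.sys_len, bleu.ref_len)
    return bleu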