# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

import itertools
import logging
import os

import numpy as np
import torch

from fairseq import metrics
from fairseq.data import (
    ConcatDataset,
    ConcatSentencesDataset,
    data_utils,
    Dictionary,
    IdDataset,
    indexed_dataset,
    NestedDictionaryDataset,
    NumSamplesDataset,
    NumelDataset,
    PrependTokenDataset,
    RawLabelDataset,
    RightPadDataset,
    SortDataset,
    TruncateDataset,
    TokenBlockDataset,
)
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.tasks import FairseqTask, register_task

from omegaconf import II, MISSING


# Highest n-gram order whose counts/totals are carried in the BLEU target rows.
EVAL_BLEU_ORDER = 4
TARGET_METRIC_CHOICES = ChoiceEnum(["bleu", "ter"])

logger = logging.getLogger(__name__)


@dataclass
class DiscriminativeRerankingNMTConfig(FairseqDataclass):
    """Configuration for the ``discriminative_reranking_nmt`` task."""

    data: str = field(default=MISSING, metadata={"help": "path to data directory"})
    num_data_splits: int = field(
        default=1, metadata={"help": "total number of data splits"}
    )
    no_shuffle: bool = field(
        default=False, metadata={"help": "do not shuffle training data"}
    )
    max_positions: int = field(
        default=512, metadata={"help": "number of positional embeddings to learn"}
    )
    include_src: bool = field(
        default=False, metadata={"help": "include source sentence"}
    )
    mt_beam: int = field(default=50, metadata={"help": "beam size of input hypotheses"})
    eval_target_metric: bool = field(
        default=False,
        metadata={"help": "evaluation with the target metric during validation"},
    )
    target_metric: TARGET_METRIC_CHOICES = field(
        default="bleu", metadata={"help": "name of the target metric to optimize for"}
    )
    train_subset: str = field(
        default=II("dataset.train_subset"),
        metadata={"help": "data subset to use for training (e.g. train, valid, test)"},
    )
    seed: int = field(
        default=II("common.seed"),
        metadata={"help": "pseudo random number generator seed"},
    )


class RerankerScorer(object):
    """Scores the target for a given (source (optional), target) input."""

    def __init__(self, args, mt_beam):
        # ``args`` is accepted for interface compatibility but is not used here.
        self.mt_beam = mt_beam

    @torch.no_grad()
    def generate(self, models, sample, **kwargs):
        """Score a batch of translations.

        Args:
            models: list holding exactly one model (ensembles are rejected).
            sample: batch dict; only ``sample["net_input"]`` is read.

        Returns:
            The output of ``model.classification_forward`` for the batch.
        """
        net_input = sample["net_input"]

        assert len(models) == 1, "does not support model ensemble"
        model = models[0]

        bs = net_input["src_tokens"].shape[0]
        # Joint classification regroups the batch by beam, so the batch must
        # hold a whole number of beams.
        assert (
            model.joint_classification == "none" or bs % self.mt_beam == 0
        ), f"invalid batch size ({bs}) for joint classification with beam size ({self.mt_beam})"

        model.eval()
        logits = model(**net_input)

        batch_out = model.sentence_forward(logits, net_input["src_tokens"])
        if model.joint_classification == "sent":
            batch_out = model.joint_forward(
                batch_out.view(self.mt_beam, bs // self.mt_beam, -1)
            )
        scores = model.classification_forward(
            batch_out.view(bs, 1, -1)
        )  # input: B x T x C

        return scores


@register_task(
    "discriminative_reranking_nmt", dataclass=DiscriminativeRerankingNMTConfig
)
class DiscriminativeRerankingNMTTask(FairseqTask):
    """
    Translation rerank task.
    The input can be either (src, tgt) sentence pairs or tgt sentence only.
    """

    cfg: DiscriminativeRerankingNMTConfig

    def __init__(self, cfg: DiscriminativeRerankingNMTConfig, data_dictionary=None):
        super().__init__(cfg)
        self.dictionary = data_dictionary
        self._max_positions = cfg.max_positions
        # args.tokens_per_sample = self._max_positions
        # self.num_classes = 1  # for model

    @classmethod
    def load_dictionary(cls, cfg, filename):
        """Load the dictionary from the filename"""
        dictionary = Dictionary.load(filename)
        # NOTE(review): an empty-string symbol was being added here, which looks
        # like a "<mask>" token whose text was lost; restored so the vocabulary
        # lines up with the pretrained XLMR model -- confirm against the checkpoint.
        dictionary.add_symbol("<mask>")  # for loading pretrained XLMR model

        return dictionary

    @classmethod
    def setup_task(cls, cfg: DiscriminativeRerankingNMTConfig, **kwargs):
        """Build the task, loading ``input_src/dict.txt`` from ``cfg.data``."""
        # load data dictionary (assume joint dictionary)
        data_path = cfg.data
        data_dict = cls.load_dictionary(
            cfg, os.path.join(data_path, "input_src/dict.txt")
        )

        logger.info("[input] src dictionary: {} types".format(len(data_dict)))

        return DiscriminativeRerankingNMTTask(cfg, data_dict)

    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        """Load a given dataset split (e.g., train, valid, test).

        When ``cfg.data`` ends with ``"1"`` the trailing character is replaced
        by a shard number derived from ``epoch`` and ``cfg.num_data_splits``.
        For the training subset, ``train``, ``train1``, ``train2``, ... are
        loaded until one is missing and then concatenated.

        Raises:
            FileNotFoundError: if the first training shard does not exist.
        """
        if self.cfg.data.endswith("1"):
            data_shard = (epoch - 1) % self.cfg.num_data_splits + 1
            data_path = self.cfg.data[:-1] + str(data_shard)
        else:
            data_path = self.cfg.data

        def get_path(kind, data_split):
            # <data_path>/<kind>/<data_split>
            return os.path.join(data_path, str(kind), data_split)

        def make_dataset(kind, dictionary, data_split, combine):
            split_path = get_path(kind, data_split)

            dataset = data_utils.load_indexed_dataset(
                split_path, dictionary, combine=combine,
            )
            return dataset

        def load_split(data_split, metric):
            """Return (input_src or None, input_tgt, label) for one split."""
            input_src = None
            if self.cfg.include_src:
                input_src = make_dataset(
                    "input_src", self.dictionary, data_split, combine=False
                )
                assert input_src is not None, "could not find dataset: {}".format(
                    get_path("input_src", data_split)
                )

            input_tgt = make_dataset(
                "input_tgt", self.dictionary, data_split, combine=False
            )
            assert input_tgt is not None, "could not find dataset: {}".format(
                get_path("input_tgt", data_split)
            )

            label_path = f"{get_path(metric, data_split)}.{metric}"
            assert os.path.exists(label_path), f"could not find dataset: {label_path}"

            np_labels = np.loadtxt(label_path)
            # TER labels are stored negated; valid_step flips the sign back
            # before logging edit counts and reference lengths.
            if metric == "ter":
                np_labels = -np_labels
            label = RawLabelDataset(np_labels)

            return input_src, input_tgt, label

        src_datasets = []
        tgt_datasets = []
        label_datasets = []

        if split == self.cfg.train_subset:
            for k in itertools.count():
                split_k = "train" + (str(k) if k > 0 else "")
                prefix = os.path.join(data_path, "input_tgt", split_k)
                if not indexed_dataset.dataset_exists(prefix, impl=None):
                    if k > 0:
                        break
                    else:
                        raise FileNotFoundError(f"Dataset not found: {prefix}")
                input_src, input_tgt, label = load_split(
                    split_k, self.cfg.target_metric
                )
                src_datasets.append(input_src)
                tgt_datasets.append(input_tgt)
                label_datasets.append(label)
        else:
            input_src, input_tgt, label = load_split(split, self.cfg.target_metric)
            src_datasets.append(input_src)
            tgt_datasets.append(input_tgt)
            label_datasets.append(label)

        if len(tgt_datasets) == 1:
            input_tgt, label = tgt_datasets[0], label_datasets[0]
            if self.cfg.include_src:
                input_src = src_datasets[0]
        else:
            input_tgt = ConcatDataset(tgt_datasets)
            label = ConcatDataset(label_datasets)
            if self.cfg.include_src:
                input_src = ConcatDataset(src_datasets)

        input_tgt = TruncateDataset(input_tgt, self.cfg.max_positions)
        if self.cfg.include_src:
            input_src = PrependTokenDataset(input_src, self.dictionary.bos())
            input_src = TruncateDataset(input_src, self.cfg.max_positions)
            src_lengths = NumelDataset(input_src, reduce=False)
            src_tokens = ConcatSentencesDataset(input_src, input_tgt)
        else:
            src_tokens = PrependTokenDataset(input_tgt, self.dictionary.bos())
            src_lengths = NumelDataset(src_tokens, reduce=False)

        dataset = {
            "id": IdDataset(),
            "net_input": {
                "src_tokens": RightPadDataset(
                    src_tokens, pad_idx=self.source_dictionary.pad(),
                ),
                "src_lengths": src_lengths,
            },
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(src_tokens, reduce=True),
            "target": label,
        }

        dataset = NestedDictionaryDataset(dataset, sizes=[src_tokens.sizes],)

        assert len(dataset) % self.cfg.mt_beam == 0, (
            "dataset size (%d) is not a multiple of beam size (%d)"
            % (len(dataset), self.cfg.mt_beam)
        )

        # no need to shuffle valid/test sets
        if not self.cfg.no_shuffle and split == self.cfg.train_subset:

            # need to keep all hypotheses together: shuffle the first index of
            # every beam-sized group, then expand each to mt_beam consecutive ids
            start_idx = np.arange(0, len(dataset), self.cfg.mt_beam)
            with data_utils.numpy_seed(self.cfg.seed + epoch):
                np.random.shuffle(start_idx)

            idx = np.arange(0, self.cfg.mt_beam)
            shuffle = np.tile(idx, (len(start_idx), 1)).reshape(-1) + np.tile(
                start_idx, (self.cfg.mt_beam, 1)
            ).transpose().reshape(-1)

            dataset = SortDataset(dataset, sort_order=[shuffle],)

        logger.info(f"Loaded {split} with #samples: {len(dataset)}")

        self.datasets[split] = dataset
        return self.datasets[split]

    def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
        """Wrap raw token tensors into the dataset layout used by ``load_dataset``.

        The last element of every ``src_tokens`` / ``src_lengths`` entry is the
        hypothesis; when ``cfg.include_src`` is set, each entry must have two
        elements and the first one is the source sentence. No ``"target"``
        field is produced.
        """
        assert not self.cfg.include_src or len(src_tokens[0]) == 2
        input_src = None
        if self.cfg.include_src:
            input_src = TokenBlockDataset(
                [tokens[0] for tokens in src_tokens],
                [lengths[0] for lengths in src_lengths],
                block_size=None,  # ignored for "eos" break mode
                pad=self.source_dictionary.pad(),
                eos=self.source_dictionary.eos(),
                break_mode="eos",
            )
            input_src = PrependTokenDataset(input_src, self.dictionary.bos())
            input_src = TruncateDataset(input_src, self.cfg.max_positions)

        input_tgt = TokenBlockDataset(
            [tokens[-1] for tokens in src_tokens],
            [lengths[-1] for lengths in src_lengths],
            block_size=None,  # ignored for "eos" break mode
            pad=self.source_dictionary.pad(),
            eos=self.source_dictionary.eos(),
            break_mode="eos",
        )
        input_tgt = TruncateDataset(input_tgt, self.cfg.max_positions)
        if self.cfg.include_src:
            src_tokens = ConcatSentencesDataset(input_src, input_tgt)
            src_lengths = NumelDataset(input_src, reduce=False)
        else:
            input_tgt = PrependTokenDataset(input_tgt, self.dictionary.bos())
            src_tokens = input_tgt
            src_lengths = NumelDataset(src_tokens, reduce=False)

        dataset = {
            "id": IdDataset(),
            "net_input": {
                "src_tokens": RightPadDataset(
                    src_tokens, pad_idx=self.source_dictionary.pad(),
                ),
                "src_lengths": src_lengths,
            },
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(src_tokens, reduce=True),
        }

        return NestedDictionaryDataset(dataset, sizes=[src_tokens.sizes],)

    def build_model(self, cfg: FairseqDataclass):
        """Defer model construction to the parent task."""
        return super().build_model(cfg)

    def build_generator(self, args):
        """Return a :class:`RerankerScorer` configured with ``cfg.mt_beam``."""
        return RerankerScorer(args, mt_beam=self.cfg.mt_beam)

    def max_positions(self):
        """Return ``cfg.max_positions`` as captured at construction time."""
        return self._max_positions

    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary

    def create_dummy_batch(self, device):
        """Build an all-zero batch of ``cfg.mt_beam`` rows on ``device``.

        Used by ``train_step`` / ``valid_step`` when they receive no sample.
        The target width matches what ``valid_step`` asserts for the
        configured metric: 3 columns for TER, ``EVAL_BLEU_ORDER * 2 + 3``
        for BLEU.
        """
        # The config has no ``eval_ter`` field; the metric is selected through
        # ``target_metric``, as everywhere else in this task.
        if self.cfg.target_metric == "ter":
            num_target_cols = 3
        else:
            num_target_cols = EVAL_BLEU_ORDER * 2 + 3
        dummy_target = torch.zeros(self.cfg.mt_beam, num_target_cols).long().to(device)

        return {
            "id": torch.zeros(self.cfg.mt_beam, 1).long().to(device),
            "net_input": {
                "src_tokens": torch.zeros(self.cfg.mt_beam, 4).long().to(device),
                "src_lengths": torch.ones(self.cfg.mt_beam, 1).long().to(device),
            },
            "nsentences": 0,
            "ntokens": 0,
            "target": dummy_target,
        }

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        """Run the parent ``train_step``, substituting a dummy batch when
        ``sample`` is ``None`` and gradients are being ignored."""
        if ignore_grad and sample is None:
            sample = self.create_dummy_batch(model.device)

        return super().train_step(
            sample, model, criterion, optimizer, update_num, ignore_grad
        )

    def _sum_top_scored_stats(self, sample, scores, sample_size):
        """Sum target columns ``1:`` over the best-scored row of every beam.

        ``scores`` is arg-maxed over dim 1 and the result is offset by
        multiples of ``cfg.mt_beam`` to index rows of ``sample["target"]``.
        Presumably ``scores`` has one row per source sentence and one column
        per hypothesis -- TODO confirm against the criterion that emits it.
        """
        max_id = torch.argmax(scores, dim=1)
        select_id = max_id + torch.arange(
            0, sample_size * self.cfg.mt_beam, self.cfg.mt_beam
        ).to(max_id.device)
        return sample["target"][select_id, 1:].sum(0).data

    def valid_step(self, sample, model, criterion):
        """Run validation and, if ``cfg.eval_target_metric`` is set, attach the
        BLEU or TER sufficient statistics of the top-scored hypotheses to the
        logging output.

        Returns:
            ``(loss, sample_size, logging_output)`` from the parent
            ``valid_step``, with extra ``_bleu_*`` or ``_ter_*`` entries.
        """
        if sample is None:
            sample = self.create_dummy_batch(model.device)

        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)

        if not self.cfg.eval_target_metric:
            return loss, sample_size, logging_output

        scores = logging_output["scores"]
        num_target_cols = sample["target"].shape[1]

        if self.cfg.target_metric == "bleu":
            assert num_target_cols == EVAL_BLEU_ORDER * 2 + 3, (
                "target does not contain enough information ("
                + str(num_target_cols)
                + ") for evaluating BLEU"
            )

            # Column 0 is skipped; the rest is read as
            # sys_len, ref_len, counts[0..3], totals[0..3].
            bleu_data = self._sum_top_scored_stats(sample, scores, sample_size)

            logging_output["_bleu_sys_len"] = bleu_data[0]
            logging_output["_bleu_ref_len"] = bleu_data[1]
            for i in range(EVAL_BLEU_ORDER):
                logging_output["_bleu_counts_" + str(i)] = bleu_data[2 + i]
                logging_output["_bleu_totals_" + str(i)] = bleu_data[
                    2 + EVAL_BLEU_ORDER + i
                ]

        elif self.cfg.target_metric == "ter":
            assert num_target_cols == 3, (
                "target does not contain enough information ("
                + str(num_target_cols)
                + ") for evaluating TER"
            )

            ter_data = self._sum_top_scored_stats(sample, scores, sample_size)
            # load_dataset negates TER labels, so flip the sign back here.
            logging_output["_ter_num_edits"] = -ter_data[0]
            logging_output["_ter_ref_len"] = -ter_data[1]

        return loss, sample_size, logging_output

    def reduce_metrics(self, logging_outputs, criterion):
        """Aggregate logging outputs and register a derived ``bleu`` or ``ter``
        metric when ``cfg.eval_target_metric`` is set and statistics exist."""
        super().reduce_metrics(logging_outputs, criterion)

        if not self.cfg.eval_target_metric:
            return

        def sum_logs(key):
            return sum(log.get(key, 0) for log in logging_outputs)

        if self.cfg.target_metric == "bleu":
            counts, totals = [], []
            for i in range(EVAL_BLEU_ORDER):
                counts.append(sum_logs("_bleu_counts_" + str(i)))
                totals.append(sum_logs("_bleu_totals_" + str(i)))

            if max(totals) > 0:
                # log counts as numpy arrays -- log_scalar will sum them correctly
                metrics.log_scalar("_bleu_counts", np.array(counts))
                metrics.log_scalar("_bleu_totals", np.array(totals))
                metrics.log_scalar("_bleu_sys_len", sum_logs("_bleu_sys_len"))
                metrics.log_scalar("_bleu_ref_len", sum_logs("_bleu_ref_len"))

                def compute_bleu(meters):
                    import inspect
                    import sacrebleu

                    # The smoothing keyword is picked by inspecting the
                    # installed sacrebleu's compute_bleu signature.
                    fn_sig = inspect.getfullargspec(sacrebleu.compute_bleu)[0]
                    if "smooth_method" in fn_sig:
                        smooth = {"smooth_method": "exp"}
                    else:
                        smooth = {"smooth": "exp"}
                    bleu = sacrebleu.compute_bleu(
                        correct=meters["_bleu_counts"].sum,
                        total=meters["_bleu_totals"].sum,
                        sys_len=meters["_bleu_sys_len"].sum,
                        ref_len=meters["_bleu_ref_len"].sum,
                        **smooth,
                    )
                    return round(bleu.score, 2)

                metrics.log_derived("bleu", compute_bleu)
        elif self.cfg.target_metric == "ter":
            num_edits = sum_logs("_ter_num_edits")
            ref_len = sum_logs("_ter_ref_len")

            if ref_len > 0:
                metrics.log_scalar("_ter_num_edits", num_edits)
                metrics.log_scalar("_ter_ref_len", ref_len)

                def compute_ter(meters):
                    score = meters["_ter_num_edits"].sum / meters["_ter_ref_len"].sum
                    # float() accepts 0-dim tensors as well as plain numbers.
                    return round(float(score), 2)

                metrics.log_derived("ter", compute_ter)