# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

from dataclasses import dataclass, field
import json
import logging
import string
from typing import Optional
from argparse import Namespace

from fairseq import metrics
from fairseq.tasks import register_task
from fairseq.data import encoders

from tasks.ofa_task import OFATask, OFAConfig
from data.nlg_data.summary_dataset import SummaryDataset
from data.file_dataset import FileDataset
from datasets import load_metric

logger = logging.getLogger(__name__)
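
# Gigaword references use PTB-style bracket escapes (-lrb-, -rrb-, ...) and
# HTML-escaped entities, so hypotheses are mapped to the same surface form
# before scoring.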
_tok_dict = {"(": "-lrb-", ")": "-rrb-",
             "[": "-lsb-", "]": "-rsb-",
             "{": "-lcb-", "}": "-rcb-",
             "[UNK]": "UNK", '&': '&amp;', '<': '&lt;', '>': '&gt;'}


def _is_digit(w):
    for ch in w:
        if not (ch.isdigit() or ch == ','):
            return False
    return True
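

# Re-attach tokens that whitespace tokenization split apart (quotes,
# contractions, ellipses, digit groups, acronyms, dashes) so hypotheses match
# the reference tokenization. Illustrative example (not from the original
# source):
#   fix_tokenization("U . N . officials say 3 , 000 more")
#   -> "U.N. officials say 3,000 more"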
def fix_tokenization(text):
    input_tokens = text.split()
    output_tokens = []
    has_left_quote = False
    has_left_single_quote = False

    i = 0
    prev_dash = False
    while i < len(input_tokens):
        tok = input_tokens[i]
        flag_prev_dash = False
        if tok in _tok_dict.keys():
            output_tokens.append(_tok_dict[tok])
            i += 1
        elif tok == "\"":
            if has_left_quote:
                output_tokens.append("''")
            else:
                output_tokens.append("``")
            has_left_quote = not has_left_quote
            i += 1
        elif tok == "'" and len(output_tokens) > 0 and output_tokens[-1].endswith("n") and i < len(input_tokens) - 1 and input_tokens[i + 1] == "t":
            output_tokens[-1] = output_tokens[-1][:-1]
            output_tokens.append("n't")
            i += 2
        elif tok == "'" and i < len(input_tokens) - 1 and input_tokens[i + 1] in ("s", "d", "ll"):
            output_tokens.append("'" + input_tokens[i + 1])
            i += 2
        elif tok == "'":
            if has_left_single_quote:
                output_tokens.append("'")
            else:
                output_tokens.append("`")
            has_left_single_quote = not has_left_single_quote
            i += 1
        elif tok == "." and i < len(input_tokens) - 2 and input_tokens[i + 1] == "." and input_tokens[i + 2] == ".":
            output_tokens.append("...")
            i += 3
        elif tok == "," and len(output_tokens) > 0 and _is_digit(output_tokens[-1]) and i < len(input_tokens) - 1 and _is_digit(input_tokens[i + 1]):
            # $ 3 , 000 -> $ 3,000
            output_tokens[-1] += ',' + input_tokens[i + 1]
            i += 2
elif tok == "." and len(output_tokens) > 0 and output_tokens[-1].isdigit() and i < len(input_tokens) - 1 and input_tokens[i + 1].isdigit(): | |
# 3 . 03 -> $ 3.03 | |
output_tokens[-1] += '.'+input_tokens[i + 1] | |
i += 2 | |
elif tok == "." and len(output_tokens) > 0 and len(output_tokens[-1]) == 1 and output_tokens[-1].isupper() and i < len(input_tokens) - 2 and len(input_tokens[i + 1]) == 1 and input_tokens[i + 1].isupper() and input_tokens[i + 2] == '.': | |
# U . N . -> U.N. | |
k = i+3 | |
while k+2 < len(input_tokens): | |
if len(input_tokens[k + 1]) == 1 and input_tokens[k + 1].isupper() and input_tokens[k + 2] == '.': | |
k += 2 | |
else: | |
break | |
output_tokens[-1] += ''.join(input_tokens[i:k]) | |
i += 2 | |
elif tok == "-": | |
if i < len(input_tokens) - 1 and input_tokens[i + 1] == "-": | |
output_tokens.append("--") | |
i += 2 | |
elif i == len(input_tokens) - 1 or i == 0: | |
output_tokens.append("-") | |
i += 1 | |
elif output_tokens[-1] not in string.punctuation and input_tokens[i + 1][0] not in string.punctuation: | |
output_tokens[-1] += "-" | |
i += 1 | |
flag_prev_dash = True | |
else: | |
output_tokens.append("-") | |
i += 1 | |
elif prev_dash and len(output_tokens) > 0 and tok[0] not in string.punctuation: | |
output_tokens[-1] += tok | |
i += 1 | |
else: | |
output_tokens.append(tok) | |
i += 1 | |
prev_dash = flag_prev_dash | |
return " ".join(output_tokens) | |


@dataclass
class GigawordConfig(OFAConfig):
    # options for reporting ROUGE during validation
    eval_rouge: bool = field(
        default=False, metadata={"help": "evaluation with rouge scores"}
    )
    eval_args: Optional[str] = field(
        default='{}',
        metadata={
            "help": 'generation args for BLEU or CIDEr scoring, e.g., \'{"beam": 4, "lenpen": 0.6}\', as JSON string'
        },
    )
    eval_detok: str = field(
        default="space",
        metadata={
            "help": "detokenize before computing BLEU or CIDEr (e.g., 'moses'); "
                    "required if using --eval-bleu or --eval-cider; "
                    "use 'space' to disable detokenization; see fairseq.data.encoders for other options"
        },
    )
    eval_detok_args: Optional[str] = field(
        default="{}",
        metadata={"help": "args for building the tokenizer, if needed, as JSON string"},
    )
    eval_print_samples: bool = field(
        default=False, metadata={"help": "print sample generations during validation"}
    )
    noise_ratio: float = field(
        default=0.0, metadata={"help": "noise ratio for prev output"}
    )
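
# With fairseq's dataclass-based configs, the fields above surface as CLI
# flags with underscores turned into dashes, e.g. (hypothetical invocation)
# --eval-rouge --eval-args '{"beam": 4, "lenpen": 0.6}'.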


@register_task("gigaword", dataclass=GigawordConfig)
class GigawordTask(OFATask):
    def __init__(self, cfg: GigawordConfig, src_dict, tgt_dict):
        super().__init__(cfg, src_dict, tgt_dict)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        paths = self.cfg.data.split(',')
        assert len(paths) > 0
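        # Convention assumed from the indexing below: the last comma-separated
        # path serves validation/test, while the earlier paths are training
        # shards cycled across epochs (so at least two paths are expected).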
        if split == 'train':
            file_path = paths[(epoch - 1) % (len(paths) - 1)]
        else:
            file_path = paths[-1]
        dataset = FileDataset(file_path, self.cfg.selected_cols)

        self.datasets[split] = SummaryDataset(
            split,
            dataset,
            self.bpe,
            self.src_dict,
            self.tgt_dict,
            code_dict_size=self.cfg.code_dict_size,
            num_bins=self.cfg.num_bins,
            max_src_length=self.cfg.max_src_length,
            max_tgt_length=self.cfg.max_tgt_length,
            noise_ratio=self.cfg.noise_ratio
        )

    def build_model(self, cfg):
        model = super().build_model(cfg)
        if self.cfg.eval_rouge:
            detok_args = json.loads(self.cfg.eval_detok_args)
            self.tokenizer = encoders.build_tokenizer(
                Namespace(tokenizer=self.cfg.eval_detok, **detok_args)
            )
            gen_args = json.loads(self.cfg.eval_args)
            self.sequence_generator = self.build_generator(
                [model], Namespace(**gen_args)
            )
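            # ROUGE is computed with a local copy of the HF datasets rouge
            # metric script; the path is relative to the working directory.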
            self.metric = load_metric('../../utils/rouge.py')
        return model

    def valid_step(self, sample, model, criterion):
        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
        if self.cfg.eval_rouge:
            hyps, refs = self._inference(self.sequence_generator, sample, model)
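            # With use_agregator=False (spelling as in the bundled rouge
            # script), compute() returns one score object per hyp/ref pair;
            # sums are logged here and averaged in reduce_metrics.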
            result = self.metric.compute(predictions=hyps, references=refs, use_agregator=False, use_stemmer=True)
            result_recall = {key: sum(item.recall for item in value) * 100 for key, value in result.items()}
            result_f1 = {key: sum(item.fmeasure for item in value) * 100 for key, value in result.items()}
            logging_output['_rouge1_recall_sum'] = result_recall['rouge1']
            logging_output['_rouge2_recall_sum'] = result_recall['rouge2']
            logging_output['_rougeL_recall_sum'] = result_recall['rougeL']
            logging_output['_rouge1_f1_sum'] = result_f1['rouge1']
            logging_output['_rouge2_f1_sum'] = result_f1['rouge2']
            logging_output['_rougeL_f1_sum'] = result_f1['rougeL']
            logging_output['_rouge_cnt'] = len(hyps)
        return loss, sample_size, logging_output

    def reduce_metrics(self, logging_outputs, criterion):
        super().reduce_metrics(logging_outputs, criterion)

        def sum_logs(key):
            import torch
            result = sum(log.get(key, 0) for log in logging_outputs)
            if torch.is_tensor(result):
                result = result.cpu()
            return result

        if sum_logs("_rouge_cnt") > 0:
            metrics.log_scalar("_rouge1_recall_sum", sum_logs("_rouge1_recall_sum"))
            metrics.log_scalar("_rouge2_recall_sum", sum_logs("_rouge2_recall_sum"))
            metrics.log_scalar("_rougeL_recall_sum", sum_logs("_rougeL_recall_sum"))
            metrics.log_scalar("_rouge1_f1_sum", sum_logs("_rouge1_f1_sum"))
            metrics.log_scalar("_rouge2_f1_sum", sum_logs("_rouge2_f1_sum"))
            metrics.log_scalar("_rougeL_f1_sum", sum_logs("_rougeL_f1_sum"))
            metrics.log_scalar("_rouge_cnt", sum_logs("_rouge_cnt"))
metrics.log_derived("rouge1_recall", lambda x: x["_rouge1_recall_sum"].sum / x["_rouge_cnt"].sum) | |
metrics.log_derived("rouge2_recall", lambda x: x["_rouge2_recall_sum"].sum / x["_rouge_cnt"].sum) | |
metrics.log_derived("rougeL_recall", lambda x: x["_rougeL_recall_sum"].sum / x["_rouge_cnt"].sum) | |
metrics.log_derived("rouge1_f1", lambda x: x["_rouge1_f1_sum"].sum / x["_rouge_cnt"].sum) | |
metrics.log_derived("rouge2_f1", lambda x: x["_rouge2_f1_sum"].sum / x["_rouge_cnt"].sum) | |
metrics.log_derived("rougeL_f1", lambda x: x["_rougeL_f1_sum"].sum / x["_rouge_cnt"].sum) | |

    def _inference(self, generator, sample, model):
        def decode(toks):
            s = self.tgt_dict.string(toks.int().cpu())
            if self.bpe:
                s = self.bpe.decode(s)
            if self.tokenizer:
                s = self.tokenizer.decode(s)
            return s

        gen_out = self.inference_step(generator, [model], sample)
        hyps, refs = [], []
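        # Hypotheses are lowercased, '<unk>' is rewritten, and '1' is mapped
        # to '#' to match the digit-masked Gigaword reference style.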
        for i in range(len(gen_out)):
            hyp = decode(gen_out[i][0]["tokens"]).lower().strip()
            hyp = fix_tokenization(hyp).replace('<unk>', ' unk').replace('1', '#')
            ref = sample["target_strs"][i]
            hyps.append(hyp)
            refs.append(ref)
        if self.cfg.eval_print_samples:
            logger.info("example hypothesis: " + hyps[0])
            logger.info("example reference: " + refs[0])
        return hyps, refs
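

# A minimal usage sketch (assumed, not taken from the original source): with
# this module importable, the task can be selected through fairseq's CLI, e.g.
#   ... --task gigaword --eval-rouge --eval-print-samples \
#       --eval-args '{"beam": 4, "lenpen": 0.6}'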