#! /usr/bin/python3
import argparse
import logging
import os
import re
import sys
from collections import namedtuple

import torch
from modeling_bertabs import BertAbs, build_predictor
from torch.utils.data import DataLoader, SequentialSampler
from tqdm import tqdm
from transformers import BertTokenizer

from utils_summarization import (
    CNNDMDataset,
    build_mask,
    compute_token_type_ids,
    encode_for_summarization,
    truncate_or_pad,
)


logger = logging.getLogger(__name__)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

Batch = namedtuple("Batch", ["document_names", "batch_size", "src", "segs", "mask_src", "tgt_str"])


def evaluate(args):
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
    model = BertAbs.from_pretrained("remi/bertabs-finetuned-extractive-abstractive-summarization")
    model.to(args.device)
    model.eval()

    symbols = {
        "BOS": tokenizer.vocab["[unused0]"],
        "EOS": tokenizer.vocab["[unused1]"],
        "PAD": tokenizer.vocab["[PAD]"],
    }

    if args.compute_rouge:
        reference_summaries = []
        generated_summaries = []

        import nltk
        import rouge

        nltk.download("punkt")
        rouge_evaluator = rouge.Rouge(
            metrics=["rouge-n", "rouge-l"],
            max_n=2,
            limit_length=True,
            length_limit=args.beam_size,
            length_limit_type="words",
            apply_avg=True,
            apply_best=False,
            alpha=0.5,  # Default F1_score
            weight_factor=1.2,
            stemming=True,
        )

    # These (unused) arguments are defined to keep compatibility
    # with the legacy code and will be deleted in a next iteration.
    args.result_path = ""
    args.temp_dir = ""

    data_iterator = build_data_iterator(args, tokenizer)
    predictor = build_predictor(args, tokenizer, symbols, model)

    logger.info("***** Running evaluation *****")
    logger.info("  Number of examples = %d", len(data_iterator.dataset))
    logger.info("  Batch size = %d", args.batch_size)
    logger.info("")
    logger.info("***** Beam Search parameters *****")
    logger.info("  Beam size = %d", args.beam_size)
    logger.info("  Minimum length = %d", args.min_length)
    logger.info("  Maximum length = %d", args.max_length)
    logger.info("  Alpha (length penalty) = %.2f", args.alpha)
    logger.info("  Trigrams %s be blocked", ("will" if args.block_trigram else "will NOT"))

    for batch in tqdm(data_iterator):
        batch_data = predictor.translate_batch(batch)
        translations = predictor.from_batch(batch_data)
        summaries = [format_summary(t) for t in translations]
        save_summaries(summaries, args.summaries_output_dir, batch.document_names)

        if args.compute_rouge:
            reference_summaries += batch.tgt_str
            generated_summaries += summaries

    if args.compute_rouge:
        scores = rouge_evaluator.get_scores(generated_summaries, reference_summaries)
        str_scores = format_rouge_scores(scores)
        save_rouge_scores(str_scores)
        print(str_scores)


def save_summaries(summaries, path, original_document_names):
    """Write each summary to a file whose name is the original document's
    name with `_summary` appended before the extension.

    Attributes:
        original_document_names: List[string]
            Names of the documents that were summarized.
        path: string
            Path where the summaries will be written.
        summaries: List[string]
            The summaries that we produced.
    """
    for summary, document_name in zip(summaries, original_document_names):
        # Prepare the summary file's name
        if "." in document_name:
            bare_document_name = ".".join(document_name.split(".")[:-1])
            extension = document_name.split(".")[-1]
            name = bare_document_name + "_summary." + extension
        else:
            name = document_name + "_summary"

        file_path = os.path.join(path, name)
        with open(file_path, "w") as output:
            output.write(summary)

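
# Illustration of the naming convention implemented by `save_summaries`,
# using hypothetical document names:
#   "cnn_article_001.story" -> "cnn_article_001_summary.story"
#   "daily_mail_article"    -> "daily_mail_article_summary"
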
def format_summary(translation):
    """Transforms the output of the `from_batch` function
    into nicely formatted summaries.
    """
    raw_summary, _, _ = translation
    summary = (
        raw_summary.replace("[unused0]", "")
        .replace("[unused3]", "")
        .replace("[PAD]", "")
        .replace("[unused1]", "")
    )
    # `str.replace` does not interpret regular expressions, so use `re.sub`
    # to collapse the repeated spaces left by the token removals.
    summary = re.sub(r" +", " ", summary)
    summary = summary.replace(" [unused2] ", ". ").replace("[unused2]", "").strip()

    return summary


def format_rouge_scores(scores):
    return """\n
****** ROUGE SCORES ******

** ROUGE 1
F1        >> {:.3f}
Precision >> {:.3f}
Recall    >> {:.3f}

** ROUGE 2
F1        >> {:.3f}
Precision >> {:.3f}
Recall    >> {:.3f}

** ROUGE L
F1        >> {:.3f}
Precision >> {:.3f}
Recall    >> {:.3f}""".format(
        scores["rouge-1"]["f"],
        scores["rouge-1"]["p"],
        scores["rouge-1"]["r"],
        scores["rouge-2"]["f"],
        scores["rouge-2"]["p"],
        scores["rouge-2"]["r"],
        scores["rouge-l"]["f"],
        scores["rouge-l"]["p"],
        scores["rouge-l"]["r"],
    )


def save_rouge_scores(str_scores):
    with open("rouge_scores.txt", "w") as output:
        output.write(str_scores)


#
# LOAD the dataset
#


def build_data_iterator(args, tokenizer):
    dataset = load_and_cache_examples(args, tokenizer)
    sampler = SequentialSampler(dataset)

    def collate_fn(data):
        return collate(data, tokenizer, block_size=512, device=args.device)

    iterator = DataLoader(
        dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        collate_fn=collate_fn,
    )

    return iterator


def load_and_cache_examples(args, tokenizer):
    dataset = CNNDMDataset(args.documents_dir)
    return dataset


def collate(data, tokenizer, block_size, device):
    """Collate formats the data passed to the data loader.

    In particular we tokenize the data batch by batch to avoid keeping it
    all in memory. We output the data as a namedtuple to fit the original
    BertAbs's API.
    """
    data = [x for x in data if not len(x[1]) == 0]  # remove empty files
    names = [name for name, _, _ in data]
    summaries = [" ".join(summary_list) for _, _, summary_list in data]

    encoded_text = [encode_for_summarization(story, summary, tokenizer) for _, story, summary in data]
    encoded_stories = torch.tensor(
        [truncate_or_pad(story, block_size, tokenizer.pad_token_id) for story, _ in encoded_text]
    )
    encoder_token_type_ids = compute_token_type_ids(encoded_stories, tokenizer.cls_token_id)
    encoder_mask = build_mask(encoded_stories, tokenizer.pad_token_id)

    batch = Batch(
        document_names=names,
        batch_size=len(encoded_stories),
        src=encoded_stories.to(device),
        segs=encoder_token_type_ids.to(device),
        mask_src=encoder_mask.to(device),
        tgt_str=summaries,
    )

    return batch


def decode_summary(summary_tokens, tokenizer):
    """Decode the summary and return it in a format
    suitable for evaluation.
    """
    summary_tokens = summary_tokens.to("cpu").numpy()
    summary = tokenizer.decode(summary_tokens)
    sentences = summary.split(".")
    sentences = [s + "." for s in sentences]
    return sentences

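
# Rough shape reference for the `Batch` produced by `collate` (block_size is 512
# as passed in `build_data_iterator`); the exact semantics of `segs` and
# `mask_src` are defined in `utils_summarization`:
#   src      -> tensor [batch_size, 512], token ids truncated or padded to 512
#   segs     -> tensor [batch_size, 512], token type ids from `compute_token_type_ids`
#   mask_src -> tensor [batch_size, 512], padding mask from `build_mask`
#   tgt_str  -> List[str], the reference summaries as plain text
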
def main():
    """The main function defines the interface with the users."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--documents_dir",
        default=None,
        type=str,
        required=True,
        help="The folder where the documents to summarize are located.",
    )
    parser.add_argument(
        "--summaries_output_dir",
        default=None,
        type=str,
        required=False,
        help="The folder in which the summaries should be written. Defaults to the folder where the documents are.",
    )
    parser.add_argument(
        "--compute_rouge",
        default=False,
        type=bool,
        required=False,
        help="Compute the ROUGE metrics during evaluation. Only available for the CNN/DailyMail dataset.",
    )
    # EVALUATION options
    parser.add_argument(
        "--no_cuda",
        default=False,
        type=bool,
        help="Whether to force the execution on CPU.",
    )
    parser.add_argument(
        "--batch_size",
        default=4,
        type=int,
        help="Batch size per GPU/CPU for evaluation.",
    )
    # BEAM SEARCH arguments
    parser.add_argument(
        "--min_length",
        default=50,
        type=int,
        help="Minimum number of tokens for the summaries.",
    )
    parser.add_argument(
        "--max_length",
        default=200,
        type=int,
        help="Maximum number of tokens for the summaries.",
    )
    parser.add_argument(
        "--beam_size",
        default=5,
        type=int,
        help="The number of beams to start with for each example.",
    )
    parser.add_argument(
        "--alpha",
        default=0.95,
        type=float,
        help="The value of alpha for the length penalty in the beam search.",
    )
    parser.add_argument(
        "--block_trigram",
        default=True,
        type=bool,
        help="Whether to block repeated trigrams in the text generated by beam search.",
    )
    args = parser.parse_args()

    # Select the device (distributed execution is not available)
    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")

    # Check the existence of directories
    if not args.summaries_output_dir:
        args.summaries_output_dir = args.documents_dir

    if not documents_dir_is_valid(args.documents_dir):
        raise FileNotFoundError(
            "We could not find the directory you specified for the documents to summarize, or it was empty. Please"
            " specify a valid path."
        )
    os.makedirs(args.summaries_output_dir, exist_ok=True)

    evaluate(args)


def documents_dir_is_valid(path):
    if not os.path.exists(path):
        return False

    file_list = os.listdir(path)
    if len(file_list) == 0:
        return False

    return True


if __name__ == "__main__":
    main()
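
# Example invocation, assuming this script is saved as `run_summarization.py`
# and that `--documents_dir` points at a folder of plain-text documents
# (the paths below are placeholders):
#
#   python run_summarization.py \
#       --documents_dir /path/to/documents \
#       --summaries_output_dir /path/to/summaries \
#       --batch_size 4 \
#       --beam_size 5 \
#       --min_length 50 \
#       --max_length 200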