""" Script for decoding summarization models available through Huggingface Transformers. Usage with Huggingface Datasets: python generation.py --model --data_path Usage with custom datasets in JSONL format: python generation.py --model --dataset --split """ #!/usr/bin/env python # coding: utf-8 import argparse import json import os import torch from transformers import AutoModelForSeq2SeqLM, AutoTokenizer from datasets import load_dataset from tqdm import tqdm BATCH_SIZE = 8 DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' BART_CNNDM_CHECKPOINT = 'facebook/bart-large-cnn' BART_XSUM_CHECKPOINT = 'facebook/bart-large-xsum' PEGASUS_CNNDM_CHECKPOINT = 'google/pegasus-cnn_dailymail' PEGASUS_XSUM_CHECKPOINT = 'google/pegasus-xsum' PEGASUS_NEWSROOM_CHECKPOINT = 'google/pegasus-newsroom' PEGASUS_MULTINEWS_CHECKPOINT = 'google/pegasus-multi_news' MODEL_CHECKPOINTS = { 'bart-xsum': BART_XSUM_CHECKPOINT, 'bart-cnndm': BART_CNNDM_CHECKPOINT, 'pegasus-xsum': PEGASUS_XSUM_CHECKPOINT, 'pegasus-cnndm': PEGASUS_CNNDM_CHECKPOINT, 'pegasus-newsroom': PEGASUS_NEWSROOM_CHECKPOINT, 'pegasus-multinews': PEGASUS_MULTINEWS_CHECKPOINT } class JSONDataset(torch.utils.data.Dataset): def __init__(self, data_path): super(JSONDataset, self).__init__() with open(data_path) as fd: self.data = [json.loads(line) for line in fd] def __len__(self): return len(self.data) def __getitem__(self, idx): return self.data[idx] def preprocess_data(raw_data, dataset): """ Unify format of Huggingface Datastes :param raw_data: loaded data :param dataset: name of dataset """ if dataset == 'xsum': raw_data['article'] = raw_data['document'] raw_data['target'] = raw_data['summary'] del raw_data['document'] del raw_data['summary'] elif dataset == 'cnndm': raw_data['target'] = raw_data['highlights'] del raw_data['highlights'] elif dataset == 'gigaword': raw_data['article'] = raw_data['document'] raw_data['target'] = raw_data['summary'] del raw_data['document'] del raw_data['summary'] return raw_data def postprocess_data(raw_data, decoded): """ Remove generation artifacts and postprocess outputs :param raw_data: loaded data :param decoded: model outputs """ raw_data['target'] = [x.replace('\n', ' ') for x in raw_data['target']] raw_data['decoded'] = [x.replace('', ' ') for x in decoded] return [dict(zip(raw_data, t)) for t in zip(*raw_data.values())] if __name__ == '__main__': parser = argparse.ArgumentParser(description='Process some integers.') parser.add_argument('--model', type=str, required=True, choices=['bart-xsum', 'bart-cnndm', 'pegasus-xsum', 'pegasus-cnndm', 'pegasus-newsroom', 'pegasus-multinews']) parser.add_argument('--data_path', type=str) parser.add_argument('--dataset', type=str, choices=['xsum', 'cnndm', 'gigaword']) parser.add_argument('--split', type=str, choices=['train', 'validation', 'test']) args = parser.parse_args() if args.dataset and not args.split: raise RuntimeError('If `dataset` flag is specified `split` must also be provided.') if args.data_path: args.dataset = os.path.splitext(os.path.basename(args.data_path))[0] args.split = 'user' # Load models & data model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINTS[args.model]).to(DEVICE) tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINTS[args.model]) if not args.data_path: if args.dataset == 'cnndm': dataset = load_dataset('cnn_dailymail', '3.0.0', split=args.split) elif args.dataset =='xsum': dataset = load_dataset('xsum', split=args.split) elif args.dataset =='gigaword': dataset = load_dataset('gigaword', split=args.split) else: dataset = 
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE)

    # Run generation and write results as JSONL
    filename = '%s.%s.%s.results' % (args.model.replace('/', '-'), args.dataset, args.split)
    fd_out = open(filename, 'w')

    results = []
    model.eval()
    with torch.no_grad():
        for raw_data in tqdm(dataloader):
            raw_data = preprocess_data(raw_data, args.dataset)
            batch = tokenizer(raw_data['article'], return_tensors='pt',
                              truncation=True, padding='longest').to(DEVICE)
            summaries = model.generate(input_ids=batch.input_ids,
                                       attention_mask=batch.attention_mask)
            decoded = tokenizer.batch_decode(summaries, skip_special_tokens=True,
                                             clean_up_tokenization_spaces=False)

            result = postprocess_data(raw_data, decoded)
            results.extend(result)

            for example in result:
                fd_out.write(json.dumps(example) + '\n')

    fd_out.close()
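# Example invocations (model/dataset/split values mirror the argparse choices
# above; the JSONL path is illustrative):
#   python generation.py --model pegasus-cnndm --dataset cnndm --split test
#   python generation.py --model bart-xsum --data_path my_articles.jsonl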