"""
Script for decoding summarization models available through Huggingface Transformers.
Usage with Huggingface Datasets:
python generation.py --model <model name> --data_path <path to data in jsonl format>
Usage with custom datasets in JSONL format:
python generation.py --model <model name> --dataset <dataset name> --split <data split>
"""
#!/usr/bin/env python
# coding: utf-8
import argparse
import json
import os
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from datasets import load_dataset
from tqdm import tqdm
BATCH_SIZE = 8
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
BART_CNNDM_CHECKPOINT = 'facebook/bart-large-cnn'
BART_XSUM_CHECKPOINT = 'facebook/bart-large-xsum'
PEGASUS_CNNDM_CHECKPOINT = 'google/pegasus-cnn_dailymail'
PEGASUS_XSUM_CHECKPOINT = 'google/pegasus-xsum'
PEGASUS_NEWSROOM_CHECKPOINT = 'google/pegasus-newsroom'
PEGASUS_MULTINEWS_CHECKPOINT = 'google/pegasus-multi_news'
MODEL_CHECKPOINTS = {
'bart-xsum': BART_XSUM_CHECKPOINT,
'bart-cnndm': BART_CNNDM_CHECKPOINT,
'pegasus-xsum': PEGASUS_XSUM_CHECKPOINT,
'pegasus-cnndm': PEGASUS_CNNDM_CHECKPOINT,
'pegasus-newsroom': PEGASUS_NEWSROOM_CHECKPOINT,
'pegasus-multinews': PEGASUS_MULTINEWS_CHECKPOINT
}
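
# Note: a custom JSONL file passed via --data_path is assumed to already use the
# unified schema: one JSON object per line with "article" and "target" fields, e.g.
#   {"article": "Full document text ...", "target": "Reference summary ..."}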
class JSONDataset(torch.utils.data.Dataset):
    """Dataset wrapping a JSONL file: one JSON object per line."""
    def __init__(self, data_path):
        super().__init__()
        with open(data_path) as fd:
            self.data = [json.loads(line) for line in fd]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def preprocess_data(raw_data, dataset):
    """
    Unify the format of Huggingface Datasets examples to 'article'/'target' fields.
    :param raw_data: loaded data
    :param dataset: name of dataset
    """
    if dataset in ('xsum', 'gigaword'):
        raw_data['article'] = raw_data.pop('document')
        raw_data['target'] = raw_data.pop('summary')
    elif dataset == 'cnndm':
        raw_data['target'] = raw_data.pop('highlights')
    return raw_data
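
# Under the default DataLoader collation, each batch reaching preprocess_data is a
# dict of lists, e.g. {'article': [doc_1, ...], 'target': [sum_1, ...]}, so the
# field renames above apply to a whole batch at once.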

def postprocess_data(raw_data, decoded):
    """
    Remove generation artifacts and postprocess outputs
    :param raw_data: loaded data
    :param decoded: model outputs
    """
    # Flatten multi-line references and strip Pegasus's '<n>' newline token.
    raw_data['target'] = [x.replace('\n', ' ') for x in raw_data['target']]
    raw_data['decoded'] = [x.replace('<n>', ' ') for x in decoded]
    # Convert the batch dict of lists into a list of per-example dicts.
    return [dict(zip(raw_data, t)) for t in zip(*raw_data.values())]
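
# For example (hypothetical values), postprocess_data turns
#   {'article': ['a1', 'a2'], 'target': ['t1', 't2']} and decoded ['d1', 'd2']
# into
#   [{'article': 'a1', 'target': 't1', 'decoded': 'd1'},
#    {'article': 'a2', 'target': 't2', 'decoded': 'd2'}]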

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Decode summarization models from Huggingface Transformers.')
    parser.add_argument('--model', type=str, required=True, choices=list(MODEL_CHECKPOINTS))
    parser.add_argument('--data_path', type=str)
    parser.add_argument('--dataset', type=str, choices=['xsum', 'cnndm', 'gigaword'])
    parser.add_argument('--split', type=str, choices=['train', 'validation', 'test'])
    args = parser.parse_args()
    if args.dataset and not args.split:
        raise RuntimeError('If `dataset` flag is specified `split` must also be provided.')
    if not args.data_path and not args.dataset:
        raise RuntimeError('Either `data_path` or `dataset` must be provided.')
    if args.data_path:
        # Name results after the input file; mark the split as user-provided.
        args.dataset = os.path.splitext(os.path.basename(args.data_path))[0]
        args.split = 'user'

    # Load model & data
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINTS[args.model]).to(DEVICE)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINTS[args.model])
    if not args.data_path:
        if args.dataset == 'cnndm':
            dataset = load_dataset('cnn_dailymail', '3.0.0', split=args.split)
        elif args.dataset == 'xsum':
            dataset = load_dataset('xsum', split=args.split)
        elif args.dataset == 'gigaword':
            dataset = load_dataset('gigaword', split=args.split)
    else:
        dataset = JSONDataset(args.data_path)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE)

    # Run generation and write one JSON object per example
    filename = '%s.%s.%s.results' % (args.model.replace("/", "-"), args.dataset, args.split)
    model.eval()
    with open(filename, 'w') as fd_out, torch.no_grad():
        for raw_data in tqdm(dataloader):
            raw_data = preprocess_data(raw_data, args.dataset)
            # Truncate inputs to the model's maximum length; pad to the longest in the batch.
            batch = tokenizer(raw_data["article"], return_tensors="pt", truncation=True, padding="longest").to(DEVICE)
            # Generation hyperparameters (beam size, max length, ...) fall back to the checkpoint's config defaults.
            summaries = model.generate(input_ids=batch.input_ids, attention_mask=batch.attention_mask)
            decoded = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
            for example in postprocess_data(raw_data, decoded):
                fd_out.write(json.dumps(example) + '\n')
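
# Example invocations (a sketch; dataset names and splits as defined above):
#   python generation.py --model bart-xsum --dataset xsum --split test
#   python generation.py --model pegasus-cnndm --data_path articles.jsonl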