"""
Script for decoding summarization models available through Huggingface Transformers.

Usage with Huggingface Datasets:
python generation.py --model <model name> --data_path <path to data in jsonl format>

Usage with custom datasets in JSONL format:
python generation.py --model <model name> --dataset <dataset name> --split <data split>
"""
#!/usr/bin/env python
# coding: utf-8

import argparse
import json
import os

import torch

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from datasets import load_dataset
from tqdm import tqdm

BATCH_SIZE = 8
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

BART_CNNDM_CHECKPOINT = 'facebook/bart-large-cnn'
BART_XSUM_CHECKPOINT = 'facebook/bart-large-xsum'
PEGASUS_CNNDM_CHECKPOINT = 'google/pegasus-cnn_dailymail'
PEGASUS_XSUM_CHECKPOINT = 'google/pegasus-xsum'
PEGASUS_NEWSROOM_CHECKPOINT = 'google/pegasus-newsroom'
PEGASUS_MULTINEWS_CHECKPOINT = 'google/pegasus-multi_news'

MODEL_CHECKPOINTS = {
    'bart-xsum': BART_XSUM_CHECKPOINT,
    'bart-cnndm': BART_CNNDM_CHECKPOINT,
    'pegasus-xsum': PEGASUS_XSUM_CHECKPOINT,
    'pegasus-cnndm': PEGASUS_CNNDM_CHECKPOINT,
    'pegasus-newsroom': PEGASUS_NEWSROOM_CHECKPOINT,
    'pegasus-multinews': PEGASUS_MULTINEWS_CHECKPOINT
}


class JSONDataset(torch.utils.data.Dataset):
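    """Dataset reading a JSONL file, one JSON object per line."""
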
    def __init__(self, data_path):
        super().__init__()

        with open(data_path) as fd:
            self.data = [json.loads(line) for line in fd]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def preprocess_data(raw_data, dataset):
    """
    Unify the format of Huggingface Datasets examples to 'article'/'target' fields

    :param raw_data: loaded data
    :param dataset: name of dataset
    """
    if dataset in ('xsum', 'gigaword'):
        raw_data['article'] = raw_data['document']
        raw_data['target'] = raw_data['summary']
        del raw_data['document']
        del raw_data['summary']
    elif dataset == 'cnndm':
        raw_data['target'] = raw_data['highlights']
        del raw_data['highlights']

    return raw_data


def postprocess_data(raw_data, decoded):
    """
    Remove generation artifacts and postprocess outputs

    :param raw_data: loaded data
    :param decoded: model outputs
    """
    raw_data['target'] = [x.replace('\n', ' ') for x in raw_data['target']]
    raw_data['decoded'] = [x.replace('<n>', ' ') for x in decoded]

    # Re-shape the batched dict of column lists into a list of per-example dicts
    return [dict(zip(raw_data, t)) for t in zip(*raw_data.values())]


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Decode summarization models from Huggingface Transformers.')
    parser.add_argument('--model', type=str, required=True, choices=list(MODEL_CHECKPOINTS))
    parser.add_argument('--data_path', type=str, help='Path to a custom dataset in JSONL format')
    parser.add_argument('--dataset', type=str, choices=['xsum', 'cnndm', 'gigaword'], help='Huggingface dataset to decode')
    parser.add_argument('--split', type=str, choices=['train', 'validation', 'test'], help='Dataset split to decode')
    args = parser.parse_args()

    if args.dataset and not args.split:
        raise RuntimeError('If the `dataset` flag is specified, `split` must also be provided.')

    if args.data_path:
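        # A custom file's name stands in for the dataset name and 'user' for the split;
        # these are only used to label the output results file.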
        args.dataset = os.path.splitext(os.path.basename(args.data_path))[0]
        args.split = 'user'

    # Load models & data
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINTS[args.model]).to(DEVICE)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINTS[args.model])

    if not args.data_path:
        if args.dataset == 'cnndm':
            dataset = load_dataset('cnn_dailymail', '3.0.0', split=args.split)
        elif args.dataset == 'xsum':
            dataset = load_dataset('xsum', split=args.split)
        elif args.dataset == 'gigaword':
            dataset = load_dataset('gigaword', split=args.split)
    else:
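        # Custom JSONL data bypasses preprocess_data, so each line must already provide
        # 'article' and 'target' fields, e.g. {"article": "...", "target": "..."}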
        dataset = JSONDataset(args.data_path)

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE)

    # Run decoding and write results
    filename = '%s.%s.%s.results' % (args.model.replace("/", "-"), args.dataset, args.split)
    fd_out = open(filename, 'w')

    results = []
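    # Disable dropout and skip gradient tracking for inference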
    model.eval()
    with torch.no_grad():
        for raw_data in tqdm(dataloader):
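            # Default collation yields each batch as a dict mapping field names to lists of strings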
            raw_data = preprocess_data(raw_data, args.dataset)
            batch = tokenizer(raw_data["article"], return_tensors="pt", truncation=True, padding="longest").to(DEVICE)
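            # With no explicit decoding arguments, generate() falls back to the defaults
            # stored with each checkpoint (beam size, length penalty, max length, etc.)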
            summaries = model.generate(input_ids=batch.input_ids, attention_mask=batch.attention_mask)

            decoded = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
            result = postprocess_data(raw_data, decoded)
            results.extend(result)

            for example in result:
                fd_out.write(json.dumps(example) + '\n')

    fd_out.close()