"""
Script for decoding summarization models available through Huggingface Transformers.

To use with one of the 6 standard models:
python generation.py --model <model abbreviation> --data_path <path to data in jsonl format>
    where model abbreviation is one of: bart-xsum, bart-cnndm, pegasus-xsum, pegasus-cnndm, pegasus-newsroom,
    pegasus-multinews:

To use with arbitrary model:
python generation.py --model_name_or_path <Huggingface model name or local path> --data_path <path to data in jsonl format>

"""
# !/usr/bin/env python
# coding: utf-8
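
# Each line of the input JSONL file is expected to be a JSON object with a
# "document" field holding the source text to summarize (see the batch loop
# below), e.g.:
# {"document": "Officials said on Tuesday that ..."}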

import argparse
import json
import os

import torch
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

BATCH_SIZE = 8
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

BART_CNNDM_CHECKPOINT = 'facebook/bart-large-cnn'
BART_XSUM_CHECKPOINT = 'facebook/bart-large-xsum'
PEGASUS_CNNDM_CHECKPOINT = 'google/pegasus-cnn_dailymail'
PEGASUS_XSUM_CHECKPOINT = 'google/pegasus-xsum'
PEGASUS_NEWSROOM_CHECKPOINT = 'google/pegasus-newsroom'
PEGASUS_MULTINEWS_CHECKPOINT = 'google/pegasus-multi_news'

MODEL_CHECKPOINTS = {
    'bart-xsum': BART_XSUM_CHECKPOINT,
    'bart-cnndm': BART_CNNDM_CHECKPOINT,
    'pegasus-xsum': PEGASUS_XSUM_CHECKPOINT,
    'pegasus-cnndm': PEGASUS_CNNDM_CHECKPOINT,
    'pegasus-newsroom': PEGASUS_NEWSROOM_CHECKPOINT,
    'pegasus-multinews': PEGASUS_MULTINEWS_CHECKPOINT
}


class JSONDataset(torch.utils.data.Dataset):
    """Dataset wrapping a JSONL file; each item is the parsed object for one line."""

    def __init__(self, data_path):
        super().__init__()

        with open(data_path) as fd:
            self.data = [json.loads(line) for line in fd]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def postprocess_data(decoded):
    """
    Remove generation artifacts from decoded model outputs.

    Pegasus checkpoints emit '<n>' as a newline token; replace it with a space
    so each summary fits on a single output line.

    :param decoded: list of decoded output strings
    """
    return [x.replace('<n>', ' ') for x in decoded]
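
# For example, postprocess_data(["Sentence one.<n>Sentence two."]) returns
# ["Sentence one. Sentence two."].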


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Decode summarization models from Huggingface Transformers.')
    parser.add_argument('--model', type=str, help='abbreviation for one of the standard models (see MODEL_CHECKPOINTS)')
    parser.add_argument('--model_name_or_path', type=str, help='Huggingface model name or local checkpoint path')
    parser.add_argument('--data_path', type=str, help='path to the input data in JSONL format')
    args = parser.parse_args()

    if not (args.model or args.model_name_or_path):
        raise ValueError('One of --model or --model_name_or_path is required')

    if args.model and args.model_name_or_path:
        raise ValueError('Specify either --model or --model_name_or_path, but not both')

    # Load the model, tokenizer, and data
    if args.model:
        model_name_or_path = MODEL_CHECKPOINTS[args.model]
        file_model_name = args.model
    else:
        model_name_or_path = args.model_name_or_path
        file_model_name = model_name_or_path.replace("/", "-")
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path).to(DEVICE)
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

    dataset = JSONDataset(args.data_path)

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE)
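
    # Note: the default collate_fn turns a batch of dicts into a dict of lists,
    # so raw_data["document"] below is a list of up to BATCH_SIZE strings.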

    # Name the output file after the model and dataset
    file_dataset_name = os.path.splitext(os.path.basename(args.data_path))[0]
    filename = f'{file_model_name}.{file_dataset_name}.predictions'

    model.eval()
    with open(filename, 'w') as fd_out, torch.no_grad():
        for raw_data in tqdm(dataloader):
            batch = tokenizer(raw_data["document"], return_tensors="pt", truncation=True, padding="longest").to(DEVICE)
            # generate() falls back to the generation defaults stored with the
            # checkpoint (beam size, length limits, etc.)
            summaries = model.generate(input_ids=batch.input_ids, attention_mask=batch.attention_mask)
            decoded = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
            for example in postprocess_data(decoded):
                fd_out.write(example + '\n')