"""
Script for joining dataset of documents/reference summaries with generated summaries (likely from generate.py).

Usage with custom datasets in JSONL format:
python join.py --data_path <path to data in jsonl format> --generation_paths <paths to generated predictions>  --output_path <path to output file>

Optionally specify --model_names to override default model names.

"""
# !/usr/bin/env python
# coding: utf-8
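
# Illustrative record shapes (a sketch based only on the keys this script
# reads and writes; real files may carry additional fields):
#   input JSONL line:      {"document": "...", "summary:reference": "..."}
#   generation file line:  one generated summary per line, in the same order
#                          as the records of the input dataset
#   output JSONL line:     {"summary:reference": "...", "document": "...",
#                           "summary:<model_name>": "..."}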

import argparse
import json
import os
from pathlib import Path

import torch
from tqdm import tqdm

BATCH_SIZE = 8  # note: not used anywhere in this script


class JSONDataset(torch.utils.data.Dataset):
    """Map-style dataset over a JSONL file: one JSON object per line."""

    def __init__(self, data_path):
        super().__init__()

        # Read the whole file up front; each line is one JSON record.
        with open(data_path) as fd:
            self.data = [json.loads(line) for line in fd]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
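
# A quick usage sketch (hypothetical file name):
#   ds = JSONDataset("data.jsonl")
#   print(len(ds), ds[0])  # number of records and the first record as a dict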


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str, required=True)
    parser.add_argument('--generation_paths', type=str, nargs="+", required=True)
    parser.add_argument('--output_path', type=str, required=True)
    parser.add_argument('--model_names', type=str, nargs="+")
    args = parser.parse_args()

    if args.model_names and len(args.generation_paths) != len(args.model_names):
        raise ValueError('Length of args.generation_paths must equal length of args.model_names')

    if args.model_names:
        model_names = args.model_names
    else:
        model_names = [Path(p).name.split(".")[0] for p in args.generation_paths]
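    # e.g. a generation path like "preds/bart-large.hypo" (hypothetical) yields
    # the model name "bart-large": the file name up to its first dot.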

    # Metadata derived from the input path (not used further in this script).
    args.dataset = os.path.splitext(os.path.basename(args.data_path))[0]
    args.split = 'user'

    # Load data

    dataset = JSONDataset(args.data_path)

    # Join the files and write out a single JSONL dataset. Each generation file
    # is expected to contain one generated summary per line, aligned with the
    # records of the input dataset.

    generation_files = [open(fname) for fname in args.generation_paths]

    with open(args.output_path, 'w') as outp:
        for data, *generations in tqdm(zip(dataset, *generation_files)):
            # Process each original data record together with the corresponding
            # generation(s) of the model(s). Note that zip stops at the shortest
            # input, so files of mismatched length are silently truncated.
            result = {}
            result['summary:reference'] = data['summary:reference']
            result['document'] = data['document']
            for model_name, gen in zip(model_names, generations):
                # Strip the trailing newline so it does not end up inside the
                # JSON string value.
                result[f'summary:{model_name}'] = gen.rstrip('\n')
            outp.write(json.dumps(result) + '\n')

    for file in generation_files:
        file.close()
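
# Example invocation (hypothetical paths):
#   python join.py --data_path data/test.jsonl \
#       --generation_paths preds/bart.hypo preds/pegasus.hypo \
#       --output_path joined.jsonl
# This would add "summary:bart" and "summary:pegasus" keys to every output record.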