File size: 2,480 Bytes
500fbd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os.path as osp
import argparse
import json
from data import Tasks, DATASET_TASK_DICT
from utils import preprocess_path


def process_result(entry, name, task):
    """Reduce one dataset's raw metrics to the fields reported for its task.

    Args:
        entry: mapping of raw metric values for a single dataset, keyed as
            ``'<metric>,none'`` (only those keys are read here).
        name: dataset name; ``'mkqa_tr'`` is special-cased for extractive QA.
        task: the dataset's task, presumably a ``Tasks`` member.

    Returns:
        A dict with ``name``, ``str(task)`` and the task's metrics. A task not
        handled below yields only ``name`` and ``task``.

    Raises:
        KeyError: if a metric required for the task is absent from ``entry``.
    """
    result = {'name': name, 'task': str(task)}
    classification_like = (
        Tasks.MULTIPLE_CHOICE,
        Tasks.NATURAL_LANGUAGE_INFERENCE,
        Tasks.TEXT_CLASSIFICATION,
    )

    if task == Tasks.EXTRACTIVE_QUESTION_ANSWERING:
        # mkqa_tr exposes 'em' unscaled; every other QA dataset exposes
        # 'exact'/'f1' that are multiplied by 0.01 (presumably percent -> fraction).
        if name == 'mkqa_tr':
            em_key, factor = 'em,none', 1
        else:
            em_key, factor = 'exact,none', 0.01
        result['exact_match'] = factor * entry[em_key]
        result['f1'] = factor * entry['f1,none']
    elif task == Tasks.SUMMARIZATION:
        result.update({
            'rouge1': entry['rouge1,none'],
            'rouge2': entry['rouge2,none'],
            'rougeL': entry['rougeL,none'],
        })
    elif task in classification_like:
        accuracy = entry['acc,none']
        result['acc'] = accuracy
        # Fall back to plain accuracy when no normalized accuracy is reported.
        result['acc_norm'] = entry.get('acc_norm,none', accuracy)
    elif task == Tasks.MACHINE_TRANSLATION:
        result.update({
            'wer': entry['wer,none'],
            'bleu': entry['bleu,none'],
        })
    elif task == Tasks.GRAMMATICAL_ERROR_CORRECTION:
        result['exact_match'] = entry['exact_match,none']

    return result


def _parse_model_args(raw_model_args):
    """Convert the results file's ``config.model_args`` value into a dict.

    Accepts either a comma-separated ``key=value`` string or a mapping (which
    is copied so the loaded JSON is not mutated). Blank pieces, e.g. from an
    empty string or a trailing comma, are skipped. Each piece is split on its
    first ``=`` only, so values that themselves contain ``=`` are preserved.

    Raises:
        ValueError: if a non-blank piece contains no ``=``.
    """
    if isinstance(raw_model_args, dict):
        return dict(raw_model_args)

    parsed = {}
    for pair in raw_model_args.split(','):
        pair = pair.strip()
        if not pair:
            continue
        key, sep, value = pair.partition('=')
        if not sep:
            raise ValueError(
                f'Malformed model_args entry {pair!r}; expected key=value.'
            )
        parsed[key] = value
    return parsed


def main():
    """Convert a raw evaluation-results JSON file into the summary format.

    Reads ``--input-file``, builds the model description from
    ``config.model_args`` and ``config.model``, and keeps only datasets listed
    in ``DATASET_TASK_DICT``, each formatted by ``process_result``. The result
    is written to ``--output-file`` as indented JSON.
    """
    parser = argparse.ArgumentParser(description='Results file formatter.')
    # Both paths are mandatory: without them argparse would pass None on to
    # preprocess_path instead of reporting a usage error.
    parser.add_argument('-i', '--input-file', type=str, required=True,
                        help='Input JSON file for the results.')
    parser.add_argument('-o', '--output-file', type=str, required=True,
                        help='Output JSON file for the formatted results.')
    args = parser.parse_args()

    with open(preprocess_path(args.input_file), encoding='utf-8') as f:
        raw_data = json.load(f)

    config = raw_data['config']

    # First, the model description: the parsed model_args, with the config's
    # 'model' value recorded under 'api'.
    model_args = _parse_model_args(config['model_args'])
    # Report the checkpoint under 'model'; argument sets without a
    # 'pretrained' entry are kept as-is instead of raising KeyError.
    if 'pretrained' in model_args:
        model_args['model'] = model_args.pop('pretrained')
    model_args['api'] = config['model']

    # Then the per-dataset results, skipping datasets with no known task.
    results = [
        process_result(entry, dataset, DATASET_TASK_DICT[dataset])
        for dataset, entry in raw_data['results'].items()
        if dataset in DATASET_TASK_DICT
    ]

    processed = {'model': model_args, 'results': results}
    with open(preprocess_path(args.output_file), 'w', encoding='utf-8') as f:
        json.dump(processed, f, indent=4)

    print('done')


if __name__ == '__main__':
    # Script entry point; main() returns None, so the process exits with status 0.
    raise SystemExit(main())