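"""
Evaluate critics on the good/bad sentence pairs in eval_critic/eval_data.jsonl:
  1. a sanity check that GPT-2 assigns higher log-probability to the good
     sentence of each pair,
  2. a baseline critic that thresholds the raw log-probability at its mean,
  3. the LM-Critic, which judges a sentence by comparing its probability
     against sampled local perturbations of it.
"""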
import sys
import json
import numpy as np
from tqdm import tqdm

sys.path.insert(0, '.')
from utils.text_utils import detokenize_sent
from critic.critic import run_gpt2, gpt2_critic

def load_data():
    """Load parallel good/bad sentence pairs from the JSONL eval set."""
    data_path = 'eval_critic/eval_data.jsonl'
    good_sents, bad_sents = [], []
    with open(data_path) as f:
        for line in f:
            obj = json.loads(line)
            good_sents.append(obj['good'])
            bad_sents.append(obj['bad'])
    return good_sents, bad_sents
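
# Each line of eval_data.jsonl is a JSON object holding one good/bad pair;
# the sentences below are illustrative placeholders, not taken from the file:
#   {"good": "She likes cats .", "bad": "She like cats ."}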

good_sents, bad_sents = load_data()


def get_logps(sents):
    """Score sentences with GPT-2 in batches of 100. Returns one record per
    sentence: its original tokenized form, its detokenized form, and its
    log-probability under GPT-2."""
    batch_size = 100
    final = []
    for start in tqdm(range(0, len(sents), batch_size)):
        sents_sub = sents[start: start + batch_size]
        sents_sub_detok = [detokenize_sent(sent) for sent in sents_sub]
        logps = run_gpt2(sents_sub_detok)
        assert logps is not None
        for sent, sent_detok, logp in zip(sents_sub, sents_sub_detok, logps):
            final.append({'sent': sent, 'sent_detok': sent_detok, 'logp': float(logp)})
    return final

def evaluate_logp():
    """
    Check whether log p(bad_sent) < log p(good_sent) for each aligned pair.
    """
    good_logps = get_logps(good_sents)
    bad_logps = get_logps(bad_sents)
    accs = [int(bad['logp'] < good['logp']) for good, bad in zip(good_logps, bad_logps)]
    avg_acc = float(sum(accs)) / len(accs)
    print(f'log p(bad) < log p(good)? {sum(accs)} / {len(accs)} = {avg_acc:.3f}')
    return good_logps, bad_logps

good_logps, bad_logps = evaluate_logp()
# log p(bad) < log p(good)? 555 / 586 = 0.947
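
# The pairwise check above only tests relative ranking within a pair. A
# practical critic must make an absolute good/bad call on a single sentence,
# which the two critics below attempt.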


def compute_metrics(good_accs, bad_accs):
    """Per-class precision / recall / F0.5. good_accs[i] records whether good
    sentence i was correctly judged good; bad_accs[i] likewise for bad.
    False positives for the 'good' class are the bad sentences judged good,
    i.e. len(bad_accs) - sum(bad_accs), and symmetrically for 'bad'."""
    goodP = float(sum(good_accs)) / (len(bad_accs) - sum(bad_accs) + sum(good_accs))
    goodR = float(sum(good_accs)) / len(good_accs)
    goodF05 = (1 + 0.5**2) * float(goodP * goodR) / ((0.5**2 * goodP) + goodR)
    badP = float(sum(bad_accs)) / (len(good_accs) - sum(good_accs) + sum(bad_accs))
    badR = float(sum(bad_accs)) / len(bad_accs)
    badF05 = (1 + 0.5**2) * float(badP * badR) / ((0.5**2 * badP) + badR)
    print(f'  Good precision = {sum(good_accs)} / {len(bad_accs) - sum(bad_accs) + sum(good_accs)} = {goodP:.3f}')
    print(f'  Good recall    = {sum(good_accs)} / {len(good_accs)} = {goodR:.3f}')
    print(f'  Good F0.5      = {goodF05:.3f}')
    print(f'  Bad precision  = {sum(bad_accs)} / {len(good_accs) - sum(good_accs) + sum(bad_accs)} = {badP:.3f}')
    print(f'  Bad recall     = {sum(bad_accs)} / {len(bad_accs)} = {badR:.3f}')
    print(f'  Bad F0.5       = {badF05:.3f}')
    return {'goodP': goodP, 'goodR': goodR, 'goodF05': goodF05, 'badP': badP, 'badR': badR, 'badF05': badF05}
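
# Illustrative sanity check of the formulas (toy numbers, not from the eval
# set): with good_accs = [1, 1, 1, 0] and bad_accs = [1, 1, 0, 0], two bad
# sentences are misjudged as good, so goodP = 3 / (4 - 2 + 3) = 0.6,
# goodR = 3 / 4 = 0.75, and goodF05 = 1.25 * 0.45 / (0.25 * 0.6 + 0.75) = 0.625.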

def evaluate_baseline_critic():
    """Baseline critic: judge a sentence as good iff its GPT-2 log-probability
    exceeds the mean log-probability over all good and bad sentences."""
    threshold = np.mean([elm['logp'] for elm in good_logps + bad_logps])
    good_accs = [obj['logp'] > threshold for obj in good_logps]
    bad_accs = [obj['logp'] <= threshold for obj in bad_logps]
    print('\nBaseline critic:')
    stats = compute_metrics(good_accs, bad_accs)
    with open('baseline_critic.stats.json', 'w') as f:
        json.dump(stats, f, indent=2)

evaluate_baseline_critic()
# Baseline critic:
#   Good precision = 365 / 668 = 0.546
#   Good recall    = 365 / 586 = 0.623
#   Good F0.5      = 0.560
#   Bad precision  = 283 / 504 = 0.562
#   Bad recall     = 283 / 586 = 0.483
#   Bad F0.5       = 0.544


def evaluate_LM_Critic():
    """LM-Critic: res[0] is the binary verdict (1 = judged good), obtained by
    comparing the sentence's probability against n_samples local perturbations."""
    good_accs, bad_accs = [], []
    for obj in tqdm(good_logps):
        res = gpt2_critic(obj['sent_detok'], verbose=0, seed=1, n_samples=100, word_level_mode='refine')
        good_accs.append(int(res[0]) == 1)
    for obj in tqdm(bad_logps):
        res = gpt2_critic(obj['sent_detok'], verbose=0, seed=1, n_samples=100, word_level_mode='refine')
        bad_accs.append(int(res[0]) == 0)
    print('\nLM-Critic:')
    stats = compute_metrics(good_accs, bad_accs)
    with open('lm_critic.stats.json', 'w') as f:
        json.dump(stats, f, indent=2)

evaluate_LM_Critic()
# LM-Critic (numbers vary slightly across runs due to the randomness of
# perturbation sampling and some variation in GPT-2 scores):
#   Good precision = 446 / 654 = 0.682
#   Good recall    = 446 / 586 = 0.761
#   Good F0.5      = 0.696
#   Bad precision  = 378 / 518 = 0.730
#   Bad recall     = 378 / 586 = 0.645
#   Bad F0.5       = 0.711
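
# Relative to the baseline critic (F0.5 ≈ 0.56 good / 0.54 bad), LM-Critic
# improves F0.5 by roughly 0.14-0.17 on both classes.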