import re
import os
import sys
import json
import numpy as np
import editdistance
from tqdm import tqdm
from collections import Counter
sys.path.insert(0, '.')
from utils.text_utils import detokenize_sent
from critic.critic import run_gpt2, gpt2_critic
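# Helper functions imported from this repo (descriptions inferred from how they are used below):
# - detokenize_sent: turns a tokenized sentence back into natural, detokenized text
# - run_gpt2: returns a GPT-2 log-probability score for each sentence in a batch
# - gpt2_critic: runs the LM-Critic judgment on a single sentence
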
def load_data():
    data_path = 'eval_critic/eval_data.jsonl'
    good_sents, bad_sents = [], []
    for line in open(data_path):
        obj = json.loads(line)
        good_sents.append(obj['good'])
        bad_sents.append(obj['bad'])
    return good_sents, bad_sents

good_sents, bad_sents = load_data()
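# eval_data.jsonl is expected to hold one JSON object per line, pairing a grammatical
# sentence with an ungrammatical counterpart, e.g. (illustrative example only):
#   {"good": "He went to the store .", "bad": "He go to the store ."}
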
def get_logps(sents):
    final = []
    for start in tqdm(range(0, len(sents), 100)):
        sents_sub = sents[start: start+100]
        sents_sub_detok = [detokenize_sent(sent) for sent in sents_sub]
        logps = run_gpt2(sents_sub_detok)
        assert logps is not None
        for i in range(len(sents_sub)):
            final.append({'sent': sents_sub[i], 'sent_detok': sents_sub_detok[i], 'logp': float(logps[i])})
    return final

def evaluate_logp():
    """
    Check whether log p(bad_sent) < log p(good_sent).
    """
    good_logps = get_logps(good_sents)
    bad_logps = get_logps(bad_sents)
    accs = []
    for good, bad in zip(good_logps, bad_logps):
        accs.append(int(bad['logp'] < good['logp']))
    avg_acc = float(sum(accs)) / len(accs)
    print(f'log p(bad) < log p(good)? {sum(accs)} / {len(accs)} = {avg_acc:.3f}')
    return good_logps, bad_logps

good_logps, bad_logps = evaluate_logp()
# log p(bad) < log p(good)? 555 / 586 = 0.947
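# Note: ranking within a (good, bad) pair is easy for raw GPT-2 log p (94.7% above), but
# classifying a single sentence in isolation is harder; compare the threshold baseline
# and LM-Critic results below.
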
def compute_metrics(good_accs, bad_accs):
    goodP = float(sum(good_accs)) / (len(bad_accs) - sum(bad_accs) + sum(good_accs))
    goodR = float(sum(good_accs)) / len(good_accs)
    goodF05 = (1 + 0.5**2) * float(goodP * goodR) / ((0.5**2 * goodP) + goodR)
    badP = float(sum(bad_accs)) / (len(good_accs) - sum(good_accs) + sum(bad_accs))
    badR = float(sum(bad_accs)) / len(bad_accs)
    badF05 = (1 + 0.5**2) * float(badP * badR) / ((0.5**2 * badP) + badR)
    print(f' Good precision = {sum(good_accs)} / {(len(bad_accs) - sum(bad_accs) + sum(good_accs))} = {goodP:.3f}')
    print(f' Good recall = {sum(good_accs)} / {len(good_accs)} = {goodR:.3f}')
    print(f' Good F0.5 = {goodF05:.3f}')
    print(f' Bad precision = {sum(bad_accs)} / {(len(good_accs) - sum(good_accs) + sum(bad_accs))} = {badP:.3f}')
    print(f' Bad recall = {sum(bad_accs)} / {len(bad_accs)} = {badR:.3f}')
    print(f' Bad F0.5 = {badF05:.3f}')
    return {'goodP': goodP, 'goodR': goodR, 'goodF05': goodF05, 'badP': badP, 'badR': badR, 'badF05': badF05}
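# How compute_metrics derives the counts above (restating the code, for readability):
#   good precision = #good predicted good / (#good predicted good + #bad predicted good)
#   good recall    = #good predicted good / #good
#   F0.5           = (1 + 0.5^2) * P * R / (0.5^2 * P + R), weighting precision over recall
# and symmetrically for the 'bad' class.
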
def evaluate_baseline_critic():
    # Baseline: call a sentence "good" iff its GPT-2 log p exceeds the mean log p
    # over all (good + bad) evaluation sentences.
    threshold = np.mean([elm['logp'] for elm in good_logps + bad_logps])
    good_accs, bad_accs = [], []
    for obj in good_logps:
        pred = int(obj['logp'] > threshold)
        good_accs.append(pred == 1)
    for obj in bad_logps:
        pred = int(obj['logp'] > threshold)
        bad_accs.append(pred == 0)
    print('\nBaseline critic:')
    stats = compute_metrics(good_accs, bad_accs)
    json.dump(stats, open('baseline_critic.stats.json', 'w'), indent=2)

evaluate_baseline_critic()
# Baseline critic:
# Good precision = 365 / 668 = 0.546
# Good recall = 365 / 586 = 0.623
# Good F0.5 = 0.560
# Bad precision = 283 / 504 = 0.562
# Bad recall = 283 / 586 = 0.483
# Bad F0.5 = 0.544
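# gpt2_critic implements the LM-Critic: as described in the LM-Critic paper (method not
# re-verified from critic/critic.py here), a sentence is judged grammatical only if it has
# the highest GPT-2 log-probability within a sampled neighborhood of local perturbations
# (n_samples=100 below); word_level_mode='refine' selects the word-level perturbation strategy.
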
def evaluate_LM_Critic():
    good_accs, bad_accs = [], []
    for obj in tqdm(good_logps):
        res = gpt2_critic(obj['sent_detok'], verbose=0, seed=1, n_samples=100, word_level_mode='refine')
        pred = int(res[0])
        good_accs.append(pred == 1)
    for obj in tqdm(bad_logps):
        res = gpt2_critic(obj['sent_detok'], verbose=0, seed=1, n_samples=100, word_level_mode='refine')
        pred = int(res[0])
        bad_accs.append(pred == 0)
    print('\nLM-Critic:')
    stats = compute_metrics(good_accs, bad_accs)
    json.dump(stats, open('lm_critic.stats.json', 'w'), indent=2)

evaluate_LM_Critic()
# LM-Critic: (results vary slightly across runs due to the randomness of perturbation sampling and small variation in the GPT-2 scores)
# Good precision = 446 / 654 = 0.682
# Good recall = 446 / 586 = 0.761
# Good F0.5 = 0.696
# Bad precision = 378 / 518 = 0.730
# Bad recall = 378 / 586 = 0.645
# Bad F0.5 = 0.711