# live-lm-critic / eval_critic/eval_critic.py
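
# Evaluates critics on paired sentences: each pair has a grammatical ('good') and an
# ungrammatical ('bad') version. Three checks follow:
#   1) does GPT-2 assign a higher log p to the good sentence than to its bad counterpart,
#   2) a baseline critic that thresholds raw log p at the dataset mean,
#   3) LM-Critic (gpt2_critic), which compares a sentence against sampled local perturbations.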
import re
import os
import sys
import json
import numpy as np
import editdistance
from tqdm import tqdm
from collections import Counter
sys.path.insert(0, '.')
from utils.text_utils import detokenize_sent
from critic.critic import run_gpt2, gpt2_critic
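
# eval_data.jsonl holds one JSON object per line with fields 'good' (grammatical sentence)
# and 'bad' (its ungrammatical counterpart); order is preserved, so index i pairs them.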
def load_data():
    data_path = 'eval_critic/eval_data.jsonl'
    good_sents, bad_sents = [], []
    for line in open(data_path):
        obj = json.loads(line)
        good_sents.append(obj['good'])
        bad_sents.append(obj['bad'])
    return good_sents, bad_sents
good_sents, bad_sents = load_data()
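
# Score sentences with GPT-2 in batches of 100: detokenize each sentence and record the
# log-probability returned by run_gpt2, keeping both the original and detokenized forms.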
def get_logps(sents):
    final = []
    for start in tqdm(range(0, len(sents), 100)):
        sents_sub = sents[start: start+100]
        sents_sub_detok = [detokenize_sent(sent) for sent in sents_sub]
        logps = run_gpt2(sents_sub_detok)
        assert logps is not None
        for i in range(len(sents_sub)):
            final.append({'sent': sents_sub[i], 'sent_detok': sents_sub_detok[i], 'logp': float(logps[i])})
    return final
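
# Since load_data() keeps the pairs aligned, zip(good_logps, bad_logps) below compares
# each bad sentence against its own good counterpart.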
def evaluate_logp():
    """
    Check whether log p(bad_sent) < log p(good_sent)
    """
    good_logps = get_logps(good_sents)
    bad_logps = get_logps(bad_sents)
    accs = []
    for good, bad in zip(good_logps, bad_logps):
        accs.append(int(bad['logp'] < good['logp']))
    avg_acc = float(sum(accs)) / len(accs)
    print(f'log p(bad) < log p(good)? {sum(accs)} / {len(accs)} = {avg_acc:.3f}')
    return good_logps, bad_logps
good_logps, bad_logps = evaluate_logp()
# log p(bad) < log p(good)? 555 / 586 = 0.947
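
# Precision/recall per label. For the 'good' label, the precision denominator
# (len(bad_accs) - sum(bad_accs) + sum(good_accs)) is the number of sentences predicted good:
# bad sentences misclassified as good plus good sentences correctly kept.
# F0.5 emphasizes precision over recall; the 'bad' label is computed symmetrically.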
def compute_metrics(good_accs, bad_accs):
    goodP = float(sum(good_accs)) / (len(bad_accs) - sum(bad_accs) + sum(good_accs))
    goodR = float(sum(good_accs)) / len(good_accs)
    goodF05 = (1 + 0.5**2) * float(goodP * goodR) / ((0.5**2 * goodP) + goodR)
    badP = float(sum(bad_accs)) / (len(good_accs) - sum(good_accs) + sum(bad_accs))
    badR = float(sum(bad_accs)) / len(bad_accs)
    badF05 = (1 + 0.5**2) * float(badP * badR) / ((0.5**2 * badP) + badR)
    print(f' Good precision = {sum(good_accs)} / {(len(bad_accs)-sum(bad_accs)+sum(good_accs))} = {goodP:.3f}')
    print(f' Good recall = {sum(good_accs)} / {len(good_accs)} = {goodR:.3f}')
    print(f' Good F0.5 = {goodF05:.3f}')
    print(f' Bad precision = {sum(bad_accs)} / {(len(good_accs)-sum(good_accs)+sum(bad_accs))} = {badP:.3f}')
    print(f' Bad recall = {sum(bad_accs)} / {len(bad_accs)} = {badR:.3f}')
    print(f' Bad F0.5 = {badF05:.3f}')
    return {'goodP': goodP, 'goodR': goodR, 'goodF05': goodF05, 'badP': badP, 'badR': badR, 'badF05': badF05}
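
# Baseline critic: threshold raw GPT-2 log p at the mean over all good and bad sentences,
# predicting 'good' whenever a sentence's log p exceeds that global threshold.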
def evaluate_baseline_critic():
    threshold = np.mean([elm['logp'] for elm in good_logps + bad_logps])
    good_accs, bad_accs = [], []
    for obj in good_logps:
        pred = int(obj['logp'] > threshold)
        good_accs.append(pred == 1)
    for obj in bad_logps:
        pred = int(obj['logp'] > threshold)
        bad_accs.append(pred == 0)
    print('\nBaseline critic:')
    stats = compute_metrics(good_accs, bad_accs)
    json.dump(stats, open('baseline_critic.stats.json', 'w'), indent=2)
evaluate_baseline_critic()
# Baseline critic:
# Good precision = 365 / 668 = 0.546
# Good recall = 365 / 586 = 0.623
# Good F0.5 = 0.560
# Bad precision = 283 / 504 = 0.562
# Bad recall = 283 / 586 = 0.483
# Bad F0.5 = 0.544
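
# LM-Critic: gpt2_critic samples n_samples=100 word-level perturbations of the sentence and
# judges it good if it scores at least as well under GPT-2 as the sampled neighbors
# (the LM-Critic local-optimum criterion); res[0] is the resulting boolean verdict.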
def evaluate_LM_Critic():
    good_accs, bad_accs = [], []
    for obj in tqdm(good_logps):
        res = gpt2_critic(obj['sent_detok'], verbose=0, seed=1, n_samples=100, word_level_mode='refine')
        pred = int(res[0])
        good_accs.append(pred == 1)
    for obj in tqdm(bad_logps):
        res = gpt2_critic(obj['sent_detok'], verbose=0, seed=1, n_samples=100, word_level_mode='refine')
        pred = int(res[0])
        bad_accs.append(pred == 0)
    print('\nLM-Critic:')
    stats = compute_metrics(good_accs, bad_accs)
    json.dump(stats, open('lm_critic.stats.json', 'w'), indent=2)
evaluate_LM_Critic()
# LM-Critic (results vary slightly across runs due to the randomness of perturbation sampling
# and small variations in the GPT-2 scores):
# Good precision = 446 / 654 = 0.682
# Good recall = 446 / 586 = 0.761
# Good F0.5 = 0.696
# Bad precision = 378 / 518 = 0.730
# Bad recall = 378 / 586 = 0.645
# Bad F0.5 = 0.711