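# Evaluation utilities for ChartQA: relaxed-accuracy scoring of direct answers,
# execution-based scoring of program-of-thought (PoT) answers, and rule-based /
# oracle merging of the two prediction sets.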
import os
import json
import math
import copy
import argparse
import numpy as np
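# Serialize a list of dicts to a JSON-lines file, one record per line.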
def write_jsonl(data, filename):
    with open(filename, 'w') as f:
        for item in data:
            f.write(json.dumps(item) + '\n')
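# ChartQA relaxed accuracy: numeric answers count as correct within 5% relative
# error of the ground truth; non-numeric answers require an exact string match.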
def RelaxedAccuracy(pred, gt):
    try:
        gt = float(gt)
        pred = float(pred)
        if gt == 0.0:
            return 1.0 if pred == gt else 0.0
        # Compare relative error against the magnitude of the ground truth;
        # abs() guards against negative targets flipping the sign of the ratio.
        if abs(pred - gt) / abs(gt) <= 0.05:
            return 1.0
        return 0.0
    except Exception:
        return 1.0 if str(gt) == str(pred) else 0.0
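# Execute the parsed PoT commands and post-process the resulting 'Answer'
# variable: unwrap single-element arrays, join lists into a phrase, and map
# booleans to 'Yes'/'No'.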
def evaluate_cmds(cmds):
    # Run all commands in one shared namespace so that variables defined by
    # earlier commands (including the final 'Answer') remain visible; np and
    # math are provided in case the generated code references them.
    namespace = {'np': np, 'math': math}
    for cmd in cmds:
        exec(cmd, namespace)
    answer = namespace['Answer']
    if isinstance(answer, (list, np.ndarray)) and len(answer) == 1:
        answer = answer[0]
    if isinstance(answer, (list, np.ndarray)):
        # Join multi-element answers into a natural-language list: "a, b and c".
        parts = [str(x) for x in answer]
        answer = ', '.join(parts[:-1]) + ' and ' + parts[-1]
    if isinstance(answer, (bool, np.bool_)):
        answer = 'Yes' if answer else 'No'
    return answer
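# Extract the executable program from the model output: keep the lines wrapped
# in <step></step> tags, with the tags stripped.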
def parse_model_output(cmdstr):
    lines = cmdstr.split('\n')
    new_lines = []
    for line in lines:
        if '<step>' in line or '</step>' in line:
            line = line.replace('<step>', '').replace('</step>', '')
            new_lines.append(line)
    return new_lines
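# Score direct-answer predictions with relaxed accuracy; annotates each item
# with 'relaxed_acc' and returns the overall accuracy.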
def chartqa_evaluator(data, key='final_model_answer'):
    acc = 0
    for item in data:
        item['relaxed_acc'] = RelaxedAccuracy(item[key], item['gt_answer'].split('<pot_note>')[0])
        if item['relaxed_acc'] == 1.0:
            acc += 1
    accuracy = acc / len(data)
    return data, accuracy
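# Score PoT predictions by executing the generated program; items whose program
# fails to execute are marked 'Execute <error>' and counted in the error rate.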
def chartqapot_evaluator(output_data):
    correct_items = []
    wrong_items = []
    error_items = []
    output_data = copy.deepcopy(output_data)
    for item in output_data:
        eval_cmds = parse_model_output(item['model_answer'])
        try:
            answer = evaluate_cmds(eval_cmds)
            item['final_model_answer'] = str(answer)
        except Exception:
            error_items.append(item)
            item['final_model_answer'] = 'Execute <error>'
            item['relaxed_acc'] = 0.0
            continue
        item['gt_answer'] = item['gt_answer'].split('<cot_note>')[0]
        item['relaxed_acc'] = RelaxedAccuracy(str(answer), item['gt_answer'])
        if item['relaxed_acc'] == 1.0:
            correct_items.append(item)
        else:
            wrong_items.append(item)
    total = len(output_data)
    accuracy = len(correct_items) / total
    error_rate = len(error_items) / total
    return output_data, accuracy, error_rate
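# Heuristic router: questions containing computation-related keywords are routed
# to the PoT answer, everything else to the direct answer.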
def rule_based_divider(question):
    calculate_words = [
        'sum', 'difference', 'times', 'summation', 'exceed',
        'below', 'addition', 'fewer', 'subtract', ' mode ',
        'ratio', 'division', 'average', 'mean', 'bigger',
        'greater', ' less ', 'tallest', 'number', 'divide',
        ' add ', 'absolute', 'dividing', 'differ', ' minus ',
        'how many colors', 'lowest', 'what is the value', 'higher',
        'longer', ' biggest ',
    ]
    for w in calculate_words:
        if w in question.lower():
            return 'pot'
    return 'direct'
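# Merge direct and PoT predictions using the keyword router, falling back to the
# direct answer when the PoT result errored or is inf/nan.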
def chartqa_rule_merger_evaluator(direct_data, pot_data):
    direct_data, _ = chartqa_evaluator(direct_data, key='model_answer')
    assert len(direct_data) == len(pot_data), 'direct and pot num inconsistent'
    acc_count = 0
    merged_data = []
    for datum1, datum2 in zip(direct_data, pot_data):
        if rule_based_divider(datum1['question']) == 'pot' \
                and '<error>' not in datum2['final_model_answer'] \
                and datum2['final_model_answer'] not in ['inf', '-inf', 'nan', 'np.nan', 'np.inf', '-np.inf']:
            acc_count += datum2['relaxed_acc']
            merged_data.append(datum2)
        else:
            acc_count += datum1['relaxed_acc']
            merged_data.append(datum1)
    accuracy = acc_count / len(direct_data)
    return merged_data, accuracy
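# Oracle merge: take the PoT answer whenever the direct answer is wrong, giving
# an upper bound on the achievable merged accuracy.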
def chartqa_oracle_merger_evaluator(direct_data, pot_data):
    direct_data, _ = chartqa_evaluator(direct_data, key='model_answer')
    assert len(direct_data) == len(pot_data), 'direct and pot num inconsistent'
    acc_count = 0
    merged_data = []
    for datum1, datum2 in zip(direct_data, pot_data):
        if datum1['relaxed_acc'] != 1.0:
            acc_count += datum2['relaxed_acc']
            merged_data.append(datum2)
        else:
            acc_count += datum1['relaxed_acc']
            merged_data.append(datum1)
    accuracy = acc_count / len(direct_data)
    return merged_data, accuracy
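# CLI entry point: load the two prediction files, merge them, and write the
# merged predictions as JSON lines.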
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--direct', default='../eval_iter12000_0226/ChartQA_test_12000_pred.jsonl')
    parser.add_argument('--pot', default='../eval_iter12000_0226/ChartQA_test_pot_12000_eval.jsonl')
    parser.add_argument('--output', default='../eval_iter12000_0226/ChartQA_test_pot_12000_merged.jsonl')
    args = parser.parse_args()
    # Load the direct-answer and program-of-thought prediction files (JSON lines).
    with open(args.direct) as f:
        direct_data = [json.loads(line) for line in f]
    with open(args.pot) as f:
        pot_data = [json.loads(line) for line in f]
    # Oracle merge gives an upper bound; the rule-based merge is what gets written out.
    merged, oracle_acc = chartqa_oracle_merger_evaluator(direct_data, pot_data)
    merged, rule_acc = chartqa_rule_merger_evaluator(direct_data, pot_data)
    write_jsonl(merged, args.output)