TinyChart-3B / tinychart /eval /eval_metric.py
xzl12306's picture
first commit
d6bc023
raw
history blame contribute delete
No virus
5.5 kB
import os
import json
import os
import math
import copy
import argparse
import numpy as np
def write_jsonl(data, filename):
    """Write ``data`` to ``filename`` in JSON Lines format.

    Each element of ``data`` is serialized with ``json.dumps`` and written on
    its own line. The file is opened in write mode, so any previous content
    is replaced.

    Args:
        data: Iterable of JSON-serializable objects.
        filename: Path of the file to create or overwrite.
    """
    with open(filename, 'w') as out:
        out.writelines(json.dumps(record) + '\n' for record in data)
def RelaxedAccuracy(pred, gt):
    """Score a prediction against a ground-truth answer with relaxed matching.

    When both values can be converted to ``float`` the prediction is correct
    if its relative error with respect to ``gt`` is at most 5% (an exact match
    is required when ``gt`` is zero). Otherwise the two values are compared
    as strings and must match exactly.

    Args:
        pred: Predicted answer (number or string).
        gt: Ground-truth answer (number or string).

    Returns:
        ``1.0`` if the prediction is accepted, ``0.0`` otherwise.
    """
    try:
        gt_value = float(gt)
        pred_value = float(pred)
    except (TypeError, ValueError, OverflowError):
        # At least one side is not numeric: fall back to exact string match
        # on the values as they were passed in.
        return 1.0 if str(gt) == str(pred) else 0.0

    if gt_value == 0.0:
        # Relative error is undefined for a zero reference; require equality.
        return 1.0 if pred_value == gt_value else 0.0

    # Normalize by |gt|: dividing by a negative gt would make the ratio
    # negative and therefore always "within tolerance".
    relative_error = abs(pred_value - gt_value) / abs(gt_value)
    return 1.0 if relative_error <= 0.05 else 0.0
def evaluate_cmds(cmds):
    """Execute Python statements in order and return the formatted ``Answer``.

    The statements share one namespace, so later statements can use names
    bound by earlier ones. After all statements have run, the value bound to
    ``Answer`` is post-processed:

    * a one-element list / ``np.ndarray`` is unwrapped to its single element;
    * a longer list / ``np.ndarray`` of strings is rendered as
      ``'a, b and c'``;
    * a ``bool`` / ``np.bool_`` becomes ``'Yes'`` or ``'No'``.

    Args:
        cmds: Iterable of source strings, each executed with ``exec``.

    Returns:
        The post-processed value of ``Answer``.

    Raises:
        NameError: If no statement binds ``Answer``.
        Exception: Anything raised by the executed statements or by the
            post-processing (e.g. ``TypeError`` for a list of non-strings).
    """
    # NOTE(security): ``cmds`` is run through exec(), which executes arbitrary
    # code with this module's globals. Only feed it output you are willing to
    # run, or sandbox the process.
    #
    # An explicit locals dict is required: since Python 3.13 (PEP 667) each
    # exec()/eval() call inside a function sees an independent snapshot of the
    # function locals, so names bound by one call are invisible to the next.
    scope = {}
    for cmd in cmds:
        exec(cmd, globals(), scope)
    answer = eval('Answer', globals(), scope)

    if isinstance(answer, (list, np.ndarray)) and len(answer) == 1:
        answer = answer[0]
    if isinstance(answer, (list, np.ndarray)):
        # "a, b and c"; elements are expected to be strings.
        answer = ', '.join(answer[:-1]) + ' and ' + answer[-1]
    if isinstance(answer, (bool, np.bool_)):
        answer = 'Yes' if answer else 'No'
    return answer
def parse_model_output(cmdstr):
    """Split model output into lines and drop the ``<step>`` markers.

    Args:
        cmdstr: Raw model output; lines are separated by ``'\\n'``.

    Returns:
        A list with one entry per input line, in order, with every
        ``<step>`` and ``</step>`` substring removed.
    """
    return [
        raw.replace('<step>', '').replace('</step>', '')
        for raw in cmdstr.split('\n')
    ]
def chartqa_evaluator(data, key='final_model_answer'):
    """Score every item with ``RelaxedAccuracy`` and compute overall accuracy.

    Each item is updated in place: its ``'relaxed_acc'`` entry is set to the
    score of ``item[key]`` against the part of ``item['gt_answer']`` that
    precedes the first ``'<pot_note>'`` marker.

    Args:
        data: Sequence of dicts with ``key`` and ``'gt_answer'`` entries.
        key: Name of the entry that holds the model's answer.

    Returns:
        ``(data, accuracy)`` where ``accuracy`` is the fraction of items
        scored ``1.0``. For empty ``data`` the accuracy is ``0.0``.
    """
    correct = 0
    for item in data:
        reference = item['gt_answer'].split('<pot_note>')[0]
        item['relaxed_acc'] = RelaxedAccuracy(item[key], reference)
        if item['relaxed_acc'] == 1.0:
            correct += 1

    total = len(data)
    # Nothing to score: report 0.0 rather than raising ZeroDivisionError.
    accuracy = correct / total if total else 0.0
    return data, accuracy
def chartqapot_evaluator(output_data):
    """Execute each item's program-style answer and score the result.

    Works on a deep copy of ``output_data``, so the caller's items are not
    modified. For every item, ``item['model_answer']`` is parsed with
    ``parse_model_output`` and run with ``evaluate_cmds``:

    * on success, ``'final_model_answer'`` is set to ``str(answer)``,
      ``'gt_answer'`` is truncated at the first ``'<cot_note>'`` marker, and
      ``'relaxed_acc'`` is set by ``RelaxedAccuracy``;
    * on failure, ``'final_model_answer'`` is set to ``'Execute <error>'``
      and ``'relaxed_acc'`` to ``0.0`` (``'gt_answer'`` is left as is).

    Args:
        output_data: Sequence of dicts with ``'model_answer'`` and
            ``'gt_answer'`` entries.

    Returns:
        ``(items, accuracy, error_rate)``: the annotated copies, the fraction
        scored ``1.0`` and the fraction whose program failed to execute. Both
        rates are ``0.0`` for empty input.
    """
    output_data = copy.deepcopy(output_data)
    num_correct = 0
    num_errors = 0

    for item in output_data:
        eval_cmds = parse_model_output(item['model_answer'])
        try:
            answer = evaluate_cmds(eval_cmds)
            item['final_model_answer'] = str(answer)
        except (Exception, SystemExit):
            # Executed programs may fail in arbitrary ways, including calling
            # exit(); count those as execution errors. KeyboardInterrupt is
            # deliberately not caught so the run can still be aborted.
            num_errors += 1
            item['final_model_answer'] = 'Execute <error>'
            item['relaxed_acc'] = 0.0
            continue

        item['gt_answer'] = item['gt_answer'].split('<cot_note>')[0]
        item['relaxed_acc'] = RelaxedAccuracy(str(answer), item['gt_answer'])
        if item['relaxed_acc'] == 1.0:
            num_correct += 1

    total = len(output_data)
    if total == 0:
        # Avoid ZeroDivisionError when there is nothing to evaluate.
        return output_data, 0.0, 0.0
    return output_data, num_correct / total, num_errors / total
def rule_based_divider(question):
    """Route a question to the ``'pot'`` or ``'direct'`` answer by keyword.

    The question is lower-cased once and searched for each keyword as a plain
    substring. Keywords written with surrounding spaces (e.g. ``' less '``)
    therefore only match when those spaces are present in the question.

    Args:
        question: The question text.

    Returns:
        ``'pot'`` if any calculation keyword occurs in the question,
        ``'direct'`` otherwise.
    """
    calculate_words = (
        'sum', 'difference', 'times', 'summation', 'exceed',
        'below', 'addition', 'fewer', 'subtract', ' mode ',
        'ratio', 'division', 'average', 'mean', 'bigger',
        'greater', ' less ', 'tallest', 'number', 'divide',
        ' add ', 'absolute', 'dividing', 'differ', ' minus ',
        'how many colors', 'lowest', 'what is the value', 'higher',
        'longer', ' biggest ',
    )
    lowered = question.lower()
    if any(word in lowered for word in calculate_words):
        return 'pot'
    return 'direct'
def chartqa_rule_merger_evaluator(direct_data, pot_data):
    """Merge direct and program-based answers using the keyword rule.

    For each aligned pair, the program-based item is chosen when
    ``rule_based_divider`` classifies the question as ``'pot'`` and the
    item's ``'final_model_answer'`` neither contains ``'<error>'`` nor is one
    of the non-finite placeholders (``'inf'``, ``'nan'``, ...). Otherwise the
    direct item is chosen.

    ``direct_data`` is first scored with ``chartqa_evaluator`` using the
    ``'model_answer'`` key, which sets ``'relaxed_acc'`` on its items in
    place. Items of ``pot_data`` must already carry ``'final_model_answer'``
    and ``'relaxed_acc'``.

    Args:
        direct_data: Sequence of dicts with ``'question'``,
            ``'model_answer'`` and ``'gt_answer'``.
        pot_data: Sequence of scored program-based items, aligned with
            ``direct_data``.

    Returns:
        ``(merged_items, accuracy)``; accuracy is ``0.0`` for empty input.

    Raises:
        AssertionError: If the two inputs differ in length.
    """
    # Validate before scoring so a mismatch does not leave direct_data
    # half-processed; raised explicitly so the check survives ``python -O``.
    if len(direct_data) != len(pot_data):
        raise AssertionError('direct and pot num inconsistent')
    if len(direct_data) == 0:
        # Nothing to merge: avoid ZeroDivisionError.
        return [], 0.0

    direct_data, _ = chartqa_evaluator(direct_data, key='model_answer')
    non_finite = ('inf', '-inf', 'nan', 'np.nan', 'np.inf', '-np.inf')

    acc_count = 0
    merged_data = []
    for direct_item, pot_item in zip(direct_data, pot_data):
        chosen = direct_item
        if rule_based_divider(direct_item['question']) == 'pot':
            pot_answer = pot_item['final_model_answer']
            if '<error>' not in pot_answer and pot_answer not in non_finite:
                chosen = pot_item
        acc_count += chosen['relaxed_acc']
        merged_data.append(chosen)

    accuracy = acc_count / len(direct_data)
    return merged_data, accuracy
def chartqa_oracle_merger_evaluator(direct_data, pot_data):
    """Merge direct and program-based answers with oracle selection.

    For each aligned pair, the direct item is kept when its ``'relaxed_acc'``
    is ``1.0``; otherwise the program-based item is used. Because the choice
    looks at the scores themselves, the resulting accuracy is an upper bound
    on what any per-item selection between the two answers can reach.

    ``direct_data`` is first scored with ``chartqa_evaluator`` using the
    ``'model_answer'`` key, which sets ``'relaxed_acc'`` on its items in
    place. Items of ``pot_data`` must already carry ``'relaxed_acc'``.

    Args:
        direct_data: Sequence of dicts with ``'model_answer'`` and
            ``'gt_answer'``.
        pot_data: Sequence of scored program-based items, aligned with
            ``direct_data``.

    Returns:
        ``(merged_items, accuracy)``; accuracy is ``0.0`` for empty input.

    Raises:
        AssertionError: If the two inputs differ in length.
    """
    # Validate before scoring so a mismatch does not leave direct_data
    # half-processed; raised explicitly so the check survives ``python -O``.
    if len(direct_data) != len(pot_data):
        raise AssertionError('direct and pot num inconsistent')
    if len(direct_data) == 0:
        # Nothing to merge: avoid ZeroDivisionError.
        return [], 0.0

    direct_data, _ = chartqa_evaluator(direct_data, key='model_answer')

    acc_count = 0
    merged_data = []
    for direct_item, pot_item in zip(direct_data, pot_data):
        chosen = direct_item if direct_item['relaxed_acc'] == 1.0 else pot_item
        acc_count += chosen['relaxed_acc']
        merged_data.append(chosen)

    accuracy = acc_count / len(direct_data)
    return merged_data, accuracy
if __name__ == '__main__':
    # Command-line entry point: merge direct and program-based predictions
    # and write the rule-merged items to --output.
    parser = argparse.ArgumentParser()
    parser.add_argument('--direct', default='../eval_iter12000_0226/ChartQA_test_12000_pred.jsonl')
    parser.add_argument('--pot', default='../eval_iter12000_0226/ChartQA_test_pot_12000_eval.jsonl')
    parser.add_argument('--output', default='../eval_iter12000_0226/ChartQA_test_pot_12000_merged.jsonl')
    args = parser.parse_args()

    def _read_jsonl(path):
        # One JSON document per non-blank line.
        with open(path) as f:
            return [json.loads(line) for line in f if line.strip()]

    # The original code called oracle_merger / rule_based_merger, which are
    # not defined in this file, and handed them file paths. The evaluators
    # defined above take loaded records, so load both files first.
    # NOTE(review): assumes the --pot file already holds scored items with
    # 'final_model_answer' and 'relaxed_acc' (its default name ends in
    # "_eval.jsonl") - confirm against the script that produces it.
    direct_data = _read_jsonl(args.direct)
    pot_data = _read_jsonl(args.pot)

    _, oracle_accuracy = chartqa_oracle_merger_evaluator(direct_data, pot_data)
    print('oracle merged accuracy:', oracle_accuracy)

    # As before, the rule-based merge is the one that gets written out.
    merged, rule_accuracy = chartqa_rule_merger_evaluator(direct_data, pot_data)
    print('rule-based merged accuracy:', rule_accuracy)
    write_jsonl(merged, args.output)