MQT-LLaVA / llava /eval /eval_gpt_review_visual.py
gordonhubackup's picture
init
140efa3
raw
history blame
4.18 kB
import argparse
import json
import os
import openai
import time
NUM_SECONDS_TO_SLEEP = 0.5
def get_eval(content: str, max_tokens: int):
while True:
try:
response = openai.ChatCompletion.create(
model='gpt-4-0314',
messages=[{
'role': 'system',
'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
}, {
'role': 'user',
'content': content,
}],
temperature=0.2, # TODO: figure out which temperature is best for evaluation
max_tokens=max_tokens,
)
break
except openai.error.RateLimitError:
pass
except Exception as e:
print(e)
time.sleep(NUM_SECONDS_TO_SLEEP)
return response['choices'][0]['message']['content']
def parse_score(review):
try:
score_pair = review.split('\n')[0]
score_pair = score_pair.replace(',', ' ')
sp = score_pair.split(' ')
if len(sp) == 2:
return [float(sp[0]), float(sp[1])]
else:
print('error', review)
return [-1, -1]
except Exception as e:
print(e)
print('error', review)
return [-1, -1]
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
parser.add_argument('-q', '--question')
parser.add_argument('-c', '--context')
parser.add_argument('-a', '--answer-list', nargs='+', default=[])
parser.add_argument('-r', '--rule')
parser.add_argument('-o', '--output')
parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
args = parser.parse_args()
f_q = open(os.path.expanduser(args.question))
f_ans1 = open(os.path.expanduser(args.answer_list[0]))
f_ans2 = open(os.path.expanduser(args.answer_list[1]))
rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
if os.path.isfile(os.path.expanduser(args.output)):
cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))]
else:
cur_reviews = []
review_file = open(f'{args.output}', 'a')
context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]
image_to_context = {context['image']: context for context in context_list}
handles = []
idx = 0
for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
ques = json.loads(ques_js)
ans1 = json.loads(ans1_js)
ans2 = json.loads(ans2_js)
inst = image_to_context[ques['image']]
cap_str = '\n'.join(inst['captions'])
box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in inst['instances']])
category = json.loads(ques_js)['category']
if category in rule_dict:
rule = rule_dict[category]
else:
assert False, f"Visual QA category not found in rule file: {category}."
prompt = rule['prompt']
role = rule['role']
content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n'
f'[Question]\n{ques["text"]}\n\n'
f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
f'[System]\n{prompt}\n\n')
cur_js = {
'id': idx+1,
'question_id': ques['question_id'],
'answer1_id': ans1.get('answer_id', ans1['question_id']),
'answer2_id': ans2.get('answer_id', ans2['answer_id']),
'category': category
}
if idx >= len(cur_reviews):
review = get_eval(content, args.max_tokens)
scores = parse_score(review)
cur_js['content'] = review
cur_js['tuple'] = scores
review_file.write(json.dumps(cur_js) + '\n')
review_file.flush()
else:
print(f'Skipping {idx} as we already have it.')
idx += 1
print(idx)
review_file.close()