# Spaces:
# Runtime error
# Runtime error
import argparse
import json
import os
import time
from contextlib import ExitStack

import openai
NUM_SECONDS_TO_SLEEP = 0.5  # delay, in seconds, before retrying a failed API call


def get_eval(content: str, max_tokens: int, max_retries=None):
    """Send ``content`` to the GPT-4 judge and return the text of its reply.

    The request uses a fixed system prompt, model ``gpt-4-0314`` and
    temperature 0.2. A failed call is retried after ``NUM_SECONDS_TO_SLEEP``
    seconds: rate-limit errors are retried silently, any other exception is
    printed first.

    Args:
        content: The user message to be judged.
        max_tokens: Maximum number of tokens the judge may generate.
        max_retries: Optional cap on the number of retries after the first
            attempt. ``None`` (the default) retries forever, as this function
            always has; pass an int to stop hanging on errors that can never
            succeed (e.g. an invalid API key or an over-long request).

    Returns:
        The message content of the first choice in the API response.

    Raises:
        RuntimeError: if ``max_retries`` is not None and every attempt failed.
            The last underlying exception is chained as ``__cause__``.
    """
    attempt = 0
    last_error = None
    while True:
        try:
            response = openai.ChatCompletion.create(
                model='gpt-4-0314',
                messages=[{
                    'role': 'system',
                    'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
                }, {
                    'role': 'user',
                    'content': content,
                }],
                temperature=0.2,  # TODO: figure out which temperature is best for evaluation
                max_tokens=max_tokens,
            )
            break
        except openai.error.RateLimitError as e:
            # Expected under load: retry without noise.
            last_error = e
        except Exception as e:
            print(e)
            last_error = e
        if max_retries is not None and attempt >= max_retries:
            raise RuntimeError(
                f'get_eval: giving up after {attempt + 1} attempts'
            ) from last_error
        attempt += 1
        time.sleep(NUM_SECONDS_TO_SLEEP)

    return response['choices'][0]['message']['content']
def parse_score(review):
    """Extract the pair of scores from the first line of a judge review.

    The first line is expected to contain exactly two numbers separated by a
    comma and/or whitespace (``"8 9"``, ``"8,9"`` and ``"8, 9"`` are all
    accepted). Anything after the first line is ignored.

    Args:
        review: The raw review text returned by the judge.

    Returns:
        ``[score1, score2]`` as floats, or ``[-1, -1]`` when the first line
        does not hold exactly two numbers (the review is printed so the
        failure is visible). Parsing is best effort and never raises for
        malformed input.
    """
    try:
        first_line = review.split('\n')[0]
        # Commas count as separators; split() with no argument collapses runs
        # of whitespace, so "8, 9" or " 8 9 " yield exactly two tokens instead
        # of spurious empty strings.
        tokens = first_line.replace(',', ' ').split()
        if len(tokens) == 2:
            return [float(tokens[0]), float(tokens[1])]
        print('error', review)
        return [-1, -1]
    except (AttributeError, TypeError, ValueError) as e:
        # Non-string review (AttributeError/TypeError) or non-numeric token
        # (ValueError): report it and fall back to the sentinel pair.
        print(e)
        print('error', review)
        return [-1, -1]
if __name__ == '__main__':
    # Command-line entry point: for every question, ask the judge to compare
    # the two answers and append one JSON review per line to --output.
    parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
    parser.add_argument('-q', '--question', required=True)
    parser.add_argument('-c', '--context', required=True)
    parser.add_argument('-a', '--answer-list', nargs='+', default=[])
    parser.add_argument('-r', '--rule', required=True)
    parser.add_argument('-o', '--output', required=True)
    parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
    args = parser.parse_args()

    # The first two answer files are compared head to head; fail with a usage
    # message instead of an IndexError when fewer are given.
    if len(args.answer_list) < 2:
        parser.error('-a/--answer-list needs two answer files')

    # Expand '~' once so every stat/open below sees the same path (the output
    # file in particular is both checked for resume data and appended to).
    question_path = os.path.expanduser(args.question)
    answer1_path = os.path.expanduser(args.answer_list[0])
    answer2_path = os.path.expanduser(args.answer_list[1])
    rule_path = os.path.expanduser(args.rule)
    output_path = os.path.expanduser(args.output)
    context_path = os.path.expanduser(args.context)

    with open(rule_path, 'r', encoding='utf-8') as f:
        rule_dict = json.load(f)

    # Resume support: reviews already present in the output file are kept and
    # the same number of leading input rows is skipped below.
    if os.path.isfile(output_path):
        with open(output_path, encoding='utf-8') as f:
            cur_reviews = [json.loads(line) for line in f]
    else:
        cur_reviews = []

    with open(context_path, encoding='utf-8') as f:
        context_list = [json.loads(line) for line in f]
    image_to_context = {context['image']: context for context in context_list}

    with ExitStack() as stack:
        f_q = stack.enter_context(open(question_path, encoding='utf-8'))
        f_ans1 = stack.enter_context(open(answer1_path, encoding='utf-8'))
        f_ans2 = stack.enter_context(open(answer2_path, encoding='utf-8'))
        review_file = stack.enter_context(open(output_path, 'a', encoding='utf-8'))

        # The three inputs are JSON-lines files read in lockstep; zip stops at
        # the shortest one.
        for idx, (ques_js, ans1_js, ans2_js) in enumerate(zip(f_q, f_ans1, f_ans2)):
            ques = json.loads(ques_js)
            ans1 = json.loads(ans1_js)
            ans2 = json.loads(ans2_js)

            inst = image_to_context[ques['image']]
            caption = inst['caption']
            cap_str = '\n'.join(caption) if isinstance(caption, list) else caption

            category = 'llava_bench_' + ques['category']
            # Explicit raise rather than `assert`: asserts vanish under -O,
            # which would leave `rule` undefined or stale from a previous row.
            if category not in rule_dict:
                raise ValueError(f"Visual QA category not found in rule file: {category}.")
            rule = rule_dict[category]
            prompt = rule['prompt']
            role = rule['role']
            content = (f'[Context]\n{cap_str}\n\n'
                       f'[Question]\n{ques["text"]}\n\n'
                       f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
                       f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
                       f'[System]\n{prompt}\n\n')
            cur_js = {
                'id': idx + 1,
                'question_id': ques['question_id'],
                # Fall back to the question id when an answer has no answer_id.
                # A conditional is used because dict.get evaluates its default
                # eagerly and would raise KeyError on a missing fallback key.
                'answer1_id': ans1['answer_id'] if 'answer_id' in ans1 else ans1['question_id'],
                'answer2_id': ans2['answer_id'] if 'answer_id' in ans2 else ans2['question_id'],
                'category': category,
            }
            if idx >= len(cur_reviews):
                review = get_eval(content, args.max_tokens)
                scores = parse_score(review)
                cur_js['content'] = review
                cur_js['tuple'] = scores
                review_file.write(json.dumps(cur_js) + '\n')
                # Flush per review so an interrupted run can be resumed.
                review_file.flush()
            else:
                print(f'Skipping {idx} as we already have it.')
            # Progress: number of rows processed so far.
            print(idx + 1)