import argparse
import json
import os
import time

from tqdm import tqdm

from openai_utils import ask_gpt_on_figure, ask_gpt
from llava_utils import ask_llm, ask_llm_on_figure, restart_model
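

# This script turns a set of figure samples into (chosen, rejected) preference
# pairs in three stages:
#   1. Label: the selected backend (OpenAI GPT via openai_utils, or a local
#      LLaVA model via llava_utils) produces a label for each sample's figure.
#   2. Score: the same backend assigns an integer score to every labeled sample.
#   3. Pair: samples scoring at least 6 are grouped by index // num_samples, and
#      within a group each higher-scored output becomes "chosen" against each
#      lower-scored "rejected" one.
# Note: --save-path is reused by every stage, so later stages overwrite the
# intermediate files written by earlier ones.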
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--source-data-path", type=str, required=True)
    parser.add_argument("--figure-path", type=str, required=True)
    parser.add_argument("--save-path", type=str, required=True)
    parser.add_argument("--num-samples", type=int, required=True)
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--score-only", action="store_true", default=False)
    parser.add_argument("--gpt", action="store_true", default=False)
    args = parser.parse_args()
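
    # Unpack the arguments and pick the labeling/scoring backend: with --gpt the
    # OpenAI utilities are used and no local model is loaded (--gpu is ignored);
    # otherwise a LLaVA model is loaded on the requested GPU.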
    source_path = args.source_data_path
    folder_path = args.figure_path
    save_path = args.save_path
    num_samples = args.num_samples
    device = f'cuda:{args.gpu}'

    if args.gpt:
        func1, func2 = ask_gpt_on_figure, ask_gpt
        model = None
        processor = None
    else:
        func1, func2 = ask_llm_on_figure, ask_llm
        model, processor = restart_model(device)

    with open(source_path, 'r') as f:
        test_data = json.load(f)
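
    # Stage 1: locate each sample's figure and ask the backend for a label,
    # retrying on failures.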
    for data in tqdm(test_data):
        # Figure files are saved with a zero-padded index as the filename prefix.
        file_id = str(data['index']).zfill(6)
        file = None
        for fname in os.listdir(folder_path):
            if fname.startswith(file_id):
                file = os.path.join(folder_path, fname)
        data['figure_path'] = file

        error_cnt = 0
        while True:
            try:
                data['gpt_label'] = func1(data, model, processor)
                break
            except Exception as e:
                print(e)
                if args.gpt:
                    # Assume a transient API error: wait and retry.
                    time.sleep(3)
                else:
                    # Local model failure: reload it, and give up after 5 restarts.
                    if error_cnt == 5:
                        exit()
                    model, processor = restart_model(device)
                    error_cnt += 1

        # Checkpoint progress after every sample so an interrupted run keeps its labels.
        with open(save_path, 'w+') as f:
            json.dump(test_data, f, indent=4)

    # Reload the checkpointed labels before scoring.
    with open(save_path, 'r') as f:
        test_data = json.load(f)
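
    # Stage 2: ask the backend to score every labeled sample.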
    for data in tqdm(test_data):
        if "gpt_label" in data:
            error_cnt = 0
            while True:
                try:
                    score = func2(data, model, processor)
                    print(score)
                    break
                except Exception as e:
                    print(e)
                    if args.gpt:
                        time.sleep(3)
                    else:
                        if error_cnt == 5:
                            exit()
                        model, processor = restart_model(device)
                        error_cnt += 1
            try:
                data['gpt_score'] = int(score)
            except (ValueError, TypeError):
                # The backend returned something that does not parse as an integer score.
                print(f'ERROR: {score}')

    # Keep only the samples that received a parsable integer score.
    saved_data = [data for data in test_data if 'gpt_score' in data]
    with open(save_path, 'w+') as f:
        json.dump(saved_data, f, indent=4)

    if args.score_only:
        exit()
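
    # Stage 3: build (chosen, rejected) preference pairs from the scored samples.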
    temp_data = []
    max_idx = test_data[-1]['index']
    # Number of groups when samples are bucketed by index // num_samples.
    sample_size = max_idx // num_samples + 1

    test_data = saved_data
    for item in test_data:
        if 'gpt_score' not in item:
            continue
        # Keep only samples with a score of at least 6.
        if item['gpt_score'] >= 6:
            temp_data.append(item)
    print(test_data[-1]['index'], max_idx)

    # Group samples by index // num_samples: sample_size buckets cover indices
    # 0..max_idx.
    grouped = [[] for _ in range(sample_size)]
    for item in temp_data:
        idx = item['index']
        grouped[idx // num_samples].append(item)
    grouped = [group for group in grouped if len(group) > 0]
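
    # Within each group, every strictly higher-scored output is preferred
    # ("chosen") over every lower-scored one ("rejected"); ties yield no pair.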
    final_data = []
    for group in grouped:
        for item1 in group:
            for item2 in group:
                if item2['gpt_score'] > item1['gpt_score']:
                    info_dict = {
                        "description": item1['description'],
                        "prompt": item1['prompt'],
                        "chosen": item2['output'],
                        "rejected": item1['output']
                    }
                    final_data.append(info_dict)

    with open(save_path, 'w+') as f:
        json.dump(final_data, f, indent=4)