# CADFusion/src/dpo/make_dpo_dataset.py
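"""Build a DPO preference dataset from rendered CAD generations.

The pipeline runs in three stages:
  1. For each generation that rendered to a figure, ask a judge model
     (the OpenAI API with --gpt, otherwise a local LLaVA model) to
     critique the generation quality.
  2. Condense each critique into an integer score and drop generations
     that could not be rendered or scored.
  3. Group generations by their source prompt and emit (chosen, rejected)
     pairs in which the chosen output has the strictly higher score.

Example invocation (paths and sample count are illustrative):
    python make_dpo_dataset.py --source-data-path gens.json --figure-path renders --save-path dpo.json --num-samples 5
"""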
import argparse
import json
import os
import time

from tqdm import tqdm

from llava_utils import ask_llm, ask_llm_on_figure, restart_model
from openai_utils import ask_gpt_on_figure, ask_gpt

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--source-data-path", type=str, required=True)
    parser.add_argument("--figure-path", type=str, required=True)
    parser.add_argument("--save-path", type=str, required=True)
    parser.add_argument("--num-samples", type=int, required=True)
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--score-only", action="store_true", default=False)
    parser.add_argument("--gpt", action="store_true", default=False)
    args = parser.parse_args()

    source_path = args.source_data_path
    folder_path = args.figure_path
    save_path = args.save_path
    num_samples = args.num_samples
    device = f'cuda:{args.gpu}'
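    # Select the judge backend: with --gpt the OpenAI API is used and no local
    # model is loaded; otherwise a LLaVA judge is loaded onto the chosen GPU.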
    if args.gpt:
        func1, func2 = ask_gpt_on_figure, ask_gpt
        model = None
        processor = None
    else:
        func1, func2 = ask_llm_on_figure, ask_llm
        model, processor = restart_model(device)
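    # The source file is expected to be a JSON list in which each entry carries
    # at least an 'index', the originating 'description' and 'prompt', and the
    # model 'output' (fields accessed later in this script).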
    with open(source_path, 'r') as f:
        test_data = json.load(f)

    ####### Stage 1 #######
    # For the model generations that could be rendered into pictures, ask the
    # judge to rate the generation quality.
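    # Judge calls are retried on failure: the OpenAI path sleeps and retries
    # indefinitely (e.g. on transient rate limits), while the local path
    # reloads the model and gives up after repeated failures.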
    for data in tqdm(test_data):
        # Rendered figures are saved with the zero-padded sample index as a
        # filename prefix.
        file_id = str(data['index']).zfill(6)
        file = None
        for fname in os.listdir(folder_path):
            if fname.startswith(file_id):
                file = os.path.join(folder_path, fname)
                break
        data['figure_path'] = file
        error_cnt = 0
        while True:
            try:
                data['gpt_label'] = func1(data, model, processor)
                break
            except Exception as e:
                print(e)
                if args.gpt:
                    time.sleep(3)
                else:
                    if error_cnt == 5:
                        exit()
                    model, processor = restart_model(device)
                    error_cnt += 1
    with open(save_path, 'w+') as f:
        json.dump(test_data, f, indent=4)
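    # Reload the labeled data from disk, presumably so Stage 2 can also resume
    # from a previously saved Stage 1 output.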
    with open(save_path, 'r') as f:
        test_data = json.load(f)

    ####### Stage 2 #######
    # Condense each quality critique into a numerical score, and drop the
    # failed samples, i.e. the generations that could not be rendered.
    for data in tqdm(test_data):
        # Only samples that received a judge critique in Stage 1 get a score.
        if "gpt_label" in data:
            error_cnt = 0
            while True:
                try:
                    score = func2(data, model, processor)
                    print(score)
                    break
                except Exception as e:
                    print(e)
                    if args.gpt:
                        time.sleep(3)
                    else:
                        if error_cnt == 5:
                            exit()
                        model, processor = restart_model(device)
                        error_cnt += 1
            try:
                data['gpt_score'] = int(score)
            except (ValueError, TypeError):
                # The judge sometimes returns a non-numeric answer; leave the
                # sample unscored so the filter below removes it.
                print(f'ERROR: {score}')
    saved_data = [data for data in test_data if 'gpt_score' in data]
    with open(save_path, 'w+') as f:
        json.dump(saved_data, f, indent=4)
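    # With --score-only, stop after writing the scored dataset and skip the
    # pairing stage.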
    if args.score_only:
        exit()

    ####### Stage 3 #######
    # 1. Group the scored generations by their originating description: we
    # never compare generations that come from different source prompts.
    temp_data = []
    max_idx = test_data[-1]['index']
    # Note: despite the name, sample_size is the number of prompt groups, each
    # of which holds up to num_samples generations.
    sample_size = max_idx // num_samples + 1
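    # Three filtering strategies were explored; (a) and (b) are kept below for
    # reference, and (c) is the one currently in use.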
    # a. keep a whole group if any of its generations scores 6 or above
    # for i in range(sample_size):
    #     next_sample = test_data[i*num_samples:(i+1)*num_samples]
    #     next_sample = [item for item in next_sample if 'gpt_score' in item]
    #     above_score = [item['gpt_score'] >= 6 for item in next_sample]
    #     if any(above_score):
    #         temp_data.extend(next_sample)
    # temp_data = [data for data in temp_data if 'gpt_score' in data]
    # b. keep a whole group if its average score is 6 or above
    # for i in range(sample_size):
    #     next_sample = test_data[i*num_samples:(i+1)*num_samples]
    #     next_sample = [item for item in next_sample if 'gpt_score' in item]
    #     if len(next_sample) == 0:
    #         continue
    #     scores = sum(item['gpt_score'] for item in next_sample) / len(next_sample)
    #     if scores >= 6:
    #         temp_data.extend(next_sample)
    # temp_data = [data for data in temp_data if 'gpt_score' in data]
    # c. keep every individual generation that scores 6 or above
    test_data = saved_data
    for item in test_data:
        if 'gpt_score' not in item:
            continue
        if item['gpt_score'] >= 6:
            temp_data.append(item)
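    # Bucket the surviving generations back into their prompt groups: sample i
    # belongs to group i // num_samples.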
    # Sanity check: last surviving index vs. the overall maximum index.
    print(test_data[-1]['index'], max_idx)
    grouped = [[] for _ in range(sample_size)]
    for item in temp_data:
        idx = item['index']
        grouped[idx // num_samples].append(item)
    grouped = [item for item in grouped if len(item) > 0]
    # 2. Within each group, form preference pairs in which the chosen
    # generation scores strictly higher than the rejected one.
    # TODO: find a way to balance the number of pairs drawn from each group.
    final_data = []
    for group in grouped:
        for item1 in group:
            for item2 in group:
                if item2['gpt_score'] > item1['gpt_score']:
                    info_dict = {
                        "description": item1['description'],
                        "prompt": item1['prompt'],
                        "chosen": item2['output'],
                        "rejected": item1['output']
                    }
                    final_data.append(info_dict)
                    # Uncomment this break to emit at most one pair per
                    # rejected sample if the dataset grows too large.
                    # break
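    # Each record is a (prompt, chosen, rejected) triplet, the format commonly
    # consumed by DPO trainers.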
    with open(save_path, 'w+') as f:
        json.dump(final_data, f, indent=4)