Spaces:
Running
Running
# encoding:utf-8 | |
import ast | |
import csv | |
import json | |
import logging | |
import os | |
import random | |
import re | |
import time | |
from copy import deepcopy | |
from datetime import datetime | |
import numpy as np | |
from tenacity import RetryError, before_sleep_log, retry, stop_after_attempt, wait_exponential_jitter # for exponential backoff | |
from code_interpreter import CodeInterpreter | |
# from config import confs_nr3d, confs_scanrefer, confs_sr3d | |
# from gpt_dialogue import Dialogue | |
# from object_filter_gpt4 import ObjectFilter | |
from prompt_text import get_principle, get_principle_sr3d, get_system_message | |
# Module logger: only ERROR and above, printed by a StreamHandler with a timestamp.
logger = logging.getLogger(__name__ + 'logger')
logger.setLevel(logging.ERROR)
# Attach the console handler only once: getLogger() returns the same object on
# re-import/reload, so an unconditional addHandler would duplicate every line.
if not logger.handlers:
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.ERROR)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
def round_list(lst, length):
    """Return a new list with every element of *lst* rounded to *length* decimals.

    The input is left untouched. An earlier version assigned the rounded
    values back into *lst*, which silently altered caller-owned data and
    failed on tuples. NumPy scalars are converted to plain Python numbers
    before rounding, so ``str()`` of the result prints ``[1.23, 4.56]``
    instead of ``[np.float64(1.23), ...]`` under NumPy 2. Plain ints stay ints.

    Args:
        lst: Iterable of numbers (list, tuple, 1-D NumPy array, ...).
        length: Number of decimal places passed to ``round``.

    Returns:
        A new ``list`` with the rounded values.
    """
    return [
        round(num.item() if isinstance(num, np.generic) else num, length)
        for num in lst
    ]
def remove_spaces(s: str):
    """Return *s* with every ASCII space removed (tabs/newlines are kept)."""
    return ''.join(s.split(' '))
def rgb_to_hsl(rgb):
    """Convert an (R, G, B) triple with channels in 0-255 to [hue, saturation, lightness].

    Hue is in degrees, [0, 360); saturation and lightness are in [0, 1].
    For achromatic input (all channels equal), hue and saturation are the
    integer 0.
    """
    red, green, blue = (channel / 255.0 for channel in rgb)
    brightest = max(red, green, blue)
    darkest = min(red, green, blue)
    spread = brightest - darkest  # chroma
    lightness = (brightest + darkest) / 2

    if spread == 0:
        # Grey: no hue and no saturation.
        return [0, 0, lightness]

    # Hue sector (0..6) is determined by the dominant channel.
    if brightest == red:
        sector = ((green - blue) / spread) % 6
    elif brightest == green:
        sector = (blue - red) / spread + 2
    elif brightest == blue:
        sector = (red - green) / spread + 4
    else:
        sector = 0
    hue = sector * 60

    denominator = 2 * lightness if lightness <= 0.5 else 2 - 2 * lightness
    return [hue, spread / denominator, lightness]
def get_scene_center(objects):
    """Return the midpoint of the axis-aligned box spanned by all object centres.

    Args:
        objects: Iterable of object dicts, each with a 3-element
            ``'center_position'``.

    Returns:
        ``[x, y, z]`` rounded to 2 decimals. With no objects, every
        coordinate is NaN (``inf`` and ``-inf`` bounds average to NaN).
    """
    lower = [float('inf')] * 3
    upper = [float('-inf')] * 3
    for obj in objects:
        x, y, z = obj['center_position']
        for axis, value in enumerate((x, y, z)):
            if value < lower[axis]:
                lower[axis] = value
            if value > upper[axis]:
                upper[axis] = value
    midpoint = [(low + high) / 2 for low, high in zip(lower, upper)]
    return round_list(midpoint, 2)
def find_relevant_objects(user_instruction, scan_id):
    """Placeholder for selecting the objects relevant to *user_instruction* in *scan_id*.

    Not implemented yet: it always returns ``None``.
    """
    return None
def _direction_info(obj):
    """Return the ';direction vectors:...' suffix (with newline) for *obj*, or just '\\n'.

    Only objects with ``has_front`` get direction vectors. The front vector
    points from the OBB centre (``obb[0:3]``) to ``front_point``. Left and
    right come from crossing it with the +z up axis. A front point that
    coincides with the centre is treated like an unknown direction instead
    of emitting NaN vectors.
    """
    if not obj['has_front']:
        return "\n"  # unknown orientation: write nothing
    direction_vector = np.array(obj['front_point']) - np.array(obj['obb'][0:3])
    norm = np.linalg.norm(direction_vector)
    if norm == 0:
        return "\n"
    direction = direction_vector / norm
    up_vector = np.array([0, 0, 1])
    # All vectors are rounded to 2 decimals.
    front_vector = round_list(direction, 2)
    left_vector = round_list(np.cross(direction, up_vector), 2)
    right_vector = round_list(np.cross(up_vector, direction), 2)
    behind_vector = round_list(-np.array(front_vector), 2)
    return ";direction vectors:front=%s,left=%s,right=%s,behind=%s\n" % (
        front_vector, left_vector, right_vector, behind_vector)


def _object_line(obj, dataset=None):
    """Return the one-line quantitative description of *obj* (without newline).

    sr3d gets centre and size only; every other dataset (nr3d, scanrefer)
    also gets the average colour as HSL. All numbers are rounded to 2
    decimals.
    """
    center_position = remove_spaces(str(round_list(obj['center_position'], 2)))
    size = remove_spaces(str(round_list(obj['size'], 2)))
    if dataset == 'sr3d':
        return f'{obj["label"]},id={obj["id"]},ctr={center_position},size={size}'
    hsl = round_list(rgb_to_hsl(obj['avg_rgba'][0:3]), 2)
    return "%s,id=%s,ctr=%s,size=%s,HSL=%s" % (
        obj['label'], obj['id'], center_position, size, remove_spaces(str(hsl)))


def gen_prompt(user_instruction, scan_id, dataset=None):
    """Build the 3D referring prompt for *user_instruction* in scan *scan_id*.

    The prompt contains:
      * a header describing the coordinate system and the scene centre;
      * one line per object: label, id, centre, size, HSL colour (omitted
        for sr3d), and direction vectors when the object has a known front;
      * the instruction plus the required
        "Now the answer is complete -- {'ID':id}" answer format.

    Args:
        user_instruction: The referring expression to resolve.
        scan_id: Scan identifier (e.g. ``"scene0132_00"``). Objects are read
            from ``objects_info/objects_info_<scan_id>.npy``.
        dataset: ``'sr3d'`` drops the colour field. Any other value
            (default ``None``) keeps the nr3d/scanrefer format.

    Returns:
        The prompt string.
    """
    npy_path = os.path.join("objects_info", f"objects_info_{scan_id}.npy")
    # NOTE: allow_pickle=True unpickles arbitrary objects -- only load trusted files.
    objects_info = np.load(npy_path, allow_pickle=True)
    # Relevance filtering (find_relevant_objects) is not implemented, so every object is described.
    objects_related = objects_info
    # The scene centre must come from all objects, not only the relevant ones.
    scene_center = get_scene_center(objects_info)

    parts = [
        scan_id + ":objects with quantitative description based on right-hand Cartesian coordinate system with x-y-z axes, x-y plane=ground, z-axis=up/down. Coords format [x, y, z].\n\n",
        "Scene center:%s. If no direction vector, observer at center for obj orientation.\n\n" % remove_spaces(str(scene_center)),
        "objs list:\n",
    ]
    parts.extend(_object_line(obj, dataset) + _direction_info(obj) for obj in objects_related)

    # Task requirements.
    parts.append("\nInstruction:find the one described object in description: \n\"%s\"\n" % (user_instruction,))
    parts.append("\n\nThere is exactly one answer, so if you receive multiple answers, consider other constraints; if get no answers, loosen constraints.")
    parts.append("\n\nWork this out step by step to ensure right answer.")
    parts.append("\n\nIf the answer is complete, add \"Now the answer is complete -- {'ID':id}\" to the end of your answer(that is, your completion, not your code), where id is the id of the referred obj. Do not add anything after.")
    return ''.join(parts)
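# Retry backoff schedule: 20s, 40s, 80s, 120s + random.uniform(0, 20) jitter (note: no retry decorator is currently applied to get_gpt_response below).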
def get_gpt_response(prompt: str, code_interpreter: CodeInterpreter, max_followups=10):
    """Query the model via *code_interpreter* until it declares its answer complete.

    The first call sends *prompt*. While the reply lacks the marker
    "Now the answer is complete", an empty user message is sent so the model
    keeps going, at most *max_followups* times.

    Args:
        prompt: Full prompt text (e.g. from ``gen_prompt``).
        code_interpreter: Object exposing ``model`` and
            ``call_openai_with_code_interpreter(text)``. The latter returns a
            ``(message, token_usage)`` pair whose message is indexable by
            ``'content'``.
        max_followups: Maximum number of continuation calls (default 10,
            the previously hard-coded limit).

    Returns:
        The ``'content'`` of the last model message.
    """
    print("llm_name:", code_interpreter.model)
    call_start_time = time.time()

    message, token_usage_total = code_interpreter.call_openai_with_code_interpreter(prompt)
    response = message['content']

    count_response = 0
    while "Now the answer is complete" not in response:
        if count_response >= max_followups:
            print("Response does not end with 'Now the answer is complete.' !")
            break
        # An empty user message lets the model continue its previous answer.
        message, token_usage_add = code_interpreter.call_openai_with_code_interpreter('')
        response = message['content']
        token_usage_total += token_usage_add
        count_response += 1
    print("count_response:", count_response)

    # Usage accounting is reported at DEBUG level rather than computed and discarded.
    time_consumed = time.time() - call_start_time
    logger.debug("Refer model: token usage=%s, time consumed=%.1fs", token_usage_total, time_consumed)
    return response
def extract_answer_id_from_last_line(last_line, random_choice_list=(0,)):
    """Parse the object id from the model's final "... -- {'ID': id}" line.

    Only the text after the last ``'--'`` is searched for a ``{...}`` literal.

    Args:
        last_line: Final line of the model response.
        random_choice_list: Candidate ids drawn from at random when the reply
            is not in the expected format. The default ``(0,)`` maps a
            malformed reply to id 0. It is a tuple to avoid a shared mutable
            default; lists are still accepted.

    Returns:
        ``(answer_id, wrong_return_format)``. ``wrong_return_format`` is True
        whenever the id had to be drawn from *random_choice_list*. If
        ``'ID'`` holds a non-empty list of ints, one of them is chosen at
        random and the format still counts as correct.
    """
    match = re.search(r"\{[^\}]*\}", last_line.split('--')[-1])
    if not match:
        print("Wrong answer format!! No dict found. Random choice.")
        return random.choice(random_choice_list), True

    try:
        extracted_dict = ast.literal_eval(match.group())
        print(extracted_dict)
        answer_id = extracted_dict['ID']
    # literal_eval raises ValueError/TypeError/SyntaxError/MemoryError/RecursionError
    # on malformed input; KeyError/TypeError cover a missing 'ID' or a non-dict
    # (e.g. a set literal). Narrow on purpose so Ctrl-C is not swallowed.
    except (ValueError, TypeError, SyntaxError, KeyError, MemoryError, RecursionError):
        print("Wrong answer format!! Failed to parse dict. Random choice.")
        return random.choice(random_choice_list), True

    if isinstance(answer_id, int):
        return answer_id, False
    if isinstance(answer_id, list) and answer_id and all(isinstance(e, int) for e in answer_id):
        print("Wrong answer format: %s. random choice from this list" % str(answer_id))
        return random.choice(answer_id), False
    # The reply used the right shape but the id is unusable (e.g. None).
    print("Wrong answer format: %s. Random choice from relevant objects." % str(answer_id))
    return random.choice(random_choice_list), True
def get_openai_config(llm_name='gpt-3.5-turbo-0125'):
    """Return the keyword arguments for a ``CodeInterpreter`` using *llm_name*.

    The system message is ``get_system_message()`` followed by
    ``get_principle()``. Temperature and top_p are set to 1e-7, presumably
    to make sampling near-deterministic. Usage:
    ``CodeInterpreter(**get_openai_config())``.
    """
    system_message = ''.join([get_system_message(), get_principle()])
    return dict(
        model=llm_name,
        temperature=1e-7,
        top_p=1e-7,
        max_tokens=8192,
        system_message=system_message,
        save_path='chats',
        debug=True,
    )
if __name__ == "__main__":
    # Demo: build the prompt for one scan and query GPT-4 with it.
    # The configuration is identical to get_openai_config(), only the model differs.
    code_interpreter = CodeInterpreter(**get_openai_config('gpt-4'))
    prompt = gen_prompt("Find the chair next to the table.", "scene0132_00")
    print(prompt)
    get_gpt_response(prompt, code_interpreter)
    print("-------pretext--------")
    print(code_interpreter.pretext)