Transcrib3D-Demo / transcrib3d_main.py
kudo1026
initial
f27a827
raw
history blame
13.2 kB
# encoding:utf-8
import ast
import csv
import json
import logging
import os
import random
import re
import time
from copy import deepcopy
from datetime import datetime
import numpy as np
from tenacity import RetryError, before_sleep_log, retry, stop_after_attempt, wait_exponential_jitter # for exponential backoff
from code_interpreter import CodeInterpreter
# from config import confs_nr3d, confs_scanrefer, confs_sr3d
# from gpt_dialogue import Dialogue
# from object_filter_gpt4 import ObjectFilter
from prompt_text import get_principle, get_principle_sr3d, get_system_message
# Module-level logger. Both the logger and its console handler are limited to
# ERROR and above; records are rendered as "<time> - <level> - <message>".
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
console_handler.setLevel(logging.ERROR)

logger = logging.getLogger(__name__ + 'logger')
logger.addHandler(console_handler)
logger.setLevel(logging.ERROR)
def round_list(lst, length):
    """Round every element of ``lst`` to ``length`` decimal places.

    The rounding is written back into ``lst`` itself, so the caller's
    sequence is modified in place. A new ``list`` holding the rounded
    elements is returned as well.
    """
    for position, value in enumerate(lst):
        lst[position] = round(value, length)
    return [*lst]
def remove_spaces(s: str):
    """Return ``s`` with every ``' '`` character removed.

    Only the plain space character is dropped; tabs and newlines are kept.
    """
    return ''.join(s.split(' '))
def rgb_to_hsl(rgb):
    """Convert an RGB triple to ``[hue, saturation, lightness]``.

    Each of the three channel values is divided by 255 before conversion.
    The hue is expressed in degrees (sextant * 60). For an achromatic
    colour (all channels equal) hue and saturation are both the integer 0.
    """
    red, green, blue = (channel / 255.0 for channel in rgb)

    # Chroma is the spread between the strongest and the weakest channel.
    strongest = max(red, green, blue)
    weakest = min(red, green, blue)
    chroma = strongest - weakest
    lightness = (strongest + weakest) / 2

    # Grey / black / white: hue and saturation stay at zero.
    if chroma == 0:
        return [0, 0, lightness]

    # The dominant channel selects which sextant formula applies.
    hue = 0
    if strongest == red:
        hue = ((green - blue) / chroma) % 6
    elif strongest == green:
        hue = ((blue - red) / chroma) + 2
    elif strongest == blue:
        hue = ((red - green) / chroma) + 4
    hue *= 60

    # Saturation is normalised against the darker or brighter half.
    denominator = 2 * lightness if lightness <= 0.5 else 2 - 2 * lightness
    saturation = chroma / denominator

    return [hue, saturation, lightness]
def get_scene_center(objects):
    """Return the midpoint of the bounding box of all object centres.

    Every item of ``objects`` must provide a three-element
    ``'center_position'``. For each axis the smallest and largest coordinate
    are tracked, and the result is ``[(min + max) / 2]`` per axis rounded to
    two decimals, returned as a list ``[x, y, z]``.
    """
    lows = [float('inf')] * 3
    highs = [float('-inf')] * 3

    for obj in objects:
        x, y, z = obj['center_position']
        for axis, coord in enumerate((x, y, z)):
            if coord < lows[axis]:
                lows[axis] = coord
            if coord > highs[axis]:
                highs[axis] = coord

    return [round((low + high) / 2, 2) for low, high in zip(lows, highs)]
def find_relevant_objects(user_instruction, scan_id):
    """Placeholder for object filtering; not implemented.

    Both arguments are ignored and ``None`` is always returned.
    """
    return None
def gen_prompt(user_instruction, scan_id):
    """Build the text prompt that describes a scene and asks for one object.

    Object records are loaded from
    ``objects_info/objects_info_<scan_id>.npy``. Each record is indexed with
    the keys ``label``, ``id``, ``center_position``, ``size``, ``avg_rgba``
    and ``has_front``; when ``has_front`` is truthy, ``front_point`` and the
    first three entries of ``obb`` (used as the box centre) are read too.

    Args:
        user_instruction: Natural-language description of the target object.
        scan_id: Scene identifier; selects the ``.npy`` file and prefixes the
            prompt.

    Returns:
        The prompt string: a coordinate-system header, the scene centre, one
        line per object (label, id, centre, size, HSL colour and, when
        available, front/left/right/behind direction vectors), followed by
        the instruction and the required answer format.
    """

    def to_builtin(values):
        # NumPy >= 2 renders scalars inside a list as "np.float64(0.5)";
        # converting to Python numbers keeps the prompt text stable.
        return [v.item() if isinstance(v, np.generic) else v for v in values]

    def compact(values):
        # List rendered without spaces, e.g. "[1.0,2.5,0.3]".
        return remove_spaces(str(to_builtin(values)))

    # NOTE(security): allow_pickle=True unpickles arbitrary objects; only
    # load object-info files that come from a trusted source.
    npy_path = os.path.join("objects_info", f"objects_info_{scan_id}.npy")
    objects_info = np.load(npy_path, allow_pickle=True)

    # Relevance filtering is not implemented, so every object is described.
    objects_related = objects_info

    # The scene centre must be computed from all objects, not only the
    # relevant ones.
    scene_center = get_scene_center(objects_info)

    # Background part of the prompt.
    prompt = scan_id + ":objects with quantitative description based on right-hand Cartesian coordinate system with x-y-z axes, x-y plane=ground, z-axis=up/down. Coords format [x, y, z].\n\n"
    prompt += "Scene center:%s. If no direction vector, observer at center for obj orientation.\n\n" % compact(scene_center)
    prompt += "objs list:\n"

    # Quantitative description of each object, all values rounded to 2 decimals.
    up_vector = np.array([0, 0, 1])
    lines = []
    for obj in objects_related:
        center_position = round_list(obj['center_position'], 2)
        size = round_list(obj['size'], 2)

        # Objects without a known (or with a degenerate) front get no
        # direction information at all.
        direction_info = "\n"
        if obj['has_front']:
            direction_vector = np.array(obj['front_point']) - np.array(obj['obb'][0:3])
            norm = np.linalg.norm(direction_vector)
            # A zero-length vector cannot be normalised; skipping it avoids
            # writing "nan" components into the prompt.
            if norm > 0:
                front_vector = round_list(direction_vector / norm, 2)
                # Left/right/behind are derived from the rounded front vector.
                rounded_front = np.array(front_vector)
                left_vector = round_list(np.cross(rounded_front, up_vector), 2)
                right_vector = round_list(np.cross(up_vector, rounded_front), 2)
                behind_vector = round_list(-rounded_front, 2)
                direction_info = ";direction vectors:front=%s,left=%s,right=%s,behind=%s\n" % (
                    to_builtin(front_vector),
                    to_builtin(left_vector),
                    to_builtin(right_vector),
                    to_builtin(behind_vector),
                )

        # Colour is reported as HSL instead of raw RGB.
        rgb = obj['avg_rgba'][0:3]
        hsl = round_list(rgb_to_hsl(rgb), 2)
        line = "%s,id=%s,ctr=%s,size=%s,HSL=%s" % (
            obj['label'],
            obj['id'],
            compact(center_position),
            compact(size),
            compact(hsl),
        )
        lines.append(line + direction_info)

    prompt += ''.join(lines)

    # Task statement and answer-format requirements.
    prompt += "\nInstruction:find the one described object in description: \n\"%s\"\n" % (user_instruction,)
    prompt += "\n\nThere is exactly one answer, so if you receive multiple answers, consider other constraints; if get no answers, loosen constraints."
    prompt += "\n\nWork this out step by step to ensure right answer."
    prompt += "\n\nIf the answer is complete, add \"Now the answer is complete -- {'ID':id}\" to the end of your answer(that is, your completion, not your code), where id is the id of the referred obj. Do not add anything after."
    return prompt
@retry(wait=wait_exponential_jitter(initial=20, max=120, jitter=20), stop=stop_after_attempt(5), before_sleep=before_sleep_log(logger, logging.ERROR))  # 20s,40s,80s,120s + random.uniform(0,20)
def get_gpt_response(prompt: str, code_interpreter: CodeInterpreter):
    """Query the model through ``code_interpreter`` until the answer is complete.

    The first call sends ``prompt``. While the returned content does not
    contain "Now the answer is complete", an empty user message is sent to
    let the model continue, at most 10 times. Token usage and elapsed time
    are accumulated locally but not returned.

    The whole function is wrapped by the tenacity ``retry`` decorator above
    (up to 5 attempts with exponential back-off and jitter).

    Returns:
        The ``'content'`` of the last response received.
    """
    completion_marker = "Now the answer is complete"
    max_follow_ups = 10

    print("llm_name:", code_interpreter.model)

    # start timing
    started_at = time.time()

    # first call with the original prompt
    reply, token_usage_total = code_interpreter.call_openai_with_code_interpreter(prompt)
    response = reply['content']

    # keep nudging with empty messages until the marker shows up or the
    # follow-up budget is used up
    for count_response in range(1, max_follow_ups + 1):
        if completion_marker in response:
            break
        reply, token_usage_add = code_interpreter.call_openai_with_code_interpreter('')
        response = reply['content']
        token_usage_total += token_usage_add
        print("count_response:", count_response)
    else:
        # budget exhausted: report if the last reply is still incomplete
        if completion_marker not in response:
            print("Response does not end with 'Now the answer is complete.' !")

    # stop timing (value is currently not reported anywhere)
    time_consumed = time.time() - started_at

    return response
def extract_answer_id_from_last_line(last_line, random_choice_list=(0,)):
    """Extract the answer id from a line ending in ``-- {'ID': <id>}``.

    The text after the last ``--`` is searched for the first ``{...}``
    literal, which is parsed with ``ast.literal_eval`` and indexed with
    ``'ID'``.

    Args:
        last_line: Text expected to contain
            ``Now the answer is complete -- {'ID': <id>}``.
        random_choice_list: Candidate ids to pick from at random when no
            usable id can be extracted. Defaults to ``(0,)``.

    Returns:
        ``(answer_id, wrong_return_format)``. ``wrong_return_format`` is
        ``False`` only when ``'ID'`` maps to a single integer. If it maps to
        a non-empty list of integers, one of them is chosen at random; in
        every other case the id is chosen from ``random_choice_list``.
    """

    def is_plain_int(value):
        # bool is a subclass of int but is never a valid object id.
        return isinstance(value, int) and not isinstance(value, bool)

    def fallback(message):
        print(message)
        return random.choice(random_choice_list), True

    # Pull the dict literal out of the part following the last "--".
    match = re.search(r"\{[^\}]*\}", last_line.split('--')[-1])
    if match is None:
        return fallback("Wrong answer format!! No dict found. Random choice.")

    try:
        extracted_dict = ast.literal_eval(match.group())
        print(extracted_dict)
        answer_id = extracted_dict['ID']
    except Exception:
        # Malformed literal, missing 'ID' key, or a non-dict literal.
        # KeyboardInterrupt / SystemExit are deliberately not swallowed.
        return fallback("Wrong answer format!! No dict found. Random choice.")

    if is_plain_int(answer_id):
        return answer_id, False

    # The expected wrapper was used but the id is not a single integer.
    if isinstance(answer_id, list) and answer_id and all(is_plain_int(e) for e in answer_id):
        print("Wrong answer format: %s. random choice from this list" % str(answer_id))
        return random.choice(answer_id), True

    return fallback("Wrong answer format: %s. No dict found. Random choice from relevant objects." % str(answer_id))
def get_openai_config(llm_name='gpt-3.5-turbo-0125'):
    """Return the keyword-argument dict used to configure the code interpreter.

    The system message is the text from ``get_system_message()`` followed by
    the text from ``get_principle()``. ``llm_name`` is stored under
    ``'model'``; all other settings are fixed.
    """
    system_message = ''.join(['', get_system_message(), get_principle()])
    return dict(
        model=llm_name,
        temperature=1e-7,
        top_p=1e-7,
        max_tokens=8192,
        system_message=system_message,
        save_path='chats',
        debug=True,
    )
if __name__ == "__main__":
    # Demo run: build the prompt for one hard-coded scene and utterance,
    # send it to GPT-4 through the code interpreter and dump the dialogue.
    # The configuration comes from get_openai_config so the demo and the
    # library entry point cannot drift apart.
    openai_config = get_openai_config('gpt-4')
    code_interpreter = CodeInterpreter(**openai_config)

    prompt = gen_prompt("Find the chair next to the table.", "scene0132_00")
    print(prompt)

    # The final response is not printed; the accumulated dialogue is shown
    # through code_interpreter.pretext instead.
    get_gpt_response(prompt, code_interpreter)
    print("-------pretext--------")
    print(code_interpreter.pretext)