Spaces:
Sleeping
Sleeping
kudo1026
commited on
Commit
•
f27a827
1
Parent(s):
fbfbf30
initial
Browse files- .gitattributes +2 -0
- Dockerfile +52 -0
- README.md +11 -10
- app.py +189 -0
- appyibu.py +137 -0
- code_interpreter.py +132 -0
- display_model.py +167 -0
- gpt_dialogue.py +186 -0
- object_filter_gpt4.py +154 -0
- objects_info/objects_info_scene0132_00.npy +3 -0
- prompt_text.py +53 -0
- requirements.txt +6 -0
- scenes/scene0132_00_vh_clean_2_aligned.glb +3 -0
- scenes/scene0132_00_vh_clean_2_aligned.ply +3 -0
- sources.list +4 -0
- transcrib3d_main.py +285 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
scenes/scene0132_00_vh_clean_2_aligned.glb filter=lfs diff=lfs merge=lfs -text
|
37 |
+
scenes/scene0132_00_vh_clean_2_aligned.ply filter=lfs diff=lfs merge=lfs -text
|
Dockerfile
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 使用ubuntu22.04作为基础镜像
|
2 |
+
FROM ubuntu:22.04
|
3 |
+
|
4 |
+
# 设置工作目录为/Transcrib3D-Demo/
|
5 |
+
WORKDIR /code
|
6 |
+
|
7 |
+
COPY ./requirements.txt /code/requirements.txt
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
# 无vpn时需要换源:
|
13 |
+
# COPY ./sources.list /etc/apt/
|
14 |
+
# RUN mkdir -p ~/.pip
|
15 |
+
# RUN echo "[global]" >> ~/.pip/pip.conf
|
16 |
+
# RUN echo "index-url = https://pypi.tuna.tsinghua.edu.cn/simple" >> ~/.pip/pip.conf
|
17 |
+
# RUN cat ~/.pip/pip.conf
|
18 |
+
|
19 |
+
RUN apt-get update
|
20 |
+
RUN apt-get install -y wget
|
21 |
+
RUN apt-get install -y python3-pip
|
22 |
+
RUN pip3 config list
|
23 |
+
RUN apt-get install -y sudo
|
24 |
+
RUN apt-get install -y vim
|
25 |
+
|
26 |
+
# 在/root/Downloads目录下下载libssl包
|
27 |
+
RUN mkdir -p /root/Downloads && \
|
28 |
+
wget -P /root/Downloads http://archive.ubuntu.com/ubuntu/pool/main/o/openssl1.0/libssl1.0.0_1.0.2n-1ubuntu5_amd64.deb && \
|
29 |
+
dpkg -i /root/Downloads/libssl1.0.0_1.0.2n-1ubuntu5_amd64.deb
|
30 |
+
|
31 |
+
# 安装requirements.txt中的Python依赖项
|
32 |
+
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
|
33 |
+
|
34 |
+
# Set up a new user named "user" with user ID 1000
|
35 |
+
RUN useradd -m -u 1000 user
|
36 |
+
|
37 |
+
# Switch to the "user" user
|
38 |
+
USER user
|
39 |
+
|
40 |
+
# Set home to the user's home directory
|
41 |
+
ENV HOME=/home/user \
|
42 |
+
PATH=/home/user/.local/bin:$PATH
|
43 |
+
|
44 |
+
# Set the working directory to the user's home directory
|
45 |
+
WORKDIR $HOME/app
|
46 |
+
|
47 |
+
# Copy the current directory contents into the container at $HOME/app setting the owner to the user
|
48 |
+
COPY --chown=user . $HOME/app
|
49 |
+
|
50 |
+
# 可以设置容器启动后默认执行的命令,这里我们只是为示例,所以不设置任何启动命令
|
51 |
+
# CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
|
52 |
+
CMD [ "python3", "app.py" ]
|
README.md
CHANGED
@@ -1,13 +1,14 @@
|
|
1 |
---
|
2 |
-
title: Transcrib3D
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
-
sdk:
|
7 |
-
|
8 |
-
|
|
|
|
|
9 |
pinned: false
|
10 |
-
license:
|
11 |
---
|
12 |
-
|
13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: Transcrib3D-Demo
|
3 |
+
emoji: 🎡
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: red
|
6 |
+
sdk: docker
|
7 |
+
app_port: 7860
|
8 |
+
# sdk: gradio
|
9 |
+
# sdk_version: 4.25.0
|
10 |
+
# app_file: app.py
|
11 |
pinned: false
|
12 |
+
license: apache-2.0
|
13 |
---
|
14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
app.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, time, threading
|
2 |
+
import gradio as gr
|
3 |
+
|
4 |
+
from display_model import *
|
5 |
+
|
6 |
+
scan_id = "scene0132_00"
|
7 |
+
ply_file = os.path.join("scenes", f"{scan_id}_vh_clean_2_aligned.ply")
|
8 |
+
glb_file = os.path.join("scenes", f"{scan_id}_vh_clean_2_aligned.glb")
|
9 |
+
new_ply_file = os.path.join("scenes", f"{scan_id}_vh_clean_2_aligned_AddBox.ply")
|
10 |
+
new_glb_file = os.path.join("scenes", f"{scan_id}_vh_clean_2_aligned_AddBox.glb")
|
11 |
+
objects_info_file = os.path.join("objects_info", f"objects_info_{scan_id}.npy")
|
12 |
+
|
13 |
+
def insert_user_none_between_assistant(messages):
|
14 |
+
# 初始化结果列表
|
15 |
+
result = []
|
16 |
+
# 初始状态设置为"user",以确保列表第一个条目为"assistant"时能正确插入
|
17 |
+
last_role = "user"
|
18 |
+
|
19 |
+
for msg in messages:
|
20 |
+
# 检查当前信息的角色
|
21 |
+
current_role = msg["role"]
|
22 |
+
|
23 |
+
# 如果上一个和当前信息均为"assistant",插入content为None的"user"信息
|
24 |
+
if last_role == "assistant" and current_role == "assistant":
|
25 |
+
result.append({"role": "user", "content": None})
|
26 |
+
|
27 |
+
# 将当前信息添加到结果列表
|
28 |
+
result.append(msg)
|
29 |
+
|
30 |
+
# 更新上一条信息的角色
|
31 |
+
last_role = current_role
|
32 |
+
|
33 |
+
return result
|
34 |
+
|
35 |
+
def timer_check_update(code_interpreter, update_interval, stop_event):
|
36 |
+
"""
|
37 |
+
定时检查 code_interpreter.has_update 是否为True,
|
38 |
+
如果为True,则触发界面更新逻辑并重置状态。
|
39 |
+
参数:
|
40 |
+
- code_interpreter: CodeInterpreter的实例,预期包含has_update属性。
|
41 |
+
- update_interval: 定时器检查间隔,以秒为单位。
|
42 |
+
- stop_event: 一个threading.Event()实例,用于停止定时器线程。
|
43 |
+
"""
|
44 |
+
while not stop_event.is_set():
|
45 |
+
if code_interpreter.has_update:
|
46 |
+
# 实现更新界面显示的逻辑
|
47 |
+
print("Detected update, trigger UI refreshing...")
|
48 |
+
# 在这里添加更新界面显示的代码
|
49 |
+
# ...
|
50 |
+
# 重置has_update状态
|
51 |
+
code_interpreter.has_update = False
|
52 |
+
|
53 |
+
# 等待下次检查
|
54 |
+
time.sleep(update_interval)
|
55 |
+
|
56 |
+
def process_instruction_callback(inp_api_key, instruction, llm_name):
|
57 |
+
|
58 |
+
if not inp_api_key:
|
59 |
+
print("Please input OpenAI API Key.")
|
60 |
+
return
|
61 |
+
else:
|
62 |
+
os.environ["OPENAI_API_KEY"] = inp_api_key
|
63 |
+
from transcrib3d_main import gen_prompt, get_gpt_response, get_openai_config
|
64 |
+
from code_interpreter import CodeInterpreter
|
65 |
+
|
66 |
+
print("llm_name:",llm_name)
|
67 |
+
# generate prompt from user instruction
|
68 |
+
# scan_id = "scene0132_00"
|
69 |
+
prompt = gen_prompt(instruction, scan_id)
|
70 |
+
|
71 |
+
# get oepnai config
|
72 |
+
openai_config = get_openai_config(llm_name)
|
73 |
+
|
74 |
+
# get LLM response
|
75 |
+
|
76 |
+
code_interpreter = CodeInterpreter(**openai_config)
|
77 |
+
get_gpt_response(prompt, code_interpreter)
|
78 |
+
messages = code_interpreter.pretext
|
79 |
+
|
80 |
+
# draw the answer bounding box to the scene
|
81 |
+
generate_answer_glb(messages[-1]['content'])
|
82 |
+
# model3d.update(value=new_glb_file)
|
83 |
+
|
84 |
+
# form gradio chat history
|
85 |
+
messages = insert_user_none_between_assistant(messages[1:])
|
86 |
+
gradio_messages = []
|
87 |
+
for idx in range(int(len(messages)/2)):
|
88 |
+
gradio_message = [messages[idx*2]['content'], messages[idx*2+1]['content']]
|
89 |
+
gradio_messages.append(gradio_message)
|
90 |
+
|
91 |
+
# return gradio_messages
|
92 |
+
return new_glb_file, gradio_messages
|
93 |
+
|
94 |
+
def generate_answer_glb(answer_content):
|
95 |
+
from transcrib3d_main import extract_answer_id_from_last_line
|
96 |
+
last_line = answer_content.splitlines()[-1] if len(answer_content) > 0 else ''
|
97 |
+
answer_id, _ = extract_answer_id_from_last_line(last_line)
|
98 |
+
print("extracted answer id:", answer_id)
|
99 |
+
|
100 |
+
# get the bounding box of the answer object
|
101 |
+
|
102 |
+
box = np.load(objects_info_file, allow_pickle=True)[answer_id]['extension']
|
103 |
+
print("box extension:",box)
|
104 |
+
|
105 |
+
# add the box to ply
|
106 |
+
add_1box_to_ply(box, ply_file, new_ply_file)
|
107 |
+
ply_to_glb(new_ply_file, new_glb_file)
|
108 |
+
|
109 |
+
def llm_dropdown_callback(llm_name):
|
110 |
+
print("type in callback:",type(llm_name))
|
111 |
+
llm_name = str(llm_name)
|
112 |
+
print("llm_name in callback:",llm_name)
|
113 |
+
return llm_name
|
114 |
+
|
115 |
+
with gr.Blocks() as demo:
|
116 |
+
gr.Markdown("## Transcrib3D-Demo")
|
117 |
+
with gr.Row():
|
118 |
+
with gr.Column():
|
119 |
+
model3d = gr.Model3D(
|
120 |
+
value="scenes/scene0132_00_vh_clean_2_aligned.glb",
|
121 |
+
label="ScanNet-scene0132_00",
|
122 |
+
camera_position=(90,120,8),
|
123 |
+
zoom_speed=0.25,
|
124 |
+
# height=635,
|
125 |
+
height=725
|
126 |
+
)
|
127 |
+
# print("Type1:",type(model3d))
|
128 |
+
|
129 |
+
# gr.Markdown("🖱️:arrow_up::arrow_down:: SCROLL to zoom in/out.\t🖱️🔁 DRAG to rotate.\tCTRL+🖱️🔁 Press CTRL and DRAG to pan.")
|
130 |
+
html_content = """
|
131 |
+
<div style='text-align: center;'>
|
132 |
+
🖱️🔼🔽: SCROLL to zoom in/out. 🖱️🔁: DRAG to rotate. [CTRL]+🖱️🔁: Press CTRL and DRAG to pan.
|
133 |
+
</div>
|
134 |
+
"""
|
135 |
+
gr.HTML(value=html_content)
|
136 |
+
|
137 |
+
with gr.Column():
|
138 |
+
|
139 |
+
inp_api_key = gr.Textbox(label='OpenAI API Key (this is not stored anywhere)', lines=1)
|
140 |
+
|
141 |
+
llm_dropdown = gr.Dropdown(
|
142 |
+
# choices=['gpt-4-turbo','gpt-4','gpt-3.5-turbo'],
|
143 |
+
choices=['gpt-4-0125-preview', 'gpt-4-1106-preview', 'gpt-3.5-turbo-0125'],
|
144 |
+
label="LLM Selection",
|
145 |
+
type='value'
|
146 |
+
)
|
147 |
+
# llm_name = "gpt-4-turbo"
|
148 |
+
llm_name_text = gr.Text(visible=False)
|
149 |
+
llm_dropdown.select(fn=llm_dropdown_callback, inputs=llm_dropdown, outputs=llm_name_text)
|
150 |
+
|
151 |
+
|
152 |
+
user_instruction_textbox = gr.Textbox(
|
153 |
+
label="Instruction",
|
154 |
+
placeholder="Describe an object in the scene with its attributes and its relation with other objects, e.g. 'The largest table in the scene.",
|
155 |
+
# scale=4
|
156 |
+
)
|
157 |
+
bt = gr.Button(
|
158 |
+
value="Submit",
|
159 |
+
# scale=1
|
160 |
+
)
|
161 |
+
|
162 |
+
dialogue = gr.Chatbot(
|
163 |
+
height=470
|
164 |
+
# value = [["1","2"], [None, '3']]
|
165 |
+
)
|
166 |
+
|
167 |
+
|
168 |
+
# print("Type2:",type(model3d))
|
169 |
+
# 直接在 inputs列表里写model3d,会导致实际传给callback函数的是str
|
170 |
+
# bt.click(fn=process_instruction_callback, inputs=user_instruction_textbox, outputs=dialogue)
|
171 |
+
bt.click(fn=process_instruction_callback, inputs=[inp_api_key, user_instruction_textbox,llm_name_text], outputs=[model3d,dialogue])
|
172 |
+
user_instruction_textbox.submit(fn=process_instruction_callback, inputs=[inp_api_key, user_instruction_textbox, llm_name_text], outputs=[model3d,dialogue])
|
173 |
+
|
174 |
+
# 直接用lambda函数定义一个映射
|
175 |
+
# type(user_instruction_textbox.value)
|
176 |
+
# user_instruction_textbox.
|
177 |
+
# user_instruction_textbox.submit(fn=lambda: process_instruction_callback(user_instruction_textbox, model3d), inputs=[], outputs=dialogue)
|
178 |
+
# user_instruction_textbox.
|
179 |
+
# bt.click(fn=lambda: process_instruction_callback(user_instruction_textbox, model3d), inputs=[], outputs=dialogue)
|
180 |
+
|
181 |
+
# os.system('uname -a') # 显示所有系统信息
|
182 |
+
# demo.launch()
|
183 |
+
|
184 |
+
|
185 |
+
|
186 |
+
|
187 |
+
if __name__ == "__main__":
|
188 |
+
# demo.launch(share=True, server_name="0.0.0.0", server_port=7860)
|
189 |
+
demo.launch(server_name="0.0.0.0", server_port=7860)
|
appyibu.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import threading
|
3 |
+
import gradio as gr
|
4 |
+
from transcrib3d_main import gen_prompt, get_gpt_response, get_openai_config, extract_answer_id_from_last_line
|
5 |
+
from code_interpreter import CodeInterpreter
|
6 |
+
from display_model import *
|
7 |
+
|
8 |
+
scan_id = "scene0132_00"
|
9 |
+
ply_file = os.path.join("scenes", f"{scan_id}_vh_clean_2_aligned.ply")
|
10 |
+
glb_file = os.path.join("scenes", f"{scan_id}_vh_clean_2_aligned.glb")
|
11 |
+
new_ply_file = os.path.join("scenes", f"{scan_id}_vh_clean_2_aligned_AddBox.ply")
|
12 |
+
new_glb_file = os.path.join("scenes", f"{scan_id}_vh_clean_2_aligned_AddBox.glb")
|
13 |
+
objects_info_file = os.path.join("objects_info", f"objects_info_{scan_id}.npy")
|
14 |
+
|
15 |
+
def insert_user_none_between_assistant(messages):
|
16 |
+
# 初始化结果列表
|
17 |
+
result = []
|
18 |
+
# 初始状态设置为"user",以确保列表第一个条目为"assistant"时能正确插入
|
19 |
+
last_role = "user"
|
20 |
+
|
21 |
+
for msg in messages:
|
22 |
+
# 检查当前信息的角色
|
23 |
+
current_role = msg["role"]
|
24 |
+
|
25 |
+
# 如果上一个和当前信息均为"assistant",插入content为None的"user"信息
|
26 |
+
if last_role == "assistant" and current_role == "assistant":
|
27 |
+
result.append({"role": "user", "content": None})
|
28 |
+
|
29 |
+
# 将当前信息添加到结果列表
|
30 |
+
result.append(msg)
|
31 |
+
|
32 |
+
# 更新上一条信息的角色
|
33 |
+
last_role = current_role
|
34 |
+
|
35 |
+
return result
|
36 |
+
|
37 |
+
def generate_answer_glb(answer_content):
|
38 |
+
last_line = answer_content.splitlines()[-1] if len(answer_content) > 0 else ''
|
39 |
+
answer_id, _ = extract_answer_id_from_last_line(last_line)
|
40 |
+
print("extracted answer id:", answer_id)
|
41 |
+
|
42 |
+
# get the bounding box of the answer object
|
43 |
+
|
44 |
+
box = np.load(objects_info_file, allow_pickle=True)[answer_id]['extension']
|
45 |
+
print("box extension:",box)
|
46 |
+
|
47 |
+
# add the box to ply
|
48 |
+
add_1box_to_ply(box, ply_file, new_ply_file)
|
49 |
+
ply_to_glb(new_ply_file, new_glb_file)
|
50 |
+
|
51 |
+
def run_inferring(instruction, model3d, dialogue):
|
52 |
+
# generate prompt from user instruction
|
53 |
+
# scan_id = "scene0132_00"
|
54 |
+
prompt = gen_prompt(instruction, scan_id)
|
55 |
+
|
56 |
+
# get oepnai config
|
57 |
+
openai_config = get_openai_config()
|
58 |
+
|
59 |
+
# get LLM response
|
60 |
+
code_interpreter = CodeInterpreter(**openai_config)
|
61 |
+
get_gpt_response(prompt, code_interpreter)
|
62 |
+
messages = code_interpreter.pretext
|
63 |
+
|
64 |
+
# draw the answer bounding box to the scene
|
65 |
+
generate_answer_glb(messages[-1]['content'])
|
66 |
+
# global model3d
|
67 |
+
# print(model3d.value)
|
68 |
+
# model3d.postprocess(new_glb_file)
|
69 |
+
# print(model3d.value)
|
70 |
+
|
71 |
+
# form gradio chat history
|
72 |
+
messages = insert_user_none_between_assistant(messages[1:])
|
73 |
+
# print(len(messages))
|
74 |
+
# print(messages)
|
75 |
+
gradio_messages = []
|
76 |
+
for idx in range(int(len(messages)/2)):
|
77 |
+
gradio_message = [messages[idx*2]['content'], messages[idx*2+1]['content']]
|
78 |
+
gradio_messages.append(gradio_message)
|
79 |
+
|
80 |
+
# return new_glb_file, gradio_messages
|
81 |
+
model3d.update(value=new_glb_file)
|
82 |
+
dialogue.update(gradio_messages)
|
83 |
+
|
84 |
+
|
85 |
+
|
86 |
+
def process_instruction_callback(user_instruction, model3d, dialogue):
|
87 |
+
threading.Thread(target=run_inferring, args=(user_instruction, model3d, dialogue)).start()
|
88 |
+
# return "Processing your instruction, please wait...",
|
89 |
+
|
90 |
+
with gr.Blocks() as demo:
|
91 |
+
gr.Markdown("## Transcrib3D-Demo")
|
92 |
+
with gr.Row():
|
93 |
+
model3d = gr.Model3D(
|
94 |
+
value="scenes/scene0132_00_vh_clean_2_aligned.glb",
|
95 |
+
# value="scenes/scene0132_00_vh_clean_2_aligned_AddBox.glb",
|
96 |
+
# value="scenes/scene0132_00_vh_clean_2_aligned.ply",
|
97 |
+
# value="scenes/scene0132_00_vh_clean_2_aligned.obj",
|
98 |
+
# value="scenes/scene0132_00_gt_bboxes_aligned.ply",
|
99 |
+
# value="scenes/cube.ply",
|
100 |
+
label="ScanNet-scene0132_00",
|
101 |
+
camera_position=(90,120,8),
|
102 |
+
zoom_speed=0.25,
|
103 |
+
height=635
|
104 |
+
)
|
105 |
+
# print("Type1:",type(model3d))
|
106 |
+
|
107 |
+
with gr.Column():
|
108 |
+
# with gr.Row():
|
109 |
+
user_instruction_textbox = gr.Textbox(
|
110 |
+
label="Instruction",
|
111 |
+
placeholder="Describe an object in the scene with its attributes and its relation with other objects.",
|
112 |
+
# scale=4
|
113 |
+
)
|
114 |
+
bt = gr.Button(
|
115 |
+
value="Submit",
|
116 |
+
# scale=1
|
117 |
+
)
|
118 |
+
|
119 |
+
dialogue = gr.Chatbot(
|
120 |
+
height=470
|
121 |
+
# value = [["1","2"], [None, '3']]
|
122 |
+
)
|
123 |
+
|
124 |
+
# print("Type2:",type(model3d))
|
125 |
+
# 直接在 inputs列表里写model3d,会导致实际传给callback函数的是str
|
126 |
+
# bt.click(fn=process_instruction_callback, inputs=user_instruction_textbox, outputs=dialogue)
|
127 |
+
bt.click(fn=process_instruction_callback, inputs=[user_instruction_textbox, gr.State(model3d), gr.State(dialogue)])#, outputs=[model3d,dialogue])
|
128 |
+
|
129 |
+
# 直接用lambda函数定义一个映射
|
130 |
+
# type(user_instruction_textbox.value)
|
131 |
+
# user_instruction_textbox.
|
132 |
+
# user_instruction_textbox.submit(fn=lambda: process_instruction_callback(user_instruction_textbox, model3d), inputs=[], outputs=dialogue)
|
133 |
+
# user_instruction_textbox.
|
134 |
+
# bt.click(fn=lambda: process_instruction_callback(user_instruction_textbox, model3d), inputs=[], outputs=dialogue)
|
135 |
+
|
136 |
+
|
137 |
+
demo.launch()
|
code_interpreter.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, datetime, sys
|
2 |
+
from io import StringIO
|
3 |
+
from contextlib import redirect_stdout
|
4 |
+
import traceback
|
5 |
+
# import openai
|
6 |
+
from gpt_dialogue import Dialogue
|
7 |
+
# openai.api_key = os.getenv("OPENAI_API_KEY")
|
8 |
+
|
9 |
+
class CodeInterpreter(Dialogue):
|
10 |
+
|
11 |
+
def __init__(self, **kwargs):
|
12 |
+
super().__init__(**kwargs)
|
13 |
+
|
14 |
+
def call_openai_with_code_interpreter(self, user_prompt,namespace_for_exec={},token_usage_total=0):
|
15 |
+
# 如果gpt回复的内容包含python代码,则把代码的执行结果发送给gpt,继续等待其回复
|
16 |
+
# 如果gpt回复的内容不包含python代码,则此函数返回全部结果
|
17 |
+
# 每次递归统计使用的token数,最终返回总的token数
|
18 |
+
assistant_response,token_usage = self.call_openai(user_prompt)
|
19 |
+
token_usage_total+=token_usage
|
20 |
+
|
21 |
+
# check if response contain code snippet
|
22 |
+
response_content = assistant_response['content']
|
23 |
+
if self.debug:
|
24 |
+
print('response_content: ', response_content)
|
25 |
+
response_splits = response_content.split('```python')
|
26 |
+
if len(response_splits) <= 1:
|
27 |
+
# no code snippet found, return the raw response
|
28 |
+
if self.debug:
|
29 |
+
print('no code snippet found, return the raw response')
|
30 |
+
return assistant_response,token_usage_total
|
31 |
+
else:
|
32 |
+
# code snippet found, execute the code
|
33 |
+
# code_snippet = response_splits[-1].split('```')[0]
|
34 |
+
# print('code snippet: ', code_snippet)
|
35 |
+
code_snippet=""
|
36 |
+
for split in response_splits:
|
37 |
+
if '```' in split:
|
38 |
+
code_snippet+=split.split('```')[0]
|
39 |
+
f = StringIO()
|
40 |
+
# sys.stdout = f
|
41 |
+
code_exec_success=True
|
42 |
+
|
43 |
+
with redirect_stdout(f):
|
44 |
+
try:
|
45 |
+
exec(code_snippet,namespace_for_exec)
|
46 |
+
code_exe_result = f.getvalue()
|
47 |
+
except Exception as e:
|
48 |
+
code_exec_success=False
|
49 |
+
traceback_message_lines=traceback.format_exc().splitlines()
|
50 |
+
code_exe_result = '\n'.join(traceback_message_lines[-4:])
|
51 |
+
# code_exe_result = f.getvalue()
|
52 |
+
# f.close()
|
53 |
+
# sys.stdout = sys.__stdout__
|
54 |
+
|
55 |
+
#############利用保存文件的方式####################
|
56 |
+
# # 将代码片段保存到 code_snippet.py 文件
|
57 |
+
# with open("code_snippet.py", "w") as file:
|
58 |
+
# file.write(code_snippet)
|
59 |
+
|
60 |
+
# # 执行 code_snippet.py 并将输出重定向到临时文件
|
61 |
+
# os.system("python code_snippet.py > output.txt")
|
62 |
+
|
63 |
+
# # 从临时文件中读取结果
|
64 |
+
# with open("output.txt", "r") as file:
|
65 |
+
# code_exe_result = file.read()
|
66 |
+
##################################################
|
67 |
+
if code_exec_success:
|
68 |
+
code_exe_msg='code execution result:\n' + str(code_exe_result)
|
69 |
+
else:
|
70 |
+
code_exe_msg = "An error was raised when executing the code you write: %s"%code_exe_result
|
71 |
+
# code_exe_msg = 'Execution result of the above code is: ' + str(code_exe_result)
|
72 |
+
print(code_exe_msg)
|
73 |
+
return self.call_openai_with_code_interpreter(code_exe_msg,namespace_for_exec,token_usage_total)
|
74 |
+
|
75 |
+
if __name__ == '__main__':
|
76 |
+
|
77 |
+
config = {
|
78 |
+
'model': 'gpt-4',
|
79 |
+
# 'model': 'gpt-3.5-turbo',
|
80 |
+
'temperature': 0,
|
81 |
+
'top_p': 0.0,
|
82 |
+
'max_tokens': 'inf',
|
83 |
+
'system_message': "Imagine you are an artificial intelligence assitant with a python interpreter. So when answering questions, you can choose to generate python code (for example, when there is need to do quantitative evaluation). The generated code should always print out the result. The code should be written in python and should be able to run in the python environment with the following packages installed: numpy, math. The generated code should be complete and always include proper imports. Each generated code piece should be independent and NOT rely on previous generated code. When answer step by step, stop whenever you feel there is need to generate python code (for example, where there is need to do quantitative evaluation) and wait for the result from the code execution. When the answewr is complete, add 'Now the answer is complete.' to the end of your answer.",
|
84 |
+
|
85 |
+
# 'load_path': '',
|
86 |
+
'save_path': 'chats',
|
87 |
+
'debug': False
|
88 |
+
}
|
89 |
+
|
90 |
+
dialogue = CodeInterpreter(**config)
|
91 |
+
print('======================Instructions======================')
|
92 |
+
print('Type "exit" to exit the dialogue')
|
93 |
+
print('Type "reset" to reset the dialogue')
|
94 |
+
print('Type "pretext" to see the current dialogue history')
|
95 |
+
print('Type "config" to see the current config')
|
96 |
+
print('Type "save" to save the current dialogue history')
|
97 |
+
print('====GPT Dialogue Initialized, start asking your questions====')
|
98 |
+
|
99 |
+
while True:
|
100 |
+
user_prompt = input('You: ')
|
101 |
+
if user_prompt == 'exit':
|
102 |
+
break
|
103 |
+
elif user_prompt == 'reset':
|
104 |
+
dialogue = CodeInterpreter(**config)
|
105 |
+
print('====GPT Dialogue Initialized, start asking your questions====')
|
106 |
+
continue
|
107 |
+
elif user_prompt == 'pretext':
|
108 |
+
print('===Pretext===')
|
109 |
+
for message in dialogue.get_pretext():
|
110 |
+
print(message)
|
111 |
+
print('===Pretext===')
|
112 |
+
continue
|
113 |
+
elif user_prompt == 'config':
|
114 |
+
print('===Config===')
|
115 |
+
print(config)
|
116 |
+
print('===Config===')
|
117 |
+
continue
|
118 |
+
elif user_prompt == 'save':
|
119 |
+
timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
|
120 |
+
dialogue.save_pretext(config['save_path'], timestamp)
|
121 |
+
print('Pretext saved to', os.path.join(
|
122 |
+
config['save_path'], 'dialogue_' + timestamp + '.json'))
|
123 |
+
continue
|
124 |
+
else:
|
125 |
+
# response = dialogue.call_openai(user_prompt)['content']
|
126 |
+
response = dialogue.call_openai_with_code_interpreter(user_prompt)['content']
|
127 |
+
print('Bot:', response)
|
128 |
+
counter = 0
|
129 |
+
while not response.endswith('Now the answer is complete.') and counter < 10:
|
130 |
+
response = dialogue.call_openai_with_code_interpreter('')['content']
|
131 |
+
print('Bot:', response)
|
132 |
+
counter += 1
|
display_model.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from plyfile import PlyData, PlyElement
|
3 |
+
|
4 |
+
def ply_to_glb(ply_file, glb_file):
|
5 |
+
print("Converting PLY to GLB...")
|
6 |
+
# Import trimesh here to ensure it's only required when this function is called
|
7 |
+
import trimesh
|
8 |
+
# Load the PLY file with trimesh
|
9 |
+
try:
|
10 |
+
mesh = trimesh.load(ply_file)
|
11 |
+
# Export to GLB format
|
12 |
+
mesh.export(glb_file, file_type='glb')
|
13 |
+
print("Conversion finished.")
|
14 |
+
return "PLY to GLB conversion complete."
|
15 |
+
except Exception as e:
|
16 |
+
# In case of any issue, print the error
|
17 |
+
print(f"Error during conversion: {e}")
|
18 |
+
return "Conversion failed."
|
19 |
+
|
20 |
+
|
21 |
+
def merge_box_to_ply(ply_file, box_ply_file):
|
22 |
+
pass
|
23 |
+
|
24 |
+
def add_1box_to_ply(box, ply_file, new_ply_file, line_width=0.05, obj_id=1):
|
25 |
+
print("adding 1 box to ply...")
|
26 |
+
print("ply_file:",ply_file)
|
27 |
+
print("new_ply_file:",new_ply_file)
|
28 |
+
# box format: [xmin, ymin, zmin, xmax, ymax, zmax]
|
29 |
+
xmin, ymin, zmin, xmax, ymax, zmax = box
|
30 |
+
box_coords = np.array(
|
31 |
+
[[xmin, ymin, zmin], #0
|
32 |
+
[xmin-line_width, ymin-line_width, zmin], #1
|
33 |
+
[xmax, ymin, zmin], #2
|
34 |
+
[xmax+line_width, ymin-line_width, zmin], #3
|
35 |
+
[xmax, ymax, zmin], #4
|
36 |
+
[xmax+line_width, ymax+line_width, zmin], #5
|
37 |
+
[xmin, ymax, zmin], #6
|
38 |
+
[xmin-line_width, ymax+line_width, zmin], #7
|
39 |
+
[xmin, ymin, zmax], #8
|
40 |
+
[xmin-line_width, ymin-line_width, zmax], #9
|
41 |
+
[xmax, ymin, zmax], #10
|
42 |
+
[xmax+line_width, ymin-line_width, zmax], #11
|
43 |
+
[xmax, ymax, zmax], #12
|
44 |
+
[xmax+line_width, ymax+line_width, zmax], #13
|
45 |
+
[xmin, ymax, zmax], #14
|
46 |
+
[xmin-line_width, ymax+line_width, zmax] #15
|
47 |
+
])
|
48 |
+
|
49 |
+
# read in ply
|
50 |
+
with open(ply_file, 'rb') as f:
|
51 |
+
ply_data = PlyData.read(f)
|
52 |
+
vertices = ply_data['vertex'].data
|
53 |
+
|
54 |
+
# handle vertices and update
|
55 |
+
color = [255, 0, 0]
|
56 |
+
box_vertices = np.zeros(len(box_coords), dtype=vertices.dtype)
|
57 |
+
box_vertices['x'] = [coord[0] for coord in box_coords]
|
58 |
+
box_vertices['y'] = [coord[1] for coord in box_coords]
|
59 |
+
box_vertices['z'] = [coord[2] for coord in box_coords]
|
60 |
+
box_vertices['red'] = [color[0]] * 16
|
61 |
+
box_vertices['green'] = [color[1]] * 16
|
62 |
+
box_vertices['blue'] = [color[2]] * 16
|
63 |
+
box_vertices['alpha'] = [obj_id] * 16
|
64 |
+
|
65 |
+
# 将新的顶点数据添加到原始顶点数据后面
|
66 |
+
updated_vertices = np.concatenate((vertices, box_vertices))
|
67 |
+
|
68 |
+
# 创建包含新顶点的PlyElement对象
|
69 |
+
updated_vertex_element = PlyElement.describe(updated_vertices, 'vertex')
|
70 |
+
|
71 |
+
# 将更新后的PlyElement对象替换原始的顶点数据
|
72 |
+
# ply_data['vertex'] = updated_vertex_element
|
73 |
+
|
74 |
+
# get the number of original vertices:
|
75 |
+
num_origin_vertices = len(vertices)
|
76 |
+
|
77 |
+
# define connections of new faces
|
78 |
+
|
79 |
+
box_connections=[
|
80 |
+
[num_origin_vertices+0, num_origin_vertices+1, num_origin_vertices+3 ],
|
81 |
+
[num_origin_vertices+0, num_origin_vertices+3, num_origin_vertices+2 ],
|
82 |
+
[num_origin_vertices+2, num_origin_vertices+3, num_origin_vertices+5 ],
|
83 |
+
[num_origin_vertices+2, num_origin_vertices+5, num_origin_vertices+4 ],
|
84 |
+
[num_origin_vertices+4, num_origin_vertices+5, num_origin_vertices+7 ],
|
85 |
+
[num_origin_vertices+4, num_origin_vertices+7, num_origin_vertices+6 ],
|
86 |
+
[num_origin_vertices+0, num_origin_vertices+1, num_origin_vertices+7 ],
|
87 |
+
[num_origin_vertices+0, num_origin_vertices+7, num_origin_vertices+6 ],
|
88 |
+
[num_origin_vertices+0, num_origin_vertices+1, num_origin_vertices+9 ],
|
89 |
+
[num_origin_vertices+0, num_origin_vertices+9, num_origin_vertices+8 ],
|
90 |
+
[num_origin_vertices+2, num_origin_vertices+3, num_origin_vertices+11],
|
91 |
+
[num_origin_vertices+2, num_origin_vertices+11, num_origin_vertices+10],
|
92 |
+
[num_origin_vertices+4, num_origin_vertices+5, num_origin_vertices+13],
|
93 |
+
[num_origin_vertices+4, num_origin_vertices+13, num_origin_vertices+12],
|
94 |
+
[num_origin_vertices+6, num_origin_vertices+7, num_origin_vertices+15],
|
95 |
+
[num_origin_vertices+6, num_origin_vertices+15, num_origin_vertices+14],
|
96 |
+
[num_origin_vertices+8, num_origin_vertices+9, num_origin_vertices+11],
|
97 |
+
[num_origin_vertices+8, num_origin_vertices+11, num_origin_vertices+10],
|
98 |
+
[num_origin_vertices+10, num_origin_vertices+11, num_origin_vertices+13],
|
99 |
+
[num_origin_vertices+10, num_origin_vertices+13, num_origin_vertices+12],
|
100 |
+
[num_origin_vertices+12, num_origin_vertices+13, num_origin_vertices+15],
|
101 |
+
[num_origin_vertices+12, num_origin_vertices+15, num_origin_vertices+14],
|
102 |
+
[num_origin_vertices+8, num_origin_vertices+9, num_origin_vertices+15],
|
103 |
+
[num_origin_vertices+8, num_origin_vertices+15, num_origin_vertices+14]
|
104 |
+
]
|
105 |
+
|
106 |
+
# handle faces and update
|
107 |
+
faces = ply_data['face'].data
|
108 |
+
box_faces = np.zeros(len(box_connections), dtype=faces.dtype)
|
109 |
+
box_faces['vertex_indices'] = box_connections
|
110 |
+
|
111 |
+
# 将新的face数据添加到原始顶点数据后面
|
112 |
+
updated_faces = np.concatenate((faces, box_faces))
|
113 |
+
|
114 |
+
# 创建包含新顶点的PlyElement对象
|
115 |
+
updated_face_element = PlyElement.describe(updated_faces, 'face')
|
116 |
+
|
117 |
+
# 将更新后的PlyElement对象替换原始的顶点数据
|
118 |
+
# ply_data['face'] = updated_face_element
|
119 |
+
|
120 |
+
new_ply_data = PlyData([updated_vertex_element, updated_face_element])
|
121 |
+
|
122 |
+
# 将更新后的PlyData对象写回Ply文件
|
123 |
+
with open(new_ply_file, 'wb') as f:
|
124 |
+
new_ply_data.write(f)
|
125 |
+
|
126 |
+
print("add 1 box to ply finished.")
|
127 |
+
|
128 |
+
|
129 |
+
def ply_to_obj(ply_file, obj_file, mtl_file):
|
130 |
+
# 读取PLY文件
|
131 |
+
with open(ply_file, 'rb') as f:
|
132 |
+
plydata = PlyData.read(f)
|
133 |
+
|
134 |
+
# 获取顶点和面数据
|
135 |
+
vertices = np.vstack([plydata['vertex'][prop] for prop in ['x', 'y', 'z']]).T
|
136 |
+
colors = np.vstack([plydata['vertex'][prop] for prop in ['red', 'green', 'blue', 'alpha']]).T/255.0
|
137 |
+
faces = plydata['face']['vertex_indices']
|
138 |
+
|
139 |
+
# 写入OBJ文件
|
140 |
+
with open(obj_file, 'w') as f:
|
141 |
+
# 写入依赖的mtl文件(颜色)
|
142 |
+
f.write("mtllib %s\n"%mtl_file.split('/')[-1])
|
143 |
+
|
144 |
+
# 写入顶点信息
|
145 |
+
for vertex in vertices:
|
146 |
+
f.write(f"v {' '.join(map(str, vertex))}\n")
|
147 |
+
|
148 |
+
# 写入颜色信息
|
149 |
+
for idx in range(len(vertices)):
|
150 |
+
f.write("usemtl mat%d\n"%(idx+1))
|
151 |
+
|
152 |
+
# 写入面信息
|
153 |
+
for face in faces:
|
154 |
+
f.write("f")
|
155 |
+
for vertex_index in face:
|
156 |
+
f.write(f" {vertex_index + 1}") # OBJ文件索引从1开始
|
157 |
+
f.write("\n")
|
158 |
+
|
159 |
+
# 写入mtl文件
|
160 |
+
with open(mtl_file, 'w') as f:
|
161 |
+
for idx, color in enumerate(colors):
|
162 |
+
f.write("newmtl mat%d\n" % (idx+1))
|
163 |
+
f.write("Kd %f %f %f\n\n" % (color[0], color[1],color[2]))
|
164 |
+
|
165 |
+
if __name__ == "__main__":
|
166 |
+
# ply_to_obj("./scenes/scene0132_00_vh_clean_2_aligned.ply", "./scenes/scene0132_00_vh_clean_2_aligned.obj", "./scenes/scene0132_00_vh_clean_2_aligned_colors.mtl")
|
167 |
+
add_1box_to_ply([0,0,0,1,1,1],"scenes\scene0132_00_vh_clean_2_aligned.ply","scenes\scene0132_00_add1box.ply")
|
gpt_dialogue.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import datetime
|
4 |
+
# import openai
|
5 |
+
# openai.api_key = os.getenv("OPENAI_API_KEY")
|
6 |
+
from openai import OpenAI
|
7 |
+
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
8 |
+
|
9 |
+
|
10 |
+
# HUGGINGFACE_MODELS = {
|
11 |
+
# 'meta-llama/Llama-2-7b-chat-hf',
|
12 |
+
# 'meta-llama/Llama-2-13b-chat-hf',
|
13 |
+
# 'meta-llama/Llama-2-70b-chat-hf',
|
14 |
+
# 'codellama/CodeLlama-7b-Instruct-hf',
|
15 |
+
# 'codellama/CodeLlama-13b-Instruct-hf',
|
16 |
+
# 'codellama/CodeLlama-34b-Instruct-hf',
|
17 |
+
# 'mistralai/Mistral-7B-Instruct-v0.1',
|
18 |
+
# }
|
19 |
+
|
20 |
+
|
21 |
+
class Dialogue:
|
22 |
+
def __init__(self, model='gpt-4', temperature=0, top_p=0.0, max_tokens=10, system_message='', load_path=None, save_path='chats', debug=False):
|
23 |
+
self.model = model
|
24 |
+
self.temperature = temperature
|
25 |
+
self.top_p = top_p
|
26 |
+
self.max_tokens = max_tokens
|
27 |
+
self.system_message = system_message
|
28 |
+
self.save_path = save_path
|
29 |
+
self.debug = debug
|
30 |
+
self.has_update = False
|
31 |
+
if load_path is not None:
|
32 |
+
self.load_pretext(load_path)
|
33 |
+
else:
|
34 |
+
self.pretext = [{"role": "system", "content": self.system_message}]
|
35 |
+
|
36 |
+
if 'llama' in self.model:
|
37 |
+
from hf_conversational import HuggingfaceConversational
|
38 |
+
from transformers import Conversation
|
39 |
+
self.conversational = HuggingfaceConversational(
|
40 |
+
model_name=self.model,
|
41 |
+
temperature=self.temperature,
|
42 |
+
top_p=self.top_p,
|
43 |
+
max_length=self.max_tokens
|
44 |
+
)
|
45 |
+
|
46 |
+
def load_pretext(self, load_path):
|
47 |
+
|
48 |
+
def load_json(load_path):
|
49 |
+
with open(load_path) as json_file:
|
50 |
+
return json.load(json_file)
|
51 |
+
|
52 |
+
self.pretext = []
|
53 |
+
if isinstance(load_path, list):
|
54 |
+
for path in load_path:
|
55 |
+
self.pretext += load_json(path)
|
56 |
+
elif isinstance(load_path, str):
|
57 |
+
self.pretext = load_json(load_path)
|
58 |
+
else:
|
59 |
+
raise Exception('load_path must be a list of strings or a string')
|
60 |
+
|
61 |
+
def get_pretext(self):
|
62 |
+
return self.pretext
|
63 |
+
|
64 |
+
# def save_pretext(self, save_path, timestamp):
|
65 |
+
# if not os.path.exists(save_path):
|
66 |
+
# os.makedirs(save_path)
|
67 |
+
# json_path = os.path.join(save_path, 'dialogue_' + timestamp + '.json')
|
68 |
+
# json_object = json.dumps(self.get_pretext(), indent=4)
|
69 |
+
# with open(json_path, 'w') as f:
|
70 |
+
# f.write(json_object)
|
71 |
+
|
72 |
+
def save_pretext(self, save_folder_path, file_name):
|
73 |
+
if not os.path.exists(save_folder_path):
|
74 |
+
os.makedirs(save_folder_path)
|
75 |
+
json_path = os.path.join(save_folder_path, file_name)
|
76 |
+
json_object = json.dumps(self.get_pretext(), indent=4)
|
77 |
+
with open(json_path, 'w') as f:
|
78 |
+
f.write(json_object)
|
79 |
+
|
80 |
+
def print_pretext(self,print_system_and_user_first_prompt=True,to_print_out=True):
|
81 |
+
# determine whether to print system message and user's first prompt
|
82 |
+
from copy import deepcopy
|
83 |
+
pretext=deepcopy(self.pretext)
|
84 |
+
if not print_system_and_user_first_prompt:
|
85 |
+
pretext=pretext[2:]
|
86 |
+
printed_pretext=''
|
87 |
+
# print pretext
|
88 |
+
for piece in pretext:
|
89 |
+
if to_print_out:
|
90 |
+
print('----------------->ROLE: '+piece['role']+'\t<-----------------')
|
91 |
+
print('CONTENT: '+piece['content'])
|
92 |
+
printed_pretext=printed_pretext+'----------------->\tROLE: '+piece['role']+'\t<-----------------\n'
|
93 |
+
printed_pretext=printed_pretext+'CONTENT: '+piece['content']+'\n'
|
94 |
+
self.printed_pretext=printed_pretext
|
95 |
+
|
96 |
+
def call_openai(self, user_prompt):
|
97 |
+
user_message = [{"role": "user", "content": user_prompt}]
|
98 |
+
messages = self.pretext + user_message
|
99 |
+
# print('messages: ', messages)
|
100 |
+
if 'gpt' in self.model:
|
101 |
+
completion = client.chat.completions.create(
|
102 |
+
model=self.model,
|
103 |
+
messages=self.pretext + user_message,
|
104 |
+
temperature=self.temperature,
|
105 |
+
top_p=self.top_p,
|
106 |
+
seed=42,
|
107 |
+
)
|
108 |
+
# completion = openai.ChatCompletion.create(
|
109 |
+
# model=self.model,
|
110 |
+
# messages=self.pretext + user_message,
|
111 |
+
# temperature=self.temperature,
|
112 |
+
# top_p=self.top_p,
|
113 |
+
# )
|
114 |
+
# print('completion: ', completion)
|
115 |
+
raw_response_message = completion.choices[0].message
|
116 |
+
assistant_response_message = {'role': raw_response_message.role, 'content': raw_response_message.content}
|
117 |
+
# print('assistant_response_message: ', assistant_response_message)
|
118 |
+
token_usage = completion.usage.total_tokens
|
119 |
+
elif 'llama' in self.model:
|
120 |
+
chat_completion_messages,token_usage = self.conversational(messages)
|
121 |
+
assistant_response_message = chat_completion_messages.messages[-1]
|
122 |
+
else:
|
123 |
+
raise Exception('model name {} not supported'.format(self.model))
|
124 |
+
|
125 |
+
self.pretext = self.pretext + user_message + [assistant_response_message]
|
126 |
+
self.has_update = True
|
127 |
+
|
128 |
+
return assistant_response_message, token_usage
|
129 |
+
|
130 |
+
|
131 |
+
if __name__ == '__main__':
|
132 |
+
|
133 |
+
config = {
|
134 |
+
# 'model': 'gpt-4-1106-preview',
|
135 |
+
# 'model': 'gpt-4',
|
136 |
+
'model': 'gpt-3.5-turbo-0125',
|
137 |
+
# 'model': 'gpt-3.5-turbo',
|
138 |
+
# 'model': 'meta-llama/Llama-2-7b-chat-hf',
|
139 |
+
'temperature': 0,
|
140 |
+
'top_p': 0.0,
|
141 |
+
'max_tokens': 8192,
|
142 |
+
'system_message': '',
|
143 |
+
# 'load_path': 'chats/dialogue_an apple.json',
|
144 |
+
'save_path': 'chats',
|
145 |
+
'debug': False
|
146 |
+
}
|
147 |
+
|
148 |
+
dialogue = Dialogue(**config)
|
149 |
+
print('======================Instructions======================')
|
150 |
+
print('Type "exit" to exit the dialogue')
|
151 |
+
print('Type "reset" to reset the dialogue')
|
152 |
+
print('Type "pretext" to see the current dialogue history')
|
153 |
+
print('Type "config" to see the current config')
|
154 |
+
print('Type "save" to save the current dialogue history')
|
155 |
+
print('====GPT Dialogue Initialized, start asking your questions====')
|
156 |
+
|
157 |
+
while True:
|
158 |
+
user_prompt = input('You: ')
|
159 |
+
if user_prompt == 'exit':
|
160 |
+
break
|
161 |
+
elif user_prompt == 'reset':
|
162 |
+
dialogue = Dialogue(**config)
|
163 |
+
print('====GPT Dialogue Initialized, start asking your questions====')
|
164 |
+
continue
|
165 |
+
elif user_prompt == 'pretext':
|
166 |
+
print('===Pretext===')
|
167 |
+
for message in dialogue.get_pretext():
|
168 |
+
print(message)
|
169 |
+
# dialogue.print_pretext()
|
170 |
+
print('===Pretext===')
|
171 |
+
continue
|
172 |
+
elif user_prompt == 'config':
|
173 |
+
print('===Config===')
|
174 |
+
print(config)
|
175 |
+
print('===Config===')
|
176 |
+
continue
|
177 |
+
elif user_prompt == 'save':
|
178 |
+
timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
|
179 |
+
dialogue.save_pretext(config['save_path'], timestamp)
|
180 |
+
print('Pretext saved to', os.path.join(
|
181 |
+
config['save_path'], 'dialogue_' + timestamp + '.json'))
|
182 |
+
continue
|
183 |
+
else:
|
184 |
+
assistant_response_message, token_usage = dialogue.call_openai(user_prompt)
|
185 |
+
response = assistant_response_message['content']
|
186 |
+
print('Bot:', response)
|
object_filter_gpt4.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import os,json
|
3 |
+
import re
|
4 |
+
import logging
|
5 |
+
from gpt_dialogue import Dialogue
|
6 |
+
import openai
|
7 |
+
from tenacity import (
|
8 |
+
retry,
|
9 |
+
before_sleep_log,
|
10 |
+
stop_after_attempt,
|
11 |
+
wait_random_exponential,
|
12 |
+
wait_exponential,
|
13 |
+
wait_exponential_jitter,
|
14 |
+
RetryError
|
15 |
+
) # for exponential backoff
|
16 |
+
|
17 |
+
openai.api_key = os.getenv("OPENAI_API_KEY")
|
18 |
+
|
19 |
+
logger = logging.getLogger(__name__+'logger')
|
20 |
+
logger.setLevel(logging.ERROR)
|
21 |
+
console_handler = logging.StreamHandler()
|
22 |
+
console_handler.setLevel(logging.ERROR)
|
23 |
+
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
24 |
+
console_handler.setFormatter(formatter)
|
25 |
+
logger.addHandler(console_handler)
|
26 |
+
|
27 |
+
class ObjectFilter(Dialogue):
|
28 |
+
def __init__(self, model='gpt-4'):
|
29 |
+
config = {
|
30 |
+
# 'model': 'gpt-4',
|
31 |
+
# 'model': 'gpt-4-1106-preview',
|
32 |
+
'model': model,
|
33 |
+
'temperature': 0,
|
34 |
+
'top_p': 0.0,
|
35 |
+
'max_tokens': 8192,
|
36 |
+
# 'load_path': './object_filter_pretext.json',
|
37 |
+
'load_path': './object_filter_pretext_new.json',
|
38 |
+
'debug': False
|
39 |
+
}
|
40 |
+
super().__init__(**config)
|
41 |
+
|
42 |
+
def extract_all_int_lists_from_text(self,text) ->list:
|
43 |
+
# 匹配方括号内的内容
|
44 |
+
pattern = r'\[([^\[\]]+)\]'
|
45 |
+
matches = re.findall(pattern, text)
|
46 |
+
|
47 |
+
int_lists = []
|
48 |
+
|
49 |
+
for match in matches:
|
50 |
+
elements = match.split(',')
|
51 |
+
int_list = []
|
52 |
+
|
53 |
+
for element in elements:
|
54 |
+
element = element.strip()
|
55 |
+
try:
|
56 |
+
int_value = int(element)
|
57 |
+
int_list.append(int_value)
|
58 |
+
except ValueError:
|
59 |
+
pass
|
60 |
+
|
61 |
+
if len(int_list) == len(elements):
|
62 |
+
int_lists = int_lists + int_list
|
63 |
+
|
64 |
+
return int_lists
|
65 |
+
|
66 |
+
def extract_dict_from_text(self,text) ->dict:
|
67 |
+
# Use regular expression to match the dictionary in the text
|
68 |
+
match = re.search(r'{\s*(.*?)\s*}', text)
|
69 |
+
if match:
|
70 |
+
# Get the matched dictionary content
|
71 |
+
dict_str = match.group(1)
|
72 |
+
# Convert the dictionary string to an actual dictionary object
|
73 |
+
try:
|
74 |
+
result_dict = eval('{' + dict_str + '}')
|
75 |
+
return result_dict
|
76 |
+
except Exception as e:
|
77 |
+
print(f"Error converting string to dictionary: {e}")
|
78 |
+
return None
|
79 |
+
else:
|
80 |
+
print("No dictionary found in the given text.")
|
81 |
+
return None
|
82 |
+
|
83 |
+
@retry(wait=wait_exponential_jitter(initial=20, max=120, jitter=20), stop=stop_after_attempt(5), before_sleep=before_sleep_log(logger,logging.ERROR)) #20s,40s,80s,120s + random.uniform(0,20)
|
84 |
+
def filter_objects_by_description(self,description,use_npy_file,objects_info_path=None,object_info_list=None,to_print=True):
|
85 |
+
# first, create the prompt
|
86 |
+
print("looking for relevant objects based on description:\n'%s'"%description)
|
87 |
+
prompt=""
|
88 |
+
prompt=prompt+"description:\n'%s'\nobject list:\n"%description
|
89 |
+
# load object info data and add to prompt
|
90 |
+
if use_npy_file:
|
91 |
+
data=np.load(objects_info_path,allow_pickle=True)
|
92 |
+
for obj in data:
|
93 |
+
if obj['label']=='object':
|
94 |
+
continue
|
95 |
+
line="name=%s,id=%d; "%(obj['label'],obj['id'])
|
96 |
+
prompt=prompt+line
|
97 |
+
else: # object info list given, used for robot demo
|
98 |
+
data=object_info_list
|
99 |
+
for obj in data:
|
100 |
+
label=obj.get('cls')
|
101 |
+
if label is None:
|
102 |
+
label=obj.get('label')
|
103 |
+
# if obj['cls']=='object':
|
104 |
+
# continue
|
105 |
+
if label in ['object','otherfurniture','other','others']:
|
106 |
+
continue
|
107 |
+
line="name=%s,id=%d; "%(label,obj['id'])
|
108 |
+
prompt=prompt+line
|
109 |
+
|
110 |
+
|
111 |
+
# get response from gpt
|
112 |
+
response,token_usage=self.call_openai(prompt)
|
113 |
+
response=response['content']
|
114 |
+
# print("response:",response)
|
115 |
+
last_line = response.splitlines()[-1] if len(response) > 0 else ''
|
116 |
+
|
117 |
+
# exract answer(list/dict) from the last line of response
|
118 |
+
# answer=self.extract_all_int_lists_from_text(last_line)
|
119 |
+
answer=self.extract_dict_from_text(last_line)
|
120 |
+
if to_print:
|
121 |
+
self.print_pretext()
|
122 |
+
print("answer:",answer)
|
123 |
+
print("\n\n")
|
124 |
+
if len(answer)==0:
|
125 |
+
answer=None
|
126 |
+
return answer,token_usage
|
127 |
+
|
128 |
+
|
129 |
+
|
130 |
+
if __name__ == "__main__":
|
131 |
+
# scanrefer_path="/share/data/ripl/vincenttann/sr3d/data/scanrefer/ScanRefer_filtered_sampled50.json"
|
132 |
+
scanrefer_path="/share/data/ripl/vincenttann/sr3d/data/scanrefer/ScanRefer_filtered_train_sampled1000.json"
|
133 |
+
with open(scanrefer_path, 'r') as json_file:
|
134 |
+
scanrefer_data=json.load(json_file)
|
135 |
+
|
136 |
+
from datetime import datetime
|
137 |
+
# 记录时间作为文件名
|
138 |
+
current_time = datetime.now()
|
139 |
+
formatted_time = current_time.strftime("%Y-%m-%d-%H-%M-%S")
|
140 |
+
print("formatted_time:",formatted_time)
|
141 |
+
folder_path="/share/data/ripl/vincenttann/sr3d/object_filter_dialogue/%s/"%formatted_time
|
142 |
+
os.makedirs(folder_path)
|
143 |
+
|
144 |
+
for idx,data in enumerate(scanrefer_data):
|
145 |
+
print("processing %d/%d..."%(idx+1,len(scanrefer_data)))
|
146 |
+
description=data['description']
|
147 |
+
scan_id=data['scene_id']
|
148 |
+
target_id=data['object_id']
|
149 |
+
# path="/share/data/ripl/scannet_raw/train/objects_info_gf/objects_info_gf_%s.npy"%scan_id
|
150 |
+
path="/share/data/ripl/scannet_raw/train/objects_info/objects_info_%s.npy"%scan_id
|
151 |
+
of=ObjectFilter()
|
152 |
+
of.filter_objects_by_description(path,description)
|
153 |
+
object_filter_json_name="%d_%s_%s_object_filter.json"%(idx,scan_id,target_id)
|
154 |
+
of.save_pretext(folder_path,object_filter_json_name)
|
objects_info/objects_info_scene0132_00.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e872475b1e7cc419364fbd619131214b60579dc415b61db2963f5844d41e98c8
|
3 |
+
size 17246
|
prompt_text.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# def get_principle(utterance, use_priority=False):
|
2 |
+
def get_principle(use_priority=False):
|
3 |
+
prompt = ''
|
4 |
+
prompt = prompt + "Tips: while multiple objs may appear within the description, it points to only 1 focal object, with the other objects serving to aid in locating or contextualizing it. For instance, spatial relation with other objects might be employed to establish the position or orientation of this focal object. Examples:"
|
5 |
+
prompt = prompt + "\n1.'The brown cabinet covers the entire back wall. There is a door with a blue sign located between the brown cabinet.' The first sentence is actually a noun phrase starting with 'the,' indicating that the focal object being described is 'the brown cabinet.' The second sentence describes the spatial relationship between the door and the brown cabinet, providing supplementary details about the described brown cabinet."
|
6 |
+
prompt = prompt + "\n2.'This is a big exercise ball. The ball is under the table.' The first sentence starts with 'this is,' indicating the object being described, which is a 'big exercise ball.' The second sentence is used to provide additional information about the ball's location."
|
7 |
+
prompt = prompt + "\n3.'In the corner of the kitchen, there are three trash cans. Beside the third trash can from the left, there's a white stool.' The description first sets up a scene with three trash cans, and then move on to describe the location of the white stool in relation to the trash cans. Therefore, the white stool is the target object."
|
8 |
+
if use_priority:
|
9 |
+
prompt = prompt + "\nConsider different constraints in order (1 to 7) & priority (1 highest, 7 lowest):"
|
10 |
+
prompt = prompt + "\n1: Obj name(category). Names in description & obj list may differ (e.g. similar names such as 'table' and 'desk', 'trash can' and 'recycling bin', 'coffee table' and 'end table'), so use common sense to find all possible candidate objects, ensure no missing, don't write code. If only 1 object in list has the same/similar category with the one described object, answer it directly, discard other constraints. For instance, with description 'the black bag left to the couch' and only 1 bag in the scene, answer it directly, discard 'black' and 'left' constrains."
|
11 |
+
prompt = prompt + "\n2: Horizontal relation like 'next to''farthest''closest''nearest''between''in the middle''at the center'(if given)(not include 'behind''in front of'). Consider only center x,y,z coords of objs, disregard sizes."
|
12 |
+
prompt = prompt + "\n3: Color (if given). Be lenient with color. First convert RGB to HSL. For grayscale, use lightness to compute the difference between objects' color and the specified color as a metric. For other colors, use hue instead. When computing the hue difference, be careful that it is a circular value."
|
13 |
+
prompt = prompt + "\n4: Size & shape(if given). Be cautious not to make overly absolute judgments about obj size. E.g., 'a tiny trash can' doesn't necessarily refer to smallest one in terms of volume."
|
14 |
+
prompt = prompt + "\n5: Direction relation 'left''right'(if given). To judge A on 'left' or 'right' of B, calc vec observer-A & observer-B(both projected to x-y plane). If cross product of vec observer-A & vector observer-B(in this order) has positive z, A on right of B. If z neg, A on left of B. Note that order of cross product matters, put vec observer-A at first. Consider which two objs' left-right relation needs to be determined in sentence, that is, which is A & which is B. DON'T determine left & right relation by compare x or y coords."
|
15 |
+
prompt = prompt + "\n6: Direction relation 'in front of' and 'behind'(if given). Use 'spatially closer' to replace them. To determine which object, P1 or P2, is behind Q, calculate their distances from Q. The one with the smaller distance is behind Q. It is the same for judging 'in front of': also smaller distance. DON'T determine front & behind relation by compare x or y coords."
|
16 |
+
prompt = prompt + "\n7: Vertical relation like 'above'and'under''on''sits on'(if given). Consider only center coords of objs, disregard sizes. Be more lenient with this."
|
17 |
+
prompt = prompt + "\nExplicitly go through these 7 constraints. For every constraint, if it is not mentioned in description, tell me and skip; if mentioned, apply this constraint and record the results of each candidates. For constraint 1, use common sense, no code. For others, write code, which should print the metrics of each candidate objects, instead of only print the most possible object id. After going through all constriants, evaluate all results comprehensively basing on 1-7 priority, and choose the unique target object."
|
18 |
+
else:
|
19 |
+
prompt = prompt + "\nSo first you should identify this focal object(that is, the category name of it) from the description."
|
20 |
+
prompt = prompt + "\nNext, you can identify potential objects from the object list based on the category name of the focal object. You should rely on your common sense to comprehensively identify all relevant candidates without writing code. For example, for the category name 'table,' objects such as 'table,' 'desk,' 'end table,' 'coffee table,' and so on from the object list should all be considered as potential candidates."
|
21 |
+
prompt = prompt + "\nThen, count(do not write code) and tell me the number of candidate objects. If it is 1, which means only one candidate object, you must directly choose it as answer, then stop your response. For example, if the description is 'the white bathhub on the left of the toilet' and there is only one 'bathhub'-like object in the list, answer it directly, ignore 'white' and 'left of the toilet' constraints."
|
22 |
+
prompt = prompt + "\nIf there are multiple candidate objects, you can continue. Identify the constraints in the description. There might be multiple constraints to help finding the unique target object from multiple candidate objects. For each constraint, you can define a quantitative metric to assess the degree to which each candidate object satisfies this constraint."
|
23 |
+
prompt = prompt + "\nYou can write code to calculate the metrics, printing out the metrics of each candidate objects, instead of only print the most possible object id."
|
24 |
+
prompt = prompt + "\nSome special tips for some constraints:"
|
25 |
+
|
26 |
+
prompt = prompt + "\n- Color (if given). Be lenient with color. Given the HSL value of objects, to determine black and white, use the difference of L(light) as a metric. To determine other colors, use the difference of H(hue) as a metric, and be careful H has circular value range, that H values of 360 and 0 are equal. Do not use conditions like 'if difference(color1, color2) < threshold' to determine a color."
|
27 |
+
# prompt = prompt + "\n- Color(if given). Be lenient with color, because different shades of color mentioned in description can have different RGB values. You might use RGB-space distance as a quantitative metric."
|
28 |
+
prompt = prompt + "\n- Direction relation 'left''right'(if given). To judge obj A on 'left' or 'right' of B, calc vector observer-A & observer-B(both projected to x-y plane). If cross product of vector observer-A & vector observer-B(in this order) has positive z, A on right of B. If z neg, A on left of B. Note that order of cross product matters, put vec observer-A at first. DON'T determine left & right relation by comparing x or y coords."
|
29 |
+
prompt = prompt + "\n- Direction relation 'in front of' and 'behind'(if given). Use 'spatially closer' to replace them. To determine which object, P1 or P2, is behind Q, calculate their distances from Q. The one with the smaller distance is behind Q. It is the same for judging 'in front of': also smaller distance. DON'T determine front & behind relation by comparing x or y coords."
|
30 |
+
prompt = prompt + "\n- Vertical relation such as 'on''above''under'(if given). If obj M has vertical relation with obj N, the x,y coord of ctr of M should be inside the x,y range of obj N, while z of M and z of N should satisfy the corresponding order."
|
31 |
+
prompt = prompt + "\nAfter going through all constraints in the description, double check the given description , and evaluate all results and metrics comprehensively, then choose the unique target object."
|
32 |
+
|
33 |
+
prompt = prompt + "\nPerceive wall as plane. Distance from obj to wall=vert dist to plane, not to wall center. Wall front=side of plane where obj exist."
|
34 |
+
|
35 |
+
return prompt
|
36 |
+
|
37 |
+
|
38 |
+
def get_principle_sr3d():
|
39 |
+
prompt = ""
|
40 |
+
# prompt=prompt+"\nYou must comprehensively consider the x, y, z coordinates, not only one of them (for example, you should consider both greater z and similar x, y coordinates when judging vertical relation). "
|
41 |
+
# prompt=prompt+"\nWhen determining vertical relation such as 'above''under''on''on top of''support'(if given). If obj M has vertical relation with obj N, the x,y coord of ctr of M should be inside the x,y range of obj N, while z of M and z of N should satisfy the corresponding order. You can igonre the size of objects and only consider ctr coords here."
|
42 |
+
prompt = prompt + "\nWhen determining vertical relation such as 'above''under''on''on top of''support'(if given). For example, if obj M is on top of / supportted by obj N, the x,y coord of ctr of M should be inside the x,y range of obj N, while z of M is greater than z of N. You can igonre the size in z direction of objects here. If you cannot find the obj M after several tries, you can choose one which is closest to N."
|
43 |
+
prompt = prompt + "\nWhen determining the orientation of object B relative to object A, you should calculate the angle between the x-y plane vector from A to B(projected onto x-y plane) and one of the direction vectors of A (the one that corresponds to the direction mentioned in the problem). The smaller the angle, the more it indicates that B is in the corresponding direction of A."
|
44 |
+
return prompt
|
45 |
+
|
46 |
+
|
47 |
+
def get_system_message():
|
48 |
+
system_message = "Imagine you are an artificial intelligence assitant with a python interpreter. So when answering questions, you can choose to generate python code (for example, when there is need to do quantitative evaluation). The generated code should always use print() function to print out the result and keep two decimal places for numbers. The code should be written in python, start with '```python\nimport numpy as np\nimport math\n' and end with '```'. Keep your code and comments concise. When answer step by step, stop whenever you feel there is need to generate python code (for example, where there is need to do quantitative evaluation) and wait for the result from the code execution. Make sure your code will print out something(include failure info like 'nothing found'), especially when you use if logic.\n"
|
49 |
+
# Before generating code, say 'Let's write some python code to get the results.', then stop. You'll receive an empty message from user, then you start to generate code. If you are printing thing like 'metric: value', make it clear what the metric is."
|
50 |
+
|
51 |
+
system_message+="You will receive a information list of some objects from a 3D indoor scene, which might include their center positions, sizes, colors and so on. You will also be presented with a description of one certain object in that scene, and your job is to find that object from the object list according to this description. Below are some tips to help you reasoning and finding the object:\n"
|
52 |
+
|
53 |
+
return system_message
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
numpy
|
3 |
+
plyfile
|
4 |
+
openai
|
5 |
+
tenacity
|
6 |
+
trimesh
|
scenes/scene0132_00_vh_clean_2_aligned.glb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fc75b866ab2436dfbcd9f23da1c2e3191cdbb4ce299e5b2ce074e55a7e19af27
|
3 |
+
size 26828880
|
scenes/scene0132_00_vh_clean_2_aligned.ply
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:22caa4c712bb89fe8316cbaf829bc81d1a5ee8a33a0c9f5bc20f9c5f2dded606
|
3 |
+
size 6833502
|
sources.list
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
deb http://mirrors.aliyun.com/ubuntu/ focal main restricted universe multiverse
|
2 |
+
deb http://mirrors.aliyun.com/ubuntu/ focal-security main restricted universe multiverse
|
3 |
+
deb http://mirrors.aliyun.com/ubuntu/ focal-updates main restricted universe multiverse
|
4 |
+
deb http://mirrors.aliyun.com/ubuntu/ focal-backports main restricted universe multiverse
|
transcrib3d_main.py
ADDED
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# encoding:utf-8
|
2 |
+
import ast
|
3 |
+
import csv
|
4 |
+
import json
|
5 |
+
import logging
|
6 |
+
import os
|
7 |
+
import random
|
8 |
+
import re
|
9 |
+
import time
|
10 |
+
from copy import deepcopy
|
11 |
+
from datetime import datetime
|
12 |
+
|
13 |
+
import numpy as np
|
14 |
+
from tenacity import RetryError, before_sleep_log, retry, stop_after_attempt, wait_exponential_jitter # for exponential backoff
|
15 |
+
|
16 |
+
from code_interpreter import CodeInterpreter
|
17 |
+
# from config import confs_nr3d, confs_scanrefer, confs_sr3d
|
18 |
+
# from gpt_dialogue import Dialogue
|
19 |
+
# from object_filter_gpt4 import ObjectFilter
|
20 |
+
from prompt_text import get_principle, get_principle_sr3d, get_system_message
|
21 |
+
|
22 |
+
logger = logging.getLogger(__name__ + 'logger')
|
23 |
+
logger.setLevel(logging.ERROR)
|
24 |
+
console_handler = logging.StreamHandler()
|
25 |
+
console_handler.setLevel(logging.ERROR)
|
26 |
+
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
27 |
+
console_handler.setFormatter(formatter)
|
28 |
+
logger.addHandler(console_handler)
|
29 |
+
|
30 |
+
def round_list(lst, length):
|
31 |
+
# round every element in lst
|
32 |
+
for idx, num in enumerate(lst):
|
33 |
+
lst[idx] = round(num, length)
|
34 |
+
return list(lst)
|
35 |
+
|
36 |
+
def remove_spaces(s: str):
|
37 |
+
return s.replace(' ', '')
|
38 |
+
|
39 |
+
def rgb_to_hsl(rgb):
|
40 |
+
# Normalize RGB values to the range [0, 1]
|
41 |
+
r, g, b = [x / 255.0 for x in rgb]
|
42 |
+
# Calculate min and max values of RGB to find chroma
|
43 |
+
c_max = max(r, g, b)
|
44 |
+
c_min = min(r, g, b)
|
45 |
+
chroma = c_max - c_min
|
46 |
+
# Calculate lightness
|
47 |
+
lightness = (c_max + c_min) / 2
|
48 |
+
# Calculate hue and saturation
|
49 |
+
hue = 0
|
50 |
+
saturation = 0
|
51 |
+
if chroma != 0:
|
52 |
+
if c_max == r:
|
53 |
+
hue = ((g - b) / chroma) % 6
|
54 |
+
elif c_max == g:
|
55 |
+
hue = ((b - r) / chroma) + 2
|
56 |
+
elif c_max == b:
|
57 |
+
hue = ((r - g) / chroma) + 4
|
58 |
+
hue *= 60
|
59 |
+
# Calculate saturation
|
60 |
+
if lightness <= 0.5:
|
61 |
+
saturation = chroma / (2 * lightness)
|
62 |
+
else:
|
63 |
+
saturation = chroma / (2 - 2 * lightness)
|
64 |
+
return [hue, saturation, lightness]
|
65 |
+
|
66 |
+
def get_scene_center(objects):
|
67 |
+
xmin, ymin, zmin = float('inf'), float('inf'), float('inf')
|
68 |
+
xmax, ymax, zmax = float('-inf'), float('-inf'), float('-inf')
|
69 |
+
for obj in objects:
|
70 |
+
x, y, z = obj['center_position']
|
71 |
+
if x < xmin:
|
72 |
+
xmin = x
|
73 |
+
if x > xmax:
|
74 |
+
xmax = x
|
75 |
+
if y < ymin:
|
76 |
+
ymin = y
|
77 |
+
if y > ymax:
|
78 |
+
ymax = y
|
79 |
+
if z < zmin:
|
80 |
+
zmin = z
|
81 |
+
if z > zmax:
|
82 |
+
zmax = z
|
83 |
+
return round_list([(xmin + xmax) / 2, (ymin + ymax) / 2, (zmin + zmax) / 2], 2)
|
84 |
+
|
85 |
+
def find_relevant_objects(user_instruction, scan_id):
|
86 |
+
pass
|
87 |
+
|
88 |
+
def gen_prompt(user_instruction, scan_id):
|
89 |
+
|
90 |
+
npy_path = os.path.join("objects_info", f"objects_info_{scan_id}.npy")
|
91 |
+
objects_info = np.load(npy_path, allow_pickle=True)
|
92 |
+
|
93 |
+
|
94 |
+
# objects_related = find_relevant_objects(user_instruction, scan_id)
|
95 |
+
objects_related = objects_info
|
96 |
+
|
97 |
+
|
98 |
+
# 获取场景的中心坐标
|
99 |
+
# scene_center=get_scene_center(objects_related)
|
100 |
+
scene_center = get_scene_center(objects_info) # 注意这里应该用所有物体的信息,而不只是relevant
|
101 |
+
# 生成prompt中的背景信息部分
|
102 |
+
prompt = scan_id + ":objects with quantitative description based on right-hand Cartesian coordinate system with x-y-z axes, x-y plane=ground, z-axis=up/down. Coords format [x, y, z].\n\n"
|
103 |
+
# if dataset == 'nr3d':
|
104 |
+
# prompt = prompt + "Scene center:%s. If no direction vector, observer at center for objorientation.\n" % remove_spaces(str(scene_center))
|
105 |
+
# elif dataset == 'scanrefer':
|
106 |
+
# if use_camera_position:
|
107 |
+
# prompt = prompt + "Scene center:%s.\n" % remove_spaces(str(scene_center))
|
108 |
+
# prompt = prompt + "Observer position:%s.\n" % remove_spaces(str(round_list(camera_info_aligned['position'], 2)))
|
109 |
+
# else:
|
110 |
+
# prompt = prompt + "Scene center:%s. If no direction vector, observer at center for objorientation.\n" % remove_spaces(str(scene_center))
|
111 |
+
prompt = prompt + "Scene center:%s. If no direction vector, observer at center for obj orientation.\n\n" % remove_spaces(str(scene_center))
|
112 |
+
prompt = prompt + "objs list:\n"
|
113 |
+
lines = []
|
114 |
+
# 生成prompt中对物体的定量描述部分(遍历所有相关物体)
|
115 |
+
for obj in objects_related:
|
116 |
+
# 位置信息,保留2位小数
|
117 |
+
center_position = obj['center_position']
|
118 |
+
center_position = round_list(center_position, 2)
|
119 |
+
# size信息,保留2位小数
|
120 |
+
size = obj['size']
|
121 |
+
size = round_list(size, 2)
|
122 |
+
# extension信息,保留2位小数
|
123 |
+
extension = obj['extension']
|
124 |
+
extension = round_list(extension, 2)
|
125 |
+
# 方向信息,用方向向量表示. 注意,scanrefer由于用的不是scannet原始obj id,所以不能用方向信息
|
126 |
+
if obj['has_front']:
|
127 |
+
front_point = np.array(obj['front_point'])
|
128 |
+
center = np.array(obj['obb'][0:3])
|
129 |
+
direction_vector = front_point - center
|
130 |
+
direction_vector_normalized = direction_vector / np.linalg.norm(direction_vector)
|
131 |
+
# 再计算左和右的方向向量,全部保留两位小数
|
132 |
+
front_vector = round_list(direction_vector_normalized, 2)
|
133 |
+
up_vector = np.array([0, 0, 1])
|
134 |
+
left_vector = round_list(np.cross(direction_vector_normalized, up_vector), 2)
|
135 |
+
right_vector = round_list(np.cross(up_vector, direction_vector_normalized), 2)
|
136 |
+
behind_vector = round_list(-np.array(front_vector), 2)
|
137 |
+
# 生成方向信息
|
138 |
+
direction_info = ";direction vectors:front=%s,left=%s,right=%s,behind=%s\n" %(front_vector, left_vector, right_vector, behind_vector)
|
139 |
+
#
|
140 |
+
else:
|
141 |
+
direction_info = "\n" # 未知方向向量就啥都不写
|
142 |
+
|
143 |
+
# sr3d,给出center、size
|
144 |
+
# if dataset == 'sr3d':
|
145 |
+
if False:
|
146 |
+
line = f'{obj["label"]},id={obj["id"]},ctr={remove_spaces(str(center_position))},size={remove_spaces(str(size))}'
|
147 |
+
# nr3d和scanrefer,给出center、size、color
|
148 |
+
else:
|
149 |
+
rgb = obj['avg_rgba'][0:3]
|
150 |
+
hsl = round_list(rgb_to_hsl(rgb), 2)
|
151 |
+
# line="%s,id=%s,ctr=%s,size=%s,RGB=%s" %(obj['label'], obj['id'], self.remove_space(str(center_position)), self.remove_spaces(str(size)), self.remove_spaces(str(rgb) )) 原版rgb
|
152 |
+
line="%s,id=%s,ctr=%s,size=%s,HSL=%s" %(obj['label'], obj['id'], remove_spaces(str(center_position)), remove_spaces(str(size)), remove_spaces(str(hsl)))#rgb换成hsl
|
153 |
+
# line = "%s(relevant to %s),id=%s,ctr=%s,size=%s,HSL=%s" % (obj['label'],id_to_name_in_description[obj['id']], obj['id'], self.remove_spaces(st(center_position)), self.remove_spaces(str(size)), self.remove_spaces(str(hsl))) # 格式:name=原名称(description里的名称)
|
154 |
+
# if id_to_name_in_description[obj['id']]=='room':
|
155 |
+
# name=obj['label']
|
156 |
+
# else:
|
157 |
+
# name=id_to_name_in_description[obj['id']]
|
158 |
+
# line="%s,id=%s,ctr=%s,size=%s,HSL=%s" %(name, obj['id'], self.remove_spaces(st(center_position)), self.remove_spaces(str(size)), self.remove_spaces(str(hsl) )) # 式:name=description里的名称
|
159 |
+
lines.append(line + direction_info)
|
160 |
+
# if self.obj_info_ablation_type == 4:
|
161 |
+
# random.seed(0)
|
162 |
+
# random.shuffle(lines)
|
163 |
+
prompt += ''.join(lines)
|
164 |
+
# prompt中的要求
|
165 |
+
line = "\nInstruction:find the one described object in description: \n\"%s\"\n" % user_instruction
|
166 |
+
prompt = prompt + line
|
167 |
+
|
168 |
+
prompt = prompt + "\n\nThere is exactly one answer, so if you receive multiple answers, considerother constraints; if get no answers, loosen constraints."
|
169 |
+
prompt = prompt + "\n\nWork this out step by step to ensure right answer."
|
170 |
+
prompt = prompt + "\n\nIf the answer is complete, add \"Now the answer is complete -- {'ID':id}\" to the end of your answer(that is, your completion, not your code), where id is the id of the referred obj. Do not add anything after."
|
171 |
+
|
172 |
+
return prompt
|
173 |
+
|
174 |
+
|
175 |
+
@retry(wait=wait_exponential_jitter(initial=20, max=120, jitter=20), stop=stop_after_attempt(5), before_sleep=before_sleep_log(logger, logging.ERROR)) # 20s,40s,80s,120s + random.uniform(0,20)
|
176 |
+
def get_gpt_response(prompt: str, code_interpreter: CodeInterpreter):
|
177 |
+
print("llm_name:",code_interpreter.model)
|
178 |
+
# get response from GPT(using code interpreter). using retry from tenacity.
|
179 |
+
# count the token usage and time as well
|
180 |
+
# if the reponse does not include "Now the answer is complete", this means the answer is notdone. attach an empty user message to let GPT to keep going.
|
181 |
+
# start timing
|
182 |
+
call_start_time = time.time()
|
183 |
+
# the first call with the original prompt
|
184 |
+
response, token_usage_total = code_interpreter.call_openai_with_code_interpreter(prompt)
|
185 |
+
response = response['content']
|
186 |
+
# loop until "Now the answer is complete" is in the response, or looping more than 10 times.
|
187 |
+
count_response = 0
|
188 |
+
while not "Now the answer is complete" in response:
|
189 |
+
if count_response >= 10:
|
190 |
+
print("Response does not end with 'Now the answer is complete.' !")
|
191 |
+
break
|
192 |
+
response, token_usage_add = code_interpreter.call_openai_with_code_interpreter('')
|
193 |
+
response = response['content']
|
194 |
+
token_usage_total += token_usage_add
|
195 |
+
count_response += 1
|
196 |
+
print("count_response:", count_response)
|
197 |
+
# stop timing
|
198 |
+
call_end_time = time.time()
|
199 |
+
time_consumed = call_end_time - call_start_time
|
200 |
+
# self.token_usage_this_ques += token_usage_total
|
201 |
+
# self.token_usage_whole_run += token_usage_total
|
202 |
+
# self.time_consumed_this_ques += time_consumed
|
203 |
+
# self.time_consumed_whole_run += time_consumed
|
204 |
+
# print("\n*** Refer model: token usage=%d, time consumed=%ds, TPM=%.2f ***" %(token_usage_total, time_consumed, token_usage_total / time_consumed * 60))
|
205 |
+
return response
|
206 |
+
|
207 |
+
def extract_answer_id_from_last_line(last_line, random_choice_list=[0,]):
|
208 |
+
# 如果没有按照预期格式回复则随机选取(Sr3d)或直接选成0(Nr3d和Scanrefer);按预期格式恢复则提取答案
|
209 |
+
wrong_return_format = False
|
210 |
+
last_line_split = last_line.split('--')
|
211 |
+
# 使用正则表达式从字符串中提取字典部分
|
212 |
+
pattern = r"\{[^\}]*\}"
|
213 |
+
match = re.search(pattern, last_line_split[-1])
|
214 |
+
if match:
|
215 |
+
# 获取匹配的字典字符串
|
216 |
+
matched_dict_str = match.group()
|
217 |
+
try:
|
218 |
+
# 解析字典字符串为字典对象
|
219 |
+
extracted_dict = ast.literal_eval(matched_dict_str)
|
220 |
+
print(extracted_dict)
|
221 |
+
answer_id = extracted_dict['ID']
|
222 |
+
# 如果确实以 Now the answer is complete -- {'ID': xxx} 的格式回复了,但是xxx不是数字(例如是None),也能随机选。
|
223 |
+
if not isinstance(answer_id, int):
|
224 |
+
if isinstance(answer_id, list) and all([isinstance(e, int) for e in answer_id]):
|
225 |
+
print("Wrong answer format: %s. random choice from this list" % str(answer_id))
|
226 |
+
answer_id = random.choice(answer_id)
|
227 |
+
else:
|
228 |
+
print("Wrong answer format: %s. No dict found. Random choice from relevant objects." % str(answer_id))
|
229 |
+
answer_id = random.choice(random_choice_list)
|
230 |
+
wrong_return_format = True
|
231 |
+
except BaseException:
|
232 |
+
print("Wrong answer format!! No dict found. Random choice.")
|
233 |
+
answer_id = random.choice(random_choice_list)
|
234 |
+
wrong_return_format = True
|
235 |
+
else:
|
236 |
+
print("Wrong answer format!! No dict found. Random choice.")
|
237 |
+
answer_id = random.choice(random_choice_list)
|
238 |
+
wrong_return_format = True
|
239 |
+
return answer_id, wrong_return_format
|
240 |
+
|
241 |
+
def get_openai_config(llm_name='gpt-3.5-turbo-0125'):
|
242 |
+
system_message = ""
|
243 |
+
system_message += get_system_message()
|
244 |
+
system_message += get_principle()
|
245 |
+
openai_config = {
|
246 |
+
# 'model': 'gpt-4-turbo-preview',
|
247 |
+
'model': llm_name,
|
248 |
+
'temperature': 1e-7,
|
249 |
+
'top_p': 1e-7,
|
250 |
+
# 'max_tokens': 4096,
|
251 |
+
'max_tokens': 8192,
|
252 |
+
'system_message': system_message,
|
253 |
+
# 'load_path': '',
|
254 |
+
'save_path': 'chats',
|
255 |
+
'debug': True
|
256 |
+
}
|
257 |
+
return openai_config
|
258 |
+
|
259 |
+
if __name__ == "__main__":
|
260 |
+
|
261 |
+
# system_message = 'Imagine you are an artificial intelligence assistant. You job is to do 3D referring reasoning, namely to find the object for a given utterance from a 3d scene presented as object-centric semantic information.\n'
|
262 |
+
system_message = ""
|
263 |
+
system_message += get_system_message()
|
264 |
+
system_message += get_principle()
|
265 |
+
openai_config = {
|
266 |
+
'model': 'gpt-4',
|
267 |
+
'temperature': 1e-7,
|
268 |
+
'top_p': 1e-7,
|
269 |
+
# 'max_tokens': 4096,
|
270 |
+
'max_tokens': 8192,
|
271 |
+
'system_message': system_message,
|
272 |
+
# 'load_path': '',
|
273 |
+
'save_path': 'chats',
|
274 |
+
'debug': True
|
275 |
+
}
|
276 |
+
code_interpreter = CodeInterpreter(**openai_config)
|
277 |
+
prompt = gen_prompt("Find the chair next to the table.", "scene0132_00")
|
278 |
+
print(prompt)
|
279 |
+
|
280 |
+
response = get_gpt_response(prompt, code_interpreter)
|
281 |
+
# print(response)
|
282 |
+
print("-------pretext--------")
|
283 |
+
print(code_interpreter.pretext)
|
284 |
+
|
285 |
+
|