# Hugging Face Spaces page banner ("Spaces: Running on Zero") - web-page residue, not part of the program.
import spaces | |
import os | |
import gradio as gr | |
from time import sleep | |
from signal import SIGTERM | |
from psutil import process_iter | |
from settings import GRAND3D_Settings | |
from utils import list_dirs | |
import open3d as o3d | |
from copy import deepcopy | |
import numpy as np | |
import re | |
from bs4 import BeautifulSoup | |
import logging | |
# The following line sets the root logger level as well. | |
# It's equivalent to both previous statements combined: | |
logging.basicConfig(level=logging.INFO) | |
logger = logging.getLogger(__name__) | |
from session import Session | |
from model import load_model_and_dataloader, get_model_response | |
# Load model and tokenizer once at the start
# Everything below runs at import time; the resulting tokenizer, model and
# data_loader are module-level globals read by get_chatbot_response.
model_path = "checkpoints/lora_grounded_obj_ref_checkpoint-4896"
model_base = "checkpoints/llava-llama-2-7b-chat-lightning-preview"
# Precision flags forwarded unchanged to load_model_and_dataloader.
load_8bit = False
load_4bit = False
load_bf16 = True
# Path handed to the loader as the scene-to-object mapping.
scene_to_obj_mapping = "data/predicted_scene_data_update_5.json"
# scene_to_obj_mapping = "data/scanrefer_ground_truth_scene_graph.json"
# Generation length cap passed to get_model_response by get_chatbot_response.
max_new_tokens = 5000
# NOTE(review): not referenced anywhere else in the visible code - confirm
# whether it is still needed.
obj_context_feature_type = "text"
tokenizer, model, data_loader = load_model_and_dataloader(
    model_path=model_path,
    model_base=model_base,
    load_8bit=load_8bit,
    load_4bit=load_4bit,
    load_bf16=load_bf16,
    scene_to_obj_mapping=scene_to_obj_mapping,
)
def get_chatbot_response(user_chat_input, scene_id, temperature=0.2, top_p=0.9):
    """Query the grounding model for one user message about one scene.

    Args:
        user_chat_input: Text typed by the user.
        scene_id: Identifier of the scene the question refers to.
        temperature: Sampling temperature forwarded to ``get_model_response``.
            Defaults to the value that used to be hard-coded here.
        top_p: Nucleus-sampling threshold forwarded to ``get_model_response``.
            Defaults to the value that used to be hard-coded here.

    Returns:
        A ``(scene_id, prompt, response)`` tuple. ``scene_id`` is echoed back
        unchanged; ``prompt`` and ``response`` are whatever
        ``get_model_response`` returns (the caller in this file parses the
        prompt as the scene-graph text).
    """
    # Uses the module-level model, tokenizer and data loader loaded at import.
    prompt, response = get_model_response(
        model=model,
        tokenizer=tokenizer,
        data_loader=data_loader,
        scene_id=scene_id,
        user_input=user_chat_input,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    return scene_id, prompt, response
# def get_chatbot_response(user_chat_input): | |
# # Get the response from the chatbot | |
# scene_id = "scene0643_00" | |
# scene_graph = """ | |
# Object-centric context: <obj_0>: {'category': 'door', 'centroid': '[0.35, 1.99, 1.11]', 'extent': '[0.68, 0.65, 2.11]'}; <obj_1>: {'category': 'ceiling', 'centroid': '[1.04, -1.39, 2.68]', 'extent': '[0.18, 0.90, 0.05]'}; <obj_2>: {'category': 'ceiling', 'centroid': '[0.77, 2.09, 2.65]', 'extent': '[0.94, 0.86, 0.11]'}; <obj_3>: {'category': 'trash can', 'centroid': '[-0.61, -2.16, 0.21]', 'extent': '[0.42, 0.36, 0.41]'}; <obj_4>: {'category': 'chair', 'centroid': '[0.35, -1.35, 0.50]', 'extent': '[0.46, 0.47, 0.94]'}; <obj_5>: {'category': 'trash can', 'centroid': '[-0.22, -2.13, 0.24]', 'extent': '[0.40, 0.28, 0.39]'}; <obj_6>: {'category': 'cabinet', 'centroid': '[-1.24, 0.00, 0.58]', 'extent': '[0.61, 0.57, 0.79]'}; <obj_7>: {'category': 'cup', 'centroid': '[0.62, 0.23, 0.77]', 'extent': '[0.14, 0.14, 0.08]'}; <obj_8>: {'category': 'window', 'centroid': '[-0.35, -2.87, 1.13]', 'extent': '[2.05, 0.60, 1.07]'}; <obj_9>: {'category': 'file cabinet', 'centroid': '[0.40, -1.97, 0.39]', 'extent': '[0.40, 0.66, 0.73]'}; <obj_10>: {'category': 'monitor', 'centroid': '[0.92, -1.51, 0.97]', 'extent': '[0.25, 0.57, 0.47]'}; <obj_11>: {'category': 'chair', 'centroid': '[0.34, 0.59, 0.43]', 'extent': '[0.65, 0.64, 0.94]'}; <obj_12>: {'category': 'desk', 'centroid': '[0.64, 0.75, 0.57]', 'extent': '[0.76, 1.60, 0.82]'}; <obj_13>: {'category': 'chair', 'centroid': '[0.55, -0.33, 0.48]', 'extent': '[0.60, 0.60, 0.87]'}; <obj_14>: {'category': 'office chair', 'centroid': '[-0.28, 1.56, 0.46]', 'extent': '[0.67, 0.55, 1.02]'}; <obj_15>: {'category': 'office chair', 'centroid': '[-0.86, -1.53, 0.43]', 'extent': '[0.54, 0.64, 0.97]'}; <obj_16>: {'category': 'chair', 'centroid': '[-0.28, 1.56, 0.46]', 'extent': '[0.67, 0.55, 1.02]'}; <obj_17>: {'category': 'monitor', 'centroid': '[0.98, 0.56, 1.05]', 'extent': '[0.21, 0.60, 0.54]'}; <obj_18>: {'category': 'doorframe', 'centroid': '[-0.17, 2.42, 1.01]', 'extent': '[0.16, 0.18, 1.70]'}; <obj_19>: {'category': 'chair', 'centroid': 
'[-0.86, -1.53, 0.43]', 'extent': '[0.54, 0.64, 0.97]'}; <obj_20>: {'category': 'bookshelf', 'centroid': '[0.93, 2.00, 1.34]', 'extent': '[0.73, 0.99, 2.60]'}; <obj_21>: {'category': 'office chair', 'centroid': '[0.35, -1.35, 0.50]', 'extent': '[0.46, 0.47, 0.94]'}; <obj_22>: {'category': 'desk', 'centroid': '[-1.23, 1.60, 0.70]', 'extent': '[0.80, 2.01, 0.51]'}; <obj_23>: {'category': 'book', 'centroid': '[0.91, 1.31, 0.89]', 'extent': '[0.34, 0.32, 0.30]'}; <obj_24>: {'category': 'desk', 'centroid': '[-1.24, -1.12, 0.54]', 'extent': '[0.79, 1.88, 0.85]'}; <obj_25>: {'category': 'desk', 'centroid': '[0.63, -1.51, 0.53]', 'extent': '[0.81, 1.97, 0.85]'}; <obj_26>: {'category': 'calendar', 'centroid': '[-1.72, -0.44, 1.40]', 'extent': '[0.07, 0.88, 0.83]'}; <obj_27>: {'category': 'office chair', 'centroid': '[0.34, 0.59, 0.43]', 'extent': '[0.65, 0.64, 0.94]'}; <obj_28>: {'category': 'file cabinet', 'centroid': '[-1.02, -0.76, 0.47]', 'extent': '[0.58, 0.75, 0.81]'}; <obj_29>: {'category': 'cup', 'centroid': '[-1.26, -1.65, 0.78]', 'extent': '[0.10, 0.12, 0.04]'}; <obj_30>: {'category': 'keyboard', 'centroid': '[0.55, 0.84, 0.73]', 'extent': '[0.22, 0.15, 0.03]'} | |
# """ | |
# response = """ | |
# <detailed_grounding>a <p>brown wooden office desk</p>[<obj_12>] on the left to the <p>gray shelf</p>[<obj_20>].</detailed_grounding> <refer_expression_grounding>These sentences refer to <p>the brown wooden office desk</p>[<obj_12>].</refer_expression_grounding> | |
# """ | |
# return scene_id, scene_graph, response | |
# Resetting to blank
def reset_textbox():
    """Return a Gradio update that sets the bound component's value to ""."""
    blank_update = gr.update(value="")
    return blank_update
# to set a component as visible=False
def set_visible_false():
    """Return a Gradio update that hides the bound component."""
    hide_update = gr.update(visible=False)
    return hide_update
# to set a component as visible=True
def set_visible_true():
    """Return a Gradio update that shows the bound component."""
    show_update = gr.update(visible=True)
    return show_update
def change_scene_or_system_prompt(dropdown_scene_selection: str):
    """Start a fresh session for the selected scene and reset the viewers.

    Returns a 4-tuple matching the outputs wired to this handler: the new
    session, the path of the scene's ``.obj`` file, ``None`` (which clears the
    grounding-result viewer) and the new session's chat history.
    """
    # reset model_3d, chatbot_for_display, chat_counter, server_status_code
    fresh_session = Session.create_for_scene(dropdown_scene_selection)
    scene_mesh_path = os.path.join(
        GRAND3D_Settings.data_path,
        dropdown_scene_selection,
        f"{dropdown_scene_selection}.obj",
    )
    print(scene_mesh_path)
    cleared_grounding_result = None
    return (
        fresh_session,
        scene_mesh_path,
        cleared_grounding_result,
        fresh_session.chat_history_for_display,
    )
def cylinder_frame(p0, p1):
    """Return the 4x4 transform that places a unit cylinder between two points.

    The result is ``[R | t] @ S`` where ``S`` stretches the Z axis by the
    distance between the points, ``R`` rotates +Z onto the direction
    ``p1 - p0`` and ``t`` is the midpoint of the segment. It is meant for a
    cylinder whose axis is +Z with height 1, which is how
    ``create_cylinder_mesh`` builds it.

    Args:
        p0: Start point, any length-3 sequence of numbers.
        p1: End point, any length-3 sequence of numbers.

    Returns:
        A ``(4, 4)`` float ``numpy.ndarray``.

    Notes:
        * The rotation axis is normalised before Rodrigues' formula is
          applied. The cross product of two unit vectors has norm
          ``sin(angle)``, so using it unnormalised rotates by
          ``angle * sin(angle)``, which is only right for 0 and 90 degrees.
        * Inputs are converted to float, so integer coordinates work.
        * A zero-length segment yields an identity rotation with zero Z
          scale instead of a NaN matrix.
    """
    start = np.asarray(p0, dtype=float)
    end = np.asarray(p1, dtype=float)
    segment = end - start
    length = np.linalg.norm(segment)

    rot_matrix = np.eye(3)
    if length > 0:
        direction = segment / length
        z_axis = np.array([0.0, 0.0, 1.0])
        rot_axis = np.cross(z_axis, direction)
        sin_angle = np.linalg.norm(rot_axis)
        cos_angle = float(np.clip(np.dot(z_axis, direction), -1.0, 1.0))
        if sin_angle > 1e-12:
            # Rodrigues' formula: R = I + sin(a) K + (1 - cos(a)) K^2,
            # with K the skew-symmetric matrix of the unit rotation axis.
            kx, ky, kz = rot_axis / sin_angle
            skew = np.array(
                [
                    [0.0, -kz, ky],
                    [kz, 0.0, -kx],
                    [-ky, kx, 0.0],
                ]
            )
            rot_matrix = (
                np.eye(3) + sin_angle * skew + (1.0 - cos_angle) * (skew @ skew)
            )
        elif cos_angle < 0:
            # Direction is exactly -Z: the axis is undefined, so flip about X.
            rot_matrix = np.diag([1.0, -1.0, -1.0])

    transformation = np.eye(4)
    transformation[:3, :3] = rot_matrix
    transformation[:3, 3] = (start + end) / 2

    scaling = np.eye(4)
    scaling[2, 2] = length
    return transformation @ scaling
def create_cylinder_mesh(p0, p1, color, radius=0.02, resolution=20, split=1):
    """Build a uniformly coloured cylinder mesh spanning ``p0`` to ``p1``.

    A height-1 cylinder is created, moved into place with the matrix from
    ``cylinder_frame`` and then painted with ``color``.
    """
    unit_cylinder = o3d.geometry.TriangleMesh.create_cylinder(
        radius=radius,
        height=1,
        resolution=resolution,
        split=split,
    )
    unit_cylinder.transform(cylinder_frame(p0, p1))
    # Apply color
    unit_cylinder.paint_uniform_color(color)
    return unit_cylinder
def prettify_mesh_for_gradio(mesh):
    """Re-orient, enlarge and colour-clamp ``mesh`` in place, then return it.

    The fixed matrix maps a point ``(x, y, z)`` to ``(-y, z, -x)``. The mesh
    is then scaled by 10 about its own centre and its vertex colours are
    multiplied by ``bright_factor`` and clipped to ``[0, 1]``.
    """
    axis_remap = np.array(
        [
            [0, -1, 0, 0],
            [0, 0, 1, 0],
            [-1, 0, 0, 0],
            [0, 0, 0, 1],
        ]
    )
    mesh.transform(axis_remap)

    pivot = mesh.get_center()
    mesh.scale(10.0, center=pivot)

    bright_factor = 1  # Adjust this factor to get the desired brightness
    brightened = np.asarray(mesh.vertex_colors) * bright_factor
    mesh.vertex_colors = o3d.utility.Vector3dVector(np.clip(brightened, 0, 1))
    return mesh
# Sign pattern of the 8 box corners relative to the centre, one row per corner.
_BBOX_CORNER_SIGNS = (
    (1, 1, 1),
    (1, -1, 1),
    (-1, -1, 1),
    (-1, 1, 1),
    (1, 1, -1),
    (1, -1, -1),
    (-1, -1, -1),
    (-1, 1, -1),
)

# Corner-index pairs for the 12 edges: top face, bottom face, vertical edges.
_BBOX_EDGES = (
    (0, 1), (1, 2), (2, 3), (3, 0),
    (4, 5), (5, 6), (6, 7), (7, 4),
    (0, 4), (1, 5), (2, 6), (3, 7),
)


def _parse_xyz(value):
    """Return the first three components of ``value`` as floats.

    ``value`` may be a string such as ``"[0.35, 1.99, 1.11]"`` (the format
    stored in the scene graph) or any sequence of numbers.
    """
    if isinstance(value, str):
        value = value.replace("[", "").replace("]", "").split(",")
    components = [float(component) for component in value]
    if len(components) < 3:
        raise ValueError(f"expected 3 components, got {len(components)}")
    return components[:3]


def create_bbox(center, extents, color=(1, 0, 0), radius=0.02):
    """Create a wireframe box as a list of 12 cylinder meshes.

    Args:
        center: Box centre, either a ``"[x, y, z]"`` string or a sequence of
            three numbers.
        extents: Full side lengths along x, y and z, in the same formats as
            ``center``.
        color: RGB colour applied to every edge.
        radius: Radius of each edge cylinder.

    Returns:
        A list with one cylinder mesh per box edge, as produced by
        ``create_cylinder_mesh``.

    Raises:
        ValueError: If ``center`` or ``extents`` has fewer than three
            components or a component is not numeric.
    """
    half_extents = np.array(_parse_xyz(extents)) / 2
    corners_3d = np.array(_parse_xyz(center)) + np.array(_BBOX_CORNER_SIGNS) * half_extents

    # Pass the colour as a fresh list so callers' sequences are never shared.
    edge_color = list(color)
    return [
        create_cylinder_mesh(corners_3d[first], corners_3d[second], edge_color, radius)
        for first, second in _BBOX_EDGES
    ]
def highlight_clusters_in_mesh(
    centroids_extents_detailed,
    centroids_extends_refer,
    mesh,
    output_dir,
    output_file_name="highlighted_mesh.glb",
):
    """Overlay bounding boxes on a copy of ``mesh`` and write it to disk.

    Args:
        centroids_extents_detailed: Iterable of ``(center, extent)`` pairs
            drawn in blue (shown as "Landmark" in the UI legend).
        centroids_extends_refer: Iterable of ``(center, extent)`` pairs drawn
            in green (shown as "Chosen Target" in the UI legend). The
            misspelled name is kept for caller compatibility.
        mesh: Source mesh; it is deep-copied and never modified.
        output_dir: Directory under which ``mesh_vis/`` is created.
        output_file_name: File name of the written mesh.

    Returns:
        Path of the written file, ``<output_dir>/mesh_vis/<output_file_name>``.
    """
    print("*" * 50)
    # Visualize the highlighted points by drawing 3D bounding boxes overlay on a mesh
    output_path = os.path.join(output_dir, "mesh_vis")
    # exist_ok avoids the check-then-create race when two requests arrive together.
    os.makedirs(output_path, exist_ok=True)

    # Create a combined mesh to hold both the original and the bounding boxes
    combined_mesh = o3d.geometry.TriangleMesh()
    combined_mesh += deepcopy(mesh)

    # Blue boxes for the landmark objects
    for center, extent in centroids_extents_detailed:
        print("center: ", center)
        print("extent: ", extent)
        for edge in create_bbox(center, extent, color=[0, 0, 1]):
            combined_mesh += edge

    # Green boxes for the referred (target) objects
    for center, extent in centroids_extends_refer:
        for edge in create_bbox(center, extent, color=[0, 1, 0]):
            combined_mesh += edge

    combined_mesh = prettify_mesh_for_gradio(combined_mesh)

    # Save the combined mesh
    output_file_path = os.path.join(output_path, output_file_name)
    o3d.io.write_triangle_mesh(
        output_file_path, combined_mesh, write_vertex_colors=True
    )
    print("*" * 50)
    return output_file_path
def extract_objects(text):
    """Return every ``<obj_N>`` tag in ``text``, in order, duplicates kept."""
    tag_pattern = re.compile(r"<obj_\d+>")
    return [hit.group(0) for hit in tag_pattern.finditer(text)]
# Parse the scene graph into a dictionary
def parse_scene_graph(scene_graph):
    """Extract ``<obj_N>: {...}`` entries from scene-graph text.

    Args:
        scene_graph: Text containing entries such as
            ``<obj_0>: {'category': 'door', 'centroid': '[...]', ...}``.

    Returns:
        A dict mapping each ``"<obj_N>"`` tag to its parsed dict. When a tag
        occurs more than once the last occurrence wins. Entries whose payload
        is not a plain dict literal are skipped with a warning.

    Security:
        The payload is parsed with ``ast.literal_eval`` rather than ``eval``.
        The text parsed here is the model prompt, which can contain
        user-typed input, and ``eval`` would execute arbitrary expressions
        embedded in it.
    """
    # Local imports keep this function self-contained; the module header does
    # not import ast.
    import ast
    import logging

    scene_dict = {}
    for obj_index, payload in re.findall(r"<obj_(\d+)>: (\{.*?\})", scene_graph):
        obj_id = f"<obj_{obj_index}>"
        try:
            obj_data = ast.literal_eval(payload)
        except (ValueError, SyntaxError, TypeError):
            logging.getLogger(__name__).warning(
                "Skipping unparsable scene-graph entry for %s", obj_id
            )
            continue
        if not isinstance(obj_data, dict):
            logging.getLogger(__name__).warning(
                "Skipping non-dict scene-graph entry for %s", obj_id
            )
            continue
        scene_dict[obj_id] = obj_data
    return scene_dict
def get_centroids_extents(obj_list, scene_dict):
    """Look up ``(centroid, extent)`` for each tag in ``obj_list``.

    Tags missing from ``scene_dict`` are dropped; order and duplicates of the
    remaining tags follow ``obj_list``.
    """
    return [
        (scene_dict[tag]["centroid"], scene_dict[tag]["extent"])
        for tag in obj_list
        if tag in scene_dict
    ]
def language_model_forward(
    session_state, user_chat_input, top_p, temperature, dropdown_scene
):
    """Handle one chat turn: query the model and render the grounded boxes.

    This is a generator used as a Gradio event handler. It yields twice, each
    time a ``(session_state, model_path, chat_history)`` tuple:

    1. immediately, with the user message appended and no model path;
    2. after inference, with the model response filled in and the path of the
       mesh that has the bounding boxes drawn on it.

    Args:
        session_state: Incoming session. It is replaced by a new session for
            ``dropdown_scene`` on every call, so earlier turns are not kept.
        user_chat_input: Text typed by the user.
        top_p: Accepted for interface compatibility; not forwarded to the
            model in this function.
        temperature: Accepted for interface compatibility; not forwarded to
            the model in this function.
        dropdown_scene: Scene currently selected in the UI.

    Raises:
        ValueError: If the scene id echoed by ``get_chatbot_response`` differs
            from the session's scene.
    """
    session_state = Session.create_for_scene(dropdown_scene)
    session_state.chat_history_for_display.append(
        (user_chat_input, None)
    )  # append in a tuple format, first is user input, second is assistant response
    yield session_state, None, session_state.chat_history_for_display

    # Load in a 3D model
    file_name = f"{session_state.scene}.obj"
    original_model_path = os.path.join(
        GRAND3D_Settings.data_path, session_state.scene, file_name
    )
    print("original_model_path: ", original_model_path)
    # Load the scene's .obj mesh
    mesh = o3d.io.read_triangle_mesh(original_model_path)

    # get chatbot response
    scene_id, scene_graph, response = get_chatbot_response(
        user_chat_input, session_state.scene
    )
    # Explicit check instead of ``assert`` so it also runs under ``python -O``.
    if scene_id != session_state.scene:
        raise ValueError(
            f"Scene id mismatch: got {scene_id!r}, expected {session_state.scene!r}"
        )

    # Parse the scene graph into a dictionary
    scene_dict = parse_scene_graph(scene_graph)
    print("Model Input: " + str(scene_dict))
    print("=" * 50)
    print("Model Response: " + response)

    # Parse the response to get detailed and refer expression groundings.
    # When a section is absent ``find`` returns None and ``str(None)`` simply
    # contains no object tags.
    soup = BeautifulSoup(response, "html.parser")
    detailed_grounding_html = str(soup.find("detailed_grounding"))
    refer_expression_grounding_html = str(soup.find("refer_expression_grounding"))

    # Extract objects from both sections
    detailed_objects = extract_objects(detailed_grounding_html)
    refer_objects = extract_objects(refer_expression_grounding_html)
    print("detailed_objects: ", detailed_objects)
    print("refer_objects: ", refer_objects)

    # Landmarks are the detailed objects that are not also targets.
    # ``dict.fromkeys`` removes duplicates while keeping first-seen order, so
    # the result is deterministic (a plain set difference is not) and no box
    # is drawn twice.
    refer_lookup = set(refer_objects)
    remaining_objects = [
        tag for tag in dict.fromkeys(detailed_objects) if tag not in refer_lookup
    ]
    unique_refer_objects = list(dict.fromkeys(refer_objects))
    print("remaining_objects: ", remaining_objects)

    centroids_extents_detailed = get_centroids_extents(remaining_objects, scene_dict)
    print("centroids_extents_detailed: ", centroids_extents_detailed)
    centroids_extents_refer = get_centroids_extents(unique_refer_objects, scene_dict)
    print("centroids_extents_refer: ", centroids_extents_refer)

    # Highlight clusters in the mesh and save it
    session_output_dir = session_state.get_session_output_dir()
    highlighted_model_path = highlight_clusters_in_mesh(
        centroids_extents_detailed,
        centroids_extents_refer,
        mesh,
        session_output_dir,
        output_file_name="highlighted_model.glb",
    )

    # Update the chat history with the response
    user_message, _ = session_state.chat_history_for_display[-1]
    session_state.chat_history_for_display[-1] = (user_message, response)
    session_state.save()  # save the session state
    yield session_state, highlighted_model_path, session_state.chat_history_for_display
# HTML banner shown at the top of the page (passed to gr.HTML below).
# NOTE(review): the "π€" and "π" glyphs around the heading look like
# mis-decoded emoji - restore them from the upstream file before shipping.
title = """<h1 align="center">π€ 3D-GRAND: Towards Better Grounding and Less Hallucination for 3D-LLMs π</h1>
<p><center>
<a href="https://3d-grand.github.io/" target="_blank">[Project Page]</a>
<a href="https://www.dropbox.com/scl/fo/5p9nb4kalnz407sbqgemg/AG1KcxeIS_SUoJ1hoLPzv84?rlkey=weunabtbiz17jitfv3f4jpmm1&dl=0" target="_blank">[3D-GRAND Data]</a>
<a href="https://www.dropbox.com/scl/fo/inemjtgqt2nkckymn65rp/AGi2KSYU9AHbnpuj7TWYihs?rlkey=ldbn36b1z6nqj74yv5ph6cqwc&dl=0" target="_blank">[3D-POPE Data]</a>
</center></p>
"""
# Modifying existing Gradio Theme | |
# theme = gr.themes.Soft( | |
# primary_hue=gr.themes.colors.blue, secondary_hue=gr.themes.colors.pink | |
# ) | |
# Page layout and event wiring for the demo.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from a copy of this file that had lost its indentation - confirm it against
# the upstream source.
# NOTE(review): the leading "π" in two of the hint strings looks like a
# mis-decoded emoji - confirm against the upstream source.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Per-browser-session state; Session.create is passed as the initial value.
    session_state = gr.State(Session.create)
    gr.HTML(title)
    with gr.Column():
        with gr.Row():
            # Column with the scene selector and the two 3D viewers.
            with gr.Column(scale=5):
                dropdown_scene = gr.Dropdown(
                    choices=list_dirs(GRAND3D_Settings.data_path),
                    value=GRAND3D_Settings.default_scene,
                    interactive=True,
                    label="Select a scene",
                )
                # Viewer for the raw scene mesh, preloaded with the default scene.
                model_3d = gr.Model3D(
                    value=os.path.join(
                        GRAND3D_Settings.data_path,
                        GRAND3D_Settings.default_scene,
                        f"{GRAND3D_Settings.default_scene}.obj",
                    ),
                    clear_color=[0.0, 0.0, 0.0, 0.0],
                    label="3D Model",
                    camera_position=(-50, 65, 10),
                    zoom_speed=10.0,
                )
                gr.HTML(
                    """<center><strong>
                    π SCROLL or DRAG on the 3D Model
                    to zoom in/out and rotate. Press CTRL and DRAG to pan.
                    </strong></center>
                    """
                )
                gr.HTML(
                    """<center><strong>
                    π When grounding finishes,
                    the grounding result will be displayed below.
                    </strong></center>
                    """
                )
                # Viewer for the mesh with bounding boxes, filled by language_model_forward.
                model_3d_grounding_result = gr.Model3D(
                    clear_color=[0.0, 0.0, 0.0, 0.0],
                    label="Grounding Result",
                    zoom_speed=15.0,
                )
                # Legend for the box colours used in highlight_clusters_in_mesh.
                gr.HTML(
                    """<center><strong>
                    <div style="display:inline-block; color:blue">■</div> = Landmark
                    <div style="display:inline-block; color:green">■</div> = Chosen Target
                    </strong></center>
                    """
                )
            # Column with the chat window, input row, examples and parameters.
            with gr.Column(scale=5):
                chat_history_for_display = gr.Chatbot(
                    value=[(None, GRAND3D_Settings.INITIAL_MSG_FOR_DISPLAY)],
                    label="Chat Assistant",
                    height=510,
                    render_markdown=False,
                    sanitize_html=False,
                )
                with gr.Row():
                    with gr.Column(scale=8):
                        user_chat_input = gr.Textbox(
                            placeholder="I want to find the chair near the table",
                            show_label=False,
                        )
                    with gr.Column(scale=1, min_width=0):
                        send_button = gr.Button("Send", variant="primary")
                    with gr.Column(scale=1, min_width=0):
                        clear_button = gr.Button("Clear")
                with gr.Row():
                    with gr.Accordion(label="Examples for user message:", open=True):
                        gr.Examples(
                            examples=[
                                ["The TV on the drawer, opposing the bed."],
                                ["the desk next to the window"]
                            ],
                            inputs=user_chat_input,
                        )
                # Hidden sampling controls. Their values are passed to
                # language_model_forward, which does not forward them to the model.
                with gr.Accordion("Parameters", open=False, visible=False):
                    top_p = gr.Slider(
                        minimum=0,
                        maximum=1.0,
                        value=1.0,
                        step=0.05,
                        interactive=True,
                        label="Top-p (nucleus sampling)",
                    )
                    temperature = gr.Slider(
                        minimum=0,
                        maximum=5.0,
                        value=1.0,
                        step=0.1,
                        interactive=True,
                        label="Temperature",
                    )
    # gr.Markdown("### Terms of Service")
    # gr.HTML(
    # """By using this service, users are required to agree to the following terms:
    # The service is a research preview intended for non-commercial use only.
    # The service may collect user dialogue data for future research."""
    # )
    # Event handling
    # Changing the scene and pressing "Clear" both reset the session, the
    # viewers and the chat history through the same handler.
    dropdown_scene.change(
        fn=change_scene_or_system_prompt,
        inputs=[dropdown_scene],
        outputs=[session_state, model_3d, model_3d_grounding_result, chat_history_for_display],
    )
    clear_button.click(
        fn=change_scene_or_system_prompt,
        inputs=[dropdown_scene],
        outputs=[session_state, model_3d, model_3d_grounding_result, chat_history_for_display],
    )
    # Pressing Enter in the textbox and clicking "Send" run the same handler.
    user_chat_input.submit(
        fn=language_model_forward,
        inputs=[session_state, user_chat_input, top_p, temperature, dropdown_scene],
        outputs=[session_state, model_3d_grounding_result, chat_history_for_display],
    )
    send_button.click(
        fn=language_model_forward,
        inputs=[session_state, user_chat_input, top_p, temperature, dropdown_scene],
        outputs=[session_state, model_3d_grounding_result, chat_history_for_display],
    )
    # A second listener on each trigger clears the textbox.
    send_button.click(reset_textbox, [], [user_chat_input])
    user_chat_input.submit(reset_textbox, [], [user_chat_input])
# Launch the app, retrying with exponential backoff when the server cannot
# bind its port (OSError).
from psutil import AccessDenied, NoSuchProcess

sleep_time = 2
# NOTE(review): ``server_port=port`` is commented out in the launch call
# below, so Gradio picks its own port while the cleanup in the except branch
# targets this one - confirm which port is intended.
port = 7011
for x in range(1, 10):  # up to 9 attempts
    try:
        gr.close_all()
        demo.queue(
            max_size=20,
        ).launch(
            # debug=True,
            # server_name="0.0.0.0",
            # server_port=port,
            # share=True
        )
    except OSError:
        # Ask whichever process is listening on ``port`` to terminate.
        for proc in process_iter():
            try:
                for conns in proc.connections(kind="inet"):
                    if conns.laddr and conns.laddr.port == port:
                        proc.send_signal(SIGTERM)  # or SIGKILL
            except (AccessDenied, NoSuchProcess):
                # Processes we may not inspect, or that exited meanwhile,
                # cannot be the one to clean up; move on.
                continue
        print(f"Retrying {x} time...")
    else:
        # launch() returned without error: do not start the app again.
        break
    sleep(sleep_time)  # wait before the next attempt
    sleep_time *= 2  # exponential backoff