|
from autogen import GroupChatManager |
|
import json |
|
import re, os |
|
import networkx as nx |
|
|
|
from agents import create_parse_agents, create_graph_agents, language_summary_agents, calculation_summary_agents |
|
from agents import is_termination_msg, is_termination_require, gpt4_config |
|
from corrector_agents import get_corrector_agents |
|
from refiner_agents import get_refiner_agents |
|
|
|
from chats import InputParserGroupChat, RequirementGroupChat, LanguageGroupChat, CalculationGroupChat, SceneGraphGroupChat, SchemaGroupChat, LayoutCorrectorGroupChat, ObjectDeletionGroupChat, LayoutRefinerGroupChat |
|
|
|
from utils import get_room_priors, extract_list_from_json |
|
from utils import preprocess_scene_graph, build_graph, remove_unnecessary_edges, handle_under_prepositions, get_conflicts, get_size_conflicts, get_object_from_scene_graph |
|
from utils import get_object_from_scene_graph, get_rotation, get_cluster_objects, clean_and_extract_edges |
|
from utils import get_cluster_size |
|
from utils import get_possible_positions, is_point_bbox, calculate_overlap, get_topological_ordering, place_object, get_depth, get_visualization |
|
import openshape |
|
import torch |
|
import numpy as np |
|
import transformers |
|
import threading |
|
import multiprocessing |
|
import sys, shutil |
|
import pandas as pd |
|
from torch.nn import functional as F |
|
import objaverse |
|
import trimesh |
|
import certifi |
|
import ssl |
|
|
|
# SECURITY NOTE(review): the next statement replaces the factory used for
# default HTTPS contexts with the *unverified* one, i.e. TLS certificate
# verification is switched off for the whole process. It is kept because asset
# downloads may currently rely on it, but it should be removed once the
# certifi bundle configured below is confirmed to be sufficient.
ssl._create_default_https_context = ssl._create_unverified_context

# Point OpenSSL-based clients at certifi's CA bundle. While the override above
# is active this has little effect on urllib-style HTTPS requests.
os.environ['SSL_CERT_FILE'] = certifi.where()
|
|
|
class Generator:
    """Text-to-room-layout pipeline built on LLM group chats and CLIP retrieval.

    The public methods exchange data through instance attributes, so they are
    meant to be called in sequence:

    ``parse_input`` -> ``retrieve_local_assets`` -> ``create_scene_graph`` ->
    ``summary_language`` -> ``create_layout`` -> ``summary_calculation``.
    """

    # A fenced ```json ... ``` payload embedded in an agent reply.
    _JSON_FENCE = re.compile(r'```json\s*([^`]+)\s*```', re.DOTALL)

    # Longest file-name stem we generate: 250 characters plus ".json" stays
    # within the 255-character file-name limit of common file systems.
    _MAX_STEM_LENGTH = 250

    # Rows scored per matrix multiplication during embedding retrieval.
    _RETRIEVAL_CHUNK = 10240

    def __init__(
        self,
        layout_elements=['south_wall', 'north_wall', 'west_wall', 'east_wall', 'middle of the room', 'ceiling'],
        room_dimensions=[5.0, 5.0, 3.0],
        result_file="./results/layout_w_cot.json",
    ):
        """Load room priors, asset metadata, embeddings and the CLIP model.

        Args:
            layout_elements: Ids of the fixed room elements (walls, ceiling,
                ...). Rendered into the agent prompts and used to tell room
                elements apart from furniture.
            room_dimensions: Three room extents in metres, indexed 0..2.
            result_file: Initial value of ``self.result_file``. Note that
                ``create_layout`` overwrites it with a path derived from the
                user input.
        """
        # The defaults are lists shared between calls, so always store copies.
        self.room_dimensions = list(room_dimensions)
        self.room_priors = get_room_priors(self.room_dimensions)

        self.layout_elements = list(layout_elements)
        self.result_file = result_file
        self.scene_graph = None
        self.cot_info = {}

        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        # Objaverse metadata, keyed by the entry's 'u' field.
        with open('./embeddings/objaverse_meta.json', encoding="utf-8") as meta_file:
            meta = json.load(meta_file)
        self.meta = {entry['u']: entry for entry in meta['entries']}

        # NOTE(review): torch.load unpickles the file; only load trusted files.
        deser = torch.load('./embeddings/objaverse.pt')
        self.us = deser['us']
        self.feats = deser['feats']

        # Local asset catalogue: one record per spreadsheet row.
        local_assets = pd.read_excel("./assets/copy.xlsx", skiprows=2)
        self.caption_to_file = [
            {
                "caption": caption,
                "file_path": os.path.join("./assets/lvm_2032fbx", f"{model_name}.fbx"),
                "bbx": bbx,
            }
            for caption, model_name, bbx in zip(
                local_assets["caption_clip"].tolist(),
                local_assets["name_en"].tolist(),
                local_assets["bbx"].tolist(),
            )
        ]

        clip_checkpoint = "./ckpts/CLIP-ViT-bigG-14-laion2B-39B-b160k"
        self.clip_model = transformers.CLIPModel.from_pretrained(
            clip_checkpoint,
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16,
            offload_state_dict=True,
        )
        self.clip_prep = transformers.CLIPProcessor.from_pretrained(clip_checkpoint)

        self.local_embeddings = torch.load("./embeddings/local.pt")

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------

    @classmethod
    def _load_json(cls, text):
        """Parse an agent reply that is either raw JSON or wraps a ```json fence.

        Raises:
            json.JSONDecodeError: if neither form can be parsed.
        """
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            fenced = cls._JSON_FENCE.search(text)
            if fenced is None:
                raise
            return json.loads(fenced.group(1))

    def _result_stem(self):
        """Return a file-name-safe stem derived from ``self.user_input``."""
        stem = re.sub(r'[^a-zA-Z0-9]', '_', self.user_input)
        return stem[:self._MAX_STEM_LENGTH]

    # ------------------------------------------------------------------
    # Step 1: requirement parsing
    # ------------------------------------------------------------------

    def parse_input(self, user_input, max_number_of_objects):
        """Run the requirement group chat and store the designer's JSON reply.

        Sets ``self.user_input``, ``self.max_number_of_objects``,
        ``self.designer_response`` and ``self.cot_info["parse_cot"]``.
        """
        self.user_input = user_input
        self.max_number_of_objects = max_number_of_objects
        (
            user_proxy,
            requirements_analyzer,
            substructure_analyzer,
            _substructure_analyzer_checker,  # created by the factory but not part of this chat
            interior_designer,
            designer_checker,
        ) = create_parse_agents(self.max_number_of_objects)

        init_groupchat = RequirementGroupChat(
            agents=[user_proxy, requirements_analyzer, substructure_analyzer, interior_designer, designer_checker],
            messages=[],
            max_round=16
        )

        manager = GroupChatManager(groupchat=init_groupchat, llm_config=gpt4_config, is_termination_msg=is_termination_require)

        user_proxy.initiate_chat(
            manager,
            message=f"""
            The room has the size {self.room_dimensions[0]}m x {self.room_dimensions[1]}m x {self.room_dimensions[2]}m
            User Input (in triple backquotes):
            ```
            {self.user_input}
            ```
            Room layout elements in the room (in triple backquotes):
            ```
            {self.layout_elements}
            ```
            json
            """,
        )

        # The second-to-last message is read as the designer's JSON answer.
        self.designer_response = self._load_json(init_groupchat.messages[-2]["content"])
        self.cot_info["parse_cot"] = self.designer_response["chain_of_thought"]

    # ------------------------------------------------------------------
    # Step 2: asset retrieval
    # ------------------------------------------------------------------

    @staticmethod
    def _strip_ids(input_string):
        """Drop digits and turn underscores into spaces (e.g. 'chair_1' -> 'chair ')."""
        wo_numericals = re.sub(r'\d', '', input_string)
        return wo_numericals.replace("_", " ")

    @classmethod
    def _ranked_similarities(cls, query, bank):
        """Return ``(sims, indices)`` of *bank* rows by cosine similarity to *query*, best first."""
        query = F.normalize(query.detach().cpu(), dim=-1).squeeze()
        sims = torch.cat([
            query @ F.normalize(chunk.float(), dim=-1).T
            for chunk in torch.split(bank, cls._RETRIEVAL_CHUNK)
        ])
        return torch.sort(sims, descending=True)

    def _retrieve_local(self, query_embedding, top=1, sim_th=0.5):
        """Return up to *top* local assets whose similarity exceeds *sim_th*."""
        sims, indices = self._ranked_similarities(query_embedding, self.local_embeddings)
        keep = sims > sim_th
        results = []
        for i, sim in zip(indices[keep].tolist(), sims[keep].tolist()):
            asset = self.caption_to_file[i]
            results.append({
                "caption": asset["caption"],
                "file_path": asset["file_path"],
                "bbx": asset["bbx"],
                "sim": sim,
            })
            if len(results) >= top:
                break
        return results

    def _retrieve_objaverse(self, embedding, top=1, sim_th=0.1, filter_fn=None):
        """Return up to *top* Objaverse metadata entries (plus ``sim``) above *sim_th*."""
        sims, idx = self._ranked_similarities(embedding, self.feats)
        keep = sims > sim_th
        results = []
        for i, sim in zip(idx[keep].tolist(), sims[keep].tolist()):
            entry = self.meta.get(self.us[i])
            if entry is None:
                continue
            if filter_fn is None or filter_fn(entry):
                results.append(dict(entry, sim=sim))
                if len(results) >= top:
                    break
        return results

    @staticmethod
    def _objaverse_filter(face_min=0, face_max=34985808, anim_min=0, anim_max=563):
        """Build a predicate on metadata entries for face/animation counts.

        With the default bounds both checks are disabled and every entry passes.
        """
        any_anims = anim_min <= 0 and anim_max >= 563
        any_faces = face_min <= 0 and face_max >= 34985808

        def accept(entry):
            return (
                (any_anims or anim_min <= entry['anims'] <= anim_max)
                and (any_faces or face_min <= entry['faces'] <= face_max)
            )

        return accept

    @staticmethod
    def _model_dimensions(file_path):
        """Return (length, width, height): bounding-box extents 0, 2, 1, each divided by 100."""
        extents = trimesh.load(file_path).bounding_box.extents
        return extents[0] / 100, extents[2] / 100, extents[1] / 100

    def _encode_text(self, text):
        """Return the CLIP text features for *text* as a float32 CPU tensor."""
        tokens = self.clip_prep(
            text=[text], return_tensors='pt', truncation=True, max_length=76
        ).to(self.clip_model.device)
        return self.clip_model.get_text_features(**tokens).float().cpu()

    def retrieve_local_assets(self):
        """Attach a retrieved ``bounding_box_size`` to each designed object.

        For every object in ``self.designer_response`` a text query is encoded
        with CLIP. A local asset with similarity above 0.5 is preferred;
        otherwise the best Objaverse match is downloaded and measured (used
        only when its similarity exceeds 0.18). Objects without a usable match
        keep the size proposed by the designer.
        """
        print("Locking...")
        # Reuse the process-wide lock if one exists; creating a fresh lock on
        # every call would not provide any mutual exclusion.
        lock = getattr(sys, "clip_move_lock", None)
        if lock is None:
            lock = sys.clip_move_lock = threading.Lock()
        print("Locked.")

        if torch.cuda.is_available():
            with lock:
                self.clip_model.cuda()
        # NOTE: this disables autograd globally for the rest of the process.
        torch.set_grad_enabled(False)

        objects = extract_list_from_json(self.designer_response, 'objects')
        for obj in objects:
            text = self._strip_ids("A high-poly " + obj['object_id']) + f" with {obj['material']} material and in {obj['style']} style, high quality"
            enc = self._encode_text(text)

            retrieved_local = self._retrieve_local(enc, top=1, sim_th=0.5)
            if retrieved_local:
                retrieved_obj = retrieved_local[0]
                print("Retrieved object: ", retrieved_obj["file_path"])
                # The hit already passed the 0.5 threshold inside _retrieve_local.
                try:
                    length, width, height = map(float, retrieved_obj["bbx"].split(','))
                except (AttributeError, ValueError):
                    # Empty or malformed spreadsheet cell: keep the designer's size.
                    print(f"Skipping malformed bbx {retrieved_obj['bbx']!r} for {retrieved_obj['file_path']}")
                else:
                    obj['bounding_box_size'] = {'Length': length, 'Width': width, 'Height': height}
                continue

            hits = self._retrieve_objaverse(enc, top=1, sim_th=0.1, filter_fn=self._objaverse_filter())
            if not hits:
                print(f"No Objaverse match found for {obj['object_id']}")
                continue
            retrieved_obj = hits[0]
            print(f"Retrieved object from Objaverse: {retrieved_obj['u']}")
            objaverse_objects = objaverse.load_objects(
                uids=[retrieved_obj['u']],
                download_processes=multiprocessing.cpu_count()
            )
            if retrieved_obj["sim"] > 0.18:
                for file_path in objaverse_objects.values():
                    length, width, height = self._model_dimensions(file_path)
                    obj['bounding_box_size'] = {'Length': length, 'Width': width, 'Height': height}

        self.designer_response['objects'] = objects
        print(self.designer_response)

    # ------------------------------------------------------------------
    # Step 3: scene graph
    # ------------------------------------------------------------------

    def create_scene_graph(self):
        """Place objects one by one via group chat, then resolve conflicts.

        Sets ``self.cot_info["scene_graph_cot"]``, ``self.conflict_data`` and
        ``self.scene_graph`` (a dict with the key ``"objects_in_room"``).

        NOTE(review): both correction loops run until the helper functions
        report no more conflicts; there is no iteration limit.
        """
        graph_log = []  # handed to the utils helpers; not read again here
        user_proxy, interior_architect, schema_engineer = create_graph_agents()

        scene_graph_groupchat = SceneGraphGroupChat(
            agents=[user_proxy, interior_architect, schema_engineer],
            messages=[],
            max_round=10
        )

        cot_data, json_data = {}, {}
        # Initialised up front so an empty object list does not raise KeyError below.
        json_info = {"objects_in_room": []}
        blocks_designer = extract_list_from_json(self.designer_response, 'objects')

        for d_block in blocks_designer:
            object_key = str(d_block["object_id"])
            prompt = str(d_block)

            manager_scene_graph = GroupChatManager(groupchat=scene_graph_groupchat,
                                                   llm_config=gpt4_config,
                                                   human_input_mode="NEVER",
                                                   is_termination_msg=is_termination_msg)

            user_proxy.initiate_chat(
                manager_scene_graph,
                message=f"""
                The room has the size {self.room_dimensions[0]}m x {self.room_dimensions[1]}m x {self.room_dimensions[2]}m
                User Input (in triple backquotes):
                ```
                {self.user_input}
                ```
                Room layout elements in the room (in triple backquotes):
                ```
                {self.layout_elements}
                ```
                Previously placed objects in the room (in triple backquotes):
                ```
                {json_data}
                ```
                Object to be placed (in triple backticks):
                ```
                {prompt}
                ```
                """,
            )

            placed = self._load_json(scene_graph_groupchat.messages[-2]["content"])["objects_in_room"]
            json_info["objects_in_room"] += placed

            # Prompt-side copy without 'new_object_id'; the entry stored in
            # json_info must keep that key, hence the copy.
            json_data[object_key] = {key: value for key, value in placed[0].items() if key != 'new_object_id'}

            # Every second message, starting at index 1, carries a chain of thought.
            object_cot = cot_data.setdefault(object_key, [])
            for message in scene_graph_groupchat.messages[1::2]:
                object_cot.append(self._load_json(message["content"])["chain_of_thought"])

            user_proxy.reset()
            interior_architect.reset()
            schema_engineer.reset()
            scene_graph_groupchat.reset()

        self.cot_info["scene_graph_cot"] = cot_data
        self.scene_graph = json_info
        self.conflict_data = []

        scene_graph = preprocess_scene_graph(json_info["objects_in_room"], graph_log)
        G = build_graph(scene_graph)
        G = remove_unnecessary_edges(G, graph_log)
        G, scene_graph = handle_under_prepositions(G, scene_graph, graph_log)
        conflicts = get_conflicts(G, scene_graph, graph_log)

        print("-------------------CONFLICTS-------------------")
        for conflict in conflicts:
            print(conflict)
            print("\n\n")
        self.conflict_data.append(conflicts)

        user_proxy, spatial_corrector_agent, json_schema_debugger, object_deletion_agent = get_corrector_agents()

        # Spatial conflicts: let the corrector rewrite one object at a time.
        while len(conflicts) > 0:
            spatial_corrector_agent.reset()
            json_schema_debugger.reset()
            groupchat = LayoutCorrectorGroupChat(
                agents=[user_proxy, spatial_corrector_agent, json_schema_debugger],
                messages=[],
                max_round=15
            )
            manager = GroupChatManager(groupchat=groupchat, llm_config=gpt4_config, is_termination_msg=is_termination_msg)
            user_proxy.initiate_chat(
                manager,
                message=f"""
                {conflicts[0]}
                """,
            )
            correction_json = self._load_json(groupchat.messages[-2]["content"])
            self.conflict_data.append(correction_json)

            corrected = correction_json["corrected_object"]
            corr_obj = get_object_from_scene_graph(corrected["new_object_id"], scene_graph)
            corr_obj["is_on_the_floor"] = corrected["is_on_the_floor"]
            corr_obj["facing"] = corrected["facing"]
            corr_obj["placement"] = corrected["placement"]

            G = build_graph(scene_graph)
            conflicts = get_conflicts(G, scene_graph, graph_log)

        size_conflicts = get_size_conflicts(G, scene_graph, graph_log, self.user_input, self.room_priors)

        print("-------------------SIZE CONFLICTS-------------------")
        for conflict in size_conflicts:
            print(conflict)
            print("\n\n")
        self.conflict_data.append(size_conflicts)

        # Size conflicts: delete the chosen object together with its descendants.
        while len(size_conflicts) > 0:
            object_deletion_agent.reset()
            groupchat = ObjectDeletionGroupChat(
                agents=[user_proxy, object_deletion_agent],
                messages=[],
                max_round=2
            )
            manager = GroupChatManager(groupchat=groupchat, llm_config=gpt4_config, is_termination_msg=is_termination_msg)
            user_proxy.initiate_chat(
                manager,
                message=f"""
                {size_conflicts[0]}
                """,
            )
            correction_json = self._load_json(groupchat.messages[-1]["content"])
            object_to_delete = correction_json["object_to_delete"]
            objs_to_delete = nx.descendants(G, object_to_delete) | {object_to_delete}
            print("Objs to Delete: ", objs_to_delete)
            self.conflict_data.append(f"Objs to Delete: {objs_to_delete}")

            scene_graph = [x for x in scene_graph if x["new_object_id"] not in objs_to_delete]
            G.remove_nodes_from(objs_to_delete)

            size_conflicts = get_size_conflicts(G, scene_graph, graph_log, self.user_input, self.room_priors)

        self.scene_graph["objects_in_room"] = scene_graph

    # ------------------------------------------------------------------
    # Step 4: language summary
    # ------------------------------------------------------------------

    def summary_language(self):
        """Ask the language agent to summarise the reasoning so far.

        Reads ``self.cot_info``, ``self.conflict_data`` and ``self.scene_graph``;
        stores the last chat message in ``self.language_sum``.
        """
        user_proxy, language_architect = language_summary_agents()

        groupchat = LanguageGroupChat(
            agents=[user_proxy, language_architect],
            messages=[],
            max_round=2
        )

        manager = GroupChatManager(groupchat=groupchat, llm_config=gpt4_config, is_termination_msg=is_termination_msg)

        user_proxy.initiate_chat(
            manager,
            message=f"""
            The room has the size {self.room_dimensions[0]}m x {self.room_dimensions[1]}m x {self.room_dimensions[2]}m
            User Input (in triple backquotes):
            ```
            **chain of thought for requirements_analyzer, substructure_analyzer and interior_designer**
            {self.cot_info["parse_cot"]}
            ```
            **chain of thought for object placement**
            {self.cot_info["scene_graph_cot"]}
            ```
            **conflict data**
            {self.conflict_data}
            ```
            **scene graph**
            {self.scene_graph}
            ```
            Room layout elements in the room (in triple backquotes):
            ```
            {self.layout_elements}
            ```
            json
            """,
        )

        self.language_sum = groupchat.messages[-1]["content"]

    # ------------------------------------------------------------------
    # Step 5: layout computation
    # ------------------------------------------------------------------

    def create_layout(self, debug=False):
        """Compute object positions depth by depth and save the result as JSON.

        ``self.scene_graph`` is rebound from the ``{"objects_in_room": [...]}``
        dict to a flat list (objects followed by the room priors), so this
        method expects the dict form on entry.

        Args:
            debug: Print intermediate placement information.

        Returns:
            The error dict from ``place_object`` if an object at depth 1 cannot
            be placed (nothing is saved in that case), otherwise ``None``.
        """
        cot_data = []
        objects_in_room = self.scene_graph["objects_in_room"]
        G = build_graph(objects_in_room)

        cot_data.append("Calculate constraint area for non-layout objects only.")
        for node in G.nodes():
            if node in self.layout_elements:
                continue
            cluster_size, _ = get_cluster_size(node, G, objects_in_room, cot_data)
            node_obj = get_object_from_scene_graph(node, objects_in_room)
            cluster_size = {
                "x_neg": cluster_size["left of"],
                "x_pos": cluster_size["right of"],
                "y_neg": cluster_size["behind"],
                "y_pos": cluster_size["in front"],
            }
            node_obj["cluster"] = {"constraint_area": cluster_size}
            cot_data.append(f"The constraint area for {node} is {cluster_size}.")

        self.scene_graph = objects_in_room + self.room_priors

        prior_ids = ["south_wall", "north_wall", "east_wall", "west_wall", "ceiling", "middle of the room"]
        # True for objects whose position is fully determined by their constraints.
        point_bbox = dict.fromkeys([item["new_object_id"] for item in self.scene_graph], False)

        for item in self.scene_graph:
            if item["new_object_id"] in prior_ids:
                continue
            possible_pos = get_possible_positions(item["new_object_id"], self.scene_graph, self.room_dimensions, cot_data)
            if not possible_pos:
                continue

            # Intersect all candidate regions.
            overlap = possible_pos[0]
            for pos in possible_pos[1:]:
                overlap = calculate_overlap(overlap, pos)

            if overlap is not None and is_point_bbox(overlap):
                item["position"] = {"x": overlap[0], "y": overlap[2], "z": overlap[4]}
                point_bbox[item["new_object_id"]] = True

        scene_graph_wo_layout = [item for item in self.scene_graph if item["new_object_id"] not in self.layout_elements]

        depth_scene_graph = get_depth(scene_graph_wo_layout)
        # default=0: with no objects left the placement loop is simply skipped.
        max_depth = max(depth_scene_graph.values(), default=0)

        topological_order = get_topological_ordering(scene_graph_wo_layout)
        topological_order = [item for item in topological_order if item not in self.layout_elements]

        d = 1
        count = 0
        while d <= max_depth and count < 20:
            count += 1
            error_flag = False

            nodes = [node for node in topological_order if depth_scene_graph[node] == d]
            if debug:
                print(f"Nodes at depth {d}: ", nodes)

            cot_data.append(f"Place objects: {nodes}.")
            for node in nodes:
                if point_bbox[node]:
                    continue

                obj = next(item for item in scene_graph_wo_layout if item["new_object_id"] == node)
                cot_data.append(f"Place the object {obj['new_object_id']} at the depth {d}.")
                errors = place_object(obj, self.scene_graph, self.room_dimensions, cot_data, errors={}, debug=debug)

                if debug:
                    print(f"Errors for {obj['new_object_id']}: ", errors)

                if not errors:
                    continue

                if d <= 1:
                    # Nothing shallower to re-place: give up.
                    failure = f"Errors occur for {obj['new_object_id']} with depth 1: {errors}. The layout creation failed."
                    cot_data.append(failure)
                    print(failure)
                    self.calculation_data = []
                    return errors

                d -= 1
                cot_data.append(f"Errors occur for {obj['new_object_id']}: {errors}. Reduce depth to {d}.")
                if debug:
                    print("Reducing depth to: ", d)

                error_flag = True
                cot_data.append(f"Delete positions for objects at or beyond the current depth {d} in order to reposition the objects.")
                for del_item in scene_graph_wo_layout:
                    del_id = del_item["new_object_id"]
                    if depth_scene_graph[del_id] >= d and "position" in del_item and not point_bbox[del_id]:
                        if debug:
                            print("Deleting position for: ", del_id)
                        del del_item["position"]
                break

            if not error_flag:
                d += 1

        if d <= max_depth:
            # The 20-round budget ran out before every depth was placed.
            print(f"Warning: layout placement stopped at depth {d} of {max_depth} after {count} rounds.")

        cot_data.append("Save the scene graph.")
        self.calculation_data = cot_data
        print(cot_data)
        print("\n")

        os.makedirs("./results", exist_ok=True)
        self.result_file = os.path.join("./results", self._result_stem() + '.json')
        with open(self.result_file, "w") as file:
            json.dump(self.scene_graph, file, indent=4)

    # ------------------------------------------------------------------
    # Step 6: calculation summary
    # ------------------------------------------------------------------

    def summary_calculation(self):
        """Summarise the placement log and write the combined Markdown report.

        Does nothing when there is no calculation data (layout creation failed
        or has not run). Otherwise stores the last chat message in
        ``self.calculation_sum`` and writes ``self.language_sum`` followed by
        it to ``./Results_data/<sanitised user input>.md``.
        """
        if not getattr(self, "calculation_data", None):
            return

        user_proxy, calculation_architect = calculation_summary_agents()
        groupchat = CalculationGroupChat(
            agents=[user_proxy, calculation_architect],
            messages=[],
            max_round=2
        )
        manager = GroupChatManager(groupchat=groupchat, llm_config=gpt4_config, is_termination_msg=is_termination_msg)

        user_proxy.initiate_chat(
            manager,
            message=f"""
            The room has the size {self.room_dimensions[0]}m x {self.room_dimensions[1]}m x {self.room_dimensions[2]}m
            User Input (in triple backquotes):
            ```
            {self.calculation_data}
            ```
            Room layout elements in the room (in triple backquotes):
            ```
            {self.layout_elements}
            ```
            json
            """,
        )

        self.calculation_sum = groupchat.messages[-1]["content"]

        os.makedirs("./Results_data", exist_ok=True)
        full_path = os.path.join("./Results_data", self._result_stem() + '.md')
        with open(full_path, 'w', encoding='utf-8') as file:
            file.write(self.language_sum)
            file.write('\n\n## 6. **Object Placement**\n')
            file.write(self.calculation_sum)
|
|