|
import openshape |
|
from huggingface_hub import hf_hub_download |
|
import torch |
|
import json |
|
import numpy as np |
|
import transformers |
|
import threading |
|
import multiprocessing |
|
import sys, os, shutil |
|
import pandas as pd |
|
from torch.nn import functional as F |
|
import re |
|
|
|
|
|
# Report the accelerator in use. torch.cuda.get_device_name(0) raises when no
# CUDA device exists, so fall back to a CPU label instead of crashing
# (load_openclip() below already supports the no-CUDA case).
print("Device: ", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu")

# OpenShape point-cloud encoder.
# NOTE(review): pc_encoder appears unused in the rest of this script. Confirm
# it is still needed, since loading it costs time and memory.
pc_encoder = openshape.load_pc_encoder('openshape-pointbert-vitg14-rgb')
|
|
|
# Asset library spreadsheet. skiprows=2 drops the first two sheet rows
# (presumably title rows above the real header; confirm against the file).
local_assets = pd.read_excel("/root/IDesign/copy.xlsx", skiprows=2)

# pandas reads empty cells as NaN. A NaN caption cannot be embedded, because
# preprocess() runs re.sub on it. Drop those rows up front so the column lists
# below stay aligned and the embedding step does not crash.
local_assets = local_assets.dropna(subset=["caption_english"])

captions = local_assets["caption_english"].tolist()

# Each asset's model file is looked up as <model dir>/<name_en>.fbx.
file_paths = [
    os.path.join("/root/IDesign/lvm_2032fbx", f"{model_name}.fbx")
    for model_name in local_assets["name_en"]
]

bbx_values = local_assets["bbx"].tolist()

# One record per asset, in spreadsheet order. The rows of local_embeddings and
# the indices returned by retrieval refer to positions in this list.
caption_to_file = [
    {
        "caption": caption,
        "file_path": path,
        "bbx": bbx,
    }
    for caption, path, bbx in zip(captions, file_paths, bbx_values)
]
|
|
|
def load_openclip(model_path="/root/IDesign/CLIP-ViT-bigG-14-laion2B-39B-b160k"):
    """Load the local CLIP checkpoint and its processor.

    The model is moved to CUDA when a GPU is available. A lock stored as
    ``sys.clip_move_lock`` guards that move. Callers may rely on the
    attribute, so it is kept.

    Args:
        model_path: Directory of the Hugging Face CLIP checkpoint. Defaults to
            the path this script has always used.

    Returns:
        ``(clip_model, clip_prep)``: the ``transformers.CLIPModel`` and the
        matching ``transformers.CLIPProcessor``.
    """
    print("Locking...")
    # Reuse an already-installed lock rather than replacing it. A fresh lock
    # per call would never serialise anything across calls.
    lock = getattr(sys, "clip_move_lock", None)
    if lock is None:
        lock = threading.Lock()
        sys.clip_move_lock = lock
    print("Locked.")

    use_cuda = torch.cuda.is_available()
    clip_model = transformers.CLIPModel.from_pretrained(
        model_path,
        low_cpu_mem_usage=True,
        # Half precision is only dependable on GPU; use fp32 when on CPU.
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        offload_state_dict=True,
    )
    clip_prep = transformers.CLIPProcessor.from_pretrained(model_path)

    if use_cuda:
        with lock:
            clip_model.cuda()

    return clip_model, clip_prep
|
|
|
# Build CLIP once at import time. Every later text-encoding step reuses these
# two module-level objects.
clip_model, clip_prep = load_openclip()

# Inference only from here on. Turn autograd off for this thread
# (torch.autograd.set_grad_enabled is the same callable as
# torch.set_grad_enabled).
torch.autograd.set_grad_enabled(False)
|
|
|
def preprocess(input_string):
    """Normalise an identifier-like string into plain caption text.

    Every digit character (regex ``\\d``) is deleted, then each underscore
    becomes a space. Nothing else is altered, so runs of spaces may remain.
    """
    return re.sub(r'\d', '', input_string).replace("_", " ")
|
|
|
def compute_local_embeddings(captions, batch_size=64):
    """Embed every asset caption with the CLIP text encoder.

    Args:
        captions: Sequence of asset records (dicts). Only each record's
            ``"caption"`` value is used, after passing through ``preprocess``.
        batch_size: Number of captions tokenised and encoded per forward pass.

    Returns:
        A float32 CPU tensor with one row per record, in input order. An empty
        input yields an empty ``(0, projection_dim)`` tensor.
    """
    device = clip_model.device
    texts = [preprocess(item["caption"]) for item in captions]

    if not texts:
        # torch.cat([]) would raise, so return a correctly-shaped empty matrix.
        return torch.empty((0, clip_model.config.projection_dim))

    batch_size = max(1, int(batch_size))
    chunks = []
    for start in range(0, len(texts), batch_size):
        # Pad within the batch. CLIP pools at each sequence's first EOS token,
        # so padding does not change a caption's feature vector.
        # max_length=76 matches the query-side tokenisation in this script.
        inputs = clip_prep(
            text=texts[start:start + batch_size],
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=76,
        ).to(device)
        chunks.append(clip_model.get_text_features(**inputs).float().cpu())

    return torch.cat(chunks, dim=0)
|
|
|
# Caption embeddings for the whole asset library. Row i corresponds to
# caption_to_file[i]; retrieve_local() scores queries against this matrix.
local_embeddings = compute_local_embeddings(captions=caption_to_file)
|
|
|
def retrieve_local(query_embedding, top=1, sim_th=0.0, embeddings=None, entries=None):
    """Return the local assets whose embeddings are most similar to a query.

    Similarity is cosine similarity: both sides are L2-normalised on the CPU.

    Args:
        query_embedding: One embedding, shape ``(dim,)`` or ``(1, dim)``. It is
            squeezed, so pass a single query only.
        top: Maximum number of results. Values below 1 behave like 1, as in
            the original implementation.
        sim_th: Only matches with similarity strictly greater than this are
            returned.
        embeddings: Candidate matrix, one row per entry. Defaults to the
            module-level ``local_embeddings``.
        entries: Records aligned with the rows of ``embeddings``, each with
            ``"caption"``, ``"file_path"`` and ``"bbx"`` keys. Defaults to the
            module-level ``caption_to_file``.

    Returns:
        A list of dicts with ``caption``, ``file_path``, ``bbx`` and ``sim``
        (a Python float), ordered by descending similarity. The list may be
        empty.
    """
    if embeddings is None:
        embeddings = local_embeddings
    if entries is None:
        entries = caption_to_file

    query = F.normalize(query_embedding.detach().cpu(), dim=-1).squeeze()

    # Score in row chunks to bound the size of each temporary matrix.
    scores = torch.cat([
        query @ F.normalize(chunk.float(), dim=-1).T
        for chunk in torch.split(embeddings, 10240)
    ])

    scores, order = torch.sort(scores, descending=True)

    # Vectorised threshold. This replaces a Python loop that kept scanning
    # every candidate after the sorted scores fell below sim_th. NaN scores
    # compare False and are dropped, exactly like the old per-element check.
    keep = scores > sim_th
    limit = max(int(top), 1)
    scores = scores[keep][:limit]
    order = order[keep][:limit]

    return [
        {
            "caption": entries[idx]["caption"],
            "file_path": entries[idx]["file_path"],
            "bbx": entries[idx]["bbx"],
            "sim": sim,
        }
        for idx, sim in zip(order.tolist(), scores.tolist())
    ]
|
|
|
def _encode_query(text):
    """Return the CLIP text embedding of *text* as a float32 CPU tensor."""
    tokens = clip_prep(
        text=[text], return_tensors='pt', truncation=True, max_length=76
    ).to(clip_model.device)
    return clip_model.get_text_features(**tokens).float().cpu()


file_path = "scene_graph.json"

with open(file_path, "r", encoding="utf-8") as file:
    objects_in_room = json.load(file)

# Every retrieved asset is copied into ./Assets/, which is created once here.
destination_folder = os.path.join(os.getcwd(), "Assets")
os.makedirs(destination_folder, exist_ok=True)

for obj_in_room in objects_in_room:
    # Only objects that specify both style and material are looked up.
    if "style" not in obj_in_room or "material" not in obj_in_room:
        continue
    style, material = obj_in_room["style"], obj_in_room["material"]
    object_id = obj_in_room["new_object_id"]

    text = preprocess("A high-poly " + object_id) + f" with {material} material and in {style} style, high quality"
    enc = _encode_query(text)

    # retrieve_local may return nothing when no asset clears the threshold.
    # Indexing [0] blindly would raise IndexError and abort the whole scene.
    matches = retrieve_local(enc, top=1, sim_th=0.1)
    if not matches:
        print(f"No local asset with similarity above 0.1 for {object_id}; skipping.")
        continue
    retrieved_obj = matches[0]

    print("Retrieved object: ", retrieved_obj["file_path"])
    print("Bbox: ", retrieved_obj["bbx"])

    source_file = retrieved_obj["file_path"]
    if not os.path.isfile(source_file):
        # The asset path comes from the spreadsheet and may not exist on disk.
        print(f"Asset file not found: {source_file}; skipping {object_id}.")
        continue

    file_extension = os.path.splitext(source_file)[1]
    destination_path = os.path.join(destination_folder, f"{object_id}{file_extension}")

    # The asset is copied, not moved; the library keeps its original file.
    shutil.copy(source_file, destination_path)
    print(f"File copied to {destination_path}")