# code/retrieve_local.py
# Uploaded by MSheng-Lee via huggingface_hub ("Upload folder using huggingface_hub"),
# commit f20b100 (verified).
import openshape
from huggingface_hub import hf_hub_download
import torch
import json
import numpy as np
import transformers
import threading
import multiprocessing
import sys, os, shutil
import pandas as pd
from torch.nn import functional as F
import re
# Report the device the models will run on. torch.cuda.get_device_name raises
# on hosts without a CUDA device, so fall back to naming the CPU.
if torch.cuda.is_available():
    print("Device: ", torch.cuda.get_device_name(0))
else:
    print("Device: ", "cpu")
# Load the point-cloud encoder.
# NOTE(review): pc_encoder is not referenced anywhere else in this file;
# confirm nothing imports it before dropping this model load.
pc_encoder = openshape.load_pc_encoder('openshape-pointbert-vitg14-rgb')
# Read the local asset catalogue; the first two spreadsheet rows are skipped
# (presumably a title/header block -- confirm against copy.xlsx).
local_assets = pd.read_excel("/root/IDesign/copy.xlsx", skiprows=2)
captions = local_assets["caption_english"].tolist()

# Each asset's mesh is expected at lvm_2032fbx/<name_en>.fbx; its bounding-box
# annotation is carried along unchanged.
file_paths, bbx_values = [], []
for _, asset_row in local_assets.iterrows():
    asset_mesh = os.path.join("/root/IDesign/lvm_2032fbx", f"{asset_row['name_en']}.fbx")
    file_paths.append(asset_mesh)
    bbx_values.append(asset_row['bbx'])

# One record per spreadsheet row; list position matches the row order used
# when the caption embedding matrix is built.
caption_to_file = [
    dict(caption=caption, file_path=path, bbx=bbx)
    for caption, path, bbx in zip(captions, file_paths, bbx_values)
]
def load_openclip(model_path="/root/IDesign/CLIP-ViT-bigG-14-laion2B-39B-b160k"):
    """Load a CLIP model and its processor from a local Hugging Face checkpoint.

    When CUDA is available the model is loaded in float16 and moved to the GPU.
    On CPU-only hosts it is loaded in float32 instead, because many CPU kernels
    in older PyTorch releases (e.g. half-precision matmul) have no float16
    implementation.

    Args:
        model_path: Directory containing the CLIP checkpoint and processor
            files. Defaults to the path this script was written against.

    Returns:
        ``(clip_model, clip_prep)``: a ``transformers.CLIPModel`` and the
        matching ``transformers.CLIPProcessor``.
    """
    use_cuda = torch.cuda.is_available()
    print("Locking...")
    # The lock is shared process-wide via the sys module. Only create it if it
    # does not exist yet: replacing it on every call would let two loaders hold
    # different locks and move weights concurrently.
    if getattr(sys, "clip_move_lock", None) is None:
        sys.clip_move_lock = threading.Lock()
    print("Locked.")
    clip_model = transformers.CLIPModel.from_pretrained(
        model_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        offload_state_dict=True,
    )
    clip_prep = transformers.CLIPProcessor.from_pretrained(model_path)
    if use_cuda:
        with sys.clip_move_lock:
            clip_model.cuda()
    return clip_model, clip_prep
# Load CLIP once for the whole script. Everything after this point only runs
# inference, so autograd is switched off globally.
clip_model, clip_prep = load_openclip()
torch.set_grad_enabled(mode=False)
def preprocess(input_string):
    """Normalise an identifier-like string for use in a CLIP text prompt.

    Every (Unicode) decimal digit is deleted and each underscore becomes a
    space, e.g. ``"dining_chair_2"`` -> ``"dining chair "``.
    """
    without_digits = re.sub(r'\d', '', input_string)
    return " ".join(without_digits.split("_"))
def compute_local_embeddings(captions, batch_size=1):
    """Encode every asset caption into a CLIP text embedding.

    Args:
        captions: Sequence of dicts, each holding a ``"caption"`` string (the
            ``caption_to_file`` records). Each caption is passed through
            ``preprocess`` before tokenisation.
        batch_size: Number of captions per model forward pass. The default of 1
            reproduces one-caption-at-a-time encoding exactly. Larger values are
            faster but pad every batch to its longest caption.

    Returns:
        A float32 CPU tensor with one row per caption, in input order. Empty
        input yields a ``(0, projection_dim)`` tensor instead of raising.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    device = clip_model.device
    texts = [preprocess(item["caption"]) for item in captions]
    if not texts:
        # torch.cat([]) raises, so return a correctly shaped empty matrix.
        return torch.empty(0, clip_model.config.projection_dim)
    embeddings = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        inputs = clip_prep(
            text=batch, return_tensors='pt', truncation=True, max_length=76,
            padding=True,  # no-op for a single caption; aligns lengths in a batch
        ).to(device)
        embeddings.append(clip_model.get_text_features(**inputs).float().cpu())
    return torch.cat(embeddings, dim=0)
# Pre-compute the caption embedding matrix once; retrieve_local scores queries against it.
local_embeddings = compute_local_embeddings(captions=caption_to_file)
def retrieve_local(query_embedding, top=1, sim_th=0.0, embeddings=None, entries=None):
    """Return the local assets whose caption embeddings best match a query.

    Similarity is cosine similarity: both the query and the candidate
    embeddings are L2-normalised before the dot product.

    Args:
        query_embedding: Query text embedding, either a single vector or one
            with a leading dimension of 1 (it is squeezed). Any float dtype or
            device is accepted; it is moved to CPU float32.
        top: Maximum number of results. ``top <= 0`` returns ``[]``.
        sim_th: Only candidates with similarity strictly greater than this are
            returned.
        embeddings: Candidate matrix, one row per entry. Defaults to the
            module-level ``local_embeddings``.
        entries: Records with ``caption``, ``file_path`` and ``bbx`` keys,
            aligned with the rows of ``embeddings``. Defaults to the
            module-level ``caption_to_file``.

    Returns:
        Up to ``top`` dicts with keys ``caption``, ``file_path``, ``bbx`` and
        ``sim``, ordered by decreasing similarity.
    """
    if embeddings is None:
        embeddings = local_embeddings
    if entries is None:
        entries = caption_to_file
    if top <= 0 or len(embeddings) == 0:
        return []
    # Cast to float32 so a half-precision query can be multiplied with the
    # float32 candidate matrix.
    query = F.normalize(query_embedding.detach().cpu().float(), dim=-1).squeeze()
    # Score in chunks to bound the size of each temporary normalised copy.
    sims = torch.cat([
        query @ F.normalize(chunk.float(), dim=-1).T
        for chunk in torch.split(embeddings, 10240)
    ])
    sims, indices = torch.sort(sims, descending=True)
    results = []
    for idx, sim in zip(indices.tolist(), sims.tolist()):
        if sim > sim_th:
            entry = entries[idx]
            results.append({
                "caption": entry["caption"],
                "file_path": entry["file_path"],
                "bbx": entry["bbx"],
                "sim": sim,
            })
            if len(results) >= top:
                break
    return results
# Retrieve one local asset per scene-graph object and copy it into ./Assets/,
# named after the object's new_object_id.
file_path = "scene_graph.json"
with open(file_path, "r", encoding="utf-8") as file:
    objects_in_room = json.load(file)

# Loop-invariant setup: create the destination folder once and look up the
# model device once.
destination_folder = os.path.join(os.getcwd(), "Assets")
os.makedirs(destination_folder, exist_ok=True)
device = clip_model.device

for obj_in_room in objects_in_room:
    # Objects missing either a style or a material are not retrieved.
    if "style" not in obj_in_room or "material" not in obj_in_room:
        continue
    style, material = obj_in_room["style"], obj_in_room["material"]
    object_id = obj_in_room['new_object_id']
    text = (
        preprocess("A high-poly " + object_id)
        + f" with {material} material and in {style} style, high quality"
    )
    tn = clip_prep(
        text=[text], return_tensors='pt', truncation=True, max_length=76
    ).to(device)
    enc = clip_model.get_text_features(**tn).float().cpu()
    matches = retrieve_local(enc, top=1, sim_th=0.1)
    if not matches:
        # retrieve_local returns [] when no asset clears the threshold; skip
        # this object instead of aborting the whole run with an IndexError.
        print(f"No local asset above similarity 0.1 for {object_id}; skipping")
        continue
    retrieved_obj = matches[0]
    print("Retrieved object: ", retrieved_obj["file_path"])
    print("Bbox: ", retrieved_obj["bbx"])
    source_file = retrieved_obj["file_path"]
    file_extension = os.path.splitext(source_file)[1]
    destination_path = os.path.join(destination_folder, f"{object_id}{file_extension}")
    # The source asset is copied, not moved; the library stays intact.
    shutil.copy(source_file, destination_path)
    print(f"File copied to {destination_path}")