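# Retrieve local 3D assets for a scene graph: embed each object's description
# with CLIP, match it against the captions of local FBX models by cosine
# similarity, and copy the best match into an Assets/ folder.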
import openshape
import torch
import json
import transformers
import threading
import sys, os, shutil
import pandas as pd
from torch.nn import functional as F
import re

# Print the active GPU device (guarded in case CUDA is unavailable)
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))

# Load the point-cloud encoder (not used by the text-only retrieval below)
pc_encoder = openshape.load_pc_encoder('openshape-pointbert-vitg14-rgb')

# Asset index: one row per local model; expects columns "caption_english",
# "name_en", and "bbx". skiprows=2 skips leading header rows in the sheet.
local_assets = pd.read_excel("/root/IDesign/copy.xlsx", skiprows=2)

captions = local_assets["caption_english"].tolist()
file_paths = []
bbx_values = []
# Resolve each row to its FBX file on disk and its bounding box.
for _, row in local_assets.iterrows():
    model_name = row['name_en']
    model_path = os.path.join("/root/IDesign/lvm_2032fbx", f"{model_name}.fbx")
    file_paths.append(model_path)
    bbx_values.append(row['bbx'])

# Pair each caption with its file path and bounding box for lookup at retrieval time.
caption_to_file = [
    {
        "caption": caption,
        "file_path": path,
        "bbx": bbx
    }
    for caption, path, bbx in zip(captions, file_paths, bbx_values)
]

def load_openclip():
    # A lock, stashed on sys for global access, serializes moves to the GPU.
    sys.clip_move_lock = threading.Lock()
    clip_model = transformers.CLIPModel.from_pretrained(
        "/root/IDesign/CLIP-ViT-bigG-14-laion2B-39B-b160k",
        low_cpu_mem_usage=True, torch_dtype=torch.float16,
        offload_state_dict=True,
    )
    clip_prep = transformers.CLIPProcessor.from_pretrained(
        "/root/IDesign/CLIP-ViT-bigG-14-laion2B-39B-b160k"
    )
    if torch.cuda.is_available():
        with sys.clip_move_lock:
            clip_model.cuda()
    return clip_model, clip_prep

clip_model, clip_prep = load_openclip()
torch.set_grad_enabled(False)  # inference only; no gradients needed

def preprocess(input_string):
    # Strip digits and turn underscores into spaces, e.g. "sofa_1" -> "sofa ".
    wo_numericals = re.sub(r'\d', '', input_string)
    output = wo_numericals.replace("_", " ")
    return output

def compute_local_embeddings(items):
    # Embed every asset caption with the CLIP text encoder.
    device = clip_model.device
    embeddings = []
    for item in items:
        text = preprocess(item["caption"])
        inputs = clip_prep(text=[text], return_tensors='pt', truncation=True, max_length=76).to(device)
        embedding = clip_model.get_text_features(**inputs).float().cpu()
        embeddings.append(embedding)
    return torch.cat(embeddings, dim=0)

local_embeddings = compute_local_embeddings(caption_to_file)
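# Optional sanity check: one row per asset, one column per CLIP feature.
print("Local embeddings shape:", tuple(local_embeddings.shape))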

def retrieve_local(query_embedding, top=1, sim_th=0.0):
    # Rank all local assets by cosine similarity to the query embedding,
    # processing the embedding matrix in chunks to bound memory use.
    query_embedding = F.normalize(query_embedding.detach().cpu(), dim=-1).squeeze()
    sims = []
    for embedding in torch.split(local_embeddings, 10240):
        sims.append(query_embedding @ F.normalize(embedding.float(), dim=-1).T)
    sims = torch.cat(sims)
    sims, indices = torch.sort(sims, descending=True)
    results = []
    for i, sim in zip(indices, sims):
        if sim <= sim_th:
            break  # scores are sorted descending; none of the rest can pass
        results.append({
            "caption": caption_to_file[i]["caption"],
            "file_path": caption_to_file[i]["file_path"],
            "bbx": caption_to_file[i]["bbx"],
            "sim": sim.item()
        })
        if len(results) >= top:
            break
    return results
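# Example: retrieve_local(text_embedding, top=3, sim_th=0.1) returns up to
# three {"caption", "file_path", "bbx", "sim"} dicts, best match first.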

file_path = "scene_graph.json"
with open(file_path, "r") as file:
    objects_in_room = json.load(file)
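# Expected scene_graph.json shape (keys inferred from the loop below; the
# example id is illustrative):
# [
#   {"new_object_id": "sofa_1", "style": "modern", "material": "leather", ...},
#   ...
# ]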

for obj_in_room in objects_in_room:
    # Skip objects that lack style/material annotations.
    if "style" not in obj_in_room or "material" not in obj_in_room:
        continue
    style, material = obj_in_room['style'], obj_in_room["material"]
    text = (
        preprocess("A high-poly " + obj_in_room['new_object_id'])
        + f" with {material} material and in {style} style, high quality"
    )
    device = clip_model.device
    tn = clip_prep(
        text=[text], return_tensors='pt', truncation=True, max_length=76
    ).to(device)
    enc = clip_model.get_text_features(**tn).float().cpu()

    matches = retrieve_local(enc, top=1, sim_th=0.1)
    if not matches:
        print(f"No asset above the similarity threshold for {obj_in_room['new_object_id']}")
        continue
    retrieved_obj = matches[0]
    print("Retrieved object: ", retrieved_obj["file_path"])
    print("Bbox: ", retrieved_obj["bbx"])

    destination_folder = os.path.join(os.getcwd(), "Assets")
    os.makedirs(destination_folder, exist_ok=True)
    source_file = retrieved_obj["file_path"]
    file_extension = os.path.splitext(source_file)[1]
    destination_path = os.path.join(destination_folder, f"{obj_in_room['new_object_id']}{file_extension}")
    shutil.copy(source_file, destination_path)
    print(f"File copied to {destination_path}")