LD-T3D / app.py
yuanze1024's picture
init
f15a1cd
raw
history blame
5.54 kB
import os
import gradio as gr
import numpy as np
import torch
import functools
from datasets import load_dataset
from feature_extractors.uni3d_embedding_encoder import Uni3dEmbeddingEncoder
# Optional HTTP(S) proxy settings; disabled by default.
# os.environ['HTTP_PROXY'] = 'http://192.168.48.17:18000'
# os.environ['HTTPS_PROXY'] = 'http://192.168.48.17:18000'

# max_batch_size for Gradio request batching (batching is currently disabled).
MAX_BATCH_SIZE = 16
# Upper bound on pending requests in the Gradio queue.
MAX_QUEUE_SIZE = 10
# Maximum value of the "Top K" slider.
MAX_K_RETRIEVAL = 20

# Cache directory passed to both the query encoder and `load_dataset`.
cache_dir = "./.cache"
encoder = Uni3dEmbeddingEncoder(cache_dir)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# List of object ids; `source_to_id` maps each id back to its position in that list.
# NOTE(review): presumably this order matches the rows of the embedding files — confirm.
source_id_list = torch.load("data/source_id_list.pt")
source_to_id = {source_id: i for i, source_id in enumerate(source_id_list)}

# Rendered views used to turn retrieved row indices into preview images.
# (Plain string: the original f-string had no placeholders.)
dataset = load_dataset("VAST-AI/LD-T3D", name="rendered_imgs_diag_above", split="base", cache_dir=cache_dir)
@functools.lru_cache()
def get_embedding(option, modality, angle=None):
    """Load a precomputed gallery embedding tensor from ``data/``.

    The file read is ``data/objaverse_{option}_{modality}_embeddings.pt``, or
    ``data/objaverse_{option}_{modality}_{angle}_embeddings.pt`` when ``angle``
    is given. Successful loads are memoized by ``functools.lru_cache``, so
    repeated calls with the same arguments return the same tensor object.

    Args:
        option: embedding model name used in the file name (e.g. "uni3d").
        modality: modality name used in the file name (e.g. "text", "image", "3D").
        angle: optional view-angle suffix (e.g. "front"); omitted when None.

    Returns:
        Whatever ``torch.load`` returns for that file (an embedding tensor).

    Raises:
        gr.Error: if the expected embedding file does not exist.
    """
    suffix = f"_{angle}" if angle is not None else ""
    save_path = f'data/objaverse_{option}_{modality}{suffix}_embeddings.pt'
    if not os.path.exists(save_path):
        # Raise rather than return: a returned Error object would flow into the
        # fusion/matmul code and fail there with an unrelated message. lru_cache
        # does not cache exceptions, so a file added later is picked up next call.
        raise gr.Error(f"Embedding file not found: {save_path}")
    return torch.load(save_path)
def predict(xb, xq, top_k):
    """Return the indices of the gallery rows most similar to each query.

    Similarity is the dot product ``xq @ xb.T``, which equals cosine similarity
    when both inputs are L2-normalised.

    Args:
        xb: gallery embeddings, one row per item; moved to ``xq``'s device.
        xq: query embeddings, one row per query.
        top_k: number of results per query. Cast to ``int`` because Gradio
            sliders deliver floats and ``Tensor.topk`` requires an integer.
            Clamped to the gallery size, so asking for more results than there
            are items returns all of them instead of raising.

    Returns:
        Index tensor of shape (n_queries, min(top_k, n_items)), sorted from
        most to least similar.
    """
    xb = xb.to(xq.device)
    sim = xq @ xb.T  # (nq, nb)
    k = min(int(top_k), sim.shape[-1])
    _, indices = sim.topk(k=k, largest=True)
    return indices
def get_image(index):
    """Fetch the rendered preview image stored at row ``index`` of ``dataset``."""
    row = dataset[index]
    return row["image"]
def retrieve_3D_models(textual_query, top_k, modality_list):
    """Retrieve preview images of the objects that best match a text query.

    The gallery representation is built from the precomputed embeddings of
    every selected modality ("text", any rendered view angle, and/or "3D").
    These are fused into one vector per object and compared against the
    encoded query.

    Args:
        textual_query: free-form text entered by the user.
        top_k: number of results to return.
        modality_list: selected modality names. It is only read, never modified.

    Returns:
        List of images for the retrieved objects, best match first.

    Raises:
        gr.Error: if the query is empty/blank or no modality is selected.
    """
    # Treat None and whitespace-only input the same as an empty query.
    if not textual_query or not textual_query.strip():
        raise gr.Error("Please enter a textual query")
    if len(textual_query.split()) > 20:
        gr.Warning("Retrieval result may be inaccurate due to long textual query")
    if not modality_list:
        raise gr.Error("Please select at least one modality")

    def _retrieve_3D_models(query, top_k, modals: list):
        option = "uni3d"
        op = "add"  # fusion of per-modality gallery embeddings: "add" or "concat"
        # Classify the selection without mutating ``modals`` (the caller's list).
        is_text = "text" in modals
        is_3D = "3D" in modals
        angles = [m for m in modals if m not in ("text", "3D")]

        # Gather base (gallery) embeddings in a fixed order: text, views, 3D.
        embeddings = []
        if is_text:
            embeddings.append(get_embedding(option, "text"))
        for angle in angles:
            embeddings.append(get_embedding(option, "image", angle=angle))
        if is_3D:
            embeddings.append(get_embedding(option, "3D"))

        # Fuse multiple modalities into one vector per object, then re-normalise.
        if len(embeddings) > 1:
            if op == "concat":
                embeddings = torch.cat(embeddings, dim=-1)
            elif op == "add":
                embeddings = sum(embeddings)
            else:
                raise ValueError(f"Unsupported operation: {op}")
            # In-place is safe: cat/sum produced a fresh tensor, not a cached one.
            embeddings /= embeddings.norm(dim=-1, keepdim=True)
        else:
            embeddings = embeddings[0]

        # Encode the query and align its width with the gallery vectors.
        xq = encoder.encode_query(query)
        if op == "concat":
            xq = xq.repeat(1, embeddings.shape[-1] // xq.shape[-1])  # repeat to be aligned with the xb
        xq /= xq.norm(dim=-1, keepdim=True)
        pred_ind_list = predict(embeddings, xq, top_k)
        return pred_ind_list[0].cpu().tolist()  # we have only one query

    indices = _retrieve_3D_models(textual_query, top_k, modality_list)
    return [get_image(index) for index in indices]
def launch():
    """Build the Gradio retrieval demo and serve it on 0.0.0.0 (all interfaces)."""
    # Modalities offered in the UI: text captions, rendered view angles, and 3D.
    modality_choices = ["text", "front", "back", "left", "right", "above",
                        "below", "diag_above", "diag_below", "3D"]
    default_top_k = 5
    example_queries = ["An ice cream with a cherry on top", "A mid-age castle", "A coke"]

    with gr.Blocks() as demo:
        with gr.Row():
            textual_query = gr.Textbox(label="Textual Query", autofocus=True,
                                       placeholder="A chair with a wooden frame and a cushioned seat")
            modality_list = gr.CheckboxGroup(label="Modality List", value=[],
                                             choices=modality_choices)
        with gr.Row():
            top_k = gr.Slider(minimum=1, maximum=MAX_K_RETRIEVAL, step=1, label="Top K Retrieval Result",
                              value=default_top_k, scale=2)
            run = gr.Button("Search", scale=1)
            clear_button = gr.ClearButton(scale=1)
        with gr.Row():
            output = gr.Gallery(format="webp", label="Retrieval Result", columns=5, type="pil")

        run.click(retrieve_3D_models, [textual_query, top_k, modality_list], output,
                  # batch=True, max_batch_size=MAX_BATCH_SIZE
                  )
        # Reset every input to its initial value and empty the gallery.
        clear_button.click(lambda: ["", default_top_k, [], []],
                           outputs=[textual_query, top_k, modality_list, output])
        # Each example row gets its own copy of the modality list.
        gr.Examples(examples=[[query, 10, list(modality_choices)] for query in example_queries],
                    inputs=[textual_query, top_k, modality_list],
                    # cache_examples=True,
                    outputs=output,
                    fn=retrieve_3D_models)

    # Use the module-level limit instead of repeating its value here.
    demo.queue(max_size=MAX_QUEUE_SIZE)
    # Counterpart of the optional proxy settings at the top of the file.
    # os.environ.pop('HTTP_PROXY')
    # os.environ.pop('HTTPS_PROXY')
    demo.launch(server_name='0.0.0.0')


if __name__ == "__main__":
    launch()
# Smoke test without the UI:
# print(len(retrieve_3D_models("A chair with a wooden frame and a cushioned seat", 5, ["3D", "diag_above", "diag_below"])))