Spaces:
Runtime error
Runtime error
import os | |
import gradio as gr | |
import numpy as np | |
import torch | |
import functools | |
from datasets import load_dataset | |
from feature_extractors.uni3d_embedding_encoder import Uni3dEmbeddingEncoder | |
# os.environ['HTTP_PROXY'] = 'http://192.168.48.17:18000' | |
# os.environ['HTTPS_PROXY'] = 'http://192.168.48.17:18000' | |
# Tunables for the Gradio app.
# NOTE(review): MAX_BATCH_SIZE is only referenced by a disabled batching option
# in launch(); MAX_QUEUE_SIZE is the intended bound for the request queue.
MAX_BATCH_SIZE = 16
MAX_QUEUE_SIZE = 10
MAX_K_RETRIEVAL = 20  # upper bound of the "Top K" slider

# Shared cache directory for the encoder weights and the Hugging Face dataset.
cache_dir = "./.cache"

# Text-query encoder, constructed once at import time.
encoder = Uni3dEmbeddingEncoder(cache_dir)

# NOTE(review): `device` is not referenced by the code in this file's visible
# part; retrieval moves tensors to the query's device instead.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Ordered source ids and the reverse lookup (source id -> position).
# NOTE(review): presumably position i matches row i of every embedding file
# and of `dataset` — confirm against the data preparation script.
source_id_list = torch.load("data/source_id_list.pt")
source_to_id = {source_id: i for i, source_id in enumerate(source_id_list)}

# Rendered images returned to the result gallery (config "rendered_imgs_diag_above").
dataset = load_dataset(
    "VAST-AI/LD-T3D",
    name="rendered_imgs_diag_above",
    split="base",
    cache_dir=cache_dir,
)
def get_embedding(option, modality, angle=None):
    """Load a precomputed base-embedding file from ``data/``.

    The path is ``data/objaverse_{option}_{modality}[_{angle}]_embeddings.pt``,
    e.g. ``data/objaverse_uni3d_image_front_embeddings.pt``.

    Args:
        option: Embedding model name used in the file name (e.g. ``"uni3d"``).
        modality: Modality part of the file name (callers use ``"text"``,
            ``"image"`` and ``"3D"``).
        angle: Optional view name appended to ``modality`` with ``_``.

    Returns:
        Whatever ``torch.load`` returns for that file — presumably a
        ``(num_objects, dim)`` tensor; confirm against the data files.

    Raises:
        gr.Error: if the file does not exist. The error is raised (not
            returned) so Gradio reports it to the user, and callers never do
            tensor arithmetic on an exception instance.
    """
    suffix = modality if angle is None else f"{modality}_{angle}"
    save_path = f"data/objaverse_{option}_{suffix}_embeddings.pt"
    if not os.path.exists(save_path):
        raise gr.Error(f"Embedding file not found: {save_path}")
    return torch.load(save_path)
def predict(xb, xq, top_k):
    """Return indices of the ``top_k`` base vectors most similar to each query.

    Similarity is the raw dot product ``xq @ xb.T``; it equals cosine
    similarity only when both inputs are already L2-normalised.

    Args:
        xb: Base embeddings, ``(nb, dim)``; moved to ``xq``'s device.
        xq: Query embeddings, ``(nq, dim)``.
        top_k: Results per query. Any int-like value is accepted (Gradio
            sliders may deliver floats) and it is clamped to ``nb`` so asking
            for more results than there are base items does not raise.

    Returns:
        ``(nq, min(top_k, nb))`` tensor of indices, best match first.
    """
    xb = xb.to(xq.device)
    k = min(int(top_k), xb.shape[0])
    sim = xq @ xb.T  # (nq, nb)
    return sim.topk(k=k, dim=-1, largest=True).indices
def get_image(index):
    """Return the ``"image"`` field of row ``index`` of the module-level dataset."""
    row = dataset[index]
    return row["image"]
def _fuse_base_embeddings(modals, option="uni3d", op="add"):
    """Load the base embeddings for ``modals`` and fuse them into one tensor.

    ``"text"`` and ``"3D"`` select those embedding files; every other entry is
    treated as a view name and loaded as an image embedding for that view.
    ``modals`` itself is never modified.

    Args:
        modals: Non-empty sequence of selected modality names.
        option: Embedding model name forwarded to ``get_embedding``.
        op: ``"add"`` (sum) or ``"concat"`` (concatenate along the last dim).

    Returns:
        The single loaded tensor as-is, or the fused tensor re-normalised to
        unit L2 norm along the last dimension.

    Raises:
        ValueError: for an unsupported ``op``.
    """
    angles = [m for m in modals if m not in ("text", "3D")]
    embeddings = []
    if "text" in modals:
        embeddings.append(get_embedding(option, "text"))
    for angle in angles:
        embeddings.append(get_embedding(option, "image", angle=angle))
    if "3D" in modals:
        embeddings.append(get_embedding(option, "3D"))

    if len(embeddings) == 1:
        # A single modality is used exactly as stored, without re-normalising.
        return embeddings[0]
    if op == "concat":
        fused = torch.cat(embeddings, dim=-1)
    elif op == "add":
        fused = sum(embeddings)
    else:
        raise ValueError(f"Unsupported operation: {op}")
    # `fused` is a fresh tensor, so in-place division cannot touch the inputs;
    # it avoids an extra full-size copy of the (potentially large) base matrix.
    fused /= fused.norm(dim=-1, keepdim=True)
    return fused


def retrieve_3D_models(textual_query, top_k, modality_list):
    """Retrieve images of the 3D objects that best match a text query.

    Args:
        textual_query: Free-text description to search for.
        top_k: Number of results to return (cast to ``int``, since Gradio
            sliders may deliver floats).
        modality_list: Base modalities to search over: ``"text"``, ``"3D"``
            and/or view names such as ``"front"``. The list is not modified.

    Returns:
        List of images (via ``get_image``) ordered by decreasing similarity.

    Raises:
        gr.Error: if the query is empty/blank or no modality is selected.
    """
    if not textual_query or not textual_query.strip():
        raise gr.Error("Please enter a textual query")
    if len(textual_query.split()) > 20:
        gr.Warning("Retrieval result may be inaccurate due to long textual query")
    if not modality_list:
        raise gr.Error("Please select at least one modality")

    op = "add"
    xb = _fuse_base_embeddings(modality_list, option="uni3d", op=op)

    xq = encoder.encode_query(textual_query)
    if op == "concat":
        # Tile the query so its width matches the concatenated base embeddings.
        xq = xq.repeat(1, xb.shape[-1] // xq.shape[-1])
    # Out-of-place: never modify the tensor returned by the encoder.
    xq = xq / xq.norm(dim=-1, keepdim=True)

    # Exactly one query was encoded, so keep only the first row of indices.
    indices = predict(xb, xq, int(top_k))[0].cpu().tolist()
    return [get_image(index) for index in indices]
def launch():
    """Build the Gradio text-to-3D retrieval UI and start the server.

    The request queue is bounded by ``MAX_QUEUE_SIZE`` and the server listens
    on all interfaces (``0.0.0.0``).
    """
    all_modalities = ["text", "front", "back", "left", "right", "above",
                      "below", "diag_above", "diag_below", "3D"]
    default_top_k = 5

    with gr.Blocks() as demo:
        with gr.Row():
            textual_query = gr.Textbox(
                label="Textual Query", autofocus=True,
                placeholder="A chair with a wooden frame and a cushioned seat")
            modality_list = gr.CheckboxGroup(label="Modality List", value=[],
                                             choices=all_modalities)
        with gr.Row():
            top_k = gr.Slider(minimum=1, maximum=MAX_K_RETRIEVAL, step=1,
                              label="Top K Retrieval Result",
                              value=default_top_k, scale=2)
            run = gr.Button("Search", scale=1)
            clear_button = gr.ClearButton(scale=1)
        with gr.Row():
            output = gr.Gallery(format="webp", label="Retrieval Result",
                                columns=5, type="pil")

        # NOTE: request batching (batch=True, max_batch_size=MAX_BATCH_SIZE) is
        # disabled; retrieve_3D_models handles a single query per call.
        run.click(retrieve_3D_models, [textual_query, top_k, modality_list], output)
        # Reset the inputs to their initial values and empty the gallery.
        clear_button.click(lambda: ["", default_top_k, [], []],
                           outputs=[textual_query, top_k, modality_list, output])

        # Each example gets its own copy of the modality list so no list object
        # is shared between the checkbox choices and the examples.
        gr.Examples(
            examples=[[query, 10, list(all_modalities)]
                      for query in ("An ice cream with a cherry on top",
                                    "A mid-age castle",
                                    "A coke")],
            inputs=[textual_query, top_k, modality_list],
            outputs=output,
            fn=retrieve_3D_models,
        )

    # Configure and launch only after the Blocks context has been closed.
    demo.queue(max_size=MAX_QUEUE_SIZE)
    # os.environ.pop('HTTP_PROXY')
    # os.environ.pop('HTTPS_PROXY')
    demo.launch(server_name='0.0.0.0')
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    launch()
# Manual smoke test without the UI (needs the embedding files under data/):
# print(len(retrieve_3D_models("A chair with a wooden frame and a cushioned seat", 5, ["3D", "diag_above", "diag_below"])))