"""Gradio demo: retrieve rendered 3D-model images that match a textual query.

Pre-computed base embeddings (text, per-angle image, and 3D) are loaded from
``data/``, fused according to the modalities selected in the UI, and ranked
against the encoded textual query by dot-product similarity.
"""
import os
import gradio as gr
import numpy as np
import torch
import functools
from datasets import load_dataset
from feature_extractors.uni3d_embedding_encoder import Uni3dEmbeddingEncoder

# The proxy is set at import time (before the encoder/dataset are created
# below) and removed again in launch() before the server starts.
os.environ['HTTP_PROXY'] = 'http://192.168.48.17:18000'
os.environ['HTTPS_PROXY'] = 'http://192.168.48.17:18000'

MAX_BATCH_SIZE = 16
MAX_QUEUE_SIZE = 10
MAX_K_RETRIEVAL = 20

# Every modality the UI offers: "text" and "3D" are embedding kinds of their
# own, everything else is the viewing angle of a rendered image.
ALL_MODALITIES = ("text", "front", "back", "left", "right", "above", "below",
                  "diag_above", "diag_below", "3D")
_NON_ANGLE_MODALITIES = ("text", "3D")

cache_dir = "./.cache"

encoder = Uni3dEmbeddingEncoder(cache_dir)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): torch.load unpickles the file; only load trusted local data.
source_id_list = torch.load("data/source_id_list.pt")
source_to_id = {source_id: i for i, source_id in enumerate(source_id_list)}
dataset = load_dataset("VAST-AI/LD-T3D", name="rendered_imgs_diag_above", split="base", cache_dir=cache_dir)


@functools.lru_cache()
def get_embedding(option, modality, angle=None):
    """Load (and memoize) a pre-computed base embedding from ``data/``.

    Args:
        option: Encoder name used in the file name (e.g. ``"uni3d"``).
        modality: Embedding kind used in the file name (``"text"``,
            ``"image"`` or ``"3D"``).
        angle: Optional rendering angle; appended to the modality as
            ``_<angle>`` when given.

    Returns:
        Whatever ``torch.load`` yields for the embedding file.

    Raises:
        gr.Error: If the embedding file does not exist. The error is raised
            (not returned) so a missing file surfaces immediately instead of
            leaking an exception object into the embedding list.
    """
    suffix = f"_{angle}" if angle is not None else ""
    save_path = f'data/objaverse_{option}_{modality}{suffix}_embeddings.pt'
    if not os.path.exists(save_path):
        # lru_cache does not memoize exceptions, so a file that appears later
        # will still be picked up on the next call.
        raise gr.Error(f"Embedding file not found: {save_path}")
    return torch.load(save_path)


def predict(xb, xq, top_k):
    """Return the indices of the ``top_k`` base items most similar to each query.

    Args:
        xb: Base embeddings; moved to the device of ``xq`` before use.
        xq: Query embeddings.
        top_k: Number of results per query; cast to ``int`` and clamped to
            the number of base items so ``topk`` cannot fail on an
            out-of-range ``k``.

    Returns:
        Index tensor produced by ``topk`` over the similarity matrix.
    """
    xb = xb.to(xq.device)
    sim = xq @ xb.T  # (nq, nb)
    k = min(int(top_k), sim.shape[-1])
    _, indices = sim.topk(k=k, largest=True)
    return indices


def get_image(index):
    """Return the ``"image"`` field of dataset row ``index``."""
    return dataset[index]["image"]


def _retrieve_indices(query, top_k, modals):
    """Rank the base items against ``query`` using the selected modalities.

    ``modals`` is only read, never modified, so the caller's list (the UI
    component value) stays intact.
    """
    option = "uni3d"
    op = "add"
    is_text = "text" in modals
    is_3D = "3D" in modals
    # Everything that is neither "text" nor "3D" is an image rendering angle.
    angles = [m for m in modals if m not in _NON_ANGLE_MODALITIES]

    # get base embeddings (order: text, images per angle, 3D)
    embeddings = []
    if is_text:
        embeddings.append(get_embedding(option, "text"))
    embeddings.extend(get_embedding(option, "image", angle=angle) for angle in angles)
    if is_3D:
        embeddings.append(get_embedding(option, "3D"))

    # fuse base embeddings
    if len(embeddings) > 1:
        if op == "concat":
            fused = torch.cat(embeddings, dim=-1)
        elif op == "add":
            fused = sum(embeddings)
        else:
            raise ValueError(f"Unsupported operation: {op}")
        fused = fused / fused.norm(dim=-1, keepdim=True)
    else:
        # A single modality is used as stored; the cached tensor is not modified.
        fused = embeddings[0]

    # encode query embeddings
    xq = encoder.encode_query(query)
    if op == "concat":
        # repeat to be aligned with the xb
        xq = xq.repeat(1, fused.shape[-1] // xq.shape[-1])
    # Out-of-place so the encoder's output tensor is left untouched; the
    # ranking itself does not depend on the query's scale.
    xq = xq / xq.norm(dim=-1, keepdim=True)

    pred_ind_list = predict(fused, xq, top_k)
    return pred_ind_list[0].cpu().tolist()  # we have only one query


def retrieve_3D_models(textual_query, top_k, modality_list):
    """Gradio callback: return the images of the best-matching 3D models.

    Args:
        textual_query: Free-text description typed by the user.
        top_k: Number of results requested (slider value).
        modality_list: Selected modalities, a subset of ``ALL_MODALITIES``.

    Returns:
        List of dataset images, best match first.

    Raises:
        gr.Error: If the query is empty/blank, no modality is selected, or a
            required embedding file is missing.
    """
    if not textual_query or not textual_query.strip():
        raise gr.Error("Please enter a textual query")
    if len(textual_query.split()) > 20:
        gr.Warning("Retrieval result may be inaccurate due to long textual query")
    if not modality_list:
        raise gr.Error("Please select at least one modality")

    indices = _retrieve_indices(textual_query, top_k, modality_list)
    return [get_image(index) for index in indices]


def launch():
    """Build the Gradio UI and start serving it on all interfaces."""
    with gr.Blocks() as demo:
        with gr.Row():
            textual_query = gr.Textbox(label="Textual Query", autofocus=True,
                                       placeholder="A chair with a wooden frame and a cushioned seat")
            modality_list = gr.CheckboxGroup(label="Modality List", value=[],
                                             choices=list(ALL_MODALITIES))
        with gr.Row():
            top_k = gr.Slider(minimum=1, maximum=MAX_K_RETRIEVAL, step=1, label="Top K Retrieval Result",
                              value=5, scale=2)
            run = gr.Button("Search", scale=1)
            clear_button = gr.ClearButton(scale=1)
        with gr.Row():
            output = gr.Gallery(format="webp", label="Retrieval Result", columns=5, type="pil")
        run.click(retrieve_3D_models, [textual_query, top_k, modality_list], output,
                  # batch=True, max_batch_size=MAX_BATCH_SIZE
                  )
        # Reset every component to its initial value.
        clear_button.click(lambda: ["", 5, [], []], outputs=[textual_query, top_k, modality_list, output])
        gr.Examples(examples=[["An ice cream with a cherry on top", 10, list(ALL_MODALITIES)],
                              ["A mid-age castle", 10, list(ALL_MODALITIES)],
                              ["A coke", 10, list(ALL_MODALITIES)]],
                    inputs=[textual_query, top_k, modality_list],
                    # cache_examples=True,
                    outputs=output,
                    fn=retrieve_3D_models)

    demo.queue(max_size=MAX_QUEUE_SIZE)
    # Drop the proxy set at import time before the server starts; the default
    # keeps this from raising if the variable is already gone.
    os.environ.pop('HTTP_PROXY', None)
    os.environ.pop('HTTPS_PROXY', None)
    demo.launch(server_name='0.0.0.0')


if __name__ == "__main__":
    launch()
    # print(len(retrieve_3D_models("A chair with a wooden frame and a cushioned seat", 5, ["3D", "diag_above", "diag_below"])))