""" app.py An interactive demo of text-guided shape generation. """ from pathlib import Path from typing import Literal import gradio as gr import plotly.graph_objects as go from salad.utils.spaghetti_util import ( get_mesh_from_spaghetti, generate_zc_from_sj_gaus, load_mesher, load_spaghetti, ) import hydra from omegaconf import OmegaConf import torch from pytorch_lightning import seed_everything def load_model( model_class: Literal["phase1", "phase2", "lang_phase1", "lang_phase2"], device, ): checkpoint_dir = Path(__file__).parent / "checkpoints" c = OmegaConf.load(checkpoint_dir / f"{model_class}/hparams.yaml") model = hydra.utils.instantiate(c) ckpt = torch.load( checkpoint_dir / f"{model_class}/state_only.ckpt", map_location=device, ) model.load_state_dict(ckpt) model.eval() for p in model.parameters(): p.requires_grad_(False) model = model.to(device) return model def run_inference(prompt: str): """The entry point of the demo.""" device: torch.device = torch.device("cuda") """Device to run the demo on.""" seed: int = 63 """Random seed for reproducibility.""" # set random seed seed_everything(seed) # load SPAGHETTI and mesher spaghetti = load_spaghetti(device) mesher = load_mesher(device) # load SALAD lang_phase1_model = load_model("lang_phase1", device) lang_phase2_model = load_model("phase2", device) lang_phase1_model._build_dataset("val") # run phase 1 extrinsics = lang_phase1_model.sampling_gaussians([prompt]) # run phase 2 intrinsics = lang_phase2_model.sample(extrinsics) # generate mesh zcs = generate_zc_from_sj_gaus(spaghetti, intrinsics, extrinsics) vertices, faces = get_mesh_from_spaghetti( spaghetti, mesher, zcs[0], res=256, ) # plot figure = go.Figure( data=[ go.Mesh3d( x=vertices[:, 0], # flip front-back y=-vertices[:, 2], z=vertices[:, 1], i=faces[:, 0], j=faces[:, 1], k=faces[:, 2], color="gray", ) ], layout=dict( scene=dict( xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False), ) ), ) return figure if __name__ == "__main__": title = "SALAD: Text-Guided Shape Generation" description_text = ''' This demo features text-guided 3D shape generation from our work SALAD: Part-Level Latent Diffusion for 3D Shape Generation and Manipulation, ICCV 2023. Please refer to our project page for details. ''' # create UI with gr.Blocks(title=title) as demo: # description of demo gr.Markdown(description_text) # inputs with gr.Row(): prompt_textbox = gr.Textbox(placeholder="Describe a chair.") with gr.Row(): run_button = gr.Button(value="Generate") clear_button = gr.ClearButton( value="Clear", components=[prompt_textbox], ) # display examples examples = gr.Examples( examples=[ "an office chair", "a chair with armrests", "a chair without armrests", ], inputs=[prompt_textbox], ) # outputs mesh_viewport = gr.Plot() # run inference run_button.click( run_inference, inputs=[prompt_textbox], outputs=[mesh_viewport], ) demo.queue(max_size=30) demo.launch()