"""
app.py

An interactive demo of text-guided shape generation.
"""

from functools import lru_cache
from pathlib import Path
from typing import Literal

import gradio as gr
import hydra
import plotly.graph_objects as go
import torch
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything

from salad.utils.spaghetti_util import (
    get_mesh_from_spaghetti,
    generate_zc_from_sj_gaus,
    load_mesher,
    load_spaghetti,
)

# Device to run the demo on.
DEVICE = torch.device("cuda")

# Random seed applied before every sampling run, for reproducibility.
SEED = 63

# Resolution passed to `get_mesh_from_spaghetti` when extracting the mesh.
MESH_RESOLUTION = 256


def load_model(
    model_class: Literal["phase1", "phase2", "lang_phase1", "lang_phase2"],
    device,
):
    """Build a SALAD model from its checkpoint directory and move it to *device*.

    The model is instantiated with Hydra from
    ``checkpoints/<model_class>/hparams.yaml``, its weights are loaded from
    ``checkpoints/<model_class>/state_only.ckpt``, it is switched to eval mode
    and all of its parameters are frozen.

    Args:
        model_class: Name of the checkpoint sub-directory to load.
        device: Device the returned model is moved to.

    Returns:
        The frozen, eval-mode model on *device*.
    """
    checkpoint_dir = Path(__file__).parent / "checkpoints" / model_class
    config = OmegaConf.load(checkpoint_dir / "hparams.yaml")
    model = hydra.utils.instantiate(config)

    # Load the state dict onto the CPU (where the freshly instantiated model
    # lives); the whole model is moved to `device` afterwards. This avoids
    # depending on the device the checkpoint was saved from.
    state_dict = torch.load(checkpoint_dir / "state_only.ckpt", map_location="cpu")
    model.load_state_dict(state_dict)

    model.eval()
    model.requires_grad_(False)  # inference only; freeze every parameter
    return model.to(device)


@lru_cache(maxsize=1)
def _load_pipeline(device):
    """Load SPAGHETTI, the mesher and both SALAD phases once per device.

    Cached so that every Gradio request reuses the same models instead of
    re-reading all checkpoints and re-allocating device memory per call.

    Returns:
        ``(spaghetti, mesher, phase1_model, phase2_model)``.
    """
    # Seed before construction so any randomness during loading is
    # deterministic as well.
    seed_everything(SEED)

    spaghetti = load_spaghetti(device)
    mesher = load_mesher(device)

    phase1_model = load_model("lang_phase1", device)
    # NOTE(review): phase 2 is loaded from the "phase2" checkpoint even though
    # a "lang_phase2" class exists; kept as in the original - confirm intended.
    phase2_model = load_model("phase2", device)

    # Done once here rather than per request; presumably prepares data that
    # `sampling_gaussians` relies on - TODO confirm it is idempotent.
    phase1_model._build_dataset("val")

    return spaghetti, mesher, phase1_model, phase2_model


def _make_figure(vertices, faces) -> go.Figure:
    """Wrap a triangle mesh in a Plotly figure with all axes hidden.

    Assumes *vertices* is an (N, 3) array-like and *faces* an (M, 3) array-like
    of vertex indices - TODO confirm against `get_mesh_from_spaghetti`.
    """
    mesh = go.Mesh3d(
        x=vertices[:, 0],
        # flip front-back
        y=-vertices[:, 2],
        # Plotly's default camera treats z as "up", so the mesh's Y column is
        # plotted on that axis (presumably the mesh is Y-up).
        z=vertices[:, 1],
        i=faces[:, 0],
        j=faces[:, 1],
        k=faces[:, 2],
        color="gray",
    )
    hidden_axis = dict(visible=False)
    layout = dict(
        scene=dict(
            xaxis=hidden_axis,
            yaxis=hidden_axis,
            zaxis=hidden_axis,
        )
    )
    return go.Figure(data=[mesh], layout=layout)


def run_inference(prompt: str):
    """The entry point of the demo: generate a shape from a text prompt.

    Args:
        prompt: Free-form text description of the shape.

    Returns:
        A Plotly figure containing the generated mesh.
    """
    spaghetti, mesher, phase1_model, phase2_model = _load_pipeline(DEVICE)

    # Reseed right before sampling (models may have been loaded on an earlier
    # request) so that the same prompt gives the same shape on every call.
    # NOTE(review): the original seeded before loading; if model construction
    # consumed the same RNG stream, exact outputs may differ from it.
    seed_everything(SEED)

    # run phase 1: text prompt -> extrinsics
    extrinsics = phase1_model.sampling_gaussians([prompt])
    # run phase 2: extrinsics -> intrinsics
    intrinsics = phase2_model.sample(extrinsics)

    # generate mesh (only the first sample is rendered)
    zcs = generate_zc_from_sj_gaus(spaghetti, intrinsics, extrinsics)
    vertices, faces = get_mesh_from_spaghetti(
        spaghetti,
        mesher,
        zcs[0],
        res=MESH_RESOLUTION,
    )

    return _make_figure(vertices, faces)


if __name__ == "__main__":
    # create UI
    demo = gr.Interface(
        fn=run_inference,
        inputs="text",
        outputs=gr.Plot(),
        title="SALAD: Text-Guided Shape Generation",
        description="Describe a chair",
        examples=[
            "an office chair",
            "a chair with armrests",
            "a chair without armrests",
        ],
    )

    # initiate
    demo.queue(max_size=30)
    demo.launch()