Spaces:
Sleeping
Sleeping
""" | |
app.py
An interactive Gradio demo of text-guided 3D shape generation with SALAD.
""" | |
from pathlib import Path | |
from typing import Literal | |
import gradio as gr | |
import plotly.graph_objects as go | |
from salad.utils.spaghetti_util import ( | |
get_mesh_from_spaghetti, | |
generate_zc_from_sj_gaus, | |
load_mesher, | |
load_spaghetti, | |
) | |
import hydra | |
from omegaconf import OmegaConf | |
import torch | |
from pytorch_lightning import seed_everything | |
def load_model(
    model_class: Literal["phase1", "phase2", "lang_phase1", "lang_phase2"],
    device,
):
    """Build a model from its saved hyper-parameters and load its weights.

    The files are looked up under ``checkpoints/<model_class>/`` next to this
    source file: ``hparams.yaml`` is handed to ``hydra.utils.instantiate`` to
    construct the model, and ``state_only.ckpt`` is loaded with ``torch.load``
    and passed to ``load_state_dict``.

    Args:
        model_class: Name of the checkpoint sub-directory to read from.
        device: Device used both as ``map_location`` when reading the
            checkpoint and as the target of the final ``.to(device)``.

    Returns:
        The instantiated model, switched to eval mode, with gradients
        disabled on all of its parameters, moved to ``device``.
    """
    checkpoint_root = Path(__file__).parent / "checkpoints"
    config_path = checkpoint_root / f"{model_class}/hparams.yaml"
    weights_path = checkpoint_root / f"{model_class}/state_only.ckpt"

    # Construct the model first, then read the weights (same order as before).
    model = hydra.utils.instantiate(OmegaConf.load(config_path))
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)

    # Inference only: eval mode, and no gradient tracking on any parameter.
    model.eval()
    model.requires_grad_(False)

    return model.to(device)
def run_inference(prompt: str):
    """Generate a shape for ``prompt`` and return it as a Plotly figure.

    Steps, in order: seed the RNGs, load SPAGHETTI, the mesher and the two
    SALAD checkpoints, sample phase 1 from the prompt, sample phase 2 from
    the phase-1 output, convert the result to a mesh, and plot it.

    Args:
        prompt: Text description; passed to the phase-1 sampler as a
            one-element list.

    Returns:
        A ``plotly.graph_objects.Figure`` with a single gray ``Mesh3d`` trace
        and all three scene axes hidden.
    """
    # The demo always runs on CUDA with a fixed seed (63) for reproducibility.
    cuda = torch.device("cuda")
    seed_everything(63)

    # NOTE(review): everything below is loaded again on every call.
    # SPAGHETTI and its mesher.
    shape_decoder = load_spaghetti(cuda)
    mesher = load_mesher(cuda)

    # SALAD models.
    # NOTE(review): the second model is read from the "phase2" checkpoint,
    # not "lang_phase2" - presumably intentional, since sample() below is
    # called without the prompt; confirm.
    phase1 = load_model("lang_phase1", cuda)
    phase2 = load_model("phase2", cuda)
    phase1._build_dataset("val")

    # Phase 1 samples from the prompt; phase 2 samples from phase 1's output.
    extrinsics = phase1.sampling_gaussians([prompt])
    intrinsics = phase2.sample(extrinsics)

    # Turn the sampled quantities into a mesh; only the first entry is meshed.
    latents = generate_zc_from_sj_gaus(shape_decoder, intrinsics, extrinsics)
    vertices, faces = get_mesh_from_spaghetti(
        shape_decoder,
        mesher,
        latents[0],
        res=256,
    )

    # Remap axes for display: column 1 becomes plot z and negated column 2
    # becomes plot y (the negation is the "flip front-back").
    col0, col1, col2 = vertices[:, 0], vertices[:, 1], vertices[:, 2]
    mesh_trace = go.Mesh3d(
        x=col0,
        y=-col2,
        z=col1,
        i=faces[:, 0],
        j=faces[:, 1],
        k=faces[:, 2],
        color="gray",
    )

    # Hide all three axes so only the shape is shown.
    scene_layout = {
        "scene": {
            "xaxis": {"visible": False},
            "yaxis": {"visible": False},
            "zaxis": {"visible": False},
        }
    }

    return go.Figure(data=[mesh_trace], layout=scene_layout)
if __name__ == "__main__":
    # Static page content.
    page_title = "SALAD: Text-Guided Shape Generation"
    intro_markdown = '''
This demo features text-guided 3D shape generation from our work <a href="https://arxiv.org/abs/2303.12236">SALAD: Part-Level Latent Diffusion for 3D Shape Generation and Manipulation, ICCV 2023</a>.
Please refer to our <a href="https://salad3d.github.io/">project page</a> for details.
'''
    example_prompts = [
        "an office chair",
        "a chair with armrests",
        "a chair without armrests",
    ]

    # Page layout: intro text, prompt box, Generate/Clear buttons, clickable
    # example prompts, then the plot that shows the generated mesh.
    with gr.Blocks(title=page_title) as demo:
        gr.Markdown(intro_markdown)

        with gr.Row():
            prompt_box = gr.Textbox(placeholder="Describe a chair.")

        with gr.Row():
            generate_button = gr.Button(value="Generate")
            # "Clear" resets the prompt box only.
            gr.ClearButton(value="Clear", components=[prompt_box])

        # Example prompts fill the prompt box.
        gr.Examples(examples=example_prompts, inputs=[prompt_box])

        mesh_plot = gr.Plot()

        # "Generate" sends the prompt to run_inference and shows its figure.
        generate_button.click(
            fn=run_inference,
            inputs=[prompt_box],
            outputs=[mesh_plot],
        )

    # At most 30 requests may wait in the queue.
    demo.queue(max_size=30)
    demo.launch()