# salad-demo / app.py
# DveloperY0115's picture
# Display demo description
# 1fa8725
"""
app.py
An interactive demo of text-guided shape generation.
"""
from pathlib import Path
from typing import Literal
import gradio as gr
import plotly.graph_objects as go
from salad.utils.spaghetti_util import (
get_mesh_from_spaghetti,
generate_zc_from_sj_gaus,
load_mesher,
load_spaghetti,
)
import hydra
from omegaconf import OmegaConf
import torch
from pytorch_lightning import seed_everything
def load_model(
    model_class: Literal["phase1", "phase2", "lang_phase1", "lang_phase2"],
    device,
):
    """Instantiate a model from its checkpoint folder and prepare it for inference.

    The folder ``checkpoints/<model_class>`` next to this file is expected to
    hold ``hparams.yaml`` (the config passed to ``hydra.utils.instantiate``)
    and ``state_only.ckpt`` (the object passed to ``load_state_dict``).

    Args:
        model_class: Name of the checkpoint sub-directory to load.
        device: Device used as ``map_location`` when reading the checkpoint
            and as the destination of the returned model.

    Returns:
        The instantiated model in eval mode, with gradients disabled on all
        of its parameters, moved to ``device``.
    """
    checkpoint_root = Path(__file__).parent / "checkpoints"

    # Build the model object from the stored configuration.
    hparams = OmegaConf.load(checkpoint_root / f"{model_class}/hparams.yaml")
    model = hydra.utils.instantiate(hparams)

    # Restore the weights, mapping tensors onto the requested device.
    state = torch.load(
        checkpoint_root / f"{model_class}/state_only.ckpt",
        map_location=device,
    )
    model.load_state_dict(state)

    # Inference only: switch to eval mode and freeze every parameter.
    model.eval()
    for param in model.parameters():
        param.requires_grad_(False)

    return model.to(device)
def run_inference(prompt: str):
    """Generate a shape from a text prompt and return it as a Plotly figure.

    This is the entry point of the demo. It seeds the random number
    generators, loads SPAGHETTI, its mesher and the two SALAD models, samples
    ``extrinsics`` from the phase-1 model using the prompt, samples
    ``intrinsics`` from the phase-2 model given those extrinsics, decodes the
    pair into a mesh, and wraps the mesh in a ``go.Mesh3d`` trace.

    Args:
        prompt: Text description of the shape to generate.

    Returns:
        A ``plotly.graph_objects.Figure`` holding a single gray mesh with all
        three scene axes hidden.
    """
    # Device to run the demo on. Prefer CUDA, but fall back to the CPU so a
    # host without a usable GPU does not fail on the hard-coded device.
    device: torch.device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    # Random seed for reproducibility.
    seed: int = 63

    # set random seed
    seed_everything(seed)

    # NOTE(review): SPAGHETTI, the mesher and both SALAD models are reloaded
    # on every call; consider caching them if request latency matters.

    # load SPAGHETTI and mesher
    spaghetti = load_spaghetti(device)
    mesher = load_mesher(device)

    # load SALAD
    lang_phase1_model = load_model("lang_phase1", device)
    lang_phase2_model = load_model("phase2", device)
    lang_phase1_model._build_dataset("val")

    # run phase 1 (the prompt is passed as a batch of one)
    extrinsics = lang_phase1_model.sampling_gaussians([prompt])

    # run phase 2
    intrinsics = lang_phase2_model.sample(extrinsics)

    # generate mesh from the first (and only) sample in the batch
    zcs = generate_zc_from_sj_gaus(spaghetti, intrinsics, extrinsics)
    vertices, faces = get_mesh_from_spaghetti(
        spaghetti,
        mesher,
        zcs[0],
        res=256,
    )

    # plot
    figure = go.Figure(
        data=[
            go.Mesh3d(
                # The mesh's second and third coordinate columns are swapped
                # for display, and the column that becomes the plot's y axis
                # is negated (flip front-back).
                x=vertices[:, 0],
                y=-vertices[:, 2],
                z=vertices[:, 1],
                i=faces[:, 0],
                j=faces[:, 1],
                k=faces[:, 2],
                color="gray",
            )
        ],
        layout=dict(
            scene=dict(
                xaxis=dict(visible=False),
                yaxis=dict(visible=False),
                zaxis=dict(visible=False),
            )
        ),
    )
    return figure
if __name__ == "__main__":
    # Page title and the introductory text rendered at the top of the demo.
    title = "SALAD: Text-Guided Shape Generation"
    description_text = '''
This demo features text-guided 3D shape generation from our work <a href="https://arxiv.org/abs/2303.12236">SALAD: Part-Level Latent Diffusion for 3D Shape Generation and Manipulation, ICCV 2023</a>.
Please refer to our <a href="https://salad3d.github.io/">project page</a> for details.
'''

    # Assemble the interface.
    with gr.Blocks(title=title) as demo:
        # Demo description.
        gr.Markdown(description_text)

        # Input: a free-form text prompt.
        with gr.Row():
            prompt_textbox = gr.Textbox(placeholder="Describe a chair.")

        # Actions: start generation, or clear the prompt.
        with gr.Row():
            run_button = gr.Button(value="Generate")
            clear_button = gr.ClearButton(
                components=[prompt_textbox],
                value="Clear",
            )

        # Example prompts that fill the prompt textbox.
        examples = gr.Examples(
            inputs=[prompt_textbox],
            examples=[
                "an office chair",
                "a chair with armrests",
                "a chair without armrests",
            ],
        )

        # Output: the plot showing the generated mesh.
        mesh_viewport = gr.Plot()

        # Clicking "Generate" runs inference on the prompt and shows the result.
        run_button.click(
            fn=run_inference,
            inputs=[prompt_textbox],
            outputs=[mesh_viewport],
        )

    # Enable the request queue with a size limit of 30, then start the app.
    demo.queue(max_size=30)
    demo.launch()