"""
app.py
An interactive demo of text-guided shape generation.
"""
from pathlib import Path
from typing import Literal
import gradio as gr
import plotly.graph_objects as go
from salad.utils.spaghetti_util import (
    get_mesh_from_spaghetti,
    generate_zc_from_sj_gaus,
    load_mesher,
    load_spaghetti,
)
import hydra
from omegaconf import OmegaConf
import torch
from pytorch_lightning import seed_everything


def load_model(
    model_class: Literal["phase1", "phase2", "lang_phase1", "lang_phase2"],
    device,
):
    """Load a pretrained model from the local checkpoint directory and freeze it."""
    checkpoint_dir = Path(__file__).parent / "checkpoints"
    c = OmegaConf.load(checkpoint_dir / f"{model_class}/hparams.yaml")
    model = hydra.utils.instantiate(c)
    ckpt = torch.load(
        checkpoint_dir / f"{model_class}/state_only.ckpt",
        map_location=device,
    )
    model.load_state_dict(ckpt)
    model.eval()
    for p in model.parameters():
        p.requires_grad_(False)
    model = model.to(device)
    return model


def run_inference(prompt: str):
    """The entry point of the demo."""
    device: torch.device = torch.device("cuda")  # device to run the demo on
    seed: int = 63  # random seed for reproducibility

    # set random seed
    seed_everything(seed)

    # load SPAGHETTI and mesher
    spaghetti = load_spaghetti(device)
    mesher = load_mesher(device)

    # load SALAD
    lang_phase1_model = load_model("lang_phase1", device)
    lang_phase2_model = load_model("phase2", device)
    lang_phase1_model._build_dataset("val")

    # run phase 1: sample part-level Gaussians (extrinsics) from the text prompt
    extrinsics = lang_phase1_model.sampling_gaussians([prompt])

    # run phase 2: sample per-part latents (intrinsics) conditioned on the Gaussians
    intrinsics = lang_phase2_model.sample(extrinsics)

    # generate mesh
    zcs = generate_zc_from_sj_gaus(spaghetti, intrinsics, extrinsics)
    vertices, faces = get_mesh_from_spaghetti(
        spaghetti,
        mesher,
        zcs[0],
        res=256,
    )

    # plot (swap y/z and flip front-back so the mesh displays upright)
    figure = go.Figure(
        data=[
            go.Mesh3d(
                x=vertices[:, 0],
                y=-vertices[:, 2],
                z=vertices[:, 1],
                i=faces[:, 0],
                j=faces[:, 1],
                k=faces[:, 2],
                color="gray",
            )
        ],
        layout=dict(
            scene=dict(
                xaxis=dict(visible=False),
                yaxis=dict(visible=False),
                zaxis=dict(visible=False),
            )
        ),
    )
    return figure
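

# A minimal sketch of calling run_inference() outside the Gradio UI (assumes a CUDA
# device and the checkpoint layout described above; the output file name is only an
# example):
#
#   fig = run_inference("an office chair")
#   fig.write_html("chair.html")  # save the Plotly figure for offline viewing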


if __name__ == "__main__":
    title = "SALAD: Text-Guided Shape Generation"
    description_text = '''
This demo features text-guided 3D shape generation from our work <a href="https://arxiv.org/abs/2303.12236">SALAD: Part-Level Latent Diffusion for 3D Shape Generation and Manipulation, ICCV 2023</a>.
Please refer to our <a href="https://salad3d.github.io/">project page</a> for details.
'''

    # create UI
    with gr.Blocks(title=title) as demo:
        # description of demo
        gr.Markdown(description_text)

        # inputs
        with gr.Row():
            prompt_textbox = gr.Textbox(placeholder="Describe a chair.")
        with gr.Row():
            run_button = gr.Button(value="Generate")
            clear_button = gr.ClearButton(
                value="Clear",
                components=[prompt_textbox],
            )

        # display examples
        examples = gr.Examples(
            examples=[
                "an office chair",
                "a chair with armrests",
                "a chair without armrests",
            ],
            inputs=[prompt_textbox],
        )

        # outputs
        mesh_viewport = gr.Plot()

        # run inference
        run_button.click(
            run_inference,
            inputs=[prompt_textbox],
            outputs=[mesh_viewport],
        )

    demo.queue(max_size=30)
    demo.launch()
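# To try the demo locally (assuming the checkpoints above are in place), run
# `python app.py` and open the local URL that Gradio prints (by default
# http://127.0.0.1:7860).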