"""
app.py

An interactive demo of text-guided shape generation.
"""

from pathlib import Path
from typing import Literal

import gradio as gr
import hydra
import plotly.graph_objects as go
import torch
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything

from salad.utils.spaghetti_util import (
    generate_zc_from_sj_gaus,
    get_mesh_from_spaghetti,
    load_mesher,
    load_spaghetti,
)


def load_model(
    model_class: Literal["phase1", "phase2", "lang_phase1", "lang_phase2"],
    device,
):
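    """Instantiate a model from its Hydra config, load its checkpoint weights,
    and freeze all parameters for inference."""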
    checkpoint_dir = Path(__file__).parent / "checkpoints"
    c = OmegaConf.load(checkpoint_dir / f"{model_class}/hparams.yaml")
    model = hydra.utils.instantiate(c)
    ckpt = torch.load(
        checkpoint_dir / f"{model_class}/state_only.ckpt",
        map_location=device,
    )
    model.load_state_dict(ckpt)
    model.eval()
    for p in model.parameters():
        p.requires_grad_(False)
    model = model.to(device)
    return model


def run_inference(prompt: str):
    """The entry point of the demo."""

    device = torch.device("cuda")  # device to run the demo on
    seed = 63  # random seed for reproducibility

    # set random seed
    seed_everything(seed)

    # load SPAGHETTI and mesher
    spaghetti = load_spaghetti(device)
    mesher = load_mesher(device)

    # load SALAD: text-conditioned phase 1 and unconditional phase 2
    lang_phase1_model = load_model("lang_phase1", device)
    phase2_model = load_model("phase2", device)
    lang_phase1_model._build_dataset("val")

    # phase 1: sample per-part Gaussians (extrinsics) conditioned on the prompt
    extrinsics = lang_phase1_model.sampling_gaussians([prompt])

    # phase 2: sample per-part latent features (intrinsics) given the Gaussians
    intrinsics = phase2_model.sample(extrinsics)

    # generate mesh
    zcs = generate_zc_from_sj_gaus(spaghetti, intrinsics, extrinsics)
    vertices, faces = get_mesh_from_spaghetti(
        spaghetti,
        mesher,
        zcs[0],
        res=256,
    )

    # plot
    figure = go.Figure(
        data=[
            go.Mesh3d(
                # swap the y/z axes and flip front-back for display
                x=vertices[:, 0],
                y=-vertices[:, 2],
                z=vertices[:, 1],
                i=faces[:, 0],
                j=faces[:, 1],
                k=faces[:, 2],
                color="gray",
            )
        ],
        layout=dict(
            scene=dict(
                xaxis=dict(visible=False),
                yaxis=dict(visible=False),
                zaxis=dict(visible=False),
            )
        ),
    )

    return figure

if __name__ == "__main__":

    title = "SALAD: Text-Guided Shape Generation"

    description_text = '''
    This demo features text-guided 3D shape generation from our work <a href="https://arxiv.org/abs/2303.12236">SALAD: Part-Level Latent Diffusion for 3D Shape Generation and Manipulation, ICCV 2023</a>.  
    Please refer to our <a href="https://salad3d.github.io/">project page</a> for details.
    '''

    # create UI
    with gr.Blocks(title=title) as demo:

        # description of demo
        gr.Markdown(description_text)

        # inputs
        with gr.Row():
            prompt_textbox = gr.Textbox(label="Text prompt", placeholder="Describe a chair.")

            with gr.Row():
                run_button = gr.Button(value="Generate")
                clear_button = gr.ClearButton(
                    value="Clear",
                    components=[prompt_textbox],
                )

        # display examples
        examples = gr.Examples(
            examples=[
                "an office chair",
                "a chair with armrests",
                "a chair without armrests",
            ],
            inputs=[prompt_textbox],
        )

        # outputs
        mesh_viewport = gr.Plot()

        # run inference
        run_button.click(
            run_inference,
            inputs=[prompt_textbox],
            outputs=[mesh_viewport],
        )

    demo.queue(max_size=30)
    demo.launch()