File size: 2,851 Bytes
801501a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""
app.py

An interactive demo of text-guided shape generation.
"""

from pathlib import Path
from typing import Literal

import gradio as gr
import plotly.graph_objects as go

from salad.utils.spaghetti_util import (
    get_mesh_from_spaghetti,
    generate_zc_from_sj_gaus,
    load_mesher,
    load_spaghetti,
)
import hydra
from omegaconf import OmegaConf
import torch
from pytorch_lightning import seed_everything


def load_model(
    model_class: Literal["phase1", "phase2", "lang_phase1", "lang_phase2"],
    device,
):
    """Build a model from its checkpoint directory and freeze it for inference.

    The directory ``checkpoints/<model_class>/`` next to this file must hold
    ``hparams.yaml`` (a config that hydra can instantiate) and
    ``state_only.ckpt`` (the state dict loaded into the instantiated model).

    Args:
        model_class: Name of the checkpoint sub-directory to load.
        device: Device the returned model is moved to (anything accepted by
            ``model.to``).

    Returns:
        The instantiated model in eval mode, with gradients disabled on all
        of its parameters, placed on ``device``.
    """
    model_dir = Path(__file__).parent / "checkpoints" / model_class

    # Instantiate the architecture described by the stored hyper-parameters.
    hparams = OmegaConf.load(model_dir / "hparams.yaml")
    model = hydra.utils.instantiate(hparams)

    # Deserialize onto the CPU first: without ``map_location`` tensors are
    # restored to the device they were saved from, which fails on hosts
    # lacking that device. The model is moved to ``device`` explicitly below.
    # NOTE: ``torch.load`` unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(model_dir / "state_only.ckpt", map_location="cpu")
    model.load_state_dict(state_dict)

    # Inference only: switch to eval mode and stop tracking gradients.
    model.eval()
    for param in model.parameters():
        param.requires_grad_(False)

    return model.to(device)


def run_inference(prompt: str):
    """Generate a shape from a text prompt and return it as a Plotly figure.

    This is the entry point of the demo. Each call seeds the RNGs, loads
    SPAGHETTI, the mesher and the two SALAD models, samples ``extrinsics``
    from the prompt (phase 1) and ``intrinsics`` from those extrinsics
    (phase 2), turns them into a mesh, and wraps the mesh in a figure.

    Args:
        prompt: Text description of the shape to generate.

    Returns:
        A ``plotly.graph_objects.Figure`` holding a single gray ``Mesh3d``
        trace with all three scene axes hidden.

    Note:
        Every model is reloaded from disk on each call. Caching them would
        speed up repeated requests, but it would also change the RNG state at
        sampling time (and therefore the generated shapes), so it is left as
        a follow-up.
    """
    # Device to run the demo on. Prefer CUDA, but fall back to the CPU so the
    # demo does not crash outright on hosts without a GPU.
    device: torch.device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    # Random seed for reproducibility.
    seed: int = 63

    # set random seed
    seed_everything(seed)

    # load SPAGHETTI and mesher
    spaghetti = load_spaghetti(device)
    mesher = load_mesher(device)

    # load SALAD
    lang_phase1_model = load_model("lang_phase1", device)
    # NOTE(review): the variable is named ``lang_phase2`` but the "phase2"
    # checkpoint is loaded, and the prompt is not passed to phase 2 below.
    # Confirm this is intentional before switching to "lang_phase2".
    lang_phase2_model = load_model("phase2", device)
    lang_phase1_model._build_dataset("val")

    # run phase 1 (a batch containing only this prompt)
    extrinsics = lang_phase1_model.sampling_gaussians([prompt])

    # run phase 2
    intrinsics = lang_phase2_model.sample(extrinsics)

    # generate mesh for the first (and only) sample in the batch
    zcs = generate_zc_from_sj_gaus(spaghetti, intrinsics, extrinsics)
    vertices, faces = get_mesh_from_spaghetti(
        spaghetti,
        mesher,
        zcs[0],
        res=256,
    )

    # plot
    figure = go.Figure(
        data=[
            go.Mesh3d(
                # Remap axes for display: the mesh's second coordinate is put
                # on Plotly's z axis and its third coordinate, negated, on
                # Plotly's y axis (the front-back flip).
                x=vertices[:, 0],
                y=-vertices[:, 2],
                z=vertices[:, 1],
                i=faces[:, 0],
                j=faces[:, 1],
                k=faces[:, 2],
                color="gray",
            )
        ],
        layout=dict(
            scene=dict(
                xaxis=dict(visible=False),
                yaxis=dict(visible=False),
                zaxis=dict(visible=False),
            )
        ),
    )

    return figure

if __name__ == "__main__":
    # Example prompts handed to the interface.
    example_prompts = [
        "an office chair",
        "a chair with armrests",
        "a chair without armrests",
    ]

    # Build the UI: one text input, one Plotly plot output.
    plot_output = gr.Plot()
    demo = gr.Interface(
        fn=run_inference,
        inputs="text",
        outputs=plot_output,
        title="SALAD: Text-Guided Shape Generation",
        description="Describe a chair",
        examples=example_prompts,
    )

    # Enable the request queue (max_size=30), then start the app.
    demo.queue(max_size=30)
    demo.launch()