Maikou commited on
Commit
5b7e52a
·
1 Parent(s): b621857

Add application file

Browse files
Files changed (1) hide show
  1. app.py +372 -0
app.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import os
3
+ import time
4
+ from collections import OrderedDict
5
+ from PIL import Image
6
+ import torch
7
+ import trimesh
8
+ from typing import Optional, List
9
+ from einops import repeat, rearrange
10
+ import numpy as np
11
+ from michelangelo.models.tsal.tsal_base import Latent2MeshOutput
12
+ from michelangelo.utils.misc import get_config_from_file, instantiate_from_config
13
+ from michelangelo.utils.visualizers.pythreejs_viewer import PyThreeJSViewer
14
+ from michelangelo.utils.visualizers import html_util
15
+
16
+ import gradio as gr
17
+
18
+
19
+ gradio_cached_dir = "./gradio_cached_dir"
20
+ os.makedirs(gradio_cached_dir, exist_ok=True)
21
+
22
+ save_mesh = False
23
+
24
+ state = ""
25
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
26
+
27
+ box_v = 1.1
28
+ viewer = PyThreeJSViewer(settings={}, render_mode="WEBSITE")
29
+
30
+ image_model_config_dict = OrderedDict({
31
+ "ASLDM-256-obj": {
32
+ "config": "./configs/image_cond_diffuser_asl/image-ASLDM-256.yaml",
33
+ "ckpt_path": "./checkpoints/image_cond_diffuser_asl/image-ASLDM-256.ckpt",
34
+ },
35
+ })
36
+
37
+ text_model_config_dict = OrderedDict({
38
+ "ASLDM-256": {
39
+ "config": "./configs/text_cond_diffuser_asl/text-ASLDM-256.yaml",
40
+ "ckpt_path": "./checkpoints/text_cond_diffuser_asl/text-ASLDM-256.ckpt",
41
+ },
42
+ })
43
+
44
+
45
+ class InferenceModel(object):
46
+ model = None
47
+ name = ""
48
+
49
+
50
+ text2mesh_model = InferenceModel()
51
+ image2mesh_model = InferenceModel()
52
+
53
+
54
+ def set_state(s):
55
+ global state
56
+ state = s
57
+ print(s)
58
+
59
+
60
+ def output_to_html_frame(mesh_outputs: List[Latent2MeshOutput], bbox_size: float,
61
+ image: Optional[np.ndarray] = None,
62
+ html_frame: bool = False):
63
+ global viewer
64
+
65
+ for i in range(len(mesh_outputs)):
66
+ mesh = mesh_outputs[i]
67
+ if mesh is None:
68
+ continue
69
+
70
+ mesh_v = mesh.mesh_v.copy()
71
+ mesh_v[:, 0] += i * np.max(bbox_size)
72
+ mesh_v[:, 2] += np.max(bbox_size)
73
+ viewer.add_mesh(mesh_v, mesh.mesh_f)
74
+
75
+ mesh_tag = viewer.to_html(html_frame=False)
76
+
77
+ if image is not None:
78
+ image_tag = html_util.to_image_embed_tag(image)
79
+ frame = f"""
80
+ <table border = "1">
81
+ <tr>
82
+ <td>{image_tag}</td>
83
+ <td>{mesh_tag}</td>
84
+ </tr>
85
+ </table>
86
+ """
87
+ else:
88
+ frame = mesh_tag
89
+
90
+ if html_frame:
91
+ frame = html_util.to_html_frame(frame)
92
+
93
+ viewer.reset()
94
+
95
+ return frame
96
+
97
+
98
+ def load_model(model_name: str, model_config_dict: dict, inference_model: InferenceModel):
99
+ global device
100
+
101
+ if inference_model.name == model_name:
102
+ model = inference_model.model
103
+ else:
104
+ assert model_name in model_config_dict
105
+
106
+ if inference_model.model is not None:
107
+ del inference_model.model
108
+
109
+ config_ckpt_path = model_config_dict[model_name]
110
+
111
+ model_config = get_config_from_file(config_ckpt_path["config"])
112
+ if hasattr(model_config, "model"):
113
+ model_config = model_config.model
114
+
115
+ model = instantiate_from_config(model_config, ckpt_path=config_ckpt_path["ckpt_path"])
116
+ model = model.to(device)
117
+ model = model.eval()
118
+
119
+ inference_model.model = model
120
+ inference_model.name = model_name
121
+
122
+ return model
123
+
124
+
125
+ def prepare_img(image: np.ndarray):
126
+ image_pt = torch.tensor(image).float()
127
+ image_pt = image_pt / 255 * 2 - 1
128
+ image_pt = rearrange(image_pt, "h w c -> c h w")
129
+
130
+ return image_pt
131
+
132
+ def prepare_model_viewer(fp):
133
+ content = f"""
134
+ <head>
135
+ <script
136
+ type="module" src="https://ajax.googleapis.com/ajax/libs/model-viewer/3.1.1/model-viewer.min.js">
137
+ </script>
138
+ </head>
139
+ <body>
140
+ <model-viewer
141
+ style="height: 150px; width: 150px;"
142
+ rotation-per-second="10deg"
143
+ id="t1"
144
+ src="file/gradio_cached_dir/{fp}"
145
+ environment-image="neutral"
146
+ camera-target="0m 0m 0m"
147
+ orientation="0deg 90deg 170deg"
148
+ shadow-intensity="1"
149
+ ar:true
150
+ auto-rotate
151
+ camera-controls>
152
+ </model-viewer>
153
+ </body>
154
+ """
155
+ return content
156
+
157
+ def prepare_html_frame(content):
158
+ frame = f"""
159
+ <html>
160
+ <body>
161
+ {content}
162
+ </body>
163
+ </html>
164
+ """
165
+ return frame
166
+
167
+ def prepare_html_body(content):
168
+ frame = f"""
169
+ <body>
170
+ {content}
171
+ </body>
172
+ """
173
+ return frame
174
+
175
+ def post_process_mesh_outputs(mesh_outputs):
176
+ # html_frame = output_to_html_frame(mesh_outputs, 2 * box_v, image=None, html_frame=True)
177
+ html_content = output_to_html_frame(mesh_outputs, 2 * box_v, image=None, html_frame=False)
178
+ html_frame = prepare_html_frame(html_content)
179
+
180
+ # filename = f"{time.time()}.html"
181
+ filename = f"text-256-{time.time()}.html"
182
+ html_filepath = os.path.join(gradio_cached_dir, filename)
183
+ with open(html_filepath, "w") as writer:
184
+ writer.write(html_frame)
185
+
186
+ '''
187
+ Bug: The iframe tag does not work in Gradio.
188
+ The chrome returns "No resource with given URL found"
189
+ Solutions:
190
+ https://github.com/gradio-app/gradio/issues/884
191
+ Due to the security bitches, the server can only find files parallel to the gradio_app.py.
192
+ The path has format "file/TARGET_FILE_PATH"
193
+ '''
194
+
195
+ iframe_tag = f'<iframe src="file/gradio_cached_dir/{filename}" width="600%" height="400" frameborder="0"></iframe>'
196
+
197
+ filelist = []
198
+ filenames = []
199
+ for i, mesh in enumerate(mesh_outputs):
200
+ mesh.mesh_f = mesh.mesh_f[:, ::-1]
201
+ mesh_output = trimesh.Trimesh(mesh.mesh_v, mesh.mesh_f)
202
+
203
+ name = str(i) + "_out_mesh.obj"
204
+ filepath = gradio_cached_dir + "/" + name
205
+ mesh_output.export(filepath, include_normals=True)
206
+ filelist.append(filepath)
207
+ filenames.append(name)
208
+
209
+ filelist.append(html_filepath)
210
+ return iframe_tag, filelist
211
+
212
+ def image2mesh(image: np.ndarray,
213
+ model_name: str = "subsp+pk_asl_perceiver=01_01_udt=03",
214
+ num_samples: int = 4,
215
+ guidance_scale: int = 7.5,
216
+ octree_depth: int = 7):
217
+ global device, gradio_cached_dir, image_model_config_dict, box_v
218
+
219
+ # load model
220
+ model = load_model(model_name, image_model_config_dict, image2mesh_model)
221
+
222
+ # prepare image inputs
223
+ image_pt = prepare_img(image)
224
+ image_pt = repeat(image_pt, "c h w -> b c h w", b=num_samples)
225
+
226
+ sample_inputs = {
227
+ "image": image_pt
228
+ }
229
+ mesh_outputs = model.sample(
230
+ sample_inputs,
231
+ sample_times=1,
232
+ guidance_scale=guidance_scale,
233
+ return_intermediates=False,
234
+ bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v],
235
+ octree_depth=octree_depth,
236
+ )[0]
237
+
238
+ iframe_tag, filelist = post_process_mesh_outputs(mesh_outputs)
239
+
240
+ return iframe_tag, gr.update(value=filelist, visible=True)
241
+
242
+
243
+ def text2mesh(text: str,
244
+ model_name: str = "subsp+pk_asl_perceiver=01_01_udt=03",
245
+ num_samples: int = 4,
246
+ guidance_scale: int = 7.5,
247
+ octree_depth: int = 7):
248
+ global device, gradio_cached_dir, text_model_config_dict, text2mesh_model, box_v
249
+
250
+ # load model
251
+ model = load_model(model_name, text_model_config_dict, text2mesh_model)
252
+
253
+ # prepare text inputs
254
+ sample_inputs = {
255
+ "text": [text] * num_samples
256
+ }
257
+ mesh_outputs = model.sample(
258
+ sample_inputs,
259
+ sample_times=1,
260
+ guidance_scale=guidance_scale,
261
+ return_intermediates=False,
262
+ bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v],
263
+ octree_depth=octree_depth,
264
+ )[0]
265
+
266
+ iframe_tag, filelist = post_process_mesh_outputs(mesh_outputs)
267
+
268
+ return iframe_tag, gr.update(value=filelist, visible=True)
269
+
270
+ example_dir = './gradio_cached_dir/example/img_example'
271
+
272
+ first_page_items = [
273
+ 'alita.jpg',
274
+ 'burger.jpg'
275
+ 'loopy.jpg'
276
+ 'building.jpg',
277
+ 'mario.jpg',
278
+ 'car.jpg',
279
+ 'airplane.jpg',
280
+ 'bag.jpg',
281
+ 'bench.jpg',
282
+ 'ship.jpg'
283
+ ]
284
+ raw_example_items = [
285
+ # (os.path.join(example_dir, x), x)
286
+ os.path.join(example_dir, x)
287
+ for x in os.listdir(example_dir)
288
+ if x.endswith(('.jpg', '.png'))
289
+ ]
290
+ example_items = [x for x in raw_example_items if os.path.basename(x) in first_page_items] + [x for x in raw_example_items if os.path.basename(x) not in first_page_items]
291
+
292
+ example_text = [
293
+ ["A 3D model of a car; Audi A6."],
294
+ ["A 3D model of police car; Highway Patrol Charger"]
295
+ ],
296
+
297
+ def set_cache(data: gr.SelectData):
298
+ img_name = os.path.basename(example_items[data.index])
299
+ return os.path.join(example_dir, img_name), os.path.join(img_name)
300
+
301
+ def disable_cache():
302
+ return ""
303
+
304
+ with gr.Blocks() as app:
305
+ gr.Markdown("# Michelangelo")
306
+ gr.Markdown("## [Github](https://github.com/NeuralCarver/Michelangelo) | [Arxiv](https://arxiv.org/abs/2306.17115) | [Project Page](https://neuralcarver.github.io/michelangelo/)")
307
+ gr.Markdown("Michelangelo is a conditional 3D shape generation system that trains based on the shape-image-text aligned latent representation.")
308
+ gr.Markdown("### Hint:")
309
+ gr.Markdown("1. We provide two APIs: Image-conditioned generation and Text-conditioned generation")
310
+ gr.Markdown("2. Note that the Image-conditioned model is trained on multiple 3D datasets like ShapeNet and Objaverse")
311
+ gr.Markdown("3. We provide some examples for you to try. You can also upload images or text as input.")
312
+ gr.Markdown("4. Welcome to share your amazing results with us, and thanks for your interest in our work!")
313
+
314
+ with gr.Row():
315
+ with gr.Column():
316
+
317
+ with gr.Tab("Image to 3D"):
318
+ img = gr.Image(label="Image")
319
+ gr.Markdown("For the best results, we suggest that the images uploaded meet the following three criteria: 1. The object is positioned at the center of the image, 2. The image size is square, and 3. The background is relatively clean.")
320
+ btn_generate_img2obj = gr.Button(value="Generate")
321
+
322
+ with gr.Accordion("Advanced settings", open=False):
323
+ image_dropdown_models = gr.Dropdown(label="Model", value="ASLDM-256-obj",choices=list(image_model_config_dict.keys()))
324
+ num_samples = gr.Slider(label="samples", value=4, minimum=1, maximum=8, step=1)
325
+ guidance_scale = gr.Slider(label="Guidance scale", value=7.5, minimum=3.0, maximum=10.0, step=0.1)
326
+ octree_depth = gr.Slider(label="Octree Depth (for 3D model)", value=7, minimum=4, maximum=8, step=1)
327
+
328
+
329
+ cache_dir = gr.Textbox(value="", visible=False)
330
+ examples = gr.Gallery(label='Examples', value=example_items, elem_id="gallery", allow_preview=False, columns=[4], object_fit="contain")
331
+
332
+ with gr.Tab("Text to 3D"):
333
+ prompt = gr.Textbox(label="Prompt", placeholder="A 3D model of motorcar; Porche Cayenne Turbo.")
334
+ gr.Markdown("For the best results, we suggest that the prompt follows 'A 3D model of CATEGORY; DESCRIPTION'. For example, A 3D model of motorcar; Porche Cayenne Turbo.")
335
+ btn_generate_txt2obj = gr.Button(value="Generate")
336
+
337
+ with gr.Accordion("Advanced settings", open=False):
338
+ text_dropdown_models = gr.Dropdown(label="Model", value="ASLDM-256",choices=list(text_model_config_dict.keys()))
339
+ num_samples = gr.Slider(label="samples", value=4, minimum=1, maximum=8, step=1)
340
+ guidance_scale = gr.Slider(label="Guidance scale", value=7.5, minimum=3.0, maximum=10.0, step=0.1)
341
+ octree_depth = gr.Slider(label="Octree Depth (for 3D model)", value=7, minimum=4, maximum=8, step=1)
342
+
343
+ gr.Markdown("#### Examples:")
344
+ gr.Markdown("1. A 3D model of a coupe; Audi A6.")
345
+ gr.Markdown("2. A 3D model of a motorcar; Hummer H2 SUT.")
346
+ gr.Markdown("3. A 3D model of an airplane; Airbus.")
347
+ gr.Markdown("4. A 3D model of a fighter aircraft; Attack Fighter.")
348
+ gr.Markdown("5. A 3D model of a chair; Simple Wooden Chair.")
349
+ gr.Markdown("6. A 3D model of a laptop computer; Dell Laptop.")
350
+ gr.Markdown("7. A 3D model of a lamp; ceiling light.")
351
+ gr.Markdown("8. A 3D model of a rifle; AK47.")
352
+ gr.Markdown("9. A 3D model of a knife; Sword.")
353
+ gr.Markdown("10. A 3D model of a vase; Plant in pot.")
354
+
355
+ with gr.Column():
356
+ model_3d = gr.HTML()
357
+ file_out = gr.File(label="Files", visible=False)
358
+
359
+ outputs = [model_3d, file_out]
360
+
361
+ img.upload(disable_cache, outputs=cache_dir)
362
+ examples.select(set_cache, outputs=[img, cache_dir])
363
+ print(f'line:404: {cache_dir}')
364
+ btn_generate_img2obj.click(image2mesh, inputs=[img, image_dropdown_models, num_samples,
365
+ guidance_scale, octree_depth],
366
+ outputs=outputs, api_name="generate_img2obj")
367
+
368
+ btn_generate_txt2obj.click(text2mesh, inputs=[prompt, text_dropdown_models, num_samples,
369
+ guidance_scale, octree_depth],
370
+ outputs=outputs, api_name="generate_txt2obj")
371
+
372
+ app.launch(server_name="0.0.0.0", server_port=8008, share=False)