Zhenyu Li committed on
Commit
abbda4e
1 Parent(s): 24b9846
Files changed (1)
  1. ui_prediction.py +347 -0
ui_prediction.py ADDED
@@ -0,0 +1,347 @@
+ # MIT License
+
+ # Copyright (c) 2022 Intelligent Systems Lab Org
+
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+ # File author: Zhenyu Li
+
+ import gradio as gr
+ from PIL import Image
+ import tempfile
+ import torch
+ import numpy as np
+
+ from zoedepth.utils.arg_utils import parse_unknown
+ import argparse
+ from zoedepth.models.builder import build_model
+ from zoedepth.utils.config import get_config_user
+ import matplotlib
+ import cv2
+
+ from infer_user import regular_tile_param, random_tile_param
+ from zoedepth.models.base_models.midas import Resize
+ from torchvision.transforms import Compose
+ from torchvision import transforms
+ import torch.nn.functional as F
+
+ import trimesh
+ from zoedepth.utils.geometry import depth_to_points, create_triangles
+ from functools import partial
+
+ def depth_edges_mask(depth, occ_filter_thr):
+     """Returns a mask of edges in the depth map.
+     Args:
+         depth: 2D numpy array of shape (H, W) with dtype float32.
+         occ_filter_thr: gradient-magnitude threshold above which a pixel is treated as an edge.
+     Returns:
+         mask: 2D numpy array of shape (H, W) with dtype bool.
+     """
+     # Compute the x and y gradients of the depth map.
+     depth_dx, depth_dy = np.gradient(depth)
+     # Compute the gradient magnitude.
+     depth_grad = np.sqrt(depth_dx ** 2 + depth_dy ** 2)
+     # Compute the edge mask.
+     # mask = depth_grad > 0.05  # default in zoedepth
+     mask = depth_grad > occ_filter_thr  # preserve more edges (?)
+     return mask
+
+ def load_state_dict(model, state_dict):
+     """Load state_dict into model, handling DataParallel and DistributedDataParallel. Also checks for "model" key in state_dict.
+
+     DataParallel prefixes state_dict keys with 'module.' when saving.
+     If the model is not a DataParallel model but the state_dict is, then prefixes are removed.
+     If the model is a DataParallel model but the state_dict is not, then prefixes are added.
+     """
+     state_dict = state_dict.get('model', state_dict)
+     # if model is a DataParallel model, then state_dict keys are prefixed with 'module.'
+
+     do_prefix = isinstance(
+         model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel))
+     state = {}
+     for k, v in state_dict.items():
+         if k.startswith('module.') and not do_prefix:
+             k = k[7:]
+
+         if not k.startswith('module.') and do_prefix:
+             k = 'module.' + k
+
+         state[k] = v
+
+     model.load_state_dict(state, strict=True)
+     print("Loaded successfully")
+     return model
+
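+ # Checkpoint loading helpers: read the weights onto the CPU first, then hand them to load_state_dict above.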
+ def load_wts(model, checkpoint_path):
+     ckpt = torch.load(checkpoint_path, map_location='cpu')
+     return load_state_dict(model, ckpt)
+
+ def load_ckpt(model, checkpoint):
+     model = load_wts(model, checkpoint)
+     print("Loaded weights from {0}".format(checkpoint))
+     return model
+
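+ # Map a single-channel depth map to an RGB uint8 image with a matplotlib colormap.
+ # vmax defaults to the 95th percentile rather than the max, so a few far outliers do not wash out the range.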
+ def colorize(value, cmap='magma_r', vmin=None, vmax=None):
+     # normalize
+     vmin = value.min() if vmin is None else vmin
+     # vmax = value.max() if vmax is None else vmax
+     vmax = np.percentile(value, 95) if vmax is None else vmax
+
+     if vmin != vmax:
+         value = (value - vmin) / (vmax - vmin)  # vmin..vmax
+     else:
+         value = value * 0.
+
+     cmapper = matplotlib.cm.get_cmap(cmap)
+     value = cmapper(value, bytes=True)  # ((1)xhxwx4)
+
+     value = value[:, :, :3]  # rgba -> rgb (drop the alpha channel)
+     # rgb_value = value[..., ::-1]
+     rgb_value = value
+
+     return rgb_value
+
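+ # Tiled high-resolution inference: the input is upsampled to the processing resolution,
+ # split into overlapping patches, and the per-patch predictions are fused into a running
+ # average map that is finally resized back to the original input size.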
+ def predict_depth(model, image, mode, pn, reso, ps, device=None):
+
+     pil_image = image
+     if device is not None:
+         image = transforms.ToTensor()(pil_image).unsqueeze(0).to(device)
+     else:
+         image = transforms.ToTensor()(pil_image).unsqueeze(0).cuda()
+
+     image_height, image_width = image.shape[-2], image.shape[-1]
+
+     if reso != '':
+         image_resolution = (int(reso.split('x')[0]), int(reso.split('x')[1]))
+     else:
+         image_resolution = (2160, 3840)
+     image_hr = F.interpolate(image, image_resolution, mode='bicubic', align_corners=True)
+     preprocess = Compose([Resize(512, 384, keep_aspect_ratio=False, ensure_multiple_of=32, resize_method="minimal")])
+     image_lr = preprocess(image)
+
+     if ps != '':
+         patch_size = (int(ps.split('x')[0]), int(ps.split('x')[1]))
+     else:
+         patch_size = (int(image_resolution[0] // 4), int(image_resolution[1] // 4))
+
+     avg_depth_map = regular_tile_param(
+         model,
+         image_hr,
+         offset_x=0,
+         offset_y=0,
+         img_lr=image_lr,
+         crop_size=patch_size,
+         img_resolution=image_resolution,
+         transform=preprocess,
+         blr_mask=True)
+
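+     # Tiling modes:
+     #   P16 - keep only the base regular tiling computed above.
+     #   P49 - add three regular tilings shifted by half a patch (horizontal, vertical, and both).
+     #   R   - same shifted tilings as P49, plus `pn` extra randomly placed patches.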
+     if mode == 'P16':
+         pass
+     elif mode == 'P49':
+         regular_tile_param(
+             model,
+             image_hr,
+             offset_x=patch_size[1]//2,
+             offset_y=0,
+             img_lr=image_lr,
+             iter_pred=avg_depth_map.average_map,
+             boundary=0,
+             update=True,
+             avg_depth_map=avg_depth_map,
+             crop_size=patch_size,
+             img_resolution=image_resolution,
+             transform=preprocess,
+             blr_mask=True)
+         regular_tile_param(
+             model,
+             image_hr,
+             offset_x=0,
+             offset_y=patch_size[0]//2,
+             img_lr=image_lr,
+             iter_pred=avg_depth_map.average_map,
+             boundary=0,
+             update=True,
+             avg_depth_map=avg_depth_map,
+             crop_size=patch_size,
+             img_resolution=image_resolution,
+             transform=preprocess,
+             blr_mask=True)
+         regular_tile_param(
+             model,
+             image_hr,
+             offset_x=patch_size[1]//2,
+             offset_y=patch_size[0]//2,
+             img_lr=image_lr,
+             iter_pred=avg_depth_map.average_map,
+             boundary=0,
+             update=True,
+             avg_depth_map=avg_depth_map,
+             crop_size=patch_size,
+             img_resolution=image_resolution,
+             transform=preprocess,
+             blr_mask=True)
+     elif mode == 'R':
+         regular_tile_param(
+             model,
+             image_hr,
+             offset_x=patch_size[1]//2,
+             offset_y=0,
+             img_lr=image_lr,
+             iter_pred=avg_depth_map.average_map,
+             boundary=0,
+             update=True,
+             avg_depth_map=avg_depth_map,
+             crop_size=patch_size,
+             img_resolution=image_resolution,
+             transform=preprocess,
+             blr_mask=True)
+         regular_tile_param(
+             model,
+             image_hr,
+             offset_x=0,
+             offset_y=patch_size[0]//2,
+             img_lr=image_lr,
+             iter_pred=avg_depth_map.average_map,
+             boundary=0,
+             update=True,
+             avg_depth_map=avg_depth_map,
+             crop_size=patch_size,
+             img_resolution=image_resolution,
+             transform=preprocess,
+             blr_mask=True)
+         regular_tile_param(
+             model,
+             image_hr,
+             offset_x=patch_size[1]//2,
+             offset_y=patch_size[0]//2,
+             img_lr=image_lr,
+             iter_pred=avg_depth_map.average_map,
+             boundary=0,
+             update=True,
+             avg_depth_map=avg_depth_map,
+             crop_size=patch_size,
+             img_resolution=image_resolution,
+             transform=preprocess,
+             blr_mask=True)
+
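+         # On top of the shifted regular tilings, add `pn` randomly placed patches (the patch-number slider in the UI).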
+         for i in range(int(pn)):
+             random_tile_param(
+                 model,
+                 image_hr,
+                 img_lr=image_lr,
+                 iter_pred=avg_depth_map.average_map,
+                 boundary=0,
+                 update=True,
+                 avg_depth_map=avg_depth_map,
+                 crop_size=patch_size,
+                 img_resolution=image_resolution,
+                 transform=preprocess,
+                 blr_mask=True)
+
+     depth = avg_depth_map.average_map.detach().cpu()
+     depth = F.interpolate(depth.unsqueeze(dim=0).unsqueeze(dim=0), (image_height, image_width), mode='bicubic', align_corners=True).squeeze().numpy()
+
+     return depth
+
+ def create_demo(model):
+     gr.Markdown("## Depth Prediction Demo")
+
+     with gr.Accordion("Advanced options", open=False):
+         mode = gr.Radio(["P49", "R"], label="Tiling mode", info="We recommend P49 for fast evaluation, and R with 1024 patches for the best visualization results", elem_id='mode', value='R')
+         patch_number = gr.Slider(1, 1024, label="Number of random patches (only used when mode=R)", step=1, value=256)
+         resolution = gr.Textbox(label="Processing resolution (default 4K; use 'x' to separate height and width)", value='2160x3840')
+         patch_size = gr.Textbox(label="Patch size (default 1/4 of the processing resolution; use 'x' to separate height and width)", value='540x960')
+
+     with gr.Row():
+         input_image = gr.Image(label="Input Image", type='pil', elem_id='img-display-input')
+         depth_image = gr.Image(label="Depth Map", elem_id='img-display-output')
+         raw_file = gr.File(label="16-bit raw depth, multiplier: 256")
+     submit = gr.Button("Submit")
+
+     def on_submit(image, mode, pn, reso, ps):
+         depth = predict_depth(model, image, mode, pn, reso, ps)
+         colored_depth = colorize(depth, cmap='gray_r')
+         tmp = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
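+         # The raw depth is scaled by 256 and saved as a 16-bit PNG, matching the "multiplier: 256" file label.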
+         raw_depth = Image.fromarray((depth*256).astype('uint16'))
+         raw_depth.save(tmp.name)
+         return [colored_depth, tmp.name]
+
+     submit.click(on_submit, inputs=[input_image, mode, patch_number, resolution, patch_size], outputs=[depth_image, raw_file])
+     examples = gr.Examples(examples=["examples/example_1.jpeg", "examples/example_2.jpeg", "examples/example_3.jpeg"], inputs=[input_image])
+
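+ # Back-project the predicted depth into a colored point grid, triangulate it, and export the result as a .glb mesh.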
+ def get_mesh(model, image, mode, pn, reso, ps, keep_edges, occ_filter_thr, fov):
+     depth = predict_depth(model, image, mode, pn, reso, ps)
+
+     image.thumbnail((1024, 1024))  # limit the size of the input image
+     depth = F.interpolate(torch.from_numpy(depth).unsqueeze(dim=0).unsqueeze(dim=0), (image.height, image.width), mode='bicubic', align_corners=True).squeeze().numpy()
+
+     pts3d = depth_to_points(depth[None], fov=float(fov))
+     pts3d = pts3d.reshape(-1, 3)
+
+     # Create a trimesh mesh from the points.
+     # Each pixel is connected to its 4 neighbors.
+     # Colors are the RGB values of the image.
+
+     verts = pts3d.reshape(-1, 3)
+     image = np.array(image)
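+     # Unless occlusion edges are kept, triangles that straddle strong depth discontinuities
+     # are masked out so foreground and background are not stitched together.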
+     if keep_edges:
+         triangles = create_triangles(image.shape[0], image.shape[1])
+     else:
+         triangles = create_triangles(image.shape[0], image.shape[1], mask=~depth_edges_mask(depth, occ_filter_thr=float(occ_filter_thr)))
+     colors = image.reshape(-1, 3)
+     mesh = trimesh.Trimesh(vertices=verts, faces=triangles, vertex_colors=colors)
+
+     # Save as glb
+     glb_file = tempfile.NamedTemporaryFile(suffix='.glb', delete=False)
+     glb_path = glb_file.name
+     mesh.export(glb_path)
+     return glb_path
+
+ def create_demo_3d(model):
+
+     gr.Markdown("### Image to 3D Mesh")
+     gr.Markdown("Convert a single 2D image to a 3D mesh")
+
+     with gr.Accordion("Advanced options", open=False):
+         mode = gr.Radio(["P49", "R"], label="Tiling mode", info="We recommend P49 for fast evaluation, and R with 1024 patches for the best visualization results", elem_id='mode', value='R')
+         patch_number = gr.Slider(1, 1024, label="Number of random patches (only used when mode=R)", step=1, value=256)
+         resolution = gr.Textbox(label="Processing resolution (default 4K; use 'x' to separate height and width)", value='2160x3840')
+         patch_size = gr.Textbox(label="Patch size (default 1/4 of the processing resolution; use 'x' to separate height and width)", value='540x960')
+
+         checkbox = gr.Checkbox(label="Keep occlusion edges", value=False)
+         # occ_filter_thr = gr.Textbox(label="Occlusion filter threshold", info="A larger value preserves more edges (only used when NOT keeping occlusion edges)", value='0.5')
+         # fov = gr.Textbox(label="FOV for inv-projection", value='55')
+
+         occ_filter_thr = gr.Slider(0.01, 5, label="Occlusion edge filter threshold", info="A larger value preserves more occlusion edges (only used when NOT keeping occlusion edges)", step=0.01, value=0.2)
+         fov = gr.Slider(5, 180, label="FOV for inverse projection", step=1, value=55)
+
+
+     with gr.Row():
+         input_image = gr.Image(label="Input Image", type='pil')
+         result = gr.Model3D(label="3D mesh reconstruction", clear_color=[1.0, 1.0, 1.0, 1.0])
+
+     submit = gr.Button("Submit")
+     submit.click(partial(get_mesh, model), inputs=[input_image, mode, patch_number, resolution, patch_size, checkbox, occ_filter_thr, fov], outputs=[result])
+     examples = gr.Examples(examples=["examples/example_1.jpeg", "examples/example_4.jpeg", "examples/example_3.jpeg"], inputs=[input_image])