kxhit committed on
Commit
8cf8c7b
1 Parent(s): 5f093a6
Files changed (2)
  1. EscherNet_Demo_Readme.md +0 -5
  2. app.py +780 -0
EscherNet_Demo_Readme.md DELETED
@@ -1,5 +0,0 @@
1
- Run EscherNet using Dust3R log results, need to set `data_dir` and run:
2
- ```commandline
3
- bash ./demo_dust3r.sh
4
- ```
5
-
 
app.py ADDED
@@ -0,0 +1,780 @@
1
+ import spaces
2
+
3
+ import gradio as gr
4
+ import os
5
+ import shutil
6
+ import rembg
7
+ import numpy as np
8
+ import math
9
+ import open3d as o3d
10
+ from PIL import Image
11
+ import torch
12
+ import torchvision
13
+ import trimesh
14
+ from skimage.io import imsave
15
+ import imageio
16
+ import cv2
17
+ import matplotlib.pyplot as pl
18
+ pl.ion()
19
+
20
+ CaPE_TYPE = "6DoF"
21
+ device = 'cuda' #if torch.cuda.is_available() else 'cpu'
22
+ weight_dtype = torch.float16
23
+ torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
24
+
25
+ # EscherNet
26
+ # create angles in archimedean spiral with N steps
27
+ def get_archimedean_spiral(sphere_radius, num_steps=250):
28
+ # x-z plane, around upper y
29
+ '''
30
+ https://en.wikipedia.org/wiki/Spiral, section "Spherical spiral". c = a / pi
31
+ '''
32
+ a = 40
33
+ r = sphere_radius
34
+
35
+ translations = []
36
+ angles = []
37
+
38
+ # i = a / 2
39
+ i = 0.01
40
+ while i < a:
41
+ theta = i / a * math.pi
42
+ x = r * math.sin(theta) * math.cos(-i)
43
+ z = r * math.sin(-theta + math.pi) * math.sin(-i)
44
+ y = r * - math.cos(theta)
45
+
46
+ # translations.append((x, y, z)) # origin
47
+ translations.append((x, z, -y))
48
+ angles.append([np.rad2deg(-i), np.rad2deg(theta)])
49
+
50
+ # i += a / (2 * num_steps)
51
+ i += a / (1 * num_steps)
52
+
53
+ return np.array(translations), np.stack(angles)
54
+
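+ # Illustrative note: as used below, get_archimedean_spiral(1.5, 200) returns roughly
+ # num_steps xyz positions on a sphere of radius 1.5 together with the matching
+ # [azimuth, polar] angles in degrees for each position.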
55
+ def look_at(origin, target, up):
56
+ forward = (target - origin)
57
+ forward = forward / np.linalg.norm(forward)
58
+ right = np.cross(up, forward)
59
+ right = right / np.linalg.norm(right)
60
+ new_up = np.cross(forward, right)
61
+ rotation_matrix = np.column_stack((right, new_up, -forward, target))
62
+ matrix = np.row_stack((rotation_matrix, [0, 0, 0, 1]))
63
+ return matrix
64
+
65
+ import einops
66
+ import sys
67
+
68
+ sys.path.insert(0, "./6DoF/") # TODO change it when deploying
69
+ # use the customized diffusers modules
70
+ from diffusers import DDIMScheduler
71
+ from dataset import get_pose
72
+ from CN_encoder import CN_encoder
73
+ from pipeline_zero1to3 import Zero1to3StableDiffusionPipeline
74
+
75
+ pretrained_model_name_or_path = "kxic/EscherNet_demo"
76
+ resolution = 256
77
+ h,w = resolution,resolution
78
+ guidance_scale = 3.0
79
+ radius = 2.2
80
+ bg_color = [1., 1., 1., 1.]
81
+ image_transforms = torchvision.transforms.Compose(
82
+ [
83
+ torchvision.transforms.Resize((resolution, resolution)), # 256, 256
84
+ torchvision.transforms.ToTensor(),
85
+ torchvision.transforms.Normalize([0.5], [0.5])
86
+ ]
87
+ )
88
+ xyzs_spiral, angles_spiral = get_archimedean_spiral(1.5, 200)
89
+ # only keep the top half of the spiral
90
+ xyzs_spiral = xyzs_spiral[:100]
91
+ angles_spiral = angles_spiral[:100]
92
+
93
+ # Init pipeline
94
+ scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler", revision=None)
95
+ image_encoder = CN_encoder.from_pretrained(pretrained_model_name_or_path, subfolder="image_encoder", revision=None)
96
+ pipeline = Zero1to3StableDiffusionPipeline.from_pretrained(
97
+ pretrained_model_name_or_path,
98
+ revision=None,
99
+ scheduler=scheduler,
100
+ image_encoder=None,
101
+ safety_checker=None,
102
+ feature_extractor=None,
103
+ torch_dtype=weight_dtype,
104
+ )
105
+ pipeline.image_encoder = image_encoder.to(weight_dtype)
106
+ pipeline = pipeline.to(device)
107
+ pipeline.set_progress_bar_config(disable=False)
108
+
109
+ pipeline.enable_xformers_memory_efficient_attention()
110
+ # enable vae slicing
111
+ pipeline.enable_vae_slicing()
112
+
113
+
114
+
115
+
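+ # Note on the input contract (see get_reconstructed_scene below, which builds it):
+ # eschernet_input_dict is expected to contain
+ #   "poses": (N, 4, 4) cam2world matrices from DUSt3R, re-centred on the object with z-up
+ #   "radii": camera distance(s) from that origin, used to rescale the translations
+ #   "imgs":  RGBA input views with background removed, values in [0, 255]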
116
+ @spaces.GPU(duration=120)
117
+ def run_eschernet(tmpdirname, eschernet_input_dict, sample_steps, sample_seed, nvs_num, nvs_mode):
118
+ # set the random seed
119
+ generator = torch.Generator(device=device).manual_seed(sample_seed)
120
+ T_out = nvs_num
121
+ T_in = len(eschernet_input_dict['imgs'])
122
+ ####### output pose
123
+ # TODO choose T_out number of poses sequentially from the spiral
124
+ xyzs = xyzs_spiral[::(len(xyzs_spiral) // T_out)]
125
+ angles_out = angles_spiral[::(len(xyzs_spiral) // T_out)]
126
+
127
+ ####### input's max radius for translation scaling
128
+ radii = eschernet_input_dict['radii']
129
+ max_t = np.max(radii)
130
+ min_t = np.min(radii)
131
+
132
+ ####### input pose
133
+ pose_in = []
134
+ for T_in_index in range(T_in):
135
+ pose = get_pose(np.linalg.inv(eschernet_input_dict['poses'][T_in_index]))
136
+ pose[1:3, :] *= -1 # coordinate system conversion
137
+ pose[3, 3] *= 1. / max_t * radius # scale radius to [1.5, 2.2]
138
+ pose_in.append(torch.from_numpy(pose))
139
+
140
+ ####### input image
141
+ img = eschernet_input_dict['imgs'] / 255.
142
+ img[img[:, :, :, -1] == 0.] = bg_color
143
+ # TODO batch image_transforms
144
+ input_image = [image_transforms(Image.fromarray(np.uint8(im[:, :, :3] * 255.)).convert("RGB")) for im in img]
145
+
146
+ ####### nvs pose
147
+ pose_out = []
148
+ for T_out_index in range(T_out):
149
+ azimuth, polar = angles_out[T_out_index]
150
+ if CaPE_TYPE == "4DoF":
151
+ pose_out.append(torch.tensor([np.deg2rad(polar), np.deg2rad(azimuth), 0., 0.]))
152
+ elif CaPE_TYPE == "6DoF":
153
+ pose = look_at(origin=np.array([0, 0, 0]), target=xyzs[T_out_index], up=np.array([0, 0, 1]))
154
+ pose = np.linalg.inv(pose)
155
+ pose[2, :] *= -1
156
+ pose_out.append(torch.from_numpy(get_pose(pose)))
157
+
158
+
159
+
160
+ # [B, T, C, H, W]
161
+ input_image = torch.stack(input_image, dim=0).to(device).to(weight_dtype).unsqueeze(0)
162
+ # [B, T, 4]
163
+ pose_in = np.stack(pose_in)
164
+ pose_out = np.stack(pose_out)
165
+
166
+ if CaPE_TYPE == "6DoF":
167
+ pose_in_inv = np.linalg.inv(pose_in).transpose([0, 2, 1])
168
+ pose_out_inv = np.linalg.inv(pose_out).transpose([0, 2, 1])
169
+ pose_in_inv = torch.from_numpy(pose_in_inv).to(device).to(weight_dtype).unsqueeze(0)
170
+ pose_out_inv = torch.from_numpy(pose_out_inv).to(device).to(weight_dtype).unsqueeze(0)
171
+
172
+ pose_in = torch.from_numpy(pose_in).to(device).to(weight_dtype).unsqueeze(0)
173
+ pose_out = torch.from_numpy(pose_out).to(device).to(weight_dtype).unsqueeze(0)
174
+
175
+ input_image = einops.rearrange(input_image, "b t c h w -> (b t) c h w")
176
+ assert T_in == input_image.shape[0]
177
+ assert T_in == pose_in.shape[1]
178
+ assert T_out == pose_out.shape[1]
179
+
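+ # The pipeline is conditioned on all T_in input views at once and denoises all T_out
+ # target views jointly; for 6DoF CaPE the poses are passed as (pose, pose_inv) pairs.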
180
+ # run inference
181
+ if CaPE_TYPE == "6DoF":
182
+ with torch.autocast("cuda"):
183
+ image = pipeline(input_imgs=input_image, prompt_imgs=input_image,
184
+ poses=[[pose_out, pose_out_inv], [pose_in, pose_in_inv]],
185
+ height=h, width=w, T_in=T_in, T_out=T_out,
186
+ guidance_scale=guidance_scale, num_inference_steps=50, generator=generator,
187
+ output_type="numpy").images
188
+ elif CaPE_TYPE == "4DoF":
189
+ with torch.autocast("cuda"):
190
+ image = pipeline(input_imgs=input_image, prompt_imgs=input_image, poses=[pose_out, pose_in],
191
+ height=h, width=w, T_in=T_in, T_out=T_out,
192
+ guidance_scale=guidance_scale, num_inference_steps=50, generator=generator,
193
+ output_type="numpy").images
194
+
195
+ # save output image
196
+ output_dir = os.path.join(tmpdirname, "eschernet")
197
+ if os.path.exists(output_dir):
198
+ shutil.rmtree(output_dir)
199
+ os.makedirs(output_dir, exist_ok=True)
200
+ # save to N imgs
201
+ for i in range(T_out):
202
+ imsave(os.path.join(output_dir, f'{i}.png'), (image[i] * 255).astype(np.uint8))
203
+ # make a gif
204
+ frames = [Image.fromarray((image[i] * 255).astype(np.uint8)) for i in range(T_out)]
205
+ frame_one = frames[0]
206
+ frame_one.save(os.path.join(output_dir, "output.gif"), format="GIF", append_images=frames,
207
+ save_all=True, duration=50, loop=1)
208
+
209
+ # get a video
210
+ video_path = os.path.join(output_dir, "output.mp4")
211
+ imageio.mimwrite(video_path, np.stack(frames), fps=10, codec='h264')
212
+
213
+
214
+ return image, video_path
215
+
216
+ # TODO mesh it
217
+ @spaces.GPU(duration=120)
218
+ def make3d():
219
+ pass
220
+
221
+
222
+
223
+ ############################ Dust3r as Pose Estimation ############################
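+ # DUSt3R estimates camera poses and a dense point cloud from the unposed input images;
+ # the globally aligned cams2world matrices and background-removed RGBA crops produced
+ # here are what run_eschernet() above consumes.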
224
+ from scipy.spatial.transform import Rotation
225
+ import copy
226
+
227
+ from dust3r.inference import inference
228
+ from dust3r.model import AsymmetricCroCo3DStereo
229
+ from dust3r.image_pairs import make_pairs
230
+ from dust3r.utils.image import load_images, rgb
231
+ from dust3r.utils.device import to_numpy
232
+ from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
233
+ from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
234
+
235
+ import functools
236
+ import math
237
+
238
+ @spaces.GPU
239
+ def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
240
+ cam_color=None, as_pointcloud=False,
241
+ transparent_cams=False, silent=False, same_focals=False):
242
+ assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world)
243
+ if not same_focals:
244
+ assert (len(cams2world) == len(focals))
245
+ pts3d = to_numpy(pts3d)
246
+ imgs = to_numpy(imgs)
247
+ focals = to_numpy(focals)
248
+ cams2world = to_numpy(cams2world)
249
+
250
+ scene = trimesh.Scene()
251
+
252
+ # add axes
253
+ scene.add_geometry(trimesh.creation.axis(axis_length=0.5, axis_radius=0.001))
254
+
255
+ # full pointcloud
256
+ if as_pointcloud:
257
+ pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
258
+ col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
259
+ pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
260
+ scene.add_geometry(pct)
261
+ else:
262
+ meshes = []
263
+ for i in range(len(imgs)):
264
+ meshes.append(pts3d_to_trimesh(imgs[i], pts3d[i], mask[i]))
265
+ mesh = trimesh.Trimesh(**cat_meshes(meshes))
266
+ scene.add_geometry(mesh)
267
+
268
+ # add each camera
269
+ for i, pose_c2w in enumerate(cams2world):
270
+ if isinstance(cam_color, list):
271
+ camera_edge_color = cam_color[i]
272
+ else:
273
+ camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
274
+ if same_focals:
275
+ focal = focals[0]
276
+ else:
277
+ focal = focals[i]
278
+ add_scene_cam(scene, pose_c2w, camera_edge_color,
279
+ None if transparent_cams else imgs[i], focal,
280
+ imsize=imgs[i].shape[1::-1], screen_width=cam_size)
281
+
282
+ rot = np.eye(4)
283
+ rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
284
+ scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
285
+ outfile = os.path.join(outdir, 'scene.glb')
286
+ if not silent:
287
+ print('(exporting 3D scene to', outfile, ')')
288
+ scene.export(file_obj=outfile)
289
+ return outfile
290
+
291
+ @spaces.GPU(duration=120)
292
+ def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False,
293
+ clean_depth=False, transparent_cams=False, cam_size=0.05, same_focals=False):
294
+ """
295
+ extract 3D_model (glb file) from a reconstructed scene
296
+ """
297
+ if scene is None:
298
+ return None
299
+ # post processes
300
+ if clean_depth:
301
+ scene = scene.clean_pointcloud()
302
+ if mask_sky:
303
+ scene = scene.mask_sky()
304
+
305
+ # get optimized values from scene
306
+ rgbimg = to_numpy(scene.imgs)
307
+ focals = to_numpy(scene.get_focals().cpu())
308
+ # cams2world = to_numpy(scene.get_im_poses().cpu())
309
+ # TODO use the vis_poses
310
+ cams2world = scene.vis_poses
311
+
312
+ # 3D pointcloud from depthmap, poses and intrinsics
313
+ # pts3d = to_numpy(scene.get_pts3d())
314
+ # TODO use the vis_poses
315
+ pts3d = scene.vis_pts3d
316
+ scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
317
+ msk = to_numpy(scene.get_masks())
318
+
319
+ return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
320
+ transparent_cams=transparent_cams, cam_size=cam_size, silent=silent,
321
+ same_focals=same_focals)
322
+
323
+ @spaces.GPU(duration=120)
324
+ def get_reconstructed_scene(outdir, model, device, silent, image_size, filelist, schedule, niter, min_conf_thr,
325
+ as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
326
+ scenegraph_type, winsize, refid, same_focals):
327
+ """
328
+ from a list of images, run dust3r inference, global aligner.
329
+ then run get_3D_model_from_scene
330
+ """
331
+ # remove the directory if it already exists
332
+ if os.path.exists(outdir):
333
+ shutil.rmtree(outdir)
334
+ os.makedirs(outdir, exist_ok=True)
335
+ imgs, imgs_rgba = load_images(filelist, size=image_size, verbose=not silent, do_remove_background=True)
336
+ if len(imgs) == 1:
337
+ imgs = [imgs[0], copy.deepcopy(imgs[0])]
338
+ imgs[1]['idx'] = 1
339
+ if scenegraph_type == "swin":
340
+ scenegraph_type = scenegraph_type + "-" + str(winsize)
341
+ elif scenegraph_type == "oneref":
342
+ scenegraph_type = scenegraph_type + "-" + str(refid)
343
+
344
+ pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True)
345
+ output = inference(pairs, model, device, batch_size=1, verbose=not silent)
346
+
347
+ mode = GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2 else GlobalAlignerMode.PairViewer
348
+ scene = global_aligner(output, device=device, mode=mode, verbose=not silent, same_focals=same_focals)
349
+ lr = 0.01
350
+
351
+ if mode == GlobalAlignerMode.PointCloudOptimizer:
352
+ loss = scene.compute_global_alignment(init='mst', niter=niter, schedule=schedule, lr=lr)
353
+
354
+ # outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
355
+ # clean_depth, transparent_cams, cam_size, same_focals=same_focals)
356
+
357
+ # also return rgb, depth and confidence imgs
358
+ # depth is normalized with the max value for all images
359
+ # we apply the jet colormap on the confidence maps
360
+ rgbimg = scene.imgs
361
+ # depths = to_numpy(scene.get_depthmaps())
362
+ # confs = to_numpy([c for c in scene.im_conf])
363
+ # cmap = pl.get_cmap('jet')
364
+ # depths_max = max([d.max() for d in depths])
365
+ # depths = [d / depths_max for d in depths]
366
+ # confs_max = max([d.max() for d in confs])
367
+ # confs = [cmap(d / confs_max) for d in confs]
368
+
369
+ imgs = []
370
+ rgbaimg = []
371
+ for i in range(len(rgbimg)): # when only 1 input image is given, scene.imgs holds two duplicated views
372
+ imgs.append(rgbimg[i])
373
+ # imgs.append(rgb(depths[i]))
374
+ # imgs.append(rgb(confs[i]))
375
+ # imgs.append(imgs_rgba[i])
376
+ if len(imgs_rgba) == 1 and i == 1:
377
+ imgs.append(imgs_rgba[0])
378
+ rgbaimg.append(np.array(imgs_rgba[0]))
379
+ else:
380
+ imgs.append(imgs_rgba[i])
381
+ rgbaimg.append(np.array(imgs_rgba[i]))
382
+
383
+ rgbaimg = np.array(rgbaimg)
384
+
385
+ # for eschernet
386
+ # get optimized values from scene
387
+ rgbimg = to_numpy(scene.imgs)
388
+ focals = to_numpy(scene.get_focals().cpu())
389
+ cams2world = to_numpy(scene.get_im_poses().cpu())
390
+
391
+ # 3D pointcloud from depthmap, poses and intrinsics
392
+ pts3d = to_numpy(scene.get_pts3d())
393
+ scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
394
+ msk = to_numpy(scene.get_masks())
395
+ obj_mask = rgbaimg[..., 3] > 0
396
+
397
+ # TODO set global coordinate system at the center of the scene, z-axis is up
398
+ pts = np.concatenate([p[m] for p, m in zip(pts3d, msk)]).reshape(-1, 3)
399
+ pts_obj = np.concatenate([p[m&obj_m] for p, m, obj_m in zip(pts3d, msk, obj_mask)]).reshape(-1, 3)
400
+ centroid = np.mean(pts_obj, axis=0) # obj center
401
+ obj2world = np.eye(4)
402
+ obj2world[:3, 3] = -centroid # T_wc
403
+
404
+ # get z_up vector
405
+ # TODO fit a plane and get the normal vector
406
+ pcd = o3d.geometry.PointCloud()
407
+ pcd.points = o3d.utility.Vector3dVector(pts)
408
+ plane_model, inliers = pcd.segment_plane(distance_threshold=0.01, ransac_n=3, num_iterations=1000)
409
+ # get the normalised normal vector dim = 3
410
+ normal = plane_model[:3] / np.linalg.norm(plane_model[:3])
411
+ # the normal direction should be pointing up
412
+ if normal[1] < 0:
413
+ normal = -normal
414
+ # print("normal", normal)
415
+
416
+ # # TODO z-up 180
417
+ # z_up = np.array([[1,0,0,0],
418
+ # [0,-1,0,0],
419
+ # [0,0,-1,0],
420
+ # [0,0,0,1]])
421
+ # obj2world = z_up @ obj2world
422
+
423
+ # # avg the y
424
+ # z_up_avg = cams2world[:,:3,3].sum(0) / np.linalg.norm(cams2world[:,:3,3].sum(0), axis=-1) # average direction in cam coordinate
425
+ # # import pdb; pdb.set_trace()
426
+ # rot_axis = np.cross(np.array([0, 0, 1]), z_up_avg)
427
+ # rot_angle = np.arccos(np.dot(np.array([0, 0, 1]), z_up_avg) / (np.linalg.norm(z_up_avg) + 1e-6))
428
+ # rot = Rotation.from_rotvec(rot_angle * rot_axis)
429
+ # z_up = np.eye(4)
430
+ # z_up[:3, :3] = rot.as_matrix()
431
+
432
+ # get the rotation matrix from normal to z-axis
433
+ z_axis = np.array([0, 0, 1])
434
+ rot_axis = np.cross(normal, z_axis)
435
+ rot_angle = np.arccos(np.dot(normal, z_axis) / (np.linalg.norm(normal) + 1e-6))
436
+ rot = Rotation.from_rotvec(rot_angle * rot_axis)
437
+ z_up = np.eye(4)
438
+ z_up[:3, :3] = rot.as_matrix()
439
+ obj2world = z_up @ obj2world
440
+ # flip 180
441
+ flip_rot = np.array([[1, 0, 0, 0],
442
+ [0, -1, 0, 0],
443
+ [0, 0, -1, 0],
444
+ [0, 0, 0, 1]])
445
+ obj2world = flip_rot @ obj2world
446
+
447
+ # get new cams2obj
448
+ cams2obj = []
449
+ for i, cam2world in enumerate(cams2world):
450
+ cams2obj.append(obj2world @ cam2world)
451
+ # TODO transform pts3d to the new coordinate system
452
+ for i, pts in enumerate(pts3d):
453
+ pts_h = np.concatenate([pts, np.ones_like(pts)[..., :1]], axis=-1)  # homogeneous coordinates
454
+ pts3d[i] = (obj2world @ pts_h.transpose(2, 0, 1).reshape(4, -1)) \
455
+ .reshape(4, pts.shape[0], pts.shape[1]).transpose(1, 2, 0)[..., :3]
456
+ cams2world = np.array(cams2obj)
457
+ # TODO rewrite hack
458
+ scene.vis_poses = cams2world.copy()
459
+ scene.vis_pts3d = pts3d.copy()
460
+
461
+ # TODO save cams2world and rgbimg to each file, file name "000.npy", "001.npy", ... and "000.png", "001.png", ...
462
+ for i, (img, img_rgba, pose) in enumerate(zip(rgbimg, rgbaimg, cams2world)):
463
+ np.save(os.path.join(outdir, f"{i:03d}.npy"), pose)
464
+ pl.imsave(os.path.join(outdir, f"{i:03d}.png"), img)
465
+ pl.imsave(os.path.join(outdir, f"{i:03d}_rgba.png"), img_rgba)
466
+ # np.save(os.path.join(outdir, f"{i:03d}_focal.npy"), to_numpy(focal))
467
+ # save the per-camera radii (camera distance from the object-centred origin); min/max are used for scaling
468
+ radii = np.linalg.norm(np.linalg.inv(cams2world)[..., :3, 3], axis=-1)
469
+ np.save(os.path.join(outdir, "radii.npy"), radii)
470
+
471
+ eschernet_input = {"poses": cams2world,
472
+ "radii": radii,
473
+ "imgs": rgbaimg}
474
+
475
+ outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
476
+ clean_depth, transparent_cams, cam_size, same_focals=same_focals)
477
+
478
+ return scene, outfile, imgs, eschernet_input
479
+
480
+
481
+ def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type):
482
+ num_files = len(inputfiles) if inputfiles is not None else 1
483
+ max_winsize = max(1, math.ceil((num_files - 1) / 2))
484
+ if scenegraph_type == "swin":
485
+ winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
486
+ minimum=1, maximum=max_winsize, step=1, visible=True)
487
+ refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
488
+ maximum=num_files - 1, step=1, visible=False)
489
+ elif scenegraph_type == "oneref":
490
+ winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
491
+ minimum=1, maximum=max_winsize, step=1, visible=False)
492
+ refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
493
+ maximum=num_files - 1, step=1, visible=True)
494
+ else:
495
+ winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
496
+ minimum=1, maximum=max_winsize, step=1, visible=False)
497
+ refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
498
+ maximum=num_files - 1, step=1, visible=False)
499
+ return winsize, refid
500
+
501
+
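+ # get_examples() assumes a layout like examples/<object_name>/<image files>; it returns
+ # one gallery entry (the list of image paths) per object sub-folder.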
502
+ def get_examples(path):
503
+ objs = []
504
+ for obj_name in sorted(os.listdir(path)):
505
+ img_files = []
506
+ for img_file in sorted(os.listdir(os.path.join(path, obj_name))):
507
+ img_files.append(os.path.join(path, obj_name, img_file))
508
+ objs.append([img_files])
509
+ print("objs = ", objs)
510
+ return objs
511
+
512
+ def preview_input(inputfiles):
513
+ if inputfiles is None:
514
+ return None
515
+ imgs = []
516
+ for img_file in inputfiles:
517
+ img = pl.imread(img_file)
518
+ imgs.append(img)
519
+ return imgs
520
+
521
+ def main():
522
+ # dust3r init
523
+ silent = False
524
+ image_size = 224
525
+ weights_path = 'checkpoints/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth'
526
+ model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(device)
527
+ # dust3r will write the 3D model inside tmpdirname
528
+ # with tempfile.TemporaryDirectory(suffix='dust3r_gradio_demo') as tmpdirname:
529
+ tmpdirname = os.path.join('logs/user_object')
530
+ # remove the directory if it already exists
531
+ if os.path.exists(tmpdirname):
532
+ shutil.rmtree(tmpdirname)
533
+ os.makedirs(tmpdirname, exist_ok=True)
534
+ if not silent:
535
+ print('Outputting stuff in', tmpdirname)
536
+
537
+ recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, model, device, silent, image_size)
538
+ model_from_scene_fun = functools.partial(get_3D_model_from_scene, tmpdirname, silent)
539
+
540
+ generate_mvs = functools.partial(run_eschernet, tmpdirname)
541
+
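+ # Gradio wiring below: "Get Pose!" runs DUSt3R pose estimation (recon_fun) and
+ # "Submit" feeds the resulting eschernet_input into EscherNet (generate_mvs).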
542
+ _HEADER_ = '''
543
+ <h2><b>[CVPR'24 Oral] EscherNet: A Generative Model for Scalable View Synthesis</b></h2>
544
+ <b>EscherNet</b> is a multiview diffusion model for scalable, generative novel view synthesis from any number of input views to any number of target views at arbitrary poses.
545
+
546
+ Image views are treated as tokens and the camera pose is encoded by <b>CaPE (Camera Positional Encoding)</b>.
547
+
548
+ <a href='https://kxhit.github.io/EscherNet' target='_blank'>Project</a> <b>|</b>
549
+ <a href='https://github.com/kxhit/EscherNet' target='_blank'>GitHub</a> <b>|</b>
550
+ <a href='https://arxiv.org/abs/2402.03908' target='_blank'>ArXiv</a>
551
+
552
+ <h4><b>Tips:</b></h4>
553
+
554
+ - Our model can take <b>any number of input images</b>. The more images you provide, the better the results.
555
+
556
+ - Our model can generate novel views at <b>any number and any pose</b>. You can specify the number of views you want to generate. In this demo, we place the novel views on an <b>Archimedean spiral</b> for simplicity.
557
+
558
+ - The pose estimation is done using <a href='https://github.com/naver/dust3r' target='_blank'>DUSt3R</a>. You can also provide your own poses or obtain poses from any SLAM system.
559
+
560
+ - The current checkpoint supports 6DoF camera poses and is trained on 30k 3D <a href='https://objaverse.allenai.org/' target='_blank'>Objaverse</a> objects for this demo. Scaling up is on the roadmap!
561
+
562
+ '''
563
+
564
+ _CITE_ = r"""
565
+ 📝 <b>Citation</b>:
566
+ ```bibtex
567
+ @article{kong2024eschernet,
568
+ title={EscherNet: A Generative Model for Scalable View Synthesis},
569
+ author={Kong, Xin and Liu, Shikun and Lyu, Xiaoyang and Taher, Marwan and Qi, Xiaojuan and Davison, Andrew J},
570
+ journal={arXiv preprint arXiv:2402.03908},
571
+ year={2024}
572
+ }
573
+ ```
574
+ """
575
+
576
+ with gr.Blocks() as demo:
577
+ gr.Markdown(_HEADER_)
578
+ mv_images = gr.State()
579
+ scene = gr.State(None)
580
+ eschernet_input = gr.State(None)
581
+ with gr.Row(variant="panel"):
582
+ # left column
583
+ with gr.Column():
584
+ with gr.Row():
585
+ input_image = gr.File(file_count="multiple")
586
+ # with gr.Row():
587
+ # # set the size of the window
588
+ # preview_image = gr.Gallery(label='Input Views', rows=1,
589
+ with gr.Row():
590
+ run_dust3r = gr.Button("Get Pose!", elem_id="dust3r")
591
+ with gr.Row():
592
+ processed_image = gr.Gallery(label='Input Views', columns=2, height="100%")
593
+ with gr.Row(variant="panel"):
594
+ # input examples under "examples" folder
595
+ gr.Examples(
596
+ examples=get_examples('examples'),
597
+ # examples=[
598
+ # [['examples/controller/frame000077.jpg', 'examples/controller/frame000032.jpg', 'examples/controller/frame000172.jpg']],
599
+ # [['examples/hairdryer/frame000081.jpg', 'examples/hairdryer/frame000162.jpg', 'examples/hairdryer/frame000003.jpg']],
600
+ # ],
601
+ inputs=[input_image],
602
+ label="Examples (click one set of images to start!)",
603
+ examples_per_page=20
604
+ )
605
+
606
+
607
+
608
+
609
+
610
+ # right column
611
+ with gr.Column():
612
+
613
+ with gr.Row():
614
+ outmodel = gr.Model3D()
615
+
616
+ with gr.Row():
617
+ gr.Markdown('''
618
+ <h4><b>Check whether the pose and segmentation look correct. If not, remove the incorrect images and try again.</b></h4>
619
+ ''')
620
+
621
+ with gr.Row():
622
+ with gr.Group():
623
+ do_remove_background = gr.Checkbox(
624
+ label="Remove Background", value=True
625
+ )
626
+ sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
627
+
628
+ sample_steps = gr.Slider(
629
+ label="Sample Steps",
630
+ minimum=30,
631
+ maximum=75,
632
+ value=50,
633
+ step=5,
634
+ visible=False
635
+ )
636
+
637
+ nvs_num = gr.Slider(
638
+ label="Number of Novel Views",
639
+ minimum=5,
640
+ maximum=100,
641
+ value=30,
642
+ step=1
643
+ )
644
+
645
+ nvs_mode = gr.Dropdown(["archimedes circle"], # "fixed 4 views", "fixed 8 views"
646
+ value="archimedes circle", label="Novel Views Pose Chosen", visible=True)
647
+
648
+ with gr.Row():
649
+ gr.Markdown('''
650
+ <h4><b>Choose the desired number of novel views and generate! The more output images, the longer it takes.</b></h4>
651
+ ''')
652
+
653
+ with gr.Row():
654
+ submit = gr.Button("Submit", elem_id="eschernet", variant="primary")
655
+
656
+ with gr.Row():
657
+ # mv_show_images = gr.Image(
658
+ # label="Generated Multi-views",
659
+ # type="pil",
660
+ # width=379,
661
+ # interactive=False
662
+ # )
663
+ with gr.Column():
664
+ output_video = gr.Video(
665
+ label="video", format="mp4",
666
+ width=379,
667
+ autoplay=True,
668
+ interactive=False
669
+ )
670
+
671
+ # with gr.Row():
672
+ # with gr.Tab("OBJ"):
673
+ # output_model_obj = gr.Model3D(
674
+ # label="Output Model (OBJ Format)",
675
+ # #width=768,
676
+ # interactive=False,
677
+ # )
678
+ # gr.Markdown("Note: Downloaded .obj model will be flipped. Export .glb instead or manually flip it before usage.")
679
+ # with gr.Tab("GLB"):
680
+ # output_model_glb = gr.Model3D(
681
+ # label="Output Model (GLB Format)",
682
+ # #width=768,
683
+ # interactive=False,
684
+ # )
685
+ # gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
686
+
687
+ with gr.Row():
688
+ gr.Markdown('''The novel views are generated on an Archimedean spiral. You can download the video.''')
689
+
690
+ gr.Markdown(_CITE_)
691
+
692
+ # set dust3r parameter invisible to be clean
693
+ with gr.Column():
694
+ with gr.Row():
695
+ schedule = gr.Dropdown(["linear", "cosine"],
696
+ value='linear', label="schedule", info="For global alignment!", visible=False)
697
+ niter = gr.Number(value=300, precision=0, minimum=0, maximum=5000,
698
+ label="num_iterations", info="For global alignment!", visible=False)
699
+ scenegraph_type = gr.Dropdown(["complete", "swin", "oneref"],
700
+ value='complete', label="Scenegraph",
701
+ info="Define how to make pairs",
702
+ interactive=True, visible=False)
703
+ same_focals = gr.Checkbox(value=True, label="Focal", info="Use the same focal for all cameras", visible=False)
704
+ winsize = gr.Slider(label="Scene Graph: Window Size", value=1,
705
+ minimum=1, maximum=1, step=1, visible=False)
706
+ refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)
707
+
708
+ with gr.Row():
709
+ # adjust the confidence threshold
710
+ min_conf_thr = gr.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1, visible=False)
711
+ # adjust the camera size in the output pointcloud
712
+ cam_size = gr.Slider(label="cam_size", value=0.05, minimum=0.01, maximum=0.5, step=0.001, visible=False)
713
+ with gr.Row():
714
+ as_pointcloud = gr.Checkbox(value=False, label="As pointcloud", visible=False)
715
+ # two post process implemented
716
+ mask_sky = gr.Checkbox(value=False, label="Mask sky", visible=False)
717
+ clean_depth = gr.Checkbox(value=True, label="Clean-up depthmaps", visible=False)
718
+ transparent_cams = gr.Checkbox(value=False, label="Transparent cameras", visible=False)
719
+
720
+ # events
721
+ # scenegraph_type.change(set_scenegraph_options,
722
+ # inputs=[input_image, winsize, refid, scenegraph_type],
723
+ # outputs=[winsize, refid])
724
+ input_image.change(set_scenegraph_options,
725
+ inputs=[input_image, winsize, refid, scenegraph_type],
726
+ outputs=[winsize, refid])
727
+ # min_conf_thr.release(fn=model_from_scene_fun,
728
+ # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
729
+ # clean_depth, transparent_cams, cam_size, same_focals],
730
+ # outputs=outmodel)
731
+ # cam_size.change(fn=model_from_scene_fun,
732
+ # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
733
+ # clean_depth, transparent_cams, cam_size, same_focals],
734
+ # outputs=outmodel)
735
+ # as_pointcloud.change(fn=model_from_scene_fun,
736
+ # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
737
+ # clean_depth, transparent_cams, cam_size, same_focals],
738
+ # outputs=outmodel)
739
+ # mask_sky.change(fn=model_from_scene_fun,
740
+ # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
741
+ # clean_depth, transparent_cams, cam_size, same_focals],
742
+ # outputs=outmodel)
743
+ # clean_depth.change(fn=model_from_scene_fun,
744
+ # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
745
+ # clean_depth, transparent_cams, cam_size, same_focals],
746
+ # outputs=outmodel)
747
+ # transparent_cams.change(model_from_scene_fun,
748
+ # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
749
+ # clean_depth, transparent_cams, cam_size, same_focals],
750
+ # outputs=outmodel)
751
+ run_dust3r.click(fn=recon_fun,
752
+ inputs=[input_image, schedule, niter, min_conf_thr, as_pointcloud,
753
+ mask_sky, clean_depth, transparent_cams, cam_size,
754
+ scenegraph_type, winsize, refid, same_focals],
755
+ outputs=[scene, outmodel, processed_image, eschernet_input])
756
+
757
+
758
+ # events
759
+ # preview images on input change
760
+ input_image.change(fn=preview_input,
761
+ inputs=[input_image],
762
+ outputs=[processed_image])
763
+
764
+ submit.click(fn=generate_mvs,
765
+ inputs=[eschernet_input, sample_steps, sample_seed,
766
+ nvs_num, nvs_mode],
767
+ outputs=[mv_images, output_video],
768
+ )#.success(
769
+ # # fn=make3d,
770
+ # # inputs=[mv_images],
771
+ # # outputs=[output_video, output_model_obj, output_model_glb]
772
+ # # )
773
+
774
+
775
+
776
+ demo.queue(max_size=10)
777
+ demo.launch(share=True, server_name="0.0.0.0", server_port=None)
778
+
779
+ if __name__ == '__main__':
780
+ main()