hugoycj committed
Commit 270a9a7
1 Parent(s): 2caa1bd

feat: Add mast3r for refinement

Files changed (1)
  1. app.py +269 -63
app.py CHANGED
@@ -8,9 +8,15 @@ import tempfile
import subprocess
from dust3r.losses import L21
from spann3r.model import Spann3R
+ from mast3r.model import AsymmetricMASt3R
+
from spann3r.datasets import Demo
from torch.utils.data import DataLoader
- import trimesh
+ import cv2
+ import json
+ import glob
+ from dust3r.post_process import estimate_focal_knowing_depth
+ from mast3r.demo import get_reconstructed_scene
from scipy.spatial.transform import Rotation
from transformers import AutoModelForImageSegmentation
from torchvision import transforms
@@ -22,9 +28,16 @@ from pose_utils import solve_cemara
from gradio.helpers import Examples as GradioExamples
from gradio.utils import get_cache_folder
from pathlib import Path
+ import os
+ import shutil
+ import math
+ import zipfile
+ from pathlib import Path
+
# Default values
- DEFAULT_CKPT_PATH = './checkpoints/spann3r.pth'
+ DEFAULT_CKPT_PATH = 'checkpoints/spann3r.pth'
DEFAULT_DUST3R_PATH = 'https://huggingface.co/camenduru/dust3r/resolve/main/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth'
+ DEFAULT_MAST3R_PATH = 'checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth'
DEFAULT_DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

OPENGL = np.array([[1, 0, 0, 0],
@@ -128,17 +141,6 @@ def extract_frames(video_path: str, duration: float = 20.0, fps: float = 3.0) ->
    subprocess.run(command, check=True)
    return temp_dir

- def cat_meshes(meshes):
-     vertices, faces, colors = zip(*[(m['vertices'], m['faces'], m['face_colors']) for m in meshes])
-     n_vertices = np.cumsum([0] + [len(v) for v in vertices])
-     for i in range(len(faces)):
-         faces[i][:] += n_vertices[i]
-
-     vertices = np.concatenate(vertices)
-     colors = np.concatenate(colors)
-     faces = np.concatenate(faces)
-     return dict(vertices=vertices, face_colors=colors, faces=faces)
-
def load_ckpt(model_path_or_url, verbose=True):
    if verbose:
        print('... loading model from', model_path_or_url)
@@ -158,46 +160,10 @@ def load_model(ckpt_path, device):
    model.eval()
    return model

- def pts3d_to_trimesh(img, pts3d, valid=None):
-     H, W, THREE = img.shape
-     assert THREE == 3
-     assert img.shape == pts3d.shape
-
-     vertices = pts3d.reshape(-1, 3)
-
-     # make squares: each pixel == 2 triangles
-     idx = np.arange(len(vertices)).reshape(H, W)
-     idx1 = idx[:-1, :-1].ravel()  # top-left corner
-     idx2 = idx[:-1, +1:].ravel()  # top-right corner
-     idx3 = idx[+1:, :-1].ravel()  # bottom-left corner
-     idx4 = idx[+1:, +1:].ravel()  # bottom-right corner
-     faces = np.concatenate((
-         np.c_[idx1, idx2, idx3],
-         np.c_[idx3, idx2, idx1],  # same triangle, but backward (cheap solution to cancel face culling)
-         np.c_[idx2, idx3, idx4],
-         np.c_[idx4, idx3, idx2],  # same triangle, but backward (cheap solution to cancel face culling)
-     ), axis=0)
-
-     # prepare triangle colors
-     face_colors = np.concatenate((
-         img[:-1, :-1].reshape(-1, 3),
-         img[:-1, :-1].reshape(-1, 3),
-         img[+1:, +1:].reshape(-1, 3),
-         img[+1:, +1:].reshape(-1, 3)
-     ), axis=0)
-
-     # remove invalid faces
-     if valid is not None:
-         assert valid.shape == (H, W)
-         valid_idxs = valid.ravel()
-         valid_faces = valid_idxs[faces].all(axis=-1)
-         faces = faces[valid_faces]
-         face_colors = face_colors[valid_faces]
-
-     assert len(faces) == len(face_colors)
-     return dict(vertices=vertices, face_colors=face_colors, faces=faces)
-
model = load_model(DEFAULT_CKPT_PATH, DEFAULT_DEVICE)
+ mast3r_model = AsymmetricMASt3R.from_pretrained(DEFAULT_MAST3R_PATH).to(DEFAULT_DEVICE)
+ mast3r_model.eval()
+
birefnet = AutoModelForImageSegmentation.from_pretrained('zhengpeng7/BiRefNet', trust_remote_code=True)
birefnet.to(DEFAULT_DEVICE)
birefnet.eval()
@@ -304,6 +270,204 @@ def center_mesh(mesh: o3d.geometry.TriangleMesh, normalize=False) -> o3d.geometr
    mesh.vertices = o3d.utility.Vector3dVector(centered_vertices)
    return mesh

+ def get_transform_json(H, W, focal, poses_all):
+     transform_dict = {
+         'w': W,
+         'h': H,
+         'fl_x': focal.item(),
+         'fl_y': focal.item(),
+         'cx': W/2,
+         'cy': H/2,
+     }
+     frames = []
+
+     for i, pose in enumerate(poses_all):
+         # CV2 GL format
+         pose[:3, 1] *= -1
+         pose[:3, 2] *= -1
+         frame = {
+             'w': W,
+             'h': H,
+             'fl_x': focal.item(),
+             'fl_y': focal.item(),
+             'cx': W/2,
+             'cy': H/2,
+             'file_path': f"images/{i:04d}.jpg",
+             "mask_path": f"masks/{i:04d}.png",
+             'transform_matrix': pose.tolist()
+         }
+         frames.append(frame)
+     transform_dict['frames'] = frames
+
+     return transform_dict
+
+ def organize_and_zip_output(images_all, masks_all, transform_json_path, output_dir=None):
+     """
+     Organizes reconstruction outputs into a specific directory structure and creates a zip file.
+
+     Args:
+         images_all: List of numpy arrays containing images
+         masks_all: List of numpy arrays containing masks
+         transform_json_path: Path to the transform.json file
+         output_dir: Optional custom output directory name
+
+     Returns:
+         str: Path to the created zip file
+     """
+     try:
+         # Create temporary directory with timestamp
+         timestamp = time.strftime("%Y%m%d_%H%M%S")
+         base_dir = output_dir or f"reconstruction_{timestamp}"
+         os.makedirs(base_dir, exist_ok=True)
+
+         # Create subdirectories
+         images_dir = os.path.join(base_dir, "images")
+         masks_dir = os.path.join(base_dir, "masks")
+         os.makedirs(images_dir, exist_ok=True)
+         os.makedirs(masks_dir, exist_ok=True)
+
+         # Save images
+         for i, image in enumerate(images_all):
+             image_path = os.path.join(images_dir, f"{i:04d}.jpg")
+             cv2.imwrite(image_path, (image * 255).astype(np.uint8)[..., ::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 90])
+
+         # Save masks
+         for i, mask in enumerate(masks_all):
+             mask_path = os.path.join(masks_dir, f"{i:04d}.png")
+             cv2.imwrite(mask_path, (mask * 255).astype(np.uint8))
+
+         # Copy transform.json
+         shutil.copy2(transform_json_path, os.path.join(base_dir, "transforms.json"))
+
+         # Create zip file
+         zip_path = f"{base_dir}.zip"
+         with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
+             for root, _, files in os.walk(base_dir):
+                 for file in files:
+                     file_path = os.path.join(root, file)
+                     arcname = os.path.relpath(file_path, base_dir)
+                     zipf.write(file_path, arcname)
+
+         return zip_path
+
+     finally:
+         # Clean up temporary directories and files
+         if os.path.exists(base_dir):
+             shutil.rmtree(base_dir)
+         if os.path.exists(transform_json_path):
+             os.remove(transform_json_path)
+
+ def get_keyframes(temp_dir: str, kf_every: int = 10):
+     """
+     Select keyframes from a directory of extracted frames at specified intervals
+
+     Args:
+         temp_dir: Directory containing extracted frames (named as 001.jpg, 002.jpg, etc.)
+         kf_every: Select every Nth frame as a keyframe
+
+     Returns:
+         List[str]: Sorted list of paths to selected keyframe images
+     """
+     # Get all jpg files in the directory
+     frame_paths = glob.glob(os.path.join(temp_dir, "*.jpg"))
+
+     # Sort frames by number to ensure correct order
+     frame_paths.sort(key=lambda x: int(Path(x).stem))
+
+     # Select keyframes at specified interval
+     keyframe_paths = frame_paths[::kf_every]
+
+     # Ensure we have at least 2 frames for reconstruction
+     if len(keyframe_paths) < 2:
+         if len(frame_paths) >= 2:
+             # If we have at least 2 frames, use first and last
+             keyframe_paths = [frame_paths[0], frame_paths[-1]]
+         else:
+             raise ValueError(f"Not enough frames found in {temp_dir}. Need at least 2 frames for reconstruction.")
+
+     return keyframe_paths
+
+ from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
+ from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess
+ from dust3r.utils.image import load_images
+ from dust3r.image_pairs import make_pairs
+ from dust3r.utils.device import to_numpy
+ def invert_matrix(mat):
+     """Invert a torch or numpy matrix."""
+     if isinstance(mat, torch.Tensor):
+         return torch.linalg.inv(mat)
+     if isinstance(mat, np.ndarray):
+         return np.linalg.inv(mat)
+     raise ValueError(f'Unsupported matrix type: {type(mat)}')
+
+ def refine(
+     video_path: str,
+     conf_thresh: float = 5.0,
+     kf_every: int = 30,
+     remove_background: bool = False,
+     enable_registration: bool = True,
+     output_3d_model: bool = True
+ ) -> dict:
+     # Extract keyframes from video
+     temp_dir = extract_frames(video_path)
+     keyframe_paths = get_keyframes(temp_dir, kf_every*3)
+
+     image_size = 512
+     images = load_images(keyframe_paths, size=image_size)
+
+     # Create output directory
+     output_dir = tempfile.mkdtemp()
+
+     # Generate pairs and run inference
+     pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
+     cache_dir = os.path.join(output_dir, 'cache')
+     if os.path.exists(cache_dir):
+         os.system(f'rm -rf {cache_dir}')
+     scene = sparse_global_alignment(keyframe_paths, pairs, cache_dir,
+                                     mast3r_model, lr1=0.07, niter1=500, lr2=0.014,
+                                     niter2=200 if enable_registration else 0, device=DEFAULT_DEVICE,
+                                     opt_depth=True if enable_registration else False, shared_intrinsics=True,
+                                     matching_conf_thr=5.)
+
+     # Extract scene information
+     imgs = np.array(scene.imgs)
+
+     tsdf = TSDFPostProcess(scene, TSDF_thresh=0)
+     pts3d, _, confs = tsdf.get_dense_pts3d(clean_depth=True)
+     masks = np.array(to_numpy([c > 1.5 for c in confs]))
+
+     pcds = []
+     for pts, conf_mask, image in zip(pts3d, masks, imgs):
+         if remove_background:
+             mask = generate_mask(image)
+         else:
+             mask = np.ones_like(conf_mask)
+         combined_mask = conf_mask & (mask > 0.5)
+
+         pts = pts.reshape(combined_mask.shape[0], combined_mask.shape[1], 3)
+         pts_normal = pts2normal(pts).cpu().numpy()
+         pts = pts.cpu().numpy()
+         pcd = o3d.geometry.PointCloud()
+         pcd.points = o3d.utility.Vector3dVector(pts[combined_mask] / 5)
+         pcd.colors = o3d.utility.Vector3dVector(image[combined_mask])
+         pcd.normals = o3d.utility.Vector3dVector(pts_normal[combined_mask])
+         pcds.append(pcd)
+
+     pcd_combined = combine_and_clean_point_clouds(pcds, voxel_size=0.001)
+     o3d_geometry = point2mesh(pcd_combined, depth=9)
+     o3d_geometry_centered = center_mesh(o3d_geometry, normalize=True)
+
+     # Create coarse result
+     coarse_output_path = export_geometry(o3d_geometry_centered)
+
+     if output_3d_model:
+         gs_output_path = tempfile.mktemp(suffix='.ply')
+         point2gs(gs_output_path, pcd_combined)
+         return coarse_output_path, [gs_output_path]
+     else:
+         pcd_output_path = export_geometry(pcd_combined, file_format='ply')
+         return coarse_output_path, [pcd_output_path]
+
@torch.no_grad()
def reconstruct(video_path, conf_thresh, kf_every,
                remove_background=False, enable_registration=True, output_3d_model=True):
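Note on the pose convention in get_transform_json above: negating the Y and Z basis columns converts OpenCV camera axes (x right, y down, z forward) to the OpenGL/NeRF convention (x right, y up, z backward) before the poses are written out. A minimal equivalent sketch, not part of this commit:

import numpy as np

# Right-multiplying a 4x4 camera-to-world pose by diag(1, -1, -1, 1) negates its
# second and third columns, the same effect as pose[:3, 1] *= -1 and
# pose[:3, 2] *= -1, since the bottom row is [0, 0, 0, 1].
CV_TO_GL = np.diag([1.0, -1.0, -1.0, 1.0])

def cv_pose_to_gl(pose_c2w):
    return pose_c2w @ CV_TO_GL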
@@ -329,13 +493,46 @@ def reconstruct(video_path, conf_thresh, kf_every,

    # Process results
    pcds = []
+     poses_all = []
    cameras_all = []
+     images_all = []
+     masks_all = []
    last_focal = None
+
+     ##### estimate focal length
+     _, H, W, _ = preds[0]['pts3d'].shape
+     pp = torch.tensor((W/2, H/2))
+     focal = estimate_focal_knowing_depth(preds[0]['pts3d'].cpu(), pp, focal_mode='weiszfeld')
+     print(f'Estimated focal of first camera: {focal.item()} (224x224)')
+
+     intrinsic = np.eye(3)
+     intrinsic[0, 0] = focal
+     intrinsic[1, 1] = focal
+     intrinsic[:2, 2] = pp
+
    for j, view in enumerate(batch):
        image = view['img'].permute(0, 2, 3, 1).cpu().numpy()[0]
        image = (image + 1) / 2
+         mask = view['valid_mask'].cpu().numpy()[0]
        pts = preds[j]['pts3d' if j==0 else 'pts3d_in_other_view'].detach().cpu().numpy()[0]
        pts_normal = pts2normal(preds[j]['pts3d' if j==0 else 'pts3d_in_other_view'][0]).cpu().numpy()
+
+         ##### Solve PnP-RANSAC
+         u, v = np.meshgrid(np.arange(W), np.arange(H))
+         points_2d = np.stack((u, v), axis=-1)
+         dist_coeffs = np.zeros(4).astype(np.float32)
+         success, rotation_vector, translation_vector, inliers = cv2.solvePnPRansac(
+             pts.reshape(-1, 3).astype(np.float32),
+             points_2d.reshape(-1, 2).astype(np.float32),
+             intrinsic.astype(np.float32),
+             dist_coeffs)
+
+         rotation_matrix, _ = cv2.Rodrigues(rotation_vector)
+         # Extrinsic parameters (4x4 matrix)
+         extrinsic_matrix = np.hstack((rotation_matrix, translation_vector.reshape(-1, 1)))
+         extrinsic_matrix = np.vstack((extrinsic_matrix, [0, 0, 0, 1]))
+         poses_all.append(np.linalg.inv(extrinsic_matrix))
+
        conf = preds[j]['conf'][0].cpu().data.numpy()
        conf_sig = (conf - 1) / conf
        if remove_background:
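Note on the PnP block above: cv2.solvePnPRansac returns the world-to-camera pose (a Rodrigues rotation vector plus a translation), which the code assembles into a 4x4 extrinsic and inverts to get camera-to-world. Because the extrinsic is a rigid transform, its inverse also has the closed form [R^T | -R^T t]; a minimal sketch, not part of this commit:

import numpy as np
import cv2

def pnp_to_c2w(rotation_vector, translation_vector):
    # x_cam = R @ x_world + t  =>  x_world = R.T @ x_cam - R.T @ t
    R, _ = cv2.Rodrigues(rotation_vector)
    t = translation_vector.reshape(3)
    c2w = np.eye(4)
    c2w[:3, :3] = R.T
    c2w[:3, 3] = -R.T @ t
    return c2w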
@@ -353,9 +550,15 @@ def reconstruct(video_path, conf_thresh, kf_every,
        pcd.colors = o3d.utility.Vector3dVector(image[combined_mask])
        pcd.normals = o3d.utility.Vector3dVector(pts_normal[combined_mask])
        pcds.append(pcd)
+         images_all.append(image)
+         masks_all.append(mask)
        cameras_all.append(camera)
-

+     transform_dict = get_transform_json(H, W, focal, poses_all)
+     temp_json_file = tempfile.mktemp(suffix='.json')
+     with open(os.path.join(temp_json_file), 'w') as f:
+         json.dump(transform_dict, f, indent=4)
+
    pcd_combined = combine_and_clean_point_clouds(pcds, voxel_size=0.001)
    o3d_geometry = point2mesh(pcd_combined)
    o3d_geometry_centered = center_mesh(o3d_geometry, normalize=True)
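The transforms.json written here follows the NeRF-style layout (shared intrinsics at the top level plus per-frame entries). A minimal sketch of reading it back, not part of this commit; the file path is illustrative:

import json
import numpy as np

with open('transforms.json') as f:
    meta = json.load(f)
# Shared pinhole intrinsics
K = np.array([[meta['fl_x'], 0, meta['cx']],
              [0, meta['fl_y'], meta['cy']],
              [0, 0, 1]])
# Per-frame 4x4 camera-to-world poses (OpenGL convention after the flip above)
poses = [np.array(fr['transform_matrix']) for fr in meta['frames']]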
@@ -367,17 +570,14 @@ def reconstruct(video_path, conf_thresh, kf_every,
    pcd_combined, _, _ = improved_multiway_registration(pcds, voxel_size=0.01)
    pcd_combined = center_pcd(pcd_combined)

+     # zip_path = organize_and_zip_output(images_all, masks_all, temp_json_file)
    if output_3d_model:
        gs_output_path = tempfile.mktemp(suffix='.ply')
        point2gs(gs_output_path, pcd_combined)
-         # Create 3D model result using gaussian splatting
-         return coarse_output_path, gs_output_path
+         return coarse_output_path, [gs_output_path]
    else:
        pcd_output_path = export_geometry(pcd_combined, file_format='ply')
-         return coarse_output_path, pcd_output_path
-
-     # Clean up temporary directory
-     os.system(f"rm -rf {demo_path}")

example_videos = [os.path.join('./examples', f) for f in os.listdir('./examples') if f.endswith(('.mp4', '.webm'))]

@@ -461,6 +661,7 @@ with gr.Blocks(
                info="Generate Splat (PLY) instead of Point Cloud (PLY)"
            )
            reconstruct_btn = gr.Button("Start Reconstruction")
+             refine_btn = gr.Button("Start Refinement")

        with gr.Column(scale=2):
            with gr.Tab("3D Models"):
@@ -472,10 +673,8 @@ with gr.Blocks(
                )

                with gr.Group():
-                     output_model = gr.Model3D(
-                         label="Reconstructed PointCloud or Splat",
-                         display_mode="solid",
-                         clear_color=[0.0, 0.0, 0.0, 0.0]
+                     output_model = gr.File(
+                         label="Reconstructed Results",
                    )

            Examples(
@@ -495,6 +694,13 @@ with gr.Blocks(
        inputs=[video_input, conf_thresh, kf_every, remove_background, enable_registration, output_3d_model],
        outputs=[initial_model, output_model]
    )
+
+     refine_btn.click(
+         fn=refine,
+         inputs=[video_input, conf_thresh, kf_every, remove_background, enable_registration, output_3d_model],
+         outputs=[initial_model, output_model]
+     )
+

if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0")
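Note on the changed return values: both reconstruct and the new refine now return the coarse mesh path for the gr.Model3D viewer plus a list of file paths, matching the switch of output_model to a gr.File component (which accepts a list of downloadable files). A minimal usage sketch outside Gradio; the video path and parameter values are illustrative:

# 'examples/scene.mp4' and the thresholds below are hypothetical.
coarse_path, result_files = reconstruct('examples/scene.mp4', conf_thresh=1e-3, kf_every=1,
                                        remove_background=False, enable_registration=True,
                                        output_3d_model=True)
print(coarse_path)    # mesh shown in the Model3D viewer
print(result_files)   # e.g. ['/tmp/tmpxxxx.ply'], a splat PLY served via gr.File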