hugoycj committed · Commit 270a9a7 · Parent(s): 2caa1bd

feat: Add mast3r for refinement
app.py
CHANGED
@@ -8,9 +8,15 @@ import tempfile
 import subprocess
 from dust3r.losses import L21
 from spann3r.model import Spann3R
+from mast3r.model import AsymmetricMASt3R
+
 from spann3r.datasets import Demo
 from torch.utils.data import DataLoader
-import
+import cv2
+import json
+import glob
+from dust3r.post_process import estimate_focal_knowing_depth
+from mast3r.demo import get_reconstructed_scene
 from scipy.spatial.transform import Rotation
 from transformers import AutoModelForImageSegmentation
 from torchvision import transforms
@@ -22,9 +28,16 @@ from pose_utils import solve_cemara
 from gradio.helpers import Examples as GradioExamples
 from gradio.utils import get_cache_folder
 from pathlib import Path
+import os
+import shutil
+import math
+import zipfile
+from pathlib import Path
+
 # Default values
-DEFAULT_CKPT_PATH = '
+DEFAULT_CKPT_PATH = 'checkpoints/spann3r.pth'
 DEFAULT_DUST3R_PATH = 'https://huggingface.co/camenduru/dust3r/resolve/main/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth'
+DEFAULT_MAST3R_PATH = 'checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth'
 DEFAULT_DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
 
 OPENGL = np.array([[1, 0, 0, 0],
@@ -128,17 +141,6 @@ def extract_frames(video_path: str, duration: float = 20.0, fps: float = 3.0) ->
     subprocess.run(command, check=True)
     return temp_dir
 
-def cat_meshes(meshes):
-    vertices, faces, colors = zip(*[(m['vertices'], m['faces'], m['face_colors']) for m in meshes])
-    n_vertices = np.cumsum([0]+[len(v) for v in vertices])
-    for i in range(len(faces)):
-        faces[i][:] += n_vertices[i]
-
-    vertices = np.concatenate(vertices)
-    colors = np.concatenate(colors)
-    faces = np.concatenate(faces)
-    return dict(vertices=vertices, face_colors=colors, faces=faces)
-
 def load_ckpt(model_path_or_url, verbose=True):
     if verbose:
         print('... loading model from', model_path_or_url)
@@ -158,46 +160,10 @@ def load_model(ckpt_path, device):
     model.eval()
     return model
 
-def pts3d_to_trimesh(img, pts3d, valid=None):
-    H, W, THREE = img.shape
-    assert THREE == 3
-    assert img.shape == pts3d.shape
-
-    vertices = pts3d.reshape(-1, 3)
-
-    # make squares: each pixel == 2 triangles
-    idx = np.arange(len(vertices)).reshape(H, W)
-    idx1 = idx[:-1, :-1].ravel()  # top-left corner
-    idx2 = idx[:-1, +1:].ravel()  # right-left corner
-    idx3 = idx[+1:, :-1].ravel()  # bottom-left corner
-    idx4 = idx[+1:, +1:].ravel()  # bottom-right corner
-    faces = np.concatenate((
-        np.c_[idx1, idx2, idx3],
-        np.c_[idx3, idx2, idx1],  # same triangle, but backward (cheap solution to cancel face culling)
-        np.c_[idx2, idx3, idx4],
-        np.c_[idx4, idx3, idx2],  # same triangle, but backward (cheap solution to cancel face culling)
-    ), axis=0)
-
-    # prepare triangle colors
-    face_colors = np.concatenate((
-        img[:-1, :-1].reshape(-1, 3),
-        img[:-1, :-1].reshape(-1, 3),
-        img[+1:, +1:].reshape(-1, 3),
-        img[+1:, +1:].reshape(-1, 3)
-    ), axis=0)
-
-    # remove invalid faces
-    if valid is not None:
-        assert valid.shape == (H, W)
-        valid_idxs = valid.ravel()
-        valid_faces = valid_idxs[faces].all(axis=-1)
-        faces = faces[valid_faces]
-        face_colors = face_colors[valid_faces]
-
-    assert len(faces) == len(face_colors)
-    return dict(vertices=vertices, face_colors=face_colors, faces=faces)
-
 model = load_model(DEFAULT_CKPT_PATH, DEFAULT_DEVICE)
+mast3r_model = AsymmetricMASt3R.from_pretrained(DEFAULT_MAST3R_PATH).to(DEFAULT_DEVICE)
+mast3r_model.eval()
+
 birefnet = AutoModelForImageSegmentation.from_pretrained('zhengpeng7/BiRefNet', trust_remote_code=True)
 birefnet.to(DEFAULT_DEVICE)
 birefnet.eval()
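Note: DEFAULT_MAST3R_PATH points at a local checkpoint, so a fresh clone of this Space needs the weights in place before startup. A minimal guard, sketched under the assumption that the same checkpoint is mirrored on the Hub under the usual naver repo id (not something this commit pins down):

# Sketch only; the Hub repo id below is an assumption, not part of this commit.
if not os.path.exists(DEFAULT_MAST3R_PATH):
    mast3r_model = AsymmetricMASt3R.from_pretrained(
        'naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric')
else:
    mast3r_model = AsymmetricMASt3R.from_pretrained(DEFAULT_MAST3R_PATH)
mast3r_model = mast3r_model.to(DEFAULT_DEVICE).eval()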
@@ -304,6 +270,204 @@ def center_mesh(mesh: o3d.geometry.TriangleMesh, normalize=False) -> o3d.geometr
     mesh.vertices = o3d.utility.Vector3dVector(centered_vertices)
     return mesh
 
+def get_transform_json(H, W, focal, poses_all):
+    transform_dict = {
+        'w': W,
+        'h': H,
+        'fl_x': focal.item(),
+        'fl_y': focal.item(),
+        'cx': W/2,
+        'cy': H/2,
+    }
+    frames = []
+
+    for i, pose in enumerate(poses_all):
+        # CV2 GL format
+        pose[:3, 1] *= -1
+        pose[:3, 2] *= -1
+        frame = {
+            'w': W,
+            'h': H,
+            'fl_x': focal.item(),
+            'fl_y': focal.item(),
+            'cx': W/2,
+            'cy': H/2,
+            'file_path': f"images/{i:04d}.jpg",
+            "mask_path": f"masks/{i:04d}.png",
+            'transform_matrix': pose.tolist()
+        }
+        frames.append(frame)
+    transform_dict['frames'] = frames
+
+    return transform_dict
+
+def organize_and_zip_output(images_all, masks_all, transform_json_path, output_dir=None):
+    """
+    Organizes reconstruction outputs into a specific directory structure and creates a zip file.
+
+    Args:
+        images_all: List of numpy arrays containing images
+        masks_all: List of numpy arrays containing masks
+        transform_json_path: Path to the transform.json file
+        output_dir: Optional custom output directory name
+
+    Returns:
+        str: Path to the created zip file
+    """
+    try:
+        # Create temporary directory with timestamp
+        timestamp = time.strftime("%Y%m%d_%H%M%S")
+        base_dir = output_dir or f"reconstruction_{timestamp}"
+        os.makedirs(base_dir, exist_ok=True)
+
+        # Create subdirectories
+        images_dir = os.path.join(base_dir, "images")
+        masks_dir = os.path.join(base_dir, "masks")
+        os.makedirs(images_dir, exist_ok=True)
+        os.makedirs(masks_dir, exist_ok=True)
+
+        # Save images
+        for i, image in enumerate(images_all):
+            image_path = os.path.join(images_dir, f"{i:04d}.jpg")
+            cv2.imwrite(image_path, (image * 255).astype(np.uint8)[..., ::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 90])
+
+        # Save masks
+        for i, mask in enumerate(masks_all):
+            mask_path = os.path.join(masks_dir, f"{i:04d}.png")
+            cv2.imwrite(mask_path, (mask * 255).astype(np.uint8))
+
+        # Copy transform.json
+        shutil.copy2(transform_json_path, os.path.join(base_dir, "transforms.json"))
+
+        # Create zip file
+        zip_path = f"{base_dir}.zip"
+        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
+            for root, _, files in os.walk(base_dir):
+                for file in files:
+                    file_path = os.path.join(root, file)
+                    arcname = os.path.relpath(file_path, base_dir)
+                    zipf.write(file_path, arcname)
+
+        return zip_path
+
+    finally:
+        # Clean up temporary directories and files
+        if os.path.exists(base_dir):
+            shutil.rmtree(base_dir)
+        if os.path.exists(transform_json_path):
+            os.remove(transform_json_path)
+
+def get_keyframes(temp_dir: str, kf_every: int = 10):
+    """
+    Select keyframes from a directory of extracted frames at specified intervals
+
+    Args:
+        temp_dir: Directory containing extracted frames (named as 001.jpg, 002.jpg, etc.)
+        kf_every: Select every Nth frame as a keyframe
+
+    Returns:
+        List[str]: Sorted list of paths to selected keyframe images
+    """
+    # Get all jpg files in the directory
+    frame_paths = glob.glob(os.path.join(temp_dir, "*.jpg"))
+
+    # Sort frames by number to ensure correct order
+    frame_paths.sort(key=lambda x: int(Path(x).stem))
+
+    # Select keyframes at specified interval
+    keyframe_paths = frame_paths[::kf_every]
+
+    # Ensure we have at least 2 frames for reconstruction
+    if len(keyframe_paths) < 2:
+        if len(frame_paths) >= 2:
+            # If we have at least 2 frames, use first and last
+            keyframe_paths = [frame_paths[0], frame_paths[-1]]
+        else:
+            raise ValueError(f"Not enough frames found in {temp_dir}. Need at least 2 frames for reconstruction.")
+
+    return keyframe_paths
+
+from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
+from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess
+from dust3r.utils.image import load_images
+from dust3r.image_pairs import make_pairs
+from dust3r.utils.device import to_numpy
+def invert_matrix(mat):
+    """Invert a torch or numpy matrix."""
+    if isinstance(mat, torch.Tensor):
+        return torch.linalg.inv(mat)
+    if isinstance(mat, np.ndarray):
+        return np.linalg.inv(mat)
+    raise ValueError(f'Unsupported matrix type: {type(mat)}')
+
+def refine(
+    video_path: str,
+    conf_thresh: float = 5.0,
+    kf_every: int = 30,
+    remove_background: bool = False,
+    enable_registration: bool = True,
+    output_3d_model: bool = True
+) -> dict:
+    # Extract keyframes from video
+    temp_dir = extract_frames(video_path)
+    keyframe_paths = get_keyframes(temp_dir, kf_every*3)
+
+    image_size = 512
+    images = load_images(keyframe_paths, size=image_size)
+
+    # Create output directory
+    output_dir = tempfile.mkdtemp()
+
+    # Generate pairs and run inference
+    pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
+    cache_dir = os.path.join(output_dir, 'cache')
+    if os.path.exists(cache_dir):
+        os.system(f'rm -rf {cache_dir}')
+    scene = sparse_global_alignment(keyframe_paths, pairs, cache_dir,
+                                    mast3r_model, lr1=0.07, niter1=500, lr2=0.014,
+                                    niter2=200 if enable_registration else 0, device=DEFAULT_DEVICE,
+                                    opt_depth=True if enable_registration else False, shared_intrinsics=True,
+                                    matching_conf_thr=5.)
+
+    # Extract scene information
+    imgs = np.array(scene.imgs)
+
+    tsdf = TSDFPostProcess(scene, TSDF_thresh=0)
+    pts3d, _, confs = tsdf.get_dense_pts3d(clean_depth=True)
+    masks = np.array(to_numpy([c > 1.5 for c in confs]))
+
+    pcds = []
+    for pts, conf_mask, image in zip(pts3d, masks, imgs):
+        if remove_background:
+            mask = generate_mask(image)
+        else:
+            mask = np.ones_like(conf_mask)
+        combined_mask = conf_mask & (mask > 0.5)
+
+        pts = pts.reshape(combined_mask.shape[0], combined_mask.shape[1], 3)
+        pts_normal = pts2normal(pts).cpu().numpy()
+        pts = pts.cpu().numpy()
+        pcd = o3d.geometry.PointCloud()
+        pcd.points = o3d.utility.Vector3dVector(pts[combined_mask] / 5)
+        pcd.colors = o3d.utility.Vector3dVector(image[combined_mask])
+        pcd.normals = o3d.utility.Vector3dVector(pts_normal[combined_mask])
+        pcds.append(pcd)
+
+    pcd_combined = combine_and_clean_point_clouds(pcds, voxel_size=0.001)
+    o3d_geometry = point2mesh(pcd_combined, depth=9)
+    o3d_geometry_centered = center_mesh(o3d_geometry, normalize=True)
+
+    # Create coarse result
+    coarse_output_path = export_geometry(o3d_geometry_centered)
+
+    if output_3d_model:
+        gs_output_path = tempfile.mktemp(suffix='.ply')
+        point2gs(gs_output_path, pcd_combined)
+        return coarse_output_path, [gs_output_path]
+    else:
+        pcd_output_path = export_geometry(pcd_combined, file_format='ply')
+        return coarse_output_path, [pcd_output_path]
+
 @torch.no_grad()
 def reconstruct(video_path, conf_thresh, kf_every,
                 remove_background=False, enable_registration=True, output_3d_model=True):
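Note: the "# CV2 GL format" flip inside get_transform_json is the standard OpenCV-to-OpenGL camera-convention change expected by transforms.json consumers (instant-ngp / nerfstudio style): OpenCV cameras look down +z with +y down, OpenGL/NeRF cameras look down -z with +y up, so the y and z basis vectors of each camera-to-world pose are negated. (Incidentally, refine's `-> dict` annotation does not match the tuple it actually returns.) A self-contained sketch of the same operation; unlike get_transform_json, which flips the pose arrays in place, this copies first:

import numpy as np

def cv_to_gl(pose_c2w):
    # Negate the y and z columns of the camera-to-world rotation.
    out = pose_c2w.copy()
    out[:3, 1] *= -1
    out[:3, 2] *= -1
    return out

pose = np.eye(4)
assert np.allclose(cv_to_gl(cv_to_gl(pose)), pose)  # the flip is an involution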
@@ -329,13 +493,46 @@ def reconstruct(video_path, conf_thresh, kf_every,
 
     # Process results
     pcds = []
+    poses_all = []
     cameras_all = []
+    images_all = []
+    masks_all = []
     last_focal = None
+
+    ##### estimate focal length
+    _, H, W, _ = preds[0]['pts3d'].shape
+    pp = torch.tensor((W/2, H/2))
+    focal = estimate_focal_knowing_depth(preds[0]['pts3d'].cpu(), pp, focal_mode='weiszfeld')
+    print(f'Estimated focal of first camera: {focal.item()} (224x224)')
+
+    intrinsic = np.eye(3)
+    intrinsic[0, 0] = focal
+    intrinsic[1, 1] = focal
+    intrinsic[:2, 2] = pp
+
     for j, view in enumerate(batch):
         image = view['img'].permute(0, 2, 3, 1).cpu().numpy()[0]
         image = (image + 1) / 2
+        mask = view['valid_mask'].cpu().numpy()[0]
         pts = preds[j]['pts3d' if j==0 else 'pts3d_in_other_view'].detach().cpu().numpy()[0]
         pts_normal = pts2normal(preds[j]['pts3d' if j==0 else 'pts3d_in_other_view'][0]).cpu().numpy()
+
+        ##### Solve PnP-RANSAC
+        u, v = np.meshgrid(np.arange(W), np.arange(H))
+        points_2d = np.stack((u, v), axis=-1)
+        dist_coeffs = np.zeros(4).astype(np.float32)
+        success, rotation_vector, translation_vector, inliers = cv2.solvePnPRansac(
+            pts.reshape(-1, 3).astype(np.float32),
+            points_2d.reshape(-1, 2).astype(np.float32),
+            intrinsic.astype(np.float32),
+            dist_coeffs)
+
+        rotation_matrix, _ = cv2.Rodrigues(rotation_vector)
+        # Extrinsic parameters (4x4 matrix)
+        extrinsic_matrix = np.hstack((rotation_matrix, translation_vector.reshape(-1, 1)))
+        extrinsic_matrix = np.vstack((extrinsic_matrix, [0, 0, 0, 1]))
+        poses_all.append(np.linalg.inv(extrinsic_matrix))
+
         conf = preds[j]['conf'][0].cpu().data.numpy()
         conf_sig = (conf - 1) / conf
         if remove_background:
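Note: the block above recovers each camera as a classic PnP problem: every pixel (u, v) is paired with its predicted 3D point, cv2.solvePnPRansac estimates the world-to-camera pose under the shared intrinsics from estimate_focal_knowing_depth, and inverting the extrinsic matrix gives camera-to-world. A self-contained sanity check of that recipe on synthetic data (all values illustrative):

import cv2
import numpy as np

focal, W, H = 500.0, 640, 480
K = np.array([[focal, 0, W / 2], [0, focal, H / 2], [0, 0, 1]], dtype=np.float32)

rng = np.random.default_rng(0)
pts3d = rng.uniform([-1, -1, 4], [1, 1, 8], size=(500, 3)).astype(np.float32)
rvec_gt = np.array([0.1, -0.2, 0.05], dtype=np.float32)  # ground-truth rotation (Rodrigues)
tvec_gt = np.array([0.3, -0.1, 0.5], dtype=np.float32)   # ground-truth translation

# Project with the known pose, then recover it the same way reconstruct() does.
pts2d, _ = cv2.projectPoints(pts3d, rvec_gt, tvec_gt, K, None)
ok, rvec, tvec, inliers = cv2.solvePnPRansac(
    pts3d, pts2d.reshape(-1, 2), K, np.zeros(4, np.float32))
assert ok and np.allclose(rvec.ravel(), rvec_gt, atol=1e-3)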
@@ -353,9 +550,15 @@ def reconstruct(video_path, conf_thresh, kf_every,
         pcd.colors = o3d.utility.Vector3dVector(image[combined_mask])
         pcd.normals = o3d.utility.Vector3dVector(pts_normal[combined_mask])
         pcds.append(pcd)
+        images_all.append(image)
+        masks_all.append(mask)
         cameras_all.append(camera)
-
 
+    transform_dict = get_transform_json(H, W, focal, poses_all)
+    temp_json_file = tempfile.mktemp(suffix='.json')
+    with open(os.path.join(temp_json_file), 'w') as f:
+        json.dump(transform_dict, f, indent=4)
+
     pcd_combined = combine_and_clean_point_clouds(pcds, voxel_size=0.001)
     o3d_geometry = point2mesh(pcd_combined)
     o3d_geometry_centered = center_mesh(o3d_geometry, normalize=True)
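Note: the dumped dictionary follows the instant-ngp / nerfstudio transforms.json layout (here written to a random tempfile name; organize_and_zip_output renames it to transforms.json inside the archive). A minimal reader sketch for a downstream trainer:

import json
import numpy as np

with open('transforms.json') as f:  # path as it appears inside the zip
    meta = json.load(f)

K = np.array([[meta['fl_x'], 0, meta['cx']],
              [0, meta['fl_y'], meta['cy']],
              [0, 0, 1]])
for frame in meta['frames']:
    c2w = np.array(frame['transform_matrix'])  # 4x4 camera-to-world, GL convention
    print(frame['file_path'], c2w[:3, 3])      # image path and camera center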
@@ -367,17 +570,14 @@ def reconstruct(video_path, conf_thresh, kf_every,
         pcd_combined, _, _ = improved_multiway_registration(pcds, voxel_size=0.01)
         pcd_combined = center_pcd(pcd_combined)
 
+    # zip_path = organize_and_zip_output(images_all, masks_all, temp_json_file)
     if output_3d_model:
         gs_output_path = tempfile.mktemp(suffix='.ply')
         point2gs(gs_output_path, pcd_combined)
-
-        return coarse_output_path, gs_output_path
+        return coarse_output_path, [gs_output_path]
     else:
         pcd_output_path = export_geometry(pcd_combined, file_format='ply')
-        return coarse_output_path, pcd_output_path
-
-    # Clean up temporary directory
-    os.system(f"rm -rf {demo_path}")
 
 example_videos = [os.path.join('./examples', f) for f in os.listdir('./examples') if f.endswith(('.mp4', '.webm'))]
 
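Note: the organize_and_zip_output call stays commented out in this commit, so the images, masks, and transforms.json bundle is assembled but never exported. If it were enabled, usage would be a single call; keep in mind the helper's finally block removes both the staging directory and the JSON even on success, so only the returned zip survives:

# Hypothetical wiring (the commit leaves this line commented out):
zip_path = organize_and_zip_output(images_all, masks_all, temp_json_file)
print('NeRF-style dataset archived at', zip_path)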
@@ -461,6 +661,7 @@ with gr.Blocks(
                 info="Generate Splat (PLY) instead of Point Cloud (PLY)"
             )
             reconstruct_btn = gr.Button("Start Reconstruction")
+            refine_btn = gr.Button("Start Refinement")
 
         with gr.Column(scale=2):
             with gr.Tab("3D Models"):
@@ -472,10 +673,8 @@ with gr.Blocks(
                     )
 
                 with gr.Group():
-                    output_model = gr.
-                        label="Reconstructed
-                        display_mode="solid",
-                        clear_color=[0.0, 0.0, 0.0, 0.0]
+                    output_model = gr.File(
+                        label="Reconstructed Results",
                     )
 
                 Examples(
@@ -495,6 +694,13 @@ with gr.Blocks(
         inputs=[video_input, conf_thresh, kf_every, remove_background, enable_registration, output_3d_model],
         outputs=[initial_model, output_model]
     )
+
+    refine_btn.click(
+        fn=refine,
+        inputs=[video_input, conf_thresh, kf_every, remove_background, enable_registration, output_3d_model],
+        outputs=[initial_model, output_model]
+    )
+
 
 if __name__ == "__main__":
     iface.launch(server_name="0.0.0.0")
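Note: output_model switched from a 3D viewer to gr.File, which is why both reconstruct and refine now return the second output wrapped in a list, and refine is wired to the same inputs and outputs as reconstruct. A minimal sketch of the pattern, with illustrative component arguments (whether the Space sets file_count is not shown in this diff):

import gradio as gr

def handler(video_path):
    # ... run reconstruction, produce a viewer mesh and downloadable artifacts ...
    return 'coarse_mesh.obj', ['result.ply']  # viewer path, file list for gr.File

with gr.Blocks() as demo:
    video = gr.Video()
    viewer = gr.Model3D(label="Coarse Model")
    files = gr.File(label="Reconstructed Results", file_count="multiple")
    gr.Button("Run").click(handler, inputs=[video], outputs=[viewer, files])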