Spaces:
Running
on
Zero
Running
on
Zero
| import matplotlib | |
| matplotlib.use('Agg') | |
| import sys | |
| import argparse | |
| import torch | |
| import numpy as np | |
| import os | |
| import re | |
| import cv2 | |
| import trimesh | |
| from pathlib import Path | |
| from PIL import Image | |
| from typing import NamedTuple, Optional | |
| current_dir = os.getcwd() | |
| sys.path.append(os.path.join(current_dir, 'mast3r')) | |
| from mast3r.model import AsymmetricMASt3R | |
| from mast3r.cloud_opt.sparse_ga import sparse_global_alignment | |
| from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess | |
| import mast3r.utils.path_to_dust3r | |
| from dust3r.utils.image import load_images | |
| from dust3r.utils.device import to_numpy | |
| from dust3r.image_pairs import make_pairs | |
| from plyfile import PlyData, PlyElement | |
| from utils.dust3r_utils import visualizer, pca, upsampler | |
class BasicPointCloud(NamedTuple):
    """Immutable bundle of per-point arrays describing a point cloud.

    The fields are annotated with ``np.ndarray`` (the array type) rather
    than ``np.array`` (a factory function, which is not a valid type).
    """

    # Point positions.
    points: np.ndarray
    # Per-point colors.
    colors: np.ndarray
    # Per-point normals.
    normals: np.ndarray
def invert_matrix(mat):
    """Return the inverse of ``mat`` using the backend that matches its type.

    Args:
        mat: A ``torch.Tensor`` or a ``numpy.ndarray``.

    Returns:
        The inverse computed by ``torch.linalg.inv`` or ``np.linalg.inv``.

    Raises:
        ValueError: If ``mat`` is neither a torch tensor nor a numpy array.
    """
    # (accepted type, inversion routine) pairs, checked in order.
    backends = (
        (torch.Tensor, torch.linalg.inv),
        (np.ndarray, np.linalg.inv),
    )
    for array_type, inverse in backends:
        if isinstance(mat, array_type):
            return inverse(mat)
    raise ValueError(f'Unsupported matrix type: {type(mat)}')
def fov2focal(fov: float, pixels: float) -> float:
    """Convert a field of view (radians) to a focal length in pixels.

    Args:
        fov: Field of view in radians (it is passed to ``math.tan``).
        pixels: Image extent in pixels along the same axis as ``fov``.

    Returns:
        ``pixels / (2 * tan(fov / 2))``.
    """
    # The file's import block does not import ``math``; import it here so
    # the function no longer raises NameError when called.
    import math

    return pixels / (2 * math.tan(fov / 2))
def focal2fov(focal: float, pixels: float) -> float:
    """Convert a focal length in pixels to a field of view in radians.

    Args:
        focal: Focal length in pixels.
        pixels: Image extent in pixels along the axis of interest.

    Returns:
        ``2 * atan(pixels / (2 * focal))``.
    """
    # The file's import block does not import ``math``; import it here so
    # the function no longer raises NameError when called.
    import math

    return 2 * math.atan(pixels / (2 * focal))
def rotmat2qvec(R):
    """Convert a 3x3 rotation matrix to a quaternion ordered (w, x, y, z).

    The quaternion is taken from the eigenvector of a symmetric 4x4 matrix
    (built from the entries of ``R``) that belongs to the largest
    eigenvalue. The sign is chosen so that the first component is not
    negative.

    Args:
        R: Array exposing ``.flat`` with nine entries (row-major 3x3).

    Returns:
        A length-4 numpy array ``[w, x, y, z]``.
    """
    # Names follow the row-major unpacking order of ``R.flat``.
    Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat

    # Only the lower triangle is filled in: ``np.linalg.eigh`` reads the
    # lower triangle by default, so the zeros above the diagonal are unused.
    row0 = [Rxx - Ryy - Rzz, 0, 0, 0]
    row1 = [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0]
    row2 = [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0]
    row3 = [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]
    K = np.array([row0, row1, row2, row3]) / 3.0

    eigvals, eigvecs = np.linalg.eigh(K)
    dominant = eigvecs[:, np.argmax(eigvals)]
    # Reorder the eigenvector components so that w comes first.
    qvec = dominant[[3, 0, 1, 2]]
    return -qvec if qvec[0] < 0 else qvec
def storePly(path, xyz, rgb):
    """Write points and colors to ``path`` as a PLY file with zero normals.

    Args:
        path: Destination handed to ``PlyData.write``.
        xyz: 2-D array of point positions, one row per point.
        rgb: 2-D array of colors with the same number of rows as ``xyz``.
    """
    # Structured layout of one vertex record: position, normal, color.
    vertex_dtype = [
        ('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
        ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
        ('red', 'u1'), ('green', 'u1'), ('blue', 'u1'),
    ]

    # No normals are available here, so an all-zero block is stored.
    rows = np.concatenate((xyz, np.zeros_like(xyz), rgb), axis=1)

    vertices = np.empty(xyz.shape[0], dtype=vertex_dtype)
    vertices[:] = [tuple(row) for row in rows]

    # Wrap the records in a single 'vertex' element and write the file.
    PlyData([PlyElement.describe(vertices, 'vertex')]).write(path)
| # Ensure save directories exist | |
def init_filestructure(save_path):
    """Create the output directory tree and return its paths.

    Args:
        save_path: ``pathlib.Path`` of the output root.

    Returns:
        Tuple ``(save_path, images_path, masks_path, sparse_path)`` where the
        last three are ``save_path / 'images'``, ``save_path / 'masks'`` and
        ``save_path / 'sparse/0'``. All four directories exist on return.
    """
    images_path, masks_path, sparse_path = (
        save_path / sub for sub in ('images', 'masks', 'sparse/0')
    )
    layout = (save_path, images_path, masks_path, sparse_path)
    for directory in layout:
        # Directories that already exist are left untouched.
        directory.mkdir(exist_ok=True, parents=True)
    return layout
| # Save images and masks | |
def save_images_and_masks(imgs, masks, images_path, img_files, masks_path):
    """Write every image and its mask as ``<stem>.png`` files.

    Args:
        imgs: Iterable of images; each is multiplied by 255 before saving
            (presumably float arrays in [0, 1] — confirm against caller).
        masks: Iterable of 2-D masks, one per image.
        images_path: ``pathlib.Path`` of the directory receiving the images.
        img_files: Source file names; only the stem is reused for output.
        masks_path: ``pathlib.Path`` of the directory receiving the masks.
    """
    for image, name, mask in zip(imgs, img_files, masks):
        out_name = f"{Path(name).stem}.png"

        # Swap the first and third channels before handing to cv2.imwrite.
        swapped = cv2.cvtColor(image * 255, cv2.COLOR_BGR2RGB)
        cv2.imwrite(str(images_path / out_name), swapped)

        # Replicate the mask over three channels and scale it by 255.
        mask_rgb = np.repeat(np.expand_dims(mask, -1), 3, axis=2) * 255
        Image.fromarray(mask_rgb.astype(np.uint8)).save(masks_path / out_name)
| # Save camera information | |
def save_cameras(focals, principal_points, sparse_path, imgs_shape):
    """Write ``cameras.txt`` with one PINHOLE camera per input view.

    Args:
        focals: Iterable of focal lengths; each is written for both fx and fy.
        principal_points: Iterable of ``(cx, cy)`` pairs aligned with
            ``focals``.
        sparse_path: ``pathlib.Path`` of the directory receiving the file.
        imgs_shape: Indexable shape; index 2 is written as WIDTH and index 1
            as HEIGHT.
    """
    header = (
        "# Camera list with one line of data per camera:\n",
        "# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n",
    )
    with open(sparse_path / 'cameras.txt', 'w') as out:
        out.writelines(header)
        for cam_id, (focal, pp) in enumerate(zip(focals, principal_points)):
            width, height = imgs_shape[2], imgs_shape[1]
            cx, cy = pp[0], pp[1]
            out.write(f"{cam_id} PINHOLE {width} {height} {focal} {focal} {cx} {cy}\n")
| # Save image transformations | |
def save_images_txt(world2cam, img_files, sparse_path):
    """Write ``images.txt`` with one pose entry per view.

    Each entry is a line holding the quaternion (from ``rotmat2qvec``), the
    translation column, the camera id and the ``<stem>.png`` image name,
    followed by an empty line (no 2-D points are written).

    Args:
        world2cam: Array of poses indexed as ``[view, row, col]``; the
            upper-left 3x3 is used as rotation and ``[:3, 3]`` as translation.
        img_files: Source file names aligned with ``world2cam``.
        sparse_path: ``pathlib.Path`` of the directory receiving the file.
    """
    preamble = (
        "# Image list with two lines of data per image:\n",
        "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n",
        "# POINTS2D[] as (X, Y, POINT3D_ID)\n",
    )
    with open(sparse_path / 'images.txt', 'w') as out:
        for comment in preamble:
            out.write(comment)
        for idx, pose in enumerate(world2cam):
            stem = Path(img_files[idx]).stem
            qw, qx, qy, qz = rotmat2qvec(pose[:3, :3])
            tx, ty, tz = pose[:3, 3]
            # The view index serves as both the image id and the camera id.
            out.write(f"{idx} {qw} {qx} {qy} {qz} {tx} {ty} {tz} {idx} {stem}.png\n\n")
| # Save point cloud with normals | |
def save_pointcloud_with_normals(imgs, pts3d, masks, sparse_path):
    """Write the masked point cloud to ``points3D.ply`` in ASCII PLY format.

    Every vertex is written with its position, the first three color
    channels cast to ``int``, and a constant placeholder normal ``[0, 1, 0]``.

    Args:
        imgs: Images forwarded to ``get_point_cloud``.
        pts3d: Per-view 3-D points forwarded to ``get_point_cloud``.
        masks: Per-view masks forwarded to ``get_point_cloud``.
        sparse_path: ``pathlib.Path`` of the directory receiving the file.
    """
    cloud = get_point_cloud(imgs, pts3d, masks)
    vertices = cloud.vertices
    colors = cloud.colors
    # No normals are estimated; the same placeholder is repeated per vertex.
    normals = np.tile([0, 1, 0], (vertices.shape[0], 1))

    header_lines = (
        "ply",
        "format ascii 1.0",
        "element vertex {}",
        "property float x",
        "property float y",
        "property float z",
        "property uchar red",
        "property uchar green",
        "property uchar blue",
        "property float nx",
        "property float ny",
        "property float nz",
        "end_header",
    )
    header = ("\n".join(header_lines) + "\n").format(len(vertices))

    with open(sparse_path / 'points3D.ply', 'w') as out:
        out.write(header)
        for vertex, color, normal in zip(vertices, colors, normals):
            x, y, z = vertex
            red, green, blue = int(color[0]), int(color[1]), int(color[2])
            nx, ny, nz = normal
            out.write(f"{x} {y} {z} {red} {green} {blue} {nx} {ny} {nz}\n")
| # Generate point cloud | |
def get_point_cloud(imgs, pts3d, mask):
    """Build a ``trimesh.PointCloud`` from the masked points of every view.

    Args:
        imgs: Per-view images supplying the colors.
        pts3d: Per-view 3-D points, indexable by the flattened view mask.
        mask: Boolean masks with one leading entry per view.

    Returns:
        A ``trimesh.PointCloud`` holding every third selected point and its
        color, with a ``vertices_normal`` attribute set to a constant
        ``[0, 1, 0]`` per point.
    """
    imgs = to_numpy(imgs)
    pts3d = to_numpy(pts3d)
    mask = to_numpy(mask)

    # Points are indexed with a flattened mask, images with the 2-D mask.
    flat_masks = mask.reshape(mask.shape[0], -1)
    selected_points = []
    for view_points, view_mask in zip(pts3d, flat_masks):
        selected_points.append(view_points[view_mask])
    selected_colors = []
    for view_image, view_mask in zip(imgs, mask):
        selected_colors.append(view_image[view_mask])

    # Keep every third sample to thin out the cloud.
    points = np.concatenate(selected_points).reshape(-1, 3)[::3]
    colors = np.concatenate(selected_colors).reshape(-1, 3)[::3]

    cloud = trimesh.PointCloud(points, colors=colors)
    cloud.vertices_normal = np.tile([0, 1, 0], (points.shape[0], 1))
    return cloud
def main(image_dir, save_dir, model_path, device, batch_size, image_size, schedule, lr, niter, min_conf_thr, tsdf_thresh):
    """Run MASt3R sparse global alignment and export the result as COLMAP files.

    Args:
        image_dir: Directory scanned for ``.png`` / ``.jpg`` images. Files are
            ordered by the first run of digits in their stem.
        save_dir: Output root; a ``cache`` sub-directory is recreated inside it.
        model_path: Checkpoint passed to ``AsymmetricMASt3R.from_pretrained``.
        device: Device the model is moved to and alignment runs on.
        batch_size: Accepted for interface compatibility; not used here.
        image_size: Size forwarded to ``load_images``.
        schedule: Accepted for interface compatibility; not used here.
        lr: Accepted for interface compatibility; not used here (the
            alignment call uses fixed ``lr1`` / ``lr2`` values).
        niter: Accepted for interface compatibility; not used here (the
            alignment call uses fixed ``niter1`` / ``niter2`` values).
        min_conf_thr: Confidence values above this threshold form the masks.
        tsdf_thresh: Threshold forwarded to ``TSDFPostProcess``.

    Raises:
        AttributeError: If an image file's stem contains no digits (the sort
            key calls ``.group()`` on a failed ``re.search``).
    """
    # Local import: only needed for the cache cleanup below.
    import shutil

    # Load model and images
    model = AsymmetricMASt3R.from_pretrained(model_path).to(device)
    image_files = sorted(
        [str(x) for x in Path(image_dir).iterdir() if x.suffix in ['.png', '.jpg']],
        key=lambda x: int(re.search(r'\d+', Path(x).stem).group()))
    images = load_images(image_files, size=image_size)

    # Generate pairs and run inference
    pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
    cache_dir = os.path.join(save_dir, 'cache')
    if os.path.exists(cache_dir):
        # Remove any previous cache without going through a shell, so paths
        # containing spaces or shell metacharacters are handled correctly.
        shutil.rmtree(cache_dir)
    scene = sparse_global_alignment(image_files, pairs, cache_dir,
                                    model, lr1=0.07, niter1=500, lr2=0.014, niter2=200, device=device,
                                    # NOTE: 'depth' in 'refine' evaluates to False.
                                    opt_depth='depth' in 'refine', shared_intrinsics=False,
                                    matching_conf_thr=5.)

    # Extract scene information
    world2cam = invert_matrix(scene.get_im_poses().detach()).cpu().numpy()
    principal_points = scene.get_principal_points().detach().cpu().numpy()
    focals = scene.get_focals().detach().cpu().numpy()
    imgs = np.array(scene.imgs)
    tsdf = TSDFPostProcess(scene, TSDF_thresh=tsdf_thresh)
    pts3d, _, confs = to_numpy(tsdf.get_dense_pts3d(clean_depth=True))
    masks = np.array(to_numpy([c > min_conf_thr for c in confs]))

    # Project the stacked features, upsample them to image size and visualize.
    _, H, W, _ = imgs.shape
    feat_dim = 64
    projected_feat = pca(scene.stacked_feat, feat_dim)
    upsampled_feat = upsampler(projected_feat, H, W)
    visualizer(upsampled_feat, images, save_dir)

    # Main execution
    save_path, images_path, masks_path, sparse_path = init_filestructure(Path(save_dir))
    save_images_and_masks(imgs, masks, images_path, image_files, masks_path)
    save_cameras(focals, principal_points, sparse_path, imgs_shape=imgs.shape)
    save_images_txt(world2cam, image_files, sparse_path)
    # The point cloud belongs next to cameras.txt / images.txt. Passing the
    # raw ``save_dir`` string here failed on ``str / str`` inside the callee.
    save_pointcloud_with_normals(imgs, pts3d, masks, sparse_path)
    print(f'[INFO] Mast3R Reconstruction is successfully converted to COLMAP files in: {str(sparse_path)}')
if __name__ == "__main__":
    # Command-line options as (flag, value type, extra add_argument keywords).
    cli_options = (
        ('--image_dir', str, dict(required=True, help='Directory containing images')),
        ('--save_dir', str, dict(required=True, help='Directory to save the results')),
        ('--model_path', str, dict(required=True, help='Path to the model checkpoint')),
        ('--device', str, dict(default='cuda', help='Device to use for inference')),
        ('--batch_size', int, dict(default=1, help='Batch size for processing images')),
        ('--image_size', int, dict(default=512, help='Size to resize images')),
        ('--schedule', str, dict(default='cosine', help='Learning rate schedule')),
        ('--lr', float, dict(default=0.01, help='Learning rate')),
        ('--niter', int, dict(default=300, help='Number of iterations')),
        ('--min_conf_thr', float, dict(default=1.5, help='Minimum confidence threshold')),
        ('--tsdf_thresh', float, dict(default=0.0, help='TSDF threshold')),
    )

    parser = argparse.ArgumentParser(description='Process images and save results.')
    for flag, value_type, extra in cli_options:
        parser.add_argument(flag, type=value_type, **extra)
    args = parser.parse_args()

    # Hand every parsed option to the pipeline entry point.
    main(
        image_dir=args.image_dir,
        save_dir=args.save_dir,
        model_path=args.model_path,
        device=args.device,
        batch_size=args.batch_size,
        image_size=args.image_size,
        schedule=args.schedule,
        lr=args.lr,
        niter=args.niter,
        min_conf_thr=args.min_conf_thr,
        tsdf_thresh=args.tsdf_thresh,
    )