import pycolmap
import numpy as np
import torch
import torch.nn.functional as F

from src.model.encoder.vggt.utils.pose_enc import pose_encoding_to_extri_intri
from src.model.encoder.vggt.utils.geometry import closed_form_inverse_se3, unproject_depth_map_to_point_map
from lightglue import ALIKED, SuperPoint, SIFT
from src.utils.tensor_to_pycolmap import batch_matrix_to_pycolmap, pycolmap_to_batch_matrix

_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]

def generate_rank_by_dino(
    images, query_frame_num, image_size=518, model_name="dinov2_vitb14_reg", device="cuda", spatial_similarity=True
):
    """
    Generate a ranking of frames using DINO ViT features.

    Args:
        images: Tensor of shape (S, 3, H, W) with values in range [0, 1]
        query_frame_num: Number of frames to select
        image_size: Nominal input resolution (currently unused; images are assumed to be resized already)
        model_name: Name of the DINO model to use
        device: Device to run the model on
        spatial_similarity: Whether to use spatial token similarity or CLS token similarity

    Returns:
        List of frame indices ranked by their representativeness
    """
    dino_v2_model = torch.hub.load("facebookresearch/dinov2", model_name)
    dino_v2_model.eval()
    dino_v2_model = dino_v2_model.to(device)

    # DINOv2 expects ImageNet-normalized inputs
    resnet_mean = torch.tensor(_RESNET_MEAN, device=device).view(1, 3, 1, 1)
    resnet_std = torch.tensor(_RESNET_STD, device=device).view(1, 3, 1, 1)
    images_resnet_norm = (images - resnet_mean) / resnet_std

    with torch.no_grad():
        frame_feat = dino_v2_model(images_resnet_norm, is_training=True)

    if spatial_similarity:
        frame_feat = frame_feat["x_norm_patchtokens"]
        frame_feat_norm = F.normalize(frame_feat, p=2, dim=1)

        # Compute the (S, S) similarity matrix, averaged over patch tokens
        frame_feat_norm = frame_feat_norm.permute(1, 0, 2)
        similarity_matrix = torch.bmm(frame_feat_norm, frame_feat_norm.transpose(-1, -2))
        similarity_matrix = similarity_matrix.mean(dim=0)
    else:
        frame_feat = frame_feat["x_norm_clstoken"]
        frame_feat_norm = F.normalize(frame_feat, p=2, dim=1)
        similarity_matrix = torch.mm(frame_feat_norm, frame_feat_norm.transpose(-1, -2))

    distance_matrix = 100 - similarity_matrix.clone()

    # Ignore self-pairing when scoring how representative a frame is
    similarity_matrix.fill_diagonal_(-100)
    similarity_sum = similarity_matrix.sum(dim=1)

    # The frame most similar to all others is the "most common" frame
    most_common_frame_index = torch.argmax(similarity_sum).item()

    # Run farthest point sampling starting from the most common frame
    fps_idx = farthest_point_sampling(distance_matrix, query_frame_num, most_common_frame_index)

    return fps_idx

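# Usage sketch (hedged): `frames` is assumed to be an (S, 3, H, W) float tensor in
# [0, 1], already resized to the DINO input resolution and on the target device:
#     ranked = generate_rank_by_dino(frames, query_frame_num=3, device="cuda", spatial_similarity=False)
#     # ranked is a list like [12, 0, 47]: the most representative frame, then diverse ones
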
def farthest_point_sampling(distance_matrix, num_samples, most_common_frame_index=0):
    """
    Farthest point sampling to select diverse frames.

    Args:
        distance_matrix: (N, N) matrix of pairwise distances between frames
        num_samples: Number of frames to select
        most_common_frame_index: Index of the first frame to select

    Returns:
        List of selected frame indices
    """
    distance_matrix = distance_matrix.clamp(min=0)
    N = distance_matrix.size(0)

    # Initialize with the most common frame
    selected_indices = [most_common_frame_index]
    check_distances = distance_matrix[selected_indices]

    while len(selected_indices) < num_samples:
        # Pick the point farthest from the most recently selected point
        farthest_point = torch.argmax(check_distances)
        selected_indices.append(farthest_point.item())

        check_distances = distance_matrix[farthest_point]
        # Zero out already-selected points so they cannot be picked again
        check_distances[selected_indices] = 0

        # Stop early once every point has been selected
        if len(selected_indices) == N:
            break

    return selected_indices

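# Usage sketch (hedged): greedy selection over a toy pairwise-distance matrix.
#     feats = torch.randn(6, 16)                 # hypothetical per-frame features
#     D = torch.cdist(feats, feats)              # (6, 6) pairwise distances
#     picked = farthest_point_sampling(D, num_samples=3)
#     # starts at index 0, then greedily adds the point farthest from the last pick
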
def calculate_index_mappings(query_index, S, device=None):
    """
    Construct an ordering that swaps [query_index] and [0],
    so that the content at query_index is placed at position 0.

    Args:
        query_index: Index to swap with 0
        S: Total number of elements
        device: Device to place the tensor on

    Returns:
        Tensor of indices with the swapped order
    """
    new_order = torch.arange(S)
    new_order[0] = query_index
    new_order[query_index] = 0
    if device is not None:
        new_order = new_order.to(device)
    return new_order

def switch_tensor_order(tensors, order, dim=1):
    """
    Reorder tensors along a specific dimension according to the given order.

    Args:
        tensors: List of tensors to reorder (entries may be None)
        order: Tensor of indices specifying the new order
        dim: Dimension along which to reorder

    Returns:
        List of reordered tensors; None entries are passed through unchanged
    """
    return [torch.index_select(tensor, dim, order) if tensor is not None else None for tensor in tensors]

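# Usage sketch: the two helpers above cooperate to make any frame the "first"
# frame and then undo the reordering; the swap permutation is its own inverse.
#     order = calculate_index_mappings(2, 4)                        # tensor([2, 1, 0, 3])
#     (reordered,) = switch_tensor_order([frames], order, dim=0)    # frame 2 now at index 0
#     (restored,) = switch_tensor_order([reordered], order, dim=0)  # original order back
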
def predict_track(model, images, query_points, dtype=torch.bfloat16, use_tf32_for_track=True, iters=4):
    """
    Predict tracks for query points across frames.

    Args:
        model: VGGT model
        images: Tensor of images of shape (S, 3, H, W)
        query_points: Query points to track
        dtype: Autocast data type for the aggregator forward pass
        use_tf32_for_track: Whether to run the track head outside autocast (TF32/fp32) for precision
        iters: Number of refinement iterations for tracking

    Returns:
        Predicted tracks, visibility scores, and confidence scores
    """
    with torch.no_grad():
        with torch.cuda.amp.autocast(dtype=dtype):
            images = images[None]  # add batch dimension
            aggregated_tokens_list, ps_idx = model.aggregator(images)

            if not use_tf32_for_track:
                track_list, vis_score, conf_score = model.track_head(
                    aggregated_tokens_list, images, ps_idx, query_points=query_points, iters=iters
                )

        if use_tf32_for_track:
            # Run the track head outside autocast for higher numerical precision
            with torch.cuda.amp.autocast(enabled=False):
                track_list, vis_score, conf_score = model.track_head(
                    aggregated_tokens_list, images, ps_idx, query_points=query_points, iters=iters
                )

    # The track head returns one prediction per iteration; keep the final refinement
    pred_track = track_list[-1]

    return pred_track.squeeze(0), vis_score.squeeze(0), conf_score.squeeze(0)

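# Usage sketch (hedged): `pts` is assumed to be a (1, N, 2) tensor of pixel
# coordinates in the first frame of `frames` (an (S, 3, H, W) tensor on the GPU).
#     tracks, vis, conf = predict_track(model, frames, pts)
#     # tracks: per-frame point locations; vis / conf: per-frame visibility and confidence
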
def initialize_feature_extractors(max_query_num, det_thres, extractor_method="aliked", device="cuda"):
    """
    Initialize reusable feature extractors from a method string.

    Args:
        max_query_num: Maximum number of keypoints to extract, split evenly across extractors
        det_thres: Detection threshold for keypoint extraction
        extractor_method: String specifying which extractors to use (e.g., "aliked", "sp+sift", "aliked+sp+sift")
        device: Device to run extraction on

    Returns:
        Dictionary of initialized extractors
    """
    extractors = {}
    methods = extractor_method.lower().split("+")
    # Split the keypoint budget evenly across the requested extractors
    active_extractors = len(methods)

    for method in methods:
        method = method.strip()
        if method == "aliked":
            aliked_max_points = max_query_num // active_extractors
            aliked_extractor = ALIKED(max_num_keypoints=aliked_max_points, detection_threshold=det_thres)
            extractors["aliked"] = aliked_extractor.to(device).eval()
        elif method == "sp":
            sp_max_points = max_query_num // active_extractors
            sp_extractor = SuperPoint(max_num_keypoints=sp_max_points, detection_threshold=det_thres)
            extractors["sp"] = sp_extractor.to(device).eval()
        elif method == "sift":
            sift_max_points = max_query_num // active_extractors
            sift_extractor = SIFT(max_num_keypoints=sift_max_points)
            extractors["sift"] = sift_extractor.to(device).eval()
        else:
            print(f"Warning: Unknown feature extractor '{method}', ignoring.")

    if not extractors:
        print(f"Warning: No valid extractors found in '{extractor_method}'. Using ALIKED by default.")
        aliked_extractor = ALIKED(max_num_keypoints=max_query_num, detection_threshold=det_thres)
        extractors["aliked"] = aliked_extractor.to(device).eval()

    return extractors

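# Usage sketch: a 2048-keypoint budget split evenly across three detectors.
#     extractors = initialize_feature_extractors(
#         max_query_num=2048, det_thres=0.005, extractor_method="aliked+sp+sift", device="cuda"
#     )
#     # extractors == {"aliked": ..., "sp": ..., "sift": ...}, each capped at 2048 // 3 keypoints
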
def extract_keypoints(query_image, extractors):
    """
    Extract keypoints using pre-initialized feature extractors.

    Args:
        query_image: Input image tensor of shape (3, H, W), values in [0, 1]
        extractors: Dictionary of initialized extractors

    Returns:
        Tensor of rounded keypoint coordinates of shape (1, N, 2)
    """
    query_points_round = None

    with torch.no_grad():
        for extractor_name, extractor in extractors.items():
            query_points_data = extractor.extract(query_image)
            extractor_points = query_points_data["keypoints"].round()

            # Pool keypoints from all extractors along the point dimension
            if query_points_round is not None:
                query_points_round = torch.cat([query_points_round, extractor_points], dim=1)
            else:
                query_points_round = extractor_points

    return query_points_round

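# Usage sketch (hedged): pool keypoints from every initialized extractor for one image.
#     kps = extract_keypoints(frames[0], extractors)
#     print(kps.shape)  # e.g. torch.Size([1, 1847, 2]) -- rounded (x, y) pixel coordinates
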
def run_vggt_with_ba(
    model,
    images,
    image_names=None,
    dtype=torch.bfloat16,
    max_query_num=2048,
    det_thres=0.005,
    query_frame_num=3,
    extractor_method="aliked+sp+sift",
    max_reproj_error=4,
    shared_camera=True,
):
    """
    Run VGGT with bundle adjustment for pose estimation.

    Args:
        model: VGGT model
        images: Tensor of images of shape (S, 3, H, W)
        image_names: Optional list of image names
        dtype: Data type for computation
        max_query_num: Maximum number of keypoints to extract per query frame
        det_thres: Detection threshold for keypoint extraction
        query_frame_num: Number of representative query frames to select
        extractor_method: Feature extractor combination (see initialize_feature_extractors)
        max_reproj_error: Maximum reprojection error when building the pycolmap reconstruction
        shared_camera: Whether all frames share a single camera model

    Returns:
        Predicted extrinsic camera parameters after bundle adjustment

    TODO:
        - [ ] Use VGGT's ViT instead of DINOv2 for rank generation
    """
    device = images.device
    frame_num = images.shape[0]

    # Select representative query frames for feature extraction
    # (see TODO above: switch from DINOv2 to VGGT's own ViT)
    query_frame_indexes = generate_rank_by_dino(
        images,
        query_frame_num,
        image_size=518,
        model_name="dinov2_vitb14_reg",
        device=device,
        spatial_similarity=False,
    )

    # Ensure frame 0 is always the first query frame
    if 0 in query_frame_indexes:
        query_frame_indexes.remove(0)
    query_frame_indexes = [0, *query_frame_indexes]

    # Get initial pose and depth predictions
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):
        aggregated_tokens_list, patch_start_idx = model.aggregator(
            images, intermediate_layer_idx=model.cfg.intermediate_layer_idx
        )

        with torch.cuda.amp.autocast(enabled=False):
            # Decode cameras from the aggregated tokens in full precision
            fp32_tokens = [token.float() for token in aggregated_tokens_list]
            pred_all_pose_enc = model.camera_head(fp32_tokens)[-1]
            pred_all_extrinsic, pred_all_intrinsic = pose_encoding_to_extri_intri(pred_all_pose_enc, images.shape[-2:])

        pred_extrinsic = pred_all_extrinsic[0]
        pred_intrinsic = pred_all_intrinsic[0]

        depth_map, depth_conf = model.depth_head(aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx)
        world_points = unproject_depth_map_to_point_map(depth_map, pred_extrinsic, pred_intrinsic)
        world_points = torch.from_numpy(world_points).to(device)
        world_points_conf = depth_conf.to(device)

    torch.cuda.empty_cache()

    # Lists to store per-query-frame predictions
    pred_tracks = []
    pred_vis_scores = []
    pred_conf_scores = []
    pred_world_points = []
    pred_world_points_conf = []

    # Initialize feature extractors
    extractors = initialize_feature_extractors(max_query_num, det_thres, extractor_method, device)

    # Process each query frame
    for query_index in query_frame_indexes:
        query_image = images[query_index]
        query_points_round = extract_keypoints(query_image, extractors)

        # Reorder images so the query image comes first
        reorder_index = calculate_index_mappings(query_index, frame_num, device=device)
        reorder_images = switch_tensor_order([images], reorder_index, dim=0)[0]

        # Track the query points across all frames
        reorder_tracks, reorder_vis_score, reorder_conf_score = predict_track(
            model, reorder_images, query_points_round, dtype=dtype, use_tf32_for_track=True, iters=4
        )

        # Restore the original frame order
        pred_track, pred_vis, pred_score = switch_tensor_order(
            [reorder_tracks, reorder_vis_score, reorder_conf_score], reorder_index, dim=0
        )

        pred_tracks.append(pred_track)
        pred_vis_scores.append(pred_vis)
        pred_conf_scores.append(pred_score)

        # Look up the corresponding 3D points in the unprojected depth map (indexing is [y, x])
        query_points_round_long = query_points_round.squeeze(0).long()
        query_world_points = world_points[query_index][query_points_round_long[:, 1], query_points_round_long[:, 0]]
        query_world_points_conf = world_points_conf[query_index][
            query_points_round_long[:, 1], query_points_round_long[:, 0]
        ]

        pred_world_points.append(query_world_points)
        pred_world_points_conf.append(query_world_points_conf)

    # Concatenate predictions across query frames
    pred_tracks = torch.cat(pred_tracks, dim=1)
    pred_vis_scores = torch.cat(pred_vis_scores, dim=1)
    pred_conf_scores = torch.cat(pred_conf_scores, dim=1)
    pred_world_points = torch.cat(pred_world_points, dim=0)
    pred_world_points_conf = torch.cat(pred_world_points_conf, dim=0)

    # Filter points by confidence, but skip the filter when it would leave too few points
    filtered_flag = pred_world_points_conf > 1.5
    if filtered_flag.sum() > 1024:
        pred_world_points = pred_world_points[filtered_flag]
        pred_world_points_conf = pred_world_points_conf[filtered_flag]
        pred_tracks = pred_tracks[:, filtered_flag]
        pred_vis_scores = pred_vis_scores[:, filtered_flag]
        pred_conf_scores = pred_conf_scores[:, filtered_flag]

    torch.cuda.empty_cache()

    # Bundle adjustment parameters
    S, _, H, W = images.shape
    image_size = torch.tensor([W, H], dtype=pred_tracks.dtype, device=device)

    # Convert predictions to a pycolmap reconstruction and run bundle adjustment
    reconstruction = batch_matrix_to_pycolmap(
        pred_world_points,
        pred_extrinsic,
        pred_intrinsic,
        pred_tracks,
        image_size,
        max_reproj_error=max_reproj_error,
        shared_camera=shared_camera,
    )
    ba_options = pycolmap.BundleAdjustmentOptions()
    pycolmap.bundle_adjustment(reconstruction, ba_options)

    _, updated_extrinsic, _, _ = pycolmap_to_batch_matrix(reconstruction, device=device, camera_type="SIMPLE_PINHOLE")

    return updated_extrinsic
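
# Usage sketch (hedged): end-to-end wiring of the pipeline above. The loaders below
# are hypothetical placeholders -- only run_vggt_with_ba is defined in this module.
#     model = load_vggt_model().to("cuda").eval()        # hypothetical model loader
#     frames = load_images_as_tensor("path/to/images")   # hypothetical; (S, 3, H, W) in [0, 1]
#     extrinsics = run_vggt_with_ba(model, frames.to("cuda"))
#     # extrinsics: refined per-frame extrinsic matrices after bundle adjustment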