| import os |
| import sys |
| import json |
| import gzip |
| import argparse |
| import numpy as np |
| from PIL import Image |
|
|
| import torch |
| import torch.nn as nn |
| import torchvision |
| from einops import rearrange |
| from lpips import LPIPS |
|
|
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
| from src.model.model.anysplat import AnySplat |
| from src.model.encoder.vggt.utils.pose_enc import pose_encoding_to_extri_intri |
| from src.model.encoder.vggt.utils.load_fn import load_and_preprocess_images |
| from src.utils.pose import align_to_first_camera, calculate_auc_np, convert_pt3d_RT_to_opencv, se3_to_relative_pose_error |
| from src.misc.cam_utils import camera_normalization, pose_auc, rotation_6d_to_matrix, update_pose, get_pnp_pose |
|
|
def setup_args():
    """Parse and return the command-line options of the CO3D evaluation script."""
    parser = argparse.ArgumentParser(description='Test AnySplat on CO3D dataset')

    # Boolean switches (all default to False), registered in their original order.
    boolean_flags = (
        ('--debug', 'Enable debug mode (only test on specific category)'),
        ('--use_ba', 'Enable bundle adjustment'),
        ('--fast_eval', 'Only evaluate 10 sequences per category'),
    )
    for flag_name, flag_help in boolean_flags:
        parser.add_argument(flag_name, action='store_true', default=False, help=flag_help)

    # Integer sampling options.
    parser.add_argument('--min_num_images', type=int, default=50, help='Minimum number of images for a sequence')
    parser.add_argument('--num_frames', type=int, default=10, help='Number of frames to use for testing')

    # Mandatory dataset locations.
    for path_flag, path_help in (
        ('--co3d_dir', 'Path to CO3D dataset'),
        ('--co3d_anno_dir', 'Path to CO3D annotations'),
    ):
        parser.add_argument(path_flag, type=str, required=True, help=path_help)

    parser.add_argument('--seed', type=int, default=0, help='Random seed for reproducibility')
    return parser.parse_args()
|
|
# Perceptual metric (LPIPS with a VGG backbone) used by the bundle-adjustment
# rendering loss. It is placed on the same device the evaluation selects
# (CUDA when available) because the rendered images it scores live there;
# leaving it on CPU makes every LPIPS call fail with a device mismatch.
lpips = LPIPS(net="vgg").to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
|
|
def rendering_loss(pred_image, image, lpips_fn=None, lpips_weight=0.05):
    """Photometric + perceptual loss between renderings and input images.

    Args:
        pred_image: Rendered images laid out as (b, v, c, h, w). They are compared
            directly against ``(image + 1) / 2``, i.e. expected in [0, 1].
        image: Reference images laid out as (b, v, c, h, w), in [-1, 1].
        lpips_fn: Perceptual metric called as ``lpips_fn(x, y, normalize=True)``
            on [0, 1] inputs; defaults to the module-level ``lpips`` instance.
        lpips_weight: Weight of the mean perceptual term.

    Returns:
        Scalar tensor: ``mean squared error + lpips_weight * mean LPIPS``.
    """
    if lpips_fn is None:
        lpips_fn = lpips
    if isinstance(lpips_fn, nn.Module):
        # Keep the metric on the same device as the renderings (no-op if already there).
        lpips_fn = lpips_fn.to(pred_image.device)

    # Map the reference from [-1, 1] to [0, 1] once and use it for BOTH terms.
    # With ``normalize=True`` LPIPS rescales [0, 1] inputs to [-1, 1] itself, so it
    # must receive the [0, 1] target, not the raw [-1, 1] image.
    target = (image + 1) / 2

    # Fold views into the batch dimension: (b, v, c, h, w) -> (b*v, c, h, w).
    lpips_loss = lpips_fn(pred_image.flatten(0, 1), target.flatten(0, 1), normalize=True)
    mse_loss = ((pred_image - target) ** 2).mean()
    return mse_loss + lpips_weight * lpips_loss.mean()
|
|
def _load_sequence_metadata(seq_data):
    """Convert CO3D frame annotations into ``{"filepath", "extri"}`` records.

    Returns None when any frame's translation components sum above 1e5
    (presumably corrupt/outlier annotations), so the whole sequence is skipped.
    """
    if any(data["T"][0] + data["T"][1] + data["T"][2] > 1e5 for data in seq_data):
        return None
    return [
        {
            "filepath": data["filepath"],
            "extri": convert_pt3d_RT_to_opencv(data["R"], data["T"]),
        }
        for data in seq_data
    ]


def _predict_extrinsics_vggt(model, images, dtype):
    """Predict per-frame extrinsics with the encoder's aggregator and camera head.

    Runs without gradients: the aggregator under mixed precision ``dtype``, the
    camera head on float32 copies of the tokens. Returns the first batch element
    of the extrinsics decoded by ``pose_encoding_to_extri_intri``.
    """
    with torch.no_grad():
        with torch.cuda.amp.autocast(dtype=dtype):
            aggregated_tokens_list, _ = model.encoder.aggregator(
                images, intermediate_layer_idx=model.encoder.cfg.intermediate_layer_idx
            )
        with torch.cuda.amp.autocast(dtype=torch.float32):
            fp32_tokens = [token.float() for token in aggregated_tokens_list]
            pred_all_pose_enc = model.encoder.camera_head(fp32_tokens)[-1]
            pred_all_extrinsic, _ = pose_encoding_to_extri_intri(pred_all_pose_enc, images.shape[-2:])
    return pred_all_extrinsic[0]


def _apply_pose_delta(extrinsics, rot_delta, trans_delta, identity_6d):
    """Right-multiply (b, v, 4, 4) ``extrinsics`` by the rigid transform encoded by the deltas.

    ``rot_delta`` is an offset from ``identity_6d`` in the 6D rotation
    representation consumed by ``rotation_6d_to_matrix``; ``trans_delta`` is the
    translation column of the transform.
    """
    b, v = rot_delta.shape[:2]
    rot = rotation_6d_to_matrix(rot_delta + identity_6d.expand(b, v, -1))
    transform = torch.eye(4, device=extrinsics.device).repeat((b, v, 1, 1))
    transform[..., :3, :3] = rot
    transform[..., :3, 3] = trans_delta
    return torch.matmul(extrinsics, transform)


def _predict_extrinsics_ba(model, batch, images, num_iters=100, lr=0.005, vis_dir=None):
    """Refine the encoder's predicted poses by gradient descent on a rendering loss.

    Per-view rotation (6D) and translation deltas are optimized with Adam so that
    the decoder's renderings of the encoder's Gaussians match the input images.
    Returns the inverse of the refined extrinsics for the first batch element with
    the homogeneous bottom row dropped. NOTE(review): presumably this converts the
    encoder's camera-to-world poses to the world-to-camera form the VGGT path
    returns — confirm against the encoder's pose convention.

    When ``vis_dir`` is given, per-iteration renderings and the input images are
    saved there (the directory is created if needed); otherwise nothing is written.
    """
    encoder_output = model.encoder(batch, global_step=0, visualization_dump={})
    gaussians = encoder_output.gaussians
    pred_extrinsic = encoder_output.pred_context_pose["extrinsic"]
    pred_intrinsic = encoder_output.pred_context_pose["intrinsic"].float()

    b, v, _, h, w = images.shape
    device = pred_extrinsic.device
    extrinsics = pred_extrinsic.clone().float()
    # Identity rotation in 6D form; kept local instead of registering a buffer on
    # the shared model object.
    identity_6d = torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0], dtype=torch.float32, device=device)
    target = images * 2.0 - 1

    if vis_dir is not None:
        os.makedirs(vis_dir, exist_ok=True)

    with torch.set_grad_enabled(True), torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
        cam_rot_delta = nn.Parameter(torch.zeros([b, v, 6], device=device, dtype=torch.float32))
        cam_trans_delta = nn.Parameter(torch.zeros([b, v, 3], device=device, dtype=torch.float32))
        pose_optimizer = torch.optim.Adam(
            [
                {"params": [cam_rot_delta], "lr": lr},
                {"params": [cam_trans_delta], "lr": lr},
            ]
        )

        loss = None
        for i in range(num_iters):
            pose_optimizer.zero_grad()
            new_extrinsics = _apply_pose_delta(extrinsics, cam_rot_delta, cam_trans_delta, identity_6d)
            output = model.decoder.forward(gaussians, new_extrinsics, pred_intrinsic, 0.1, 100.0, (h, w))
            # Bound to ``loss``, NOT ``rendering_loss``: assigning to the function's
            # name here would make it a local and raise UnboundLocalError.
            loss = rendering_loss(output.color, target)
            if vis_dir is not None:
                torchvision.utils.save_image(output.color[0], os.path.join(vis_dir, f"output_co3d_{i}.png"))
            loss.backward()
            pose_optimizer.step()

    if loss is not None:
        print(f"Rendering loss: {loss.item()}")
    if vis_dir is not None:
        torchvision.utils.save_image(images[0], os.path.join(vis_dir, "gt_co3d.png"))

    # Re-apply the deltas after the final optimizer step (the in-loop extrinsics lag
    # one step behind) and without autograd, so later ``.numpy()`` calls succeed.
    with torch.no_grad():
        final_extrinsics = _apply_pose_delta(extrinsics, cam_rot_delta, cam_trans_delta, identity_6d)
    return final_extrinsics.inverse()[0][:, :-1, :]


def process_sequence(model, seq_name, seq_data, category, co3d_dir, min_num_images, num_frames,
                     use_ba, device, dtype, vis_dir=None):
    """
    Process a single sequence and compute pose errors.

    Args:
        model: AnySplat model
        seq_name: Sequence name
        seq_data: Sequence data (list of per-frame dicts with "R", "T", "filepath")
        category: Category name
        co3d_dir: CO3D dataset directory
        min_num_images: Minimum number of images required
        num_frames: Number of frames to sample
        use_ba: Whether to refine poses with rendering-based bundle adjustment;
            falls back to plain camera-head inference if it raises
        device: Device to run on
        dtype: Data type for mixed-precision inference
        vis_dir: Optional directory for BA debug renderings; nothing is written when None

    Returns:
        (rError, tError): the rotation / translation ``*_deg`` errors produced by
        ``se3_to_relative_pose_error``, as numpy arrays; or (None, None) when the
        sequence is skipped (too few frames, outlier translations, or a first image
        whose longer side is below 448 px).
    """
    if len(seq_data) < min_num_images:
        return None, None

    metadata = _load_sequence_metadata(seq_data)
    # Sampling without replacement needs at least ``num_frames`` frames.
    if metadata is None or len(metadata) < num_frames:
        return None, None

    ids = np.random.choice(len(metadata), num_frames, replace=False)
    image_names = [os.path.join(co3d_dir, metadata[i]["filepath"]) for i in ids]
    gt_extri = np.stack([np.array(metadata[i]["extri"]) for i in ids], axis=0)

    # Skip low-resolution sequences; the context manager closes PIL's file handle.
    with Image.open(image_names[0]) as first_image:
        max_size = max(first_image.size)
    if max_size < 448:
        return None, None
    images = load_and_preprocess_images(image_names)[None].to(device)

    pred_extrinsic = None
    if use_ba:
        batch = {
            "context": {
                "image": images * 2.0 - 1,
                "image_names": image_names,
                "index": ids,
            },
            "scene": "co3d",
        }
        try:
            pred_extrinsic = _predict_extrinsics_ba(model, batch, images, vis_dir=vis_dir)
        except Exception as e:
            print(f"BA failed with error: {e}. Falling back to standard VGGT inference.")
    if pred_extrinsic is None:
        pred_extrinsic = _predict_extrinsics_vggt(model, images, dtype)

    with torch.cuda.amp.autocast(dtype=torch.float32):
        gt_extrinsic = torch.from_numpy(gt_extri).to(device)
        # Append the homogeneous [0, 0, 0, 1] row to turn 3x4 poses into 4x4.
        add_row = torch.tensor([0, 0, 0, 1], device=device).expand(pred_extrinsic.size(0), 1, 4)
        pred_se3 = torch.cat((pred_extrinsic, add_row), dim=1)
        gt_se3 = torch.cat((gt_extrinsic, add_row), dim=1)
        # Presumably re-expresses the GT poses relative to the first sampled camera
        # (per the helper's name) so they are comparable with the predictions.
        gt_se3 = align_to_first_camera(gt_se3)
        rel_rangle_deg, rel_tangle_deg = se3_to_relative_pose_error(pred_se3, gt_se3, num_frames)

    print(f"{category} sequence {seq_name} Rot Error: {rel_rangle_deg.mean().item():.4f}")
    print(f"{category} sequence {seq_name} Trans Error: {rel_tangle_deg.mean().item():.4f}")
    return rel_rangle_deg.cpu().numpy(), rel_tangle_deg.cpu().numpy()
|
|
CO3D_SEEN_CATEGORIES = [
    "apple", "backpack", "banana", "baseballbat", "baseballglove",
    "bench", "bicycle", "bottle", "bowl", "broccoli",
    "cake", "car", "carrot", "cellphone", "chair",
    "cup", "donut", "hairdryer", "handbag", "hydrant",
    "keyboard", "laptop", "microwave", "motorcycle", "mouse",
    "orange", "parkingmeter", "pizza", "plant", "stopsign",
    "teddybear", "toaster", "toilet", "toybus", "toyplane",
    "toytrain", "toytruck", "tv", "umbrella", "vase", "wineglass",
]
# AUC thresholds; results are stored under the keys "Auc_<threshold>".
AUC_THRESHOLDS = (5, 10, 20, 30)
# Number of sequences per category evaluated when --fast_eval is set.
FAST_EVAL_NUM_SEQUENCES = 10


def _select_autocast_dtype(device):
    """Return float16 on CUDA devices older than compute capability 8 (no native bfloat16), else bfloat16."""
    if device.type == "cuda" and torch.cuda.get_device_capability(device)[0] < 8:
        return torch.float16
    return torch.bfloat16


def _evaluate_category(model, category, args, device, dtype):
    """Evaluate the test sequences of one category.

    Returns a dict with the concatenated "rError"/"tError" arrays and "Auc_<t>"
    for every t in AUC_THRESHOLDS, or None when the annotation file is missing or
    no sequence produced a valid result.
    """
    print(f"Loading annotation for {category} test set")
    annotation_file = os.path.join(args.co3d_anno_dir, f"{category}_test.jgz")
    try:
        with gzip.open(annotation_file, "r") as fin:
            annotation = json.loads(fin.read())
    except FileNotFoundError:
        print(f"Annotation file not found for {category}, skipping")
        return None

    seq_items = list(annotation.items())
    if args.fast_eval:
        seq_items = seq_items[:FAST_EVAL_NUM_SEQUENCES]

    rot_errors, trans_errors = [], []
    for seq_name, seq_data in seq_items:
        print("-" * 50)
        print(f"Processing {seq_name} for {category} test set")
        if args.debug and not os.path.exists(os.path.join(args.co3d_dir, category, seq_name)):
            print(f"Skipping {seq_name} (not found)")
            continue

        seq_rError, seq_tError = process_sequence(
            model, seq_name, seq_data, category, args.co3d_dir,
            args.min_num_images, args.num_frames, args.use_ba, device, dtype,
        )
        print("-" * 50)

        if seq_rError is not None and seq_tError is not None:
            rot_errors.extend(seq_rError)
            trans_errors.extend(seq_tError)

    if not rot_errors:
        print(f"No valid sequences found for {category}, skipping")
        return None

    rot_errors = np.array(rot_errors)
    trans_errors = np.array(trans_errors)
    aucs = {
        threshold: calculate_auc_np(rot_errors, trans_errors, max_threshold=threshold)[0]
        for threshold in AUC_THRESHOLDS
    }

    print("=" * 80)
    print(f"AUC of {category} test set: {aucs[30]:.4f}")
    print("=" * 80)

    result = {"rError": rot_errors, "tError": trans_errors}
    result.update({f"Auc_{threshold}": aucs[threshold] for threshold in AUC_THRESHOLDS})
    return result


def _category_summary_lines(category, result):
    """Format one category's AUCs, largest threshold first."""
    return [
        f"{category:<15} AUC_{threshold}: {result[f'Auc_{threshold}']:.4f}"
        for threshold in sorted(AUC_THRESHOLDS, reverse=True)
    ]


def _mean_summary_lines(per_category_results):
    """Format the across-category mean AUCs, largest threshold first (empty if no results)."""
    if not per_category_results:
        return []
    lines = []
    for threshold in sorted(AUC_THRESHOLDS, reverse=True):
        mean_auc = np.mean([result[f"Auc_{threshold}"] for result in per_category_results.values()])
        lines.append(f"Mean AUC_{threshold}: {mean_auc:.4f}")
    return lines


def evaluate(args: argparse.Namespace):
    """Run the CO3D pose evaluation, print a summary and write it to a timestamped file.

    Honors ``--seed`` (frame sampling is reproducible), ``--debug`` (only the
    "apple" category), ``--fast_eval`` (first FAST_EVAL_NUM_SEQUENCES sequences per
    category) and ``--use_ba``.
    """
    import datetime

    model = AnySplat.from_pretrained("lhjiang/anysplat")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    dtype = _select_autocast_dtype(device)

    # Seed after model loading: process_sequence samples frames with np.random.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    categories = ["apple"] if args.debug else CO3D_SEEN_CATEGORIES

    per_category_results = {}
    for category in categories:
        result = _evaluate_category(model, category, args, device, dtype)
        if result is not None:
            per_category_results[category] = result

    mean_lines = _mean_summary_lines(per_category_results)

    print("\nSummary of AUC results:")
    print("-" * 50)
    for category in sorted(per_category_results):
        for line in _category_summary_lines(category, per_category_results[category]):
            print(line)
    if mean_lines:
        print("-" * 50)
        for line in mean_lines:
            print(line)

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    results_file = f"co3d_results_{timestamp}.txt"

    with open(results_file, "w", encoding="utf-8") as f:
        f.write("CO3D Evaluation Results\n")
        f.write("=" * 50 + "\n\n")

        f.write("Per-category results:\n")
        f.write("-" * 50 + "\n")
        for category in sorted(per_category_results):
            for line in _category_summary_lines(category, per_category_results[category]):
                f.write(line + "\n")
            f.write("\n")

        if mean_lines:
            f.write("-" * 50 + "\n")
            for line in mean_lines:
                f.write(line + "\n")
        f.write("\n" + "=" * 50 + "\n")

    print(f"Results saved to {results_file}")
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the CLI options and run the full evaluation.
    evaluate(setup_args())
|
|