| import os |
| from pathlib import Path |
| import sys |
| import json |
| import gzip |
| import argparse |
| import numpy as np |
| from PIL import Image |
|
|
| import torch |
| import torch.nn as nn |
| import torchvision |
| from einops import rearrange |
|
|
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
| from src.evaluation.metrics import compute_lpips, compute_psnr, compute_ssim |
| from misc.image_io import save_image, save_interpolated_video |
| from src.utils.image import process_image |
|
|
| from src.model.model.anysplat import AnySplat |
| from src.model.encoder.vggt.utils.pose_enc import pose_encoding_to_extri_intri |
|
|
def setup_args():
    """Set up command-line arguments for the eval NVS script."""
    parser = argparse.ArgumentParser(description='Test AnySplat on NVS evaluation')
    # Flag specs kept in one table so the CLI surface is easy to scan.
    arg_specs = [
        ('--data_dir', dict(type=str, required=True, help='Path to NVS dataset')),
        ('--llffhold', dict(type=int, default=8, help='LLFF holdout')),
        ('--output_path', dict(type=str, default="outputs/nvs", help='Path to output directory')),
    ]
    for flag, options in arg_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
|
|
def compute_metrics(pred_image, image):
    """Return the (PSNR, SSIM, LPIPS) triple for a predicted image vs. ground truth."""
    return (
        compute_psnr(pred_image, image),
        compute_ssim(pred_image, image),
        compute_lpips(pred_image, image),
    )
|
|
def evaluate(args: argparse.Namespace):
    """Run AnySplat novel-view-synthesis evaluation on a folder of images.

    Loads the pretrained model, splits the images found in ``args.data_dir``
    into context/target views with an LLFF-style holdout (every
    ``args.llffhold``-th image is held out as a target), renders the targets
    from predicted Gaussians, saves GT/prediction images and an interpolated
    video under ``args.output_path``, and prints PSNR/SSIM/LPIPS.
    """
    model = AnySplat.from_pretrained("lhjiang/anysplat")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    # Inference-only script: freeze all weights.
    for param in model.parameters():
        param.requires_grad = False

    os.makedirs(args.output_path, exist_ok=True)

    # Gather all images in the dataset directory, in sorted (deterministic) order.
    image_folder = args.data_dir
    image_names = sorted([os.path.join(image_folder, f) for f in os.listdir(image_folder) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
    images = [process_image(img_path) for img_path in image_names]
    # LLFF-style split: every `llffhold`-th image becomes a target view,
    # the rest are context views fed to the encoder.
    ctx_indices = [idx for idx, name in enumerate(image_names) if idx % args.llffhold != 0]
    tgt_indices = [idx for idx, name in enumerate(image_names) if idx % args.llffhold == 0]

    # Shapes: (1, V, C, H, W) — batch dim added by unsqueeze(0).
    ctx_images = torch.stack([images[i] for i in ctx_indices], dim=0).unsqueeze(0).to(device)
    tgt_images = torch.stack([images[i] for i in tgt_indices], dim=0).unsqueeze(0).to(device)
    # Map from [-1, 1] to [0, 1] — assumes process_image returns [-1, 1]; TODO confirm.
    ctx_images = (ctx_images+1)*0.5
    tgt_images = (tgt_images+1)*0.5
    b, v, _, h, w = tgt_images.shape

    # Encode context views into Gaussians plus per-view pose predictions.
    encoder_output = model.encoder(
        ctx_images,
        global_step=0,
        visualization_dump={},
    )
    gaussians, pred_context_pose = encoder_output.gaussians, encoder_output.pred_context_pose

    # Run the VGGT aggregator + camera head over context AND target views
    # jointly so target poses are estimated in the same frame as the context.
    num_context_view = ctx_images.shape[1]
    vggt_input_image = torch.cat((ctx_images, tgt_images), dim=1).to(torch.bfloat16)
    # NOTE(review): torch.cuda.amp.autocast is deprecated in recent PyTorch in
    # favor of torch.amp.autocast("cuda", ...) — consider migrating.
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
        aggregated_tokens_list, patch_start_idx = model.encoder.aggregator(vggt_input_image, intermediate_layer_idx=model.encoder.cfg.intermediate_layer_idx)
        with torch.cuda.amp.autocast(enabled=False):
            # Camera head is evaluated on fp32 tokens.
            fp32_tokens = [token.float() for token in aggregated_tokens_list]
            pred_all_pose_enc = model.encoder.camera_head(fp32_tokens)[-1]
            pred_all_extrinsic, pred_all_intrinsic = pose_encoding_to_extri_intri(pred_all_pose_enc, vggt_input_image.shape[-2:])

    # Pad 3x4 extrinsics with [0, 0, 0, 1] rows into 4x4 homogeneous matrices,
    # then invert — presumably flipping world-to-camera into camera-to-world;
    # TODO confirm convention against pose_encoding_to_extri_intri.
    extrinsic_padding = torch.tensor([0, 0, 0, 1], device=pred_all_extrinsic.device, dtype=pred_all_extrinsic.dtype).view(1, 1, 1, 4).repeat(b, vggt_input_image.shape[1], 1, 1)
    pred_all_extrinsic = torch.cat([pred_all_extrinsic, extrinsic_padding], dim=2).inverse()

    # Normalize intrinsics by the image size (focal/principal point in
    # pixel units divided by width/height).
    pred_all_intrinsic[:, :, 0] = pred_all_intrinsic[:, :, 0] / w
    pred_all_intrinsic[:, :, 1] = pred_all_intrinsic[:, :, 1] / h
    # Split joint predictions back into context vs. target views.
    pred_all_context_extrinsic, pred_all_target_extrinsic = pred_all_extrinsic[:, :num_context_view], pred_all_extrinsic[:, num_context_view:]
    pred_all_context_intrinsic, pred_all_target_intrinsic = pred_all_intrinsic[:, :num_context_view], pred_all_intrinsic[:, num_context_view:]

    # Heuristic scale alignment: match the mean translation magnitude of the
    # VGGT context poses to the encoder's predicted context poses, then apply
    # the same scale to the target poses (in-place on the translation column).
    scale_factor = pred_context_pose['extrinsic'][:, :, :3, 3].mean() / pred_all_context_extrinsic[:, :, :3, 3].mean()
    pred_all_target_extrinsic[..., :3, 3] = pred_all_target_extrinsic[..., :3, 3] * scale_factor
    pred_all_context_extrinsic[..., :3, 3] = pred_all_context_extrinsic[..., :3, 3] * scale_factor
    print("scale_factor:", scale_factor)

    # Render the target views from the Gaussians; near/far planes fixed at
    # 0.01 / 100 for every view.
    output = model.decoder.forward(
        gaussians,
        pred_all_target_extrinsic,
        pred_all_target_intrinsic.float(),
        torch.ones(1, v, device=device) * 0.01,
        torch.ones(1, v, device=device) * 100,
        (h, w)
    )

    # Render an interpolated camera path through the context poses as video.
    save_interpolated_video(pred_all_context_extrinsic, pred_all_context_intrinsic, b, h, w, gaussians, args.output_path, model.decoder)

    # Dump GT / predicted target frames side by side for visual inspection.
    save_path = Path(args.output_path)
    for idx, (gt_image, pred_image) in enumerate(zip(tgt_images[0], output.color[0])):
        save_image(gt_image, save_path / "gt" / f"{idx:0>6}.jpg")
        save_image(pred_image, save_path / "pred" / f"{idx:0>6}.jpg")

    # Report mean metrics over the held-out target views.
    psnr, ssim, lpips = compute_metrics(output.color[0], tgt_images[0])
    print(f"PSNR: {psnr.mean():.2f}, SSIM: {ssim.mean():.3f}, LPIPS: {lpips.mean():.3f}")
|
|
if __name__ == "__main__":
    # Script entry point: parse CLI arguments, then run the evaluation.
    cli_args = setup_args()
    evaluate(cli_args)
|
|