import os
import sys
import argparse
from pathlib import Path

import torch

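# Make the repo-root packages (src, misc) importable when this script is run
# directly; assumes this file lives one directory below the repository root.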
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.evaluation.metrics import compute_lpips, compute_psnr, compute_ssim
from misc.image_io import save_image, save_interpolated_video
from src.utils.image import process_image

from src.model.model.anysplat import AnySplat
from src.model.encoder.vggt.utils.pose_enc import pose_encoding_to_extri_intri


def setup_args():
    """Set up command-line arguments for the NVS evaluation script."""
    parser = argparse.ArgumentParser(description="Evaluate AnySplat on novel view synthesis")
    parser.add_argument("--data_dir", type=str, required=True, help="Path to a directory of input images")
    parser.add_argument("--llffhold", type=int, default=8, help="Hold out every N-th image as a target view (LLFF convention)")
    parser.add_argument("--output_path", type=str, default="outputs/nvs", help="Path to the output directory")
    return parser.parse_args()


def compute_metrics(pred_image, image):
    """Return PSNR, SSIM, and LPIPS between predicted and ground-truth images."""
    psnr = compute_psnr(pred_image, image)
    ssim = compute_ssim(pred_image, image)
    lpips = compute_lpips(pred_image, image)
    return psnr, ssim, lpips


def evaluate(args: argparse.Namespace):
    # Load the pretrained AnySplat model and freeze it for inference.
    model = AnySplat.from_pretrained("lhjiang/anysplat")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False

    os.makedirs(args.output_path, exist_ok=True)

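    # Collect the images; every llffhold-th image is held out as a target view
    # and the rest serve as context views (LLFF evaluation convention).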
    image_folder = args.data_dir
    image_names = sorted(
        os.path.join(image_folder, f)
        for f in os.listdir(image_folder)
        if f.lower().endswith((".png", ".jpg", ".jpeg"))
    )
    images = [process_image(img_path) for img_path in image_names]
    ctx_indices = [idx for idx in range(len(images)) if idx % args.llffhold != 0]
    tgt_indices = [idx for idx in range(len(images)) if idx % args.llffhold == 0]

    ctx_images = torch.stack([images[i] for i in ctx_indices], dim=0).unsqueeze(0).to(device)
    tgt_images = torch.stack([images[i] for i in tgt_indices], dim=0).unsqueeze(0).to(device)
    # process_image is assumed to return tensors in [-1, 1]; map them to [0, 1].
    ctx_images = (ctx_images + 1) * 0.5
    tgt_images = (tgt_images + 1) * 0.5
    b, v, _, h, w = tgt_images.shape

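    # Encode the context views: the AnySplat encoder predicts 3D Gaussians and
    # poses for the context cameras.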
    encoder_output = model.encoder(
        ctx_images,
        global_step=0,
        visualization_dump={},
    )
    gaussians, pred_context_pose = encoder_output.gaussians, encoder_output.pred_context_pose

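    # Run the VGGT aggregator over context and target views together, so the
    # predicted target cameras share a coordinate frame with the context cameras.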
    num_context_view = ctx_images.shape[1]
    vggt_input_image = torch.cat((ctx_images, tgt_images), dim=1).to(torch.bfloat16)
    with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
        aggregated_tokens_list, patch_start_idx = model.encoder.aggregator(
            vggt_input_image,
            intermediate_layer_idx=model.encoder.cfg.intermediate_layer_idx,
        )
    # The camera head runs in fp32 for numerical stability.
    with torch.cuda.amp.autocast(enabled=False):
        fp32_tokens = [token.float() for token in aggregated_tokens_list]
        pred_all_pose_enc = model.encoder.camera_head(fp32_tokens)[-1]
        pred_all_extrinsic, pred_all_intrinsic = pose_encoding_to_extri_intri(
            pred_all_pose_enc, vggt_input_image.shape[-2:]
        )

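    # Pad the predicted 3x4 extrinsics to homogeneous 4x4 matrices and invert
    # them (world-to-camera to camera-to-world, per the VGGT convention).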
    extrinsic_padding = (
        torch.tensor([0, 0, 0, 1], device=pred_all_extrinsic.device, dtype=pred_all_extrinsic.dtype)
        .view(1, 1, 1, 4)
        .repeat(b, vggt_input_image.shape[1], 1, 1)
    )
    pred_all_extrinsic = torch.cat([pred_all_extrinsic, extrinsic_padding], dim=2).inverse()

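    # Normalize the intrinsics by the image size (resolution-independent form)
    # and split the predictions back into context and target views.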
    pred_all_intrinsic[:, :, 0] = pred_all_intrinsic[:, :, 0] / w
    pred_all_intrinsic[:, :, 1] = pred_all_intrinsic[:, :, 1] / h
    pred_all_context_extrinsic, pred_all_target_extrinsic = pred_all_extrinsic[:, :num_context_view], pred_all_extrinsic[:, num_context_view:]
    pred_all_context_intrinsic, pred_all_target_intrinsic = pred_all_intrinsic[:, :num_context_view], pred_all_intrinsic[:, num_context_view:]

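    # The encoder's Gaussians and the VGGT poses each live in their own scale;
    # align them via the ratio of mean context-camera translations.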
    scale_factor = pred_context_pose["extrinsic"][:, :, :3, 3].mean() / pred_all_context_extrinsic[:, :, :3, 3].mean()
    pred_all_target_extrinsic[..., :3, 3] = pred_all_target_extrinsic[..., :3, 3] * scale_factor
    pred_all_context_extrinsic[..., :3, 3] = pred_all_context_extrinsic[..., :3, 3] * scale_factor
    print("scale_factor:", scale_factor.item())

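    # Render the Gaussians from the held-out target cameras.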
    output = model.decoder(
        gaussians,
        pred_all_target_extrinsic,
        pred_all_target_intrinsic.float(),
        torch.ones(1, v, device=device) * 0.01,  # near plane
        torch.ones(1, v, device=device) * 100,   # far plane
        (h, w),
    )

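    # Save a video interpolated along the context-camera trajectory, then dump
    # ground-truth / prediction pairs for the target views.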
    save_interpolated_video(pred_all_context_extrinsic, pred_all_context_intrinsic, b, h, w, gaussians, args.output_path, model.decoder)

    save_path = Path(args.output_path)
    for idx, (gt_image, pred_image) in enumerate(zip(tgt_images[0], output.color[0])):
        save_image(gt_image, save_path / "gt" / f"{idx:0>6}.jpg")
        save_image(pred_image, save_path / "pred" / f"{idx:0>6}.jpg")

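    # Report metrics averaged over all target views.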
    psnr, ssim, lpips = compute_metrics(output.color[0], tgt_images[0])
    print(f"PSNR: {psnr.mean():.2f}, SSIM: {ssim.mean():.3f}, LPIPS: {lpips.mean():.3f}")


if __name__ == "__main__":
    args = setup_args()
    evaluate(args)