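"""Evaluate AnySplat on novel-view synthesis (NVS).

Every `--llffhold`-th image in `--data_dir` is held out as a target view; the
remaining images serve as context views. Example invocation (script name and
paths are illustrative):

    python eval_nvs.py --data_dir /path/to/scene/images --output_path outputs/nvs
"""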
import os
import sys
import argparse
from pathlib import Path

import torch

# Make the repository root importable when this script is run directly.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.evaluation.metrics import compute_lpips, compute_psnr, compute_ssim
from misc.image_io import save_image, save_interpolated_video
from src.utils.image import process_image
from src.model.model.anysplat import AnySplat
from src.model.encoder.vggt.utils.pose_enc import pose_encoding_to_extri_intri
def setup_args():
    """Set up command-line arguments for the NVS evaluation script."""
    parser = argparse.ArgumentParser(description='Evaluate AnySplat on novel-view synthesis')
    parser.add_argument('--data_dir', type=str, required=True, help='Path to the NVS dataset (a directory of images)')
    parser.add_argument('--llffhold', type=int, default=8, help='Hold out every N-th image as a target view (LLFF convention)')
    parser.add_argument('--output_path', type=str, default='outputs/nvs', help='Path to the output directory')
    return parser.parse_args()
def compute_metrics(pred_image, image):
    """Return PSNR, SSIM, and LPIPS between predicted and ground-truth views."""
    psnr = compute_psnr(pred_image, image)
    ssim = compute_ssim(pred_image, image)
    lpips = compute_lpips(pred_image, image)
    return psnr, ssim, lpips
def evaluate(args: argparse.Namespace):
    model = AnySplat.from_pretrained("lhjiang/anysplat")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False

    os.makedirs(args.output_path, exist_ok=True)
    # Load images and split them into context and held-out target views.
    image_folder = args.data_dir
    image_names = sorted(
        os.path.join(image_folder, f)
        for f in os.listdir(image_folder)
        if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )
    images = [process_image(img_path) for img_path in image_names]
    ctx_indices = [idx for idx in range(len(image_names)) if idx % args.llffhold != 0]
    tgt_indices = [idx for idx in range(len(image_names)) if idx % args.llffhold == 0]
    ctx_images = torch.stack([images[i] for i in ctx_indices], dim=0).unsqueeze(0).to(device)
    tgt_images = torch.stack([images[i] for i in tgt_indices], dim=0).unsqueeze(0).to(device)
    # process_image yields tensors in [-1, 1]; map them to [0, 1].
    ctx_images = (ctx_images + 1) * 0.5
    tgt_images = (tgt_images + 1) * 0.5
    b, v, _, h, w = tgt_images.shape
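    # The evaluation proceeds in three stages:
    #   1. Encode the context views into a set of 3D Gaussians.
    #   2. Re-predict camera poses for context and target views jointly with
    #      the VGGT aggregator and camera head.
    #   3. Render the Gaussians from the predicted target poses and compare
    #      against the held-out ground-truth images.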
    # Run inference: encode the context views into 3D Gaussians.
    encoder_output = model.encoder(
        ctx_images,
        global_step=0,
        visualization_dump={},
    )
    gaussians, pred_context_pose = encoder_output.gaussians, encoder_output.pred_context_pose
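    # The encoder only sees context views, so target poses are unknown at this
    # point. The joint pass below predicts camera poses for all views at once;
    # its coordinate scale can differ from the encoder's, which the translation
    # rescaling further down compensates for.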
    num_context_view = ctx_images.shape[1]
    vggt_input_image = torch.cat((ctx_images, tgt_images), dim=1).to(torch.bfloat16)
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
        aggregated_tokens_list, patch_start_idx = model.encoder.aggregator(
            vggt_input_image, intermediate_layer_idx=model.encoder.cfg.intermediate_layer_idx
        )
        with torch.cuda.amp.autocast(enabled=False):
            # Run the camera head in fp32 for numerically stable pose regression.
            fp32_tokens = [token.float() for token in aggregated_tokens_list]
            pred_all_pose_enc = model.encoder.camera_head(fp32_tokens)[-1]
            pred_all_extrinsic, pred_all_intrinsic = pose_encoding_to_extri_intri(
                pred_all_pose_enc, vggt_input_image.shape[-2:]
            )
    # Pad the 3x4 extrinsics to homogeneous 4x4 and invert (world-to-camera -> camera-to-world).
    extrinsic_padding = torch.tensor([0, 0, 0, 1], device=pred_all_extrinsic.device, dtype=pred_all_extrinsic.dtype).view(1, 1, 1, 4).repeat(b, vggt_input_image.shape[1], 1, 1)
    pred_all_extrinsic = torch.cat([pred_all_extrinsic, extrinsic_padding], dim=2).inverse()
    # Normalize intrinsics to the unit image plane and split by view group.
    pred_all_intrinsic[:, :, 0] = pred_all_intrinsic[:, :, 0] / w
    pred_all_intrinsic[:, :, 1] = pred_all_intrinsic[:, :, 1] / h
    pred_all_context_extrinsic, pred_all_target_extrinsic = pred_all_extrinsic[:, :num_context_view], pred_all_extrinsic[:, num_context_view:]
    pred_all_context_intrinsic, pred_all_target_intrinsic = pred_all_intrinsic[:, :num_context_view], pred_all_intrinsic[:, num_context_view:]
    # Align the joint pose prediction with the encoder's coordinate scale via
    # the ratio of mean camera translations over the shared context views.
    scale_factor = pred_context_pose['extrinsic'][:, :, :3, 3].mean() / pred_all_context_extrinsic[:, :, :3, 3].mean()
    pred_all_target_extrinsic[..., :3, 3] = pred_all_target_extrinsic[..., :3, 3] * scale_factor
    pred_all_context_extrinsic[..., :3, 3] = pred_all_context_extrinsic[..., :3, 3] * scale_factor
    print("scale_factor:", scale_factor.item())
    # Render the Gaussians from the predicted target poses.
    output = model.decoder.forward(
        gaussians,
        pred_all_target_extrinsic,
        pred_all_target_intrinsic.float(),
        torch.ones(1, v, device=device) * 0.01,  # near plane
        torch.ones(1, v, device=device) * 100,   # far plane
        (h, w),
    )
    # Render an interpolated camera trajectory through the context views as a video.
    save_interpolated_video(pred_all_context_extrinsic, pred_all_context_intrinsic, b, h, w, gaussians, args.output_path, model.decoder)
    # Save ground-truth and predicted target views.
    save_path = Path(args.output_path)
    for idx, (gt_image, pred_image) in enumerate(zip(tgt_images[0], output.color[0])):
        save_image(gt_image, save_path / "gt" / f"{idx:0>6}.jpg")
        save_image(pred_image, save_path / "pred" / f"{idx:0>6}.jpg")

    # Compute image-quality metrics over all target views.
    psnr, ssim, lpips = compute_metrics(output.color[0], tgt_images[0])
    print(f"PSNR: {psnr.mean():.2f}, SSIM: {ssim.mean():.3f}, LPIPS: {lpips.mean():.3f}")
if __name__ == "__main__":
    args = setup_args()
    evaluate(args)