| import argparse |
| import math |
| import os |
| import time |
|
|
| import imageio |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| import tqdm |
| import viser |
| from pathlib import Path |
| from gsplat._helper import load_test_data |
| from gsplat.distributed import cli |
| from gsplat.rendering import rasterization |
|
|
| from nerfview import CameraState, RenderTabState, apply_float_colormap |
| from examples.gsplat_viewer import GsplatViewer, GsplatRenderTabState |
|
|
|
|
# Maps the viewer's render-mode choice to gsplat's `rasterization(render_mode=...)`.
# "alpha" renders plain RGB: the alpha map is read from `render_alphas`, which
# rasterization returns alongside the colors in every mode.
_VIEWER_RENDER_MODES = {
    "rgb": "RGB",
    "depth(accumulated)": "D",
    "depth(expected)": "ED",
    "alpha": "RGB",
}


def _load_test_scene(
    device: torch.device,
    world_rank: int,
    world_size: int,
    args: argparse.Namespace,
):
    """Load the bundled test scene, render one view per rank, and save it as a PNG.

    Gaussians and cameras are sharded across ranks by striding
    (``[world_rank::world_size]``). Only the first camera of this rank's shard
    is rendered. The summed render is back-propagated once (presumably to
    exercise the backward pass), and an ``RGB | depth | alpha`` strip is
    written to ``{args.output_dir}/render_rank{world_rank}.png``.

    Args:
        device: Device the scene is loaded onto.
        world_rank: Rank used to select this process's shard and name its image.
        world_size: Total number of ranks; at most 2 is accepted.
        args: Parsed CLI arguments (reads ``scene_grid`` and ``output_dir``).

    Returns:
        ``(means, quats, scales, opacities, colors)`` for this rank's shard,
        each marked ``requires_grad=True``.

    Raises:
        ValueError: If ``world_size`` is greater than 2.
    """
    if world_size > 2:
        raise ValueError(
            f"test-scene mode supports at most 2 ranks, got world_size={world_size}"
        )

    (
        means,
        quats,
        scales,
        opacities,
        colors,
        viewmats,
        Ks,
        width,
        height,
    ) = load_test_data(device=device, scene_grid=args.scene_grid)

    # Keep this rank's strided shard of the Gaussians and mark it as requiring
    # gradients for the backward pass below.
    means, quats, scales, opacities, colors = (
        t[world_rank::world_size].contiguous().requires_grad_()
        for t in (means, quats, scales, opacities, colors)
    )
    # Only the first camera of this rank's shard is rendered.
    viewmats = viewmats[world_rank::world_size][:1].contiguous()
    Ks = Ks[world_rank::world_size][:1].contiguous()
    print(
        "rank",
        world_rank,
        "Number of Gaussians:",
        len(means),
        "Number of Cameras:",
        len(viewmats),
    )

    for _ in tqdm.trange(1):  # a single pass, shown with a progress bar
        render_colors, render_alphas, _meta = rasterization(
            means,
            quats,
            scales,
            opacities,
            colors,
            viewmats,
            Ks,
            width,
            height,
            render_mode="RGB+D",
            packed=False,
            distributed=world_size > 1,
        )
    C = render_colors.shape[0]
    assert render_colors.shape == (C, height, width, 4)
    assert render_alphas.shape == (C, height, width, 1)
    render_colors.sum().backward()

    # "RGB+D" output: RGB in channels 0..2, depth in channel 3.
    render_rgbs = render_colors[..., 0:3]
    render_depths = render_colors[..., 3:4]
    # Normalize depth for display. The floor keeps an all-zero depth map
    # (nothing rendered) from becoming NaN.
    render_depths = render_depths / render_depths.max().clamp_min(1e-10)

    os.makedirs(args.output_dir, exist_ok=True)
    # Stack the C camera images vertically; place RGB | depth | alpha side by side.
    canvas = (
        torch.cat(
            [
                render_rgbs.reshape(C * height, width, 3),
                render_depths.reshape(C * height, width, 1).expand(-1, -1, 3),
                render_alphas.reshape(C * height, width, 1).expand(-1, -1, 3),
            ],
            dim=1,
        )
        .detach()
        .cpu()
        .numpy()
    )
    # Clip before the uint8 cast so out-of-range values saturate instead of overflowing.
    imageio.imsave(
        f"{args.output_dir}/render_rank{world_rank}.png",
        (np.clip(canvas, 0.0, 1.0) * 255).astype(np.uint8),
    )
    return means, quats, scales, opacities, colors


def _load_checkpoints(ckpt_paths, device: torch.device):
    """Load Gaussian splats from one or more checkpoints and concatenate them.

    Each file must hold a ``"splats"`` dict with raw ``means``, ``quats``,
    ``scales``, ``opacities``, ``sh0`` and ``shN`` tensors. The raw values are
    activated for rendering: quaternions are L2-normalized, scales are
    exponentiated, and opacities are passed through a sigmoid.

    Args:
        ckpt_paths: Checkpoint file paths, concatenated in the given order.
        device: Device the tensors are mapped onto.

    Returns:
        ``(means, quats, scales, opacities, colors, sh_degree)``, where
        ``colors`` is ``cat([sh0, shN], dim=-2)``.
    """
    means, quats, scales, opacities, sh0, shN = [], [], [], [], [], []
    for ckpt_path in ckpt_paths:
        # NOTE: torch.load unpickles the file; only open trusted checkpoints.
        splats = torch.load(ckpt_path, map_location=device)["splats"]
        means.append(splats["means"])
        quats.append(F.normalize(splats["quats"], p=2, dim=-1))
        scales.append(torch.exp(splats["scales"]))
        opacities.append(torch.sigmoid(splats["opacities"]))
        sh0.append(splats["sh0"])
        shN.append(splats["shN"])

    colors = torch.cat([torch.cat(sh0, dim=0), torch.cat(shN, dim=0)], dim=-2)
    # Dim -2 holds (sh_degree + 1) ** 2 coefficients; isqrt avoids float round-off.
    sh_degree = math.isqrt(colors.shape[-2]) - 1
    return (
        torch.cat(means, dim=0),
        torch.cat(quats, dim=0),
        torch.cat(scales, dim=0),
        torch.cat(opacities, dim=0),
        colors,
        sh_degree,
    )


def _make_viewer_render_fn(means, quats, scales, opacities, colors, sh_degree, device):
    """Build the callback that GsplatViewer calls to render one frame.

    The returned function closes over the given splats. It returns a NumPy
    image for the render mode selected in the viewer's render tab, and records
    the total and rendered Gaussian counts on that tab state.
    """
    total_gs_count = len(means)

    @torch.no_grad()
    def viewer_render_fn(camera_state: CameraState, render_tab_state: RenderTabState):
        assert isinstance(render_tab_state, GsplatRenderTabState)
        if render_tab_state.preview_render:
            width = render_tab_state.render_width
            height = render_tab_state.render_height
        else:
            width = render_tab_state.viewer_width
            height = render_tab_state.viewer_height
        K = torch.from_numpy(camera_state.get_K((width, height))).float().to(device)
        c2w = torch.from_numpy(camera_state.c2w).float().to(device)
        viewmat = c2w.inverse()  # camera-to-world -> world-to-camera

        render_mode = render_tab_state.render_mode
        render_colors, render_alphas, info = rasterization(
            means,
            quats,
            scales,
            opacities,
            colors,
            viewmat[None],
            K[None],
            width,
            height,
            sh_degree=(
                min(render_tab_state.max_sh_degree, sh_degree)
                if sh_degree is not None
                else None
            ),
            near_plane=render_tab_state.near_plane,
            far_plane=render_tab_state.far_plane,
            radius_clip=render_tab_state.radius_clip,
            eps2d=render_tab_state.eps2d,
            backgrounds=torch.tensor([render_tab_state.backgrounds], device=device)
            / 255.0,
            render_mode=_VIEWER_RENDER_MODES[render_mode],
            rasterize_mode=render_tab_state.rasterize_mode,
            camera_model=render_tab_state.camera_model,
        )
        render_tab_state.total_gs_count = total_gs_count
        render_tab_state.rendered_gs_count = (info["radii"] > 0).all(-1).sum().item()

        if render_mode == "rgb":
            return render_colors[0, ..., 0:3].clamp(0, 1).cpu().numpy()

        if render_mode in ("depth(accumulated)", "depth(expected)"):
            depth = render_colors[0, ..., 0:1]
            if render_tab_state.normalize_nearfar:
                near_plane = render_tab_state.near_plane
                far_plane = render_tab_state.far_plane
            else:
                near_plane = depth.min()
                far_plane = depth.max()
            depth_norm = (depth - near_plane) / (far_plane - near_plane + 1e-10)
            depth_norm = torch.clip(depth_norm, 0, 1)
            if render_tab_state.inverse:
                depth_norm = 1 - depth_norm
            return (
                apply_float_colormap(depth_norm, render_tab_state.colormap)
                .cpu()
                .numpy()
            )

        if render_mode == "alpha":
            alpha = render_alphas[0, ..., 0:1]
            if render_tab_state.inverse:
                alpha = 1 - alpha
            return apply_float_colormap(alpha, render_tab_state.colormap).cpu().numpy()

        # Unreachable while every key of _VIEWER_RENDER_MODES is handled above.
        raise ValueError(f"unsupported render mode: {render_mode!r}")

    return viewer_render_fn


def main(
    local_rank: int, world_rank: int, world_size: int, args: argparse.Namespace
):
    """Load splats (bundled test scene or checkpoints) and serve them in a viewer.

    Without ``--ckpt``, the bundled test scene is sharded across ranks and one
    preview PNG per rank is written to ``args.output_dir`` before the viewer
    starts. With ``--ckpt``, the given checkpoints are concatenated and viewed.
    Blocks until interrupted (Ctrl+C).

    Args:
        local_rank: CUDA device index used by this process.
        world_rank: Rank of this process (selects its test-scene shard).
        world_size: Total number of processes.
        args: Parsed CLI arguments (``ckpt``, ``scene_grid``, ``output_dir``,
            ``port``).
    """
    torch.manual_seed(42)
    device = torch.device("cuda", local_rank)

    if args.ckpt is None:
        means, quats, scales, opacities, colors = _load_test_scene(
            device, world_rank, world_size, args
        )
        # sh_degree=None: the test-scene colors are passed to rasterization
        # as-is (assumed to be plain colors rather than SH coefficients).
        sh_degree = None
    else:
        means, quats, scales, opacities, colors, sh_degree = _load_checkpoints(
            args.ckpt, device
        )
        print("Number of Gaussians:", len(means))

    server = viser.ViserServer(port=args.port, verbose=False)
    _ = GsplatViewer(
        server=server,
        render_fn=_make_viewer_render_fn(
            means, quats, scales, opacities, colors, sh_degree, device
        ),
        output_dir=Path(args.output_dir),
        mode="rendering",
    )
    print("Viewer running... Ctrl+C to exit.")
    # Keep the process alive until interrupted. A single fixed sleep would
    # silently shut the viewer down once it elapsed.
    while True:
        time.sleep(100000)
|
|
|
|
if __name__ == "__main__":
    # Usage examples:
    #
    #   # View a trained checkpoint on a single GPU:
    #   CUDA_VISIBLE_DEVICES=9 python -m simple_viewer \
    #       --ckpt results/garden/ckpts/ckpt_6999_rank0.pt \
    #       --output_dir results/garden/ \
    #       --port 8082
    #
    #   # Without --ckpt, render and view the bundled test scene:
    #   CUDA_VISIBLE_DEVICES=9 python -m simple_viewer \
    #       --output_dir results/garden/ \
    #       --port 8082
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output_dir", type=str, default="results/", help="where to dump outputs"
    )
    parser.add_argument(
        "--scene_grid", type=int, default=1, help="repeat the scene into a grid of NxN"
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        nargs="+",
        default=None,
        help="path(s) to .pt checkpoint file(s); multiple checkpoints are concatenated",
    )
    parser.add_argument(
        "--port", type=int, default=8080, help="port for the viewer server"
    )
    args = parser.parse_args()
    # Use parser.error rather than assert so the check survives `python -O` and
    # exits with a usage message. Negative values are rejected too: in Python,
    # -1 % 2 == 1, so an odd-only check would accept them.
    if args.scene_grid < 1 or args.scene_grid % 2 == 0:
        parser.error("scene_grid must be a positive odd integer")

    cli(main, args, verbose=True)
|
|