# Source: dreamgaussian4d/scene/hexplane.py (author: jiaweir, commit 21c4e64 "init", 6.41 kB)
import itertools
import logging as log
from typing import Optional, Union, List, Dict, Sequence, Iterable, Collection, Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
def get_normalized_directions(directions):
    """Affinely map direction components from [-1, 1] onto [0, 1].

    The SH encoding expects its inputs to lie in the range [0, 1].

    Args:
        directions: batch of directions (components presumably in [-1, 1]).

    Returns:
        ``(directions + 1) / 2``, with the same type and shape as the input.
    """
    shifted = directions + 1.0
    return shifted / 2.0
def normalize_aabb(pts, aabb):
    """Affinely rescale ``pts`` so that ``aabb[0]`` maps to -1 and ``aabb[1]`` to +1.

    Args:
        pts: points whose trailing axis broadcasts against ``aabb[0]``.
        aabb: two stacked corners; row 0 is sent to -1, row 1 to +1.

    Returns:
        The rescaled points, broadcast against the corners.
    """
    origin, corner = aabb[0], aabb[1]
    scale = 2.0 / (corner - origin)
    return (pts - origin) * scale - 1.0
def grid_sample_wrapper(grid: torch.Tensor, coords: torch.Tensor, align_corners: bool = True) -> torch.Tensor:
    """Bilinearly sample ``grid`` at ``coords`` using border padding.

    Args:
        grid: feature grid laid out as ``[B, feature_dim, *reso]``; the leading
            batch axis may be omitted and is then added here.
        coords: sample locations whose last axis holds ``grid_dim`` (2 or 3)
            coordinates in ``F.grid_sample``'s normalized convention; a 2-D
            ``[n, grid_dim]`` tensor gets a leading batch axis added.
        align_corners: forwarded unchanged to ``F.grid_sample``.

    Returns:
        Sampled features arranged as ``[B, n, feature_dim]`` and then passed
        through ``squeeze()``, so every size-1 axis (including ``B == 1``) is
        dropped.

    Raises:
        NotImplementedError: if ``coords.shape[-1]`` is neither 2 nor 3.
    """
    grid_dim = coords.shape[-1]
    if grid.dim() == grid_dim + 1:
        # Unbatched grid: prepend the batch axis grid_sample requires.
        grid = grid.unsqueeze(0)
    if coords.dim() == 2:
        coords = coords.unsqueeze(0)
    if grid_dim not in (2, 3):
        raise NotImplementedError(f"Grid-sample was called with {grid_dim}D data but is only "
                                  f"implemented for 2 and 3D data.")

    # Insert (grid_dim - 1) singleton axes: [B, 1, ..., n, grid_dim].
    singleton_axes = [1] * (grid_dim - 1)
    coords = coords.view(coords.shape[0], *singleton_axes, *coords.shape[1:])

    batch, feature_dim = grid.shape[:2]
    num_samples = coords.shape[-2]
    sampled = F.grid_sample(
        grid,
        coords,
        mode='bilinear',
        padding_mode='border',
        align_corners=align_corners,
    )
    # [B, feature_dim, 1, ..., n] -> [B, n, feature_dim]
    sampled = sampled.view(batch, feature_dim, num_samples).transpose(-1, -2)
    return sampled.squeeze()
def init_grid_param(
        grid_nd: int,
        in_dim: int,
        out_dim: int,
        reso: Sequence[int],
        a: float = 0.1,
        b: float = 0.5):
    """Create one learnable feature plane per ``grid_nd``-subset of the input axes.

    Planes are produced in ``itertools.combinations(range(in_dim), grid_nd)``
    order. Each has shape ``[1, out_dim, *sizes]``, where ``sizes`` lists the
    subset's resolutions in reverse axis order.

    Args:
        grid_nd: number of input axes spanned by each plane.
        in_dim: number of input coordinates; must equal ``len(reso)``.
        out_dim: feature channels per plane.
        reso: resolution for each input axis.
        a: lower bound of the uniform initialisation.
        b: upper bound of the uniform initialisation.

    Returns:
        An ``nn.ParameterList`` with one parameter per axis subset.
        When ``in_dim == 4``, planes that include axis 3 are filled with ones.
        All other planes are drawn from ``U(a, b)``.
    """
    assert in_dim == len(reso), "Resolution must have same number of elements as input-dimension"
    has_time_planes = in_dim == 4
    assert grid_nd <= in_dim

    planes = nn.ParameterList()
    for axes in itertools.combinations(range(in_dim), grid_nd):
        # F.grid_sample reads coordinates as (x, y[, z]), and x indexes the LAST
        # spatial tensor axis. The sizes are therefore stored in reverse axis order.
        shape = [1, out_dim, *(reso[axis] for axis in reversed(axes))]
        plane = nn.Parameter(torch.empty(shape))
        if has_time_planes and 3 in axes:
            # Time-dependent planes start as ones: the multiplicative identity.
            nn.init.ones_(plane)
        else:
            nn.init.uniform_(plane, a=a, b=b)
        planes.append(plane)
    return planes
def interpolate_ms_features(pts: torch.Tensor,
                            ms_grids: Collection[Iterable[nn.Module]],
                            grid_dimensions: int,
                            concat_features: bool,
                            num_levels: Optional[int],
                            ) -> torch.Tensor:
    """Sample every plane at ``pts`` and fuse the results across scales.

    Within one scale, the features sampled from the individual planes are
    multiplied element-wise. Across scales they are either concatenated along
    the last axis (``concat_features=True``) or summed.

    Args:
        pts: query coordinates; the last axis holds one coordinate per input
            dimension (assumed ``[N, D]`` and already normalized for
            ``grid_sample`` — TODO confirm at call sites).
        ms_grids: one plane collection per scale. In each scale, plane ``i``
            must match the ``i``-th entry of
            ``itertools.combinations(range(pts.shape[-1]), grid_dimensions)``.
        grid_dimensions: number of axes spanned by each plane.
        concat_features: concatenate (True) or sum (False) per-scale features.
        num_levels: use only the first ``num_levels`` scales; ``None`` uses all.

    Returns:
        Per-scale features are flattened to ``[rows, channels]`` via
        ``view(-1, channels)``. They are then concatenated on the last axis or
        summed. If no scale is selected, summing returns the float ``0.`` and
        concatenating fails inside ``torch.cat``.
    """
    plane_axes = list(itertools.combinations(range(pts.shape[-1]), grid_dimensions))
    if num_levels is None:
        num_levels = len(ms_grids)

    level_features = []
    for planes in ms_grids[:num_levels]:
        fused = 1.
        for plane_idx, axes in enumerate(plane_axes):
            plane = planes[plane_idx]
            channels = plane.shape[1]  # plane layout: [1, out_dim, *reso]
            sampled = grid_sample_wrapper(plane, pts[..., axes]).view(-1, channels)
            fused = fused * sampled
        level_features.append(fused)

    if concat_features:
        return torch.cat(level_features, dim=-1)

    total = 0.
    for fused in level_features:
        total = total + fused
    return total
class HexPlaneField(nn.Module):
    """Multi-resolution factorized-plane feature field over (x, y, z, t).

    For each multiplier in ``multires``, a set of learnable planes is built with
    ``init_grid_param``. The multiplier scales only the first three (spatial)
    resolutions. To query the field, points are normalized with the stored box
    ``self.aabb``, timestamps are appended as the last coordinate, and the
    planes are sampled and fused with ``interpolate_ms_features``.

    Args:
        bounds: half-extent of the initial cubic box; row 0 of ``aabb`` is
            ``+bounds`` and row 1 is ``-bounds``.
        planeconfig: dict providing ``"grid_dimensions"``,
            ``"input_coordinate_dim"``, ``"output_coordinate_dim"`` and
            ``"resolution"``.
        multires: iterable of resolution multipliers, one plane set per entry.
    """

    def __init__(
        self,
        bounds,
        planeconfig,
        multires
    ) -> None:
        super().__init__()
        # Row 0 is the max corner and row 1 the min corner, matching set_aabb's
        # argument order. Use the default float dtype so integer bounds do not
        # produce an int64 box.
        aabb = torch.tensor([[bounds, bounds, bounds],
                             [-bounds, -bounds, -bounds]],
                            dtype=torch.get_default_dtype())
        # Registered as a frozen parameter so it follows .to()/state_dict()
        # without being trained.
        self.aabb = nn.Parameter(aabb, requires_grad=False)
        self.grid_config = [planeconfig]
        self.multiscale_res_multipliers = multires
        self.concat_features = True

        # 1. Init planes
        self.grids = nn.ModuleList()
        self.feat_dim = 0
        for res in self.multiscale_res_multipliers:
            # Shallow copy: "resolution" is rebound below, never mutated in place.
            config = self.grid_config[0].copy()
            # Multi-res applies only to the three spatial axes. Any remaining
            # entries (the time axis for 4-D input) keep their configured size.
            config["resolution"] = [
                r * res for r in config["resolution"][:3]
            ] + config["resolution"][3:]
            gp = init_grid_param(
                grid_nd=config["grid_dimensions"],
                in_dim=config["input_coordinate_dim"],
                out_dim=config["output_coordinate_dim"],
                reso=config["resolution"],
            )
            # shape[1] is out_dim. Concatenated features grow by out_dim per scale.
            if self.concat_features:
                self.feat_dim += gp[-1].shape[1]
            else:
                self.feat_dim = gp[-1].shape[1]
            self.grids.append(gp)
        print("feature_dim:",self.feat_dim)

    def set_aabb(self, xyz_max, xyz_min):
        """Replace the normalization box with ``[xyz_max, xyz_min]``.

        Args:
            xyz_max: per-axis max corner (list, NumPy array or tensor).
            xyz_min: per-axis min corner with the same length as ``xyz_max``.
        """
        # Build on the current box's device/dtype. A module that was already moved
        # (e.g. to GPU) then keeps working. Float64 NumPy input also no longer
        # promotes the normalized coordinates away from the grids' dtype, which
        # F.grid_sample would reject.
        target = {"device": self.aabb.device, "dtype": self.aabb.dtype}
        aabb = torch.stack([
            torch.as_tensor(xyz_max, **target),
            torch.as_tensor(xyz_min, **target),
        ])
        # Frozen, like the box registered in __init__. A box with
        # requires_grad=True would receive gradients through normalize_aabb in
        # get_density and could be moved by an optimizer.
        self.aabb = nn.Parameter(aabb, requires_grad=False)
        print("Voxel Plane: set aabb=",self.aabb)

    def get_density(self, pts: torch.Tensor, timestamps: Optional[torch.Tensor] = None):
        """Return the interpolated plane features for ``pts`` at ``timestamps``.

        Despite the name, this returns raw feature vectors (``self.feat_dim``
        channels with concatenation enabled), not densities.

        Args:
            pts: points whose last axis holds x, y, z.
            timestamps: although typed Optional, a tensor is required. It is
                concatenated onto the normalized points as the last coordinate,
                so its leading axes must match ``pts``.

        Returns:
            A 2-D feature tensor. For an empty query it is a ``[0, 1]`` tensor.
        """
        pts = normalize_aabb(pts, self.aabb)
        pts = torch.cat((pts, timestamps), dim=-1)  # [n_rays, n_samples, 4]
        pts = pts.reshape(-1, pts.shape[-1])
        features = interpolate_ms_features(
            pts, ms_grids=self.grids,  # noqa
            grid_dimensions=self.grid_config[0]["grid_dimensions"],
            concat_features=self.concat_features, num_levels=None)
        if len(features) < 1:
            features = torch.zeros((0, 1)).to(features.device)
        return features

    def forward(self,
                pts: torch.Tensor,
                timestamps: Optional[torch.Tensor] = None):
        """Alias for :meth:`get_density`."""
        features = self.get_density(pts, timestamps)
        return features