Spaces:
Runtime error
Runtime error
import itertools | |
import logging as log | |
from typing import Optional, Union, List, Dict, Sequence, Iterable, Collection, Callable | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
def get_normalized_directions(directions):
    """Shift direction components from [-1, 1] into [0, 1] for SH encoding.

    Args:
        directions: batch of directions (any object supporting ``+`` and ``/``
            with floats, e.g. a tensor).

    Returns:
        ``(directions + 1) / 2`` elementwise.
    """
    shifted = directions + 1.0
    return shifted / 2.0
def normalize_aabb(pts, aabb):
    """Affinely rescale ``pts`` so that ``aabb[0]`` maps to -1 and ``aabb[1]`` to +1.

    The mapping is applied per axis and broadcasts over leading dimensions of
    ``pts``. No clamping is done, so points outside the box land outside
    [-1, 1].
    """
    corner_a = aabb[0]
    corner_b = aabb[1]
    scale = 2.0 / (corner_b - corner_a)
    return (pts - corner_a) * scale - 1.0
def grid_sample_wrapper(grid: torch.Tensor, coords: torch.Tensor, align_corners: bool = True) -> torch.Tensor:
    """Interpolate a 2D or 3D feature grid at a list of normalized coordinates.

    Args:
        grid: ``[B, feature_dim, *reso]`` or, without batch dim, ``[feature_dim, *reso]``.
        coords: ``[B, n, grid_dim]`` or ``[n, grid_dim]`` with values in [-1, 1]
            (``grid_dim`` must be 2 or 3).
        align_corners: forwarded to ``F.grid_sample``.

    Returns:
        Sampled features ``[B, n, feature_dim]`` with every size-1 dimension
        squeezed away (so the batch dim, ``n == 1`` or ``feature_dim == 1`` may vanish).

    Raises:
        NotImplementedError: if ``coords.shape[-1]`` is not 2 or 3.
    """
    ndim = coords.shape[-1]
    if ndim not in (2, 3):
        raise NotImplementedError(f"Grid-sample was called with {ndim}D data but is only "
                                  f"implemented for 2 and 3D data.")

    # Add missing batch dimensions.
    if grid.dim() == ndim + 1:
        grid = grid[None]
    if coords.dim() == 2:
        coords = coords[None]

    batch, channels = grid.shape[0], grid.shape[1]
    num_pts = coords.shape[-2]

    # grid_sample wants coords as [B, 1, ..., n, grid_dim] (ndim-1 singleton dims).
    sample_shape = [coords.shape[0], *([1] * (ndim - 1)), *coords.shape[1:]]
    sampled = F.grid_sample(
        grid,
        coords.view(sample_shape),
        align_corners=align_corners,
        mode='bilinear',
        padding_mode='border',
    )
    per_point = sampled.view(batch, channels, num_pts).transpose(-1, -2)
    return per_point.squeeze()
def init_grid_param(
        grid_nd: int,
        in_dim: int,
        out_dim: int,
        reso: Sequence[int],
        a: float = 0.1,
        b: float = 0.5):
    """Create one learnable grid per combination of ``grid_nd`` input axes.

    Args:
        grid_nd: number of axes spanned by each grid (e.g. 2 for planes).
        in_dim: number of input coordinates; must equal ``len(reso)``.
        out_dim: feature channels per grid.
        reso: resolution along each input axis.
        a, b: bounds of the uniform initialization.

    Returns:
        ``nn.ParameterList`` ordered like ``itertools.combinations(range(in_dim), grid_nd)``.
        Each entry has shape ``[1, out_dim, reso[axes[-1]], ..., reso[axes[0]]]``
        (axes reversed, matching grid_sample's last-coordinate-first layout).
        When ``in_dim == 4``, grids that include axis 3 start at 1; all others
        are drawn uniformly from ``[a, b]``.
    """
    assert in_dim == len(reso), "Resolution must have same number of elements as input-dimension"
    with_time_axis = in_dim == 4
    assert grid_nd <= in_dim

    planes = nn.ParameterList()
    for axes in itertools.combinations(range(in_dim), grid_nd):
        shape = [1, out_dim, *(reso[axis] for axis in reversed(axes))]
        coef = nn.Parameter(torch.empty(shape))
        if with_time_axis and 3 in axes:
            # Time planes start at 1 so the multiplicative combination begins neutral.
            nn.init.ones_(coef)
        else:
            nn.init.uniform_(coef, a=a, b=b)
        planes.append(coef)
    return planes
def interpolate_ms_features(pts: torch.Tensor,
                            ms_grids: Collection[Iterable[nn.Module]],
                            grid_dimensions: int,
                            concat_features: bool,
                            num_levels: Optional[int],
                            ) -> torch.Tensor:
    """Sample every level of a multi-scale plane grid and combine the results.

    Within one level, the features sampled from each plane (one plane per
    combination of ``grid_dimensions`` coordinate axes) are multiplied
    together. Across levels they are concatenated along the last dim
    (``concat_features=True``) or summed (``False``).

    Args:
        pts: ``[N, D]`` coordinates, already normalized for grid sampling.
        ms_grids: sliceable sequence of levels; level ``i`` holds one grid per
            axis combination, in ``itertools.combinations`` order.
        grid_dimensions: number of axes each plane spans.
        concat_features: concatenate (True) or sum (False) across levels.
        num_levels: use only the first ``num_levels`` levels; ``None`` means all.

    Returns:
        ``[N, sum(feature_dims)]`` when concatenating, otherwise the elementwise
        sum (the float ``0.`` if no level is used).
    """
    plane_axes = list(itertools.combinations(range(pts.shape[-1]), grid_dimensions))
    levels_to_use = len(ms_grids) if num_levels is None else num_levels

    level_features = []
    for planes in ms_grids[:levels_to_use]:
        product = 1.
        # Index explicitly (not zip) so a level with too few planes still raises.
        for idx, axes in enumerate(plane_axes):
            plane = planes[idx]
            channels = plane.shape[1]  # plane shape: 1, out_dim, *reso
            sampled = grid_sample_wrapper(plane, pts[..., axes]).view(-1, channels)
            product = product * sampled
        level_features.append(product)

    if concat_features:
        return torch.cat(level_features, dim=-1)
    summed = 0.
    for feat in level_features:
        summed = summed + feat
    return summed
class HexPlaneField(nn.Module):
    """Multi-resolution factorized feature field sampled at (x, y, z[, t]) points.

    For each multiplier in ``multires`` a separate set of coordinate-plane grids
    is created with ``init_grid_param`` (one grid per combination of
    ``grid_dimensions`` input axes). At query time, the features sampled from
    the grids of one level are multiplied together. The per-level results are
    concatenated (``concat_features`` is fixed to ``True``), giving ``feat_dim``
    channels per point.
    """

    def __init__(
        self,
        bounds,
        planeconfig,
        multires
    ) -> None:
        """Build one set of plane grids per resolution multiplier.

        Args:
            bounds: half-extent of the initial box; stored as
                ``aabb = [[b, b, b], [-b, -b, -b]]``.
            planeconfig: dict with keys ``"grid_dimensions"``,
                ``"input_coordinate_dim"``, ``"output_coordinate_dim"`` and
                ``"resolution"`` (one entry per input coordinate).
            multires: iterable of multipliers applied to the first three
                resolution entries; any later entries are kept unchanged.
        """
        super().__init__()
        # Row 0 is the +bounds corner and row 1 the -bounds corner;
        # normalize_aabb maps row 0 to -1 and row 1 to +1.
        aabb = torch.tensor([[bounds, bounds, bounds],
                             [-bounds, -bounds, -bounds]])
        self.aabb = nn.Parameter(aabb, requires_grad=False)
        self.grid_config = [planeconfig]
        self.multiscale_res_multipliers = multires
        self.concat_features = True

        # 1. Init planes
        self.grids = nn.ModuleList()
        self.feat_dim = 0
        for res in self.multiscale_res_multipliers:
            # A shallow copy suffices: "resolution" is rebound to a new list
            # below, so the caller's planeconfig is not mutated.
            config = self.grid_config[0].copy()
            # Resolution fix: multi-res only on spatial planes
            config["resolution"] = [
                r * res for r in config["resolution"][:3]
            ] + config["resolution"][3:]
            gp = init_grid_param(
                grid_nd=config["grid_dimensions"],
                in_dim=config["input_coordinate_dim"],
                out_dim=config["output_coordinate_dim"],
                reso=config["resolution"],
            )
            # shape[1] is out-dim - Concatenate over feature len for each scale
            if self.concat_features:
                self.feat_dim += gp[-1].shape[1]
            else:
                self.feat_dim = gp[-1].shape[1]
            self.grids.append(gp)
        print("feature_dim:", self.feat_dim)

    def set_aabb(self, xyz_max, xyz_min):
        """Replace the bounding box used to normalize query points.

        Args:
            xyz_max: per-axis corner stored as ``aabb[0]`` (anything
                ``torch.tensor`` accepts, e.g. a list or ndarray).
            xyz_min: per-axis corner stored as ``aabb[1]``.
        """
        # Build the new box on the device the current box lives on. Otherwise,
        # calling this after the module was moved (e.g. ``.to("cuda")``) would
        # leave ``aabb`` on the CPU and break normalize_aabb in get_density.
        aabb = torch.tensor([
            xyz_max,
            xyz_min
        ], device=self.aabb.device)
        # NOTE(review): unlike __init__ (requires_grad=False), this makes the
        # box trainable; kept for backward compatibility -- confirm intent.
        self.aabb = nn.Parameter(aabb, requires_grad=True)
        print("Voxel Plane: set aabb=", self.aabb)

    def get_density(self, pts: torch.Tensor, timestamps: Optional[torch.Tensor] = None):
        """Computes and returns the plane features sampled at ``pts``.

        Args:
            pts: points whose last dimension holds the spatial coordinates;
                they are normalized with ``self.aabb``.
            timestamps: extra coordinate(s) appended after normalization. Required
                when the config's ``input_coordinate_dim`` exceeds the spatial
                coordinate count (e.g. 4 for x, y, z, t). May be omitted for
                purely spatial configs.

        Returns:
            ``[N, feat_dim]`` features, where N is the number of points after
            flattening all leading dims. An empty result is replaced by a
            ``[0, 1]`` zeros tensor.

        Raises:
            ValueError: if the number of coordinates per point (spatial plus
                time) does not match ``input_coordinate_dim``.
        """
        expected_dim = self.grid_config[0]["input_coordinate_dim"]
        n_time = 0 if timestamps is None else timestamps.shape[-1]
        if pts.shape[-1] + n_time != expected_dim:
            raise ValueError(
                f"HexPlaneField expects {expected_dim} coordinates per point, "
                f"got {pts.shape[-1]} from pts and {n_time} from timestamps"
            )
        pts = normalize_aabb(pts, self.aabb)
        if timestamps is not None:
            pts = torch.cat((pts, timestamps), dim=-1)  # e.g. [n_rays, n_samples, 4]
        pts = pts.reshape(-1, pts.shape[-1])
        features = interpolate_ms_features(
            pts, ms_grids=self.grids,  # noqa
            grid_dimensions=self.grid_config[0]["grid_dimensions"],
            concat_features=self.concat_features, num_levels=None)
        if len(features) < 1:
            features = torch.zeros((0, 1)).to(features.device)
        return features

    def forward(self,
                pts: torch.Tensor,
                timestamps: Optional[torch.Tensor] = None):
        """Alias for :meth:`get_density`."""
        features = self.get_density(pts, timestamps)
        return features