import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange from my.registry import Registry VOXRF_REGISTRY = Registry("VoxRF") def to_grid_samp_coords(xyz_sampled, aabb): # output range is [-1, 1] aabbSize = aabb[1] - aabb[0] return (xyz_sampled - aabb[0]) / aabbSize * 2 - 1 def add_non_state_tsr(nn_module, key, val): # tsr added here does not appear in module's state_dict; nn_module.register_buffer(key, val, persistent=False) @VOXRF_REGISTRY.register() class VoxRF(nn.Module): def __init__( self, aabb, grid_size, step_ratio=0.5, density_shift=-10, ray_march_weight_thres=0.0001, c=3, blend_bg_texture=True, bg_texture_hw=64 ): assert aabb.shape == (2, 3) xyz = grid_size del grid_size super().__init__() add_non_state_tsr(self, "aabb", torch.tensor(aabb, dtype=torch.float32)) add_non_state_tsr(self, "grid_size", torch.LongTensor(xyz)) self.density_shift = density_shift self.ray_march_weight_thres = ray_march_weight_thres self.step_ratio = step_ratio zyx = xyz[::-1] self.density = torch.nn.Parameter( torch.zeros((1, 1, *zyx)) ) self.color = torch.nn.Parameter( torch.randn((1, c, *zyx)) ) self.blend_bg_texture = blend_bg_texture self.bg = torch.nn.Parameter( torch.randn((1, c, bg_texture_hw, bg_texture_hw)) ) self.c = c self.alphaMask = None self.feats2color = lambda feats: torch.sigmoid(feats) self.d_scale = torch.nn.Parameter(torch.tensor(0.0)) @property def device(self): return self.density.device def compute_density_feats(self, xyz_sampled): xyz_sampled = to_grid_samp_coords(xyz_sampled, self.aabb) n = xyz_sampled.shape[0] xyz_sampled = xyz_sampled.reshape(1, n, 1, 1, 3) σ = F.grid_sample(self.density, xyz_sampled).view(n) # We notice that DreamFusion also uses an exp scaling on densities. # The technique here is developed BEFORE DreamFusion came out, # and forms part of our upcoming technical report discussing invariant # scaling for volume rendering. The reseach was presented to our # funding agency (TRI) on Aug. 25th, and discussed with a few researcher friends # during the period. σ = σ * torch.exp(self.d_scale) σ = F.softplus(σ + self.density_shift) return σ def compute_app_feats(self, xyz_sampled): xyz_sampled = to_grid_samp_coords(xyz_sampled, self.aabb) n = xyz_sampled.shape[0] xyz_sampled = xyz_sampled.reshape(1, n, 1, 1, 3) feats = F.grid_sample(self.color, xyz_sampled).view(self.c, n) feats = feats.T return feats def compute_bg(self, uv): n = uv.shape[0] uv = uv.reshape(1, n, 1, 2) feats = F.grid_sample(self.bg, uv).view(self.c, n) feats = feats.T return feats def get_per_voxel_length(self): aabb_size = self.aabb[1] - self.aabb[0] # NOTE I am not -1 on grid_size here; # I interpret a voxel as a square and val sits at the center; like pixel # this is consistent with align_corners=False vox_xyz_length = aabb_size / self.grid_size return vox_xyz_length def get_num_samples(self, max_size=None): # funny way to set step size; whatever unit = torch.mean(self.get_per_voxel_length()) step_size = unit * self.step_ratio step_size = step_size.item() # get the float if max_size is None: aabb_size = self.aabb[1] - self.aabb[0] aabb_diag = torch.norm(aabb_size) max_size = aabb_diag num_samples = int((max_size / step_size).item()) + 1 return num_samples, step_size @torch.no_grad() def resample(self, target_xyz: list): zyx = target_xyz[::-1] self.density = self._resamp_param(self.density, zyx) self.color = self._resamp_param(self.color, zyx) target_xyz = torch.LongTensor(target_xyz).to(self.aabb.device) add_non_state_tsr(self, "grid_size", target_xyz) @staticmethod def _resamp_param(param, target_size): return torch.nn.Parameter(F.interpolate( param.data, size=target_size, mode="trilinear" )) @torch.no_grad() def compute_volume_alpha(self): xyz = self.grid_size.tolist() unit_xyz = self.get_per_voxel_length() xs, ys, zs = torch.meshgrid( *[torch.arange(nd) for nd in xyz], indexing="ij" ) pts = torch.stack([xs, ys, zs], dim=-1).to(unit_xyz.device) # [nx, ny, nz, 3] pts = self.aabb[0] + (pts + 0.5) * unit_xyz pts = pts.reshape(-1, 3) # could potentially filter with alpha mask itself if exists σ = self.compute_density_feats(pts) d = torch.mean(unit_xyz) α = 1 - torch.exp(-σ * d) α = rearrange(α.view(xyz), "x y z -> 1 1 z y x") α = α.contiguous() return α @torch.no_grad() def make_alpha_mask(self): α = self.compute_volume_alpha() ks = 3 α = F.max_pool3d(α, kernel_size=ks, padding=ks // 2, stride=1) α = (α > 0.08).float() vol_mask = AlphaMask(self.aabb, α) self.alphaMask = vol_mask def state_dict(self, *args, **kwargs): state = super().state_dict(*args, **kwargs) if self.alphaMask is not None: state['alpha_mask'] = self.alphaMask.export_state() return state def load_state_dict(self, state_dict): if 'alpha_mask' in state_dict.keys(): state = state_dict.pop("alpha_mask") self.alphaMask = AlphaMask.from_state(state) return super().load_state_dict(state_dict, strict=True) @VOXRF_REGISTRY.register() class V_SJC(VoxRF): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # rendering color in [-1, 1] range, since score models all operate on centered img self.feats2color = lambda feats: torch.sigmoid(feats) * 2 - 1 def opt_params(self): groups = [] for name, param in self.named_parameters(): # print(f"{name} {param.shape}") grp = {"params": param} if name in ["bg"]: grp["lr"] = 0.0001 if name in ["density"]: # grp["lr"] = 0. pass groups.append(grp) return groups def annealed_opt_params(self, base_lr, σ): groups = [] for name, param in self.named_parameters(): # print(f"{name} {param.shape}") grp = {"params": param, "lr": base_lr * σ} if name in ["density"]: grp["lr"] = base_lr * σ if name in ["d_scale"]: grp["lr"] = 0. if name in ["color"]: grp["lr"] = base_lr * σ if name in ["bg"]: grp["lr"] = 0.01 groups.append(grp) return groups @VOXRF_REGISTRY.register() class V_SD(V_SJC): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # rendering in feature space; no sigmoid thresholding self.feats2color = lambda feats: feats class AlphaMask(nn.Module): def __init__(self, aabb, alphas): super().__init__() zyx = list(alphas.shape[-3:]) add_non_state_tsr(self, "alphas", alphas.view(1, 1, *zyx)) xyz = zyx[::-1] add_non_state_tsr(self, "grid_size", torch.LongTensor(xyz)) add_non_state_tsr(self, "aabb", aabb) def sample_alpha(self, xyz_pts): xyz_pts = to_grid_samp_coords(xyz_pts, self.aabb) xyz_pts = xyz_pts.view(1, -1, 1, 1, 3) α = F.grid_sample(self.alphas, xyz_pts).view(-1) return α def export_state(self): state = {} alphas = self.alphas.bool().cpu().numpy() state['shape'] = alphas.shape state['mask'] = np.packbits(alphas.reshape(-1)) state['aabb'] = self.aabb.cpu() return state @classmethod def from_state(cls, state): shape = state['shape'] mask = state['mask'] aabb = state['aabb'] length = np.prod(shape) alphas = torch.from_numpy( np.unpackbits(mask)[:length].reshape(shape) ) amask = cls(aabb, alphas.float()) return amask def test(): device = torch.device("cuda:1") aabb = 1.5 * np.array([ [-1, -1, -1], [1, 1, 1] ]) model = VoxRF(aabb, [10, 20, 30]) model.to(device) print(model.density.shape) print(model.grid_size) return if __name__ == "__main__": test()