|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from einops import rearrange |
|
from my.registry import Registry |
|
|
|
# Registry of voxel radiance-field variants. Classes in this file add
# themselves via the ``@VOXRF_REGISTRY.register()`` decorator (presumably so
# they can be looked up by name elsewhere — see my.registry to confirm).
VOXRF_REGISTRY = Registry("VoxRF")
|
|
|
|
|
def to_grid_samp_coords(xyz_sampled, aabb):
    """Map world-space points to ``F.grid_sample`` coordinates.

    The mapping is linear per axis: a point on ``aabb[0]`` goes to -1 and a
    point on ``aabb[1]`` goes to +1.

    Args:
        xyz_sampled: points to rescale; must broadcast against ``aabb[0]``.
        aabb: pair of corners, ``aabb[0]`` the minimum and ``aabb[1]`` the
            maximum.

    Returns:
        ``xyz_sampled`` rescaled so that the box spans [-1, 1].
    """
    lo, hi = aabb[0], aabb[1]
    unit = (xyz_sampled - lo) / (hi - lo)  # 0 at lo, 1 at hi
    return unit * 2 - 1
|
|
|
|
|
def add_non_state_tsr(nn_module, key, val):
    """Attach ``val`` to ``nn_module`` as a non-persistent buffer named ``key``.

    The buffer follows the module through ``.to()``/``.cuda()`` but is left
    out of ``state_dict()``. Re-registering an existing buffer name
    overwrites it.
    """
    register = nn_module.register_buffer
    register(name=key, tensor=val, persistent=False)
|
|
|
|
|
@VOXRF_REGISTRY.register()
class VoxRF(nn.Module):
    """Dense voxel-grid radiance field with a learnable background texture.

    Storage:
        - Density lives in a dense grid of shape ``(1, 1, Z, Y, X)``.
        - Color lives in a dense grid of shape ``(1, c, Z, Y, X)``.
        - Both are sampled with ``F.grid_sample`` after world points inside
          ``aabb`` are mapped to ``[-1, 1]``.

    Sampling convention:
        - Every sampling call passes ``align_corners=False`` explicitly. This
          is PyTorch's default; spelling it out silences the UserWarning.
        - ``-1``/``+1`` therefore sit on the outer faces of the boundary
          voxels.
        - Voxel ``i`` is centred at ``aabb[0] + (i + 0.5) * voxel_length``,
          the same convention ``compute_volume_alpha`` uses.
    """

    def __init__(
        self, aabb, grid_size, step_ratio=0.5,
        density_shift=-10, ray_march_weight_thres=0.0001, c=3,
        blend_bg_texture=True, bg_texture_hw=64
    ):
        """Allocate the density/color grids and the background texture.

        Args:
            aabb: (2, 3) array-like; row 0 is the min corner, row 1 the max.
            grid_size: voxel counts along (x, y, z). Any sequence of
                integers (list, tuple, ndarray, tensor) is accepted.
            step_ratio: ray-march step as a fraction of the mean voxel length.
            density_shift: offset added to raw density before ``softplus``.
            ray_march_weight_thres: stored only; not read inside this class.
            c: number of channels of the color grid and background texture.
            blend_bg_texture: stored only; not read inside this class.
            bg_texture_hw: side length of the square background texture.
        """
        assert aabb.shape == (2, 3)
        # Plain Python ints, so list/tuple/ndarray/tensor inputs all work below.
        xyz = [int(n) for n in grid_size]

        super().__init__()
        add_non_state_tsr(self, "aabb", torch.tensor(aabb, dtype=torch.float32))
        add_non_state_tsr(self, "grid_size", torch.LongTensor(xyz))

        self.density_shift = density_shift
        self.ray_march_weight_thres = ray_march_weight_thres
        self.step_ratio = step_ratio

        # Grids are stored (1, C, Z, Y, X): grid_sample's last grid coordinate
        # (x) indexes the last tensor dim, so X must come last.
        zyx = xyz[::-1]
        # Parameter registration order (density, color, bg, d_scale) defines
        # the order of named_parameters() and hence of optimizer groups.
        self.density = torch.nn.Parameter(torch.zeros((1, 1, *zyx)))
        self.color = torch.nn.Parameter(torch.randn((1, c, *zyx)))

        self.blend_bg_texture = blend_bg_texture
        self.bg = torch.nn.Parameter(
            torch.randn((1, c, bg_texture_hw, bg_texture_hw))
        )

        self.c = c
        self.alphaMask = None  # set by make_alpha_mask / load_state_dict
        self.feats2color = lambda feats: torch.sigmoid(feats)

        # Learnable log-scale applied to raw density in compute_density_feats.
        self.d_scale = torch.nn.Parameter(torch.tensor(0.0))

    @property
    def device(self):
        """Device holding the model's parameters (that of ``self.density``)."""
        return self.density.device

    def compute_density_feats(self, xyz_sampled):
        """Sample non-negative density at world-space points.

        Args:
            xyz_sampled: (n, 3) world-space points.

        Returns:
            (n,) tensor equal to
            ``softplus(raw * exp(d_scale) + density_shift)``.
        """
        xyz_sampled = to_grid_samp_coords(xyz_sampled, self.aabb)
        n = xyz_sampled.shape[0]
        grid = xyz_sampled.reshape(1, n, 1, 1, 3)
        σ = F.grid_sample(self.density, grid, align_corners=False).view(n)
        σ = σ * torch.exp(self.d_scale)
        return F.softplus(σ + self.density_shift)

    def compute_app_feats(self, xyz_sampled):
        """Sample raw (not yet activated) color features at world-space points.

        Args:
            xyz_sampled: (n, 3) world-space points.

        Returns:
            (n, c) feature tensor.
        """
        xyz_sampled = to_grid_samp_coords(xyz_sampled, self.aabb)
        n = xyz_sampled.shape[0]
        grid = xyz_sampled.reshape(1, n, 1, 1, 3)
        feats = F.grid_sample(self.color, grid, align_corners=False)
        return feats.view(self.c, n).T

    def compute_bg(self, uv):
        """Sample the background texture.

        Args:
            uv: (n, 2) texture coordinates in ``grid_sample`` convention
                ([-1, 1] spans the texture).

        Returns:
            (n, c) feature tensor.
        """
        n = uv.shape[0]
        grid = uv.reshape(1, n, 1, 2)
        feats = F.grid_sample(self.bg, grid, align_corners=False)
        return feats.view(self.c, n).T

    def get_per_voxel_length(self):
        """Return the (3,) voxel edge lengths along x, y, z in world units."""
        aabb_size = self.aabb[1] - self.aabb[0]
        return aabb_size / self.grid_size

    def get_num_samples(self, max_size=None):
        """Compute the ray-march step and how many steps cover ``max_size``.

        Args:
            max_size: distance to cover. It may be a Python number or a
                one-element tensor. Defaults to the diagonal of ``aabb``.

        Returns:
            ``(num_samples, step_size)``:
                - ``step_size`` is ``step_ratio`` times the mean voxel edge
                  length, as a Python float.
                - ``num_samples`` is ``int(max_size / step_size) + 1``.
        """
        unit = torch.mean(self.get_per_voxel_length())
        step_size = (unit * self.step_ratio).item()

        if max_size is None:
            max_size = torch.norm(self.aabb[1] - self.aabb[0])

        # Tensors keep their original tensor-division path. Plain numbers
        # have no .item(), so they are divided directly.
        if isinstance(max_size, torch.Tensor):
            ratio = (max_size / step_size).item()
        else:
            ratio = max_size / step_size
        num_samples = int(ratio) + 1
        return num_samples, step_size

    @torch.no_grad()
    def resample(self, target_xyz: list):
        """Trilinearly resample the density and color grids.

        ``self.density`` and ``self.color`` are replaced by new Parameter
        objects, so any optimizer holding the old ones must be rebuilt.

        Args:
            target_xyz: new voxel counts along (x, y, z).
        """
        target_xyz = [int(n) for n in target_xyz]
        zyx = target_xyz[::-1]
        self.density = self._resamp_param(self.density, zyx)
        self.color = self._resamp_param(self.color, zyx)
        add_non_state_tsr(
            self, "grid_size",
            torch.LongTensor(target_xyz).to(self.aabb.device),
        )

    @staticmethod
    def _resamp_param(param, target_size):
        """Return a new Parameter holding ``param`` trilinearly resized to (Z, Y, X)."""
        return torch.nn.Parameter(F.interpolate(
            param.data, size=target_size, mode="trilinear", align_corners=False
        ))

    @torch.no_grad()
    def compute_volume_alpha(self):
        """Evaluate per-voxel opacity at every voxel centre.

        Each voxel gets ``α = 1 - exp(-σ * d)``:
            - ``σ`` is the density sampled at the voxel centre.
            - ``d`` is the mean voxel edge length.

        Returns:
            Contiguous (1, 1, Z, Y, X) tensor of alphas.
        """
        xyz = self.grid_size.tolist()
        unit_xyz = self.get_per_voxel_length()
        device = unit_xyz.device
        # Build the index grid directly on the target device.
        xs, ys, zs = torch.meshgrid(
            *[torch.arange(nd, device=device) for nd in xyz], indexing="ij"
        )
        pts = torch.stack([xs, ys, zs], dim=-1)
        pts = self.aabb[0] + (pts + 0.5) * unit_xyz  # voxel centres
        pts = pts.reshape(-1, 3)

        σ = self.compute_density_feats(pts)
        d = torch.mean(unit_xyz)
        α = 1 - torch.exp(-σ * d)
        α = rearrange(α.view(xyz), "x y z -> 1 1 z y x")
        return α.contiguous()

    @torch.no_grad()
    def make_alpha_mask(self, thres=0.08, ks=3):
        """Build ``self.alphaMask`` from the current density.

        Steps:
            1. Dilate voxel alphas with a ``ks``-wide 3-D max-pool.
            2. Binarise them at ``thres``.

        Args:
            thres: alpha threshold above which a voxel is kept.
            ks: max-pool kernel size. It should be odd: with padding
                ``ks // 2`` an even kernel grows each spatial dim by one.
        """
        α = self.compute_volume_alpha()
        α = F.max_pool3d(α, kernel_size=ks, padding=ks // 2, stride=1)
        α = (α > thres).float()
        self.alphaMask = AlphaMask(self.aabb, α)

    def state_dict(self, *args, **kwargs):
        """Return ``nn.Module.state_dict()``, plus ``'alpha_mask'`` when a mask exists.

        The mask is stored in ``AlphaMask.export_state`` form.
        """
        state = super().state_dict(*args, **kwargs)
        if self.alphaMask is not None:
            # NOTE(review): the key is not prefixed. If VoxRF is ever nested
            # in a parent module, the parent's prefix is ignored here; confirm
            # this never happens.
            state['alpha_mask'] = self.alphaMask.export_state()
        return state

    def load_state_dict(self, state_dict, strict=True, **kwargs):
        """Load parameters, restoring the packed alpha mask when present.

        The caller's mapping is left untouched: ``'alpha_mask'`` is removed
        from a shallow copy, so the same checkpoint can be loaded again. The
        restored mask is moved to this model's device.

        Args:
            state_dict: mapping as produced by ``state_dict``.
            strict: forwarded to ``nn.Module.load_state_dict``.
            **kwargs: further options forwarded to ``nn.Module.load_state_dict``
                (e.g. ``assign`` on PyTorch versions that support it).

        Returns:
            Whatever ``nn.Module.load_state_dict`` returns.
        """
        if 'alpha_mask' in state_dict:
            from collections import OrderedDict

            # Copy before popping, preserving PyTorch's version metadata.
            metadata = getattr(state_dict, "_metadata", None)
            state_dict = OrderedDict(state_dict)
            if metadata is not None:
                state_dict._metadata = metadata
            mask_state = state_dict.pop("alpha_mask")
            self.alphaMask = AlphaMask.from_state(mask_state).to(self.aabb.device)
        return super().load_state_dict(state_dict, strict=strict, **kwargs)
|
|
|
|
|
@VOXRF_REGISTRY.register()
class V_SJC(VoxRF):
    """VoxRF variant whose color head maps features to [-1, 1]."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # sigmoid rescaled from (0, 1) to (-1, 1)
        self.feats2color = lambda feats: torch.sigmoid(feats) * 2 - 1

    def opt_params(self, bg_lr=0.0001):
        """Build one optimizer param group per named parameter.

        Most groups carry no ``lr``, so the optimizer's default applies. The
        background texture is the exception and uses ``bg_lr``.

        Args:
            bg_lr: learning rate for the ``bg`` parameter.

        Returns:
            List of param-group dicts, in ``named_parameters()`` order.
        """
        groups = []
        for name, param in self.named_parameters():
            grp = {"params": param}
            if name == "bg":
                grp["lr"] = bg_lr
            groups.append(grp)
        return groups

    def annealed_opt_params(self, base_lr, σ, bg_lr=0.01):
        """Build per-parameter groups whose learning rate scales with ``σ``.

        Each parameter gets ``base_lr * σ``, with two exceptions:
            - ``d_scale`` is frozen (lr 0).
            - The background texture uses the fixed ``bg_lr``.

        Args:
            base_lr: base learning rate.
            σ: annealing multiplier applied to ``base_lr``.
            bg_lr: learning rate for the ``bg`` parameter.

        Returns:
            List of param-group dicts, in ``named_parameters()`` order.
        """
        lr_overrides = {"d_scale": 0., "bg": bg_lr}
        return [
            {"params": param, "lr": lr_overrides.get(name, base_lr * σ)}
            for name, param in self.named_parameters()
        ]
|
|
|
|
|
@VOXRF_REGISTRY.register()
class V_SD(V_SJC):
    """V_SJC variant whose color head is the identity.

    Raw grid features are returned unchanged instead of passing through
    V_SJC's rescaled sigmoid.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        def _identity(feats):
            return feats

        self.feats2color = _identity
|
|
|
|
|
class AlphaMask(nn.Module):
    """Voxel alpha grid spanning an axis-aligned box.

    State:
        - ``alphas`` is held as a (1, 1, Z, Y, X) buffer.
        - ``grid_size`` (x, y, z) and ``aabb`` are also buffers.
        - All of them are non-persistent, so nothing appears in
          ``state_dict()``.
        - Use ``export_state`` / ``from_state`` to serialise the mask as
          packed bits.
    """

    def __init__(self, aabb, alphas):
        """
        Args:
            aabb: (2, 3) tensor with the min/max world corners of the grid.
            alphas: tensor whose last three dims are (Z, Y, X). It is stored
                reshaped to (1, 1, Z, Y, X).
        """
        super().__init__()
        zyx = list(alphas.shape[-3:])
        # reshape (unlike view) also accepts non-contiguous inputs.
        add_non_state_tsr(self, "alphas", alphas.reshape(1, 1, *zyx))
        xyz = zyx[::-1]
        # Keep grid_size on the same device as the alphas.
        add_non_state_tsr(
            self, "grid_size",
            torch.tensor(xyz, dtype=torch.long, device=alphas.device),
        )
        add_non_state_tsr(self, "aabb", aabb)

    def sample_alpha(self, xyz_pts):
        """Trilinearly sample the alpha grid at world-space points.

        Args:
            xyz_pts: world-space points with a trailing dimension of 3.

        Returns:
            1-D tensor with one alpha per point.
        """
        grid = to_grid_samp_coords(xyz_pts, self.aabb).reshape(1, -1, 1, 1, 3)
        return F.grid_sample(self.alphas, grid, align_corners=False).view(-1)

    def export_state(self):
        """Serialise the mask as packed bits.

        Any non-zero alpha is stored as 1.

        Returns:
            Dict with keys:
                - ``'shape'``: numpy shape tuple of the alphas.
                - ``'mask'``: ``np.packbits`` of the flattened boolean alphas.
                - ``'aabb'``: CPU tensor of the box corners.
        """
        alphas = self.alphas.bool().cpu().numpy()
        return {
            'shape': alphas.shape,
            'mask': np.packbits(alphas.reshape(-1)),
            'aabb': self.aabb.cpu(),
        }

    @classmethod
    def from_state(cls, state):
        """Rebuild an ``AlphaMask`` from ``export_state`` output.

        The returned mask holds 0/1 float alphas on the CPU.
        """
        shape = tuple(state['shape'])
        # int(): np.prod returns a float for an empty shape, which cannot slice.
        length = int(np.prod(shape))
        bits = np.unpackbits(state['mask'])[:length].reshape(shape)
        alphas = torch.from_numpy(bits).float()
        return cls(state['aabb'], alphas)
|
|
|
|
|
def test(device=None):
    """Smoke-test VoxRF construction and device placement.

    Args:
        device: device to move the model to.
            - Defaults to ``cuda:1`` when at least two CUDA devices are
              visible, and to the CPU otherwise.
            - The fallback lets the check run on single-GPU or CPU-only
              machines.
    """
    if device is None:
        device = "cuda:1" if torch.cuda.device_count() > 1 else "cpu"
    device = torch.device(device)

    aabb = 1.5 * np.array([
        [-1, -1, -1],
        [1, 1, 1]
    ])
    model = VoxRF(aabb, [10, 20, 30])
    model.to(device)
    print(model.density.shape)  # expect (1, 1, 30, 20, 10): grids are z-major
    print(model.grid_size)
|
|
|
|
|
# Running this file directly executes the smoke test defined above.
if __name__ == "__main__":

    test()
|
|