# hzxie's picture
# feat: citydreamer inference (bugs to be fixed).
# 79df973 verified
# -*- coding: utf-8 -*-
#
# @File: __init__.py
# @Author: Jiaxiang Tang (@ashawkey)
# @Date: 2023-04-15 10:39:28
# @Last Modified by: Haozhe Xie
# @Last Modified at: 2023-04-15 13:08:46
# @Email: ashawkey1999@gmail.com
# @Ref: https://github.com/ashawkey/torch-ngp
import math
import numpy as np
import torch
import grid_encoder_ext
class GridEncoderFunction(torch.autograd.Function):
    """Autograd bridge to the compiled ``grid_encoder_ext`` kernels.

    ``forward`` allocates the output (and optional ``dy_dx``) buffers and lets
    ``grid_encoder_ext.forward`` fill them; ``backward`` allocates gradient
    buffers and lets ``grid_encoder_ext.backward`` fill those. What the kernels
    actually compute is defined by the extension, not by this file.
    """

    @staticmethod
    def forward(
        ctx,
        inputs,
        embeddings,
        offsets,
        per_level_scale,
        base_resolution,
        calc_grad_inputs=False,
        gridtype=0,
        align_corners=False,
    ):
        """Run the extension's forward kernel.

        Args:
            ctx: autograd context; receives the saved tensors and the scalar
                arguments that ``backward`` needs.
            inputs: [B, D] float coordinates in [0, 1].
            embeddings: [sO, C] float feature table.
            offsets: [L + 1] int tensor (one more entry than there are levels).
            per_level_scale: resolution multiplier between levels; its log2 is
                what gets handed to the extension.
            base_resolution: passed through to the extension unchanged.
            calc_grad_inputs: when true, a [B, L * D * C] ``dy_dx`` buffer is
                allocated and ``backward`` returns a gradient for ``inputs``.
            gridtype: integer id passed through to the extension.
            align_corners: flag passed through to the extension.

        Returns:
            Tensor of shape [B, L * C] with the dtype of ``embeddings``.
        """
        inputs = inputs.contiguous()

        n_points, n_dims = inputs.shape  # batch size, coord dim
        n_levels = offsets.shape[0] - 1
        n_feats = embeddings.shape[1]  # embedding dim for each level
        # The CUDA side applies exp2f, so the multiplier is sent as its log2.
        log2_scale = math.log2(per_level_scale)

        buffer_opts = {"device": inputs.device, "dtype": embeddings.dtype}

        # Level-major layout is friendlier to the kernel's cache behaviour; the
        # price is the permute below.
        outputs = torch.empty(n_levels, n_points, n_feats, **buffer_opts)

        # Without input gradients a 1-element tensor stands in as a placeholder
        # (TODO: a better way?).
        dy_dx_shape = (
            (n_points, n_levels * n_dims * n_feats) if calc_grad_inputs else (1,)
        )
        dy_dx = torch.empty(*dy_dx_shape, **buffer_opts)

        grid_encoder_ext.forward(
            inputs,
            embeddings,
            offsets,
            outputs,
            n_points,
            n_dims,
            n_feats,
            n_levels,
            log2_scale,
            base_resolution,
            calc_grad_inputs,
            dy_dx,
            gridtype,
            align_corners,
        )

        # [L, B, C] --> [B, L * C]
        outputs = outputs.permute(1, 0, 2).reshape(n_points, n_levels * n_feats)

        ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
        ctx.dims = [
            n_points,
            n_dims,
            n_feats,
            n_levels,
            log2_scale,
            base_resolution,
            gridtype,
        ]
        ctx.calc_grad_inputs = calc_grad_inputs
        ctx.align_corners = align_corners

        return outputs

    @staticmethod
    def backward(ctx, grad):
        """Run the extension's backward kernel.

        Args:
            ctx: autograd context populated by ``forward``.
            grad: [B, L * C] gradient with respect to the forward output.

        Returns:
            An 8-tuple matching the arguments of ``forward``: the gradient for
            ``inputs`` (``None`` unless ``calc_grad_inputs`` was set), the
            gradient for ``embeddings``, and ``None`` for the remaining six.
        """
        inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
        n_points, n_dims, n_feats, n_levels, log2_scale, base_res, gridtype = ctx.dims
        wants_input_grad = ctx.calc_grad_inputs

        # [B, L * C] --> [L, B, C], the layout used for the forward output buffer.
        level_major_grad = (
            grad.view(n_points, n_levels, n_feats).permute(1, 0, 2).contiguous()
        )

        grad_embeddings = torch.zeros_like(embeddings)
        grad_inputs = (
            torch.zeros_like(inputs, dtype=embeddings.dtype)
            if wants_input_grad
            else torch.zeros(1, device=inputs.device, dtype=embeddings.dtype)
        )

        grid_encoder_ext.backward(
            level_major_grad,
            inputs,
            embeddings,
            offsets,
            grad_embeddings,
            n_points,
            n_dims,
            n_feats,
            n_levels,
            log2_scale,
            base_res,
            wants_input_grad,
            dy_dx,
            grad_inputs,
            gridtype,
            ctx.align_corners,
        )

        # The buffer was allocated in the embeddings dtype; hand it back in the
        # dtype of the inputs.
        input_grad = grad_inputs.to(inputs.dtype) if wants_input_grad else None
        return (input_grad, grad_embeddings) + (None,) * 6
class GridEncoder(torch.nn.Module):
    """Multi-level grid encoder module built on ``GridEncoderFunction``.

    Owns a single learnable ``embeddings`` table shared by every level and an
    ``offsets`` buffer holding the row at which each level starts (plus a final
    entry equal to the total row count).
    """

    def __init__(
        self,
        in_channels,
        n_levels,
        lvl_channels,
        desired_resolution,
        per_level_scale=2,
        base_resolution=16,
        log2_hashmap_size=19,
        gridtype="hash",
        align_corners=False,
    ):
        """Size and allocate the embedding table.

        Args:
            in_channels: size of the last input dimension (coordinate dim).
            n_levels: number of levels.
            lvl_channels: encoded channels per level.
            desired_resolution: target resolution; together with
                ``base_resolution`` and ``n_levels`` it determines
                ``self.per_level_scale``, the value handed to the kernel.
            per_level_scale: growth factor used when sizing each level's
                slice of the table (see the review note in the body). Also
                used as ``self.per_level_scale`` when ``n_levels == 1``.
            base_resolution: resolution of the first level.
            log2_hashmap_size: log2 of the maximum number of rows per level.
            gridtype: ``"hash"`` selects id 0; any other value selects id 1.
            align_corners: forwarded to the kernel; also decides whether a
                level is sized from ``resolution`` or ``resolution + 1``.
        """
        super().__init__()

        self.in_channels = in_channels
        self.n_levels = n_levels
        self.lvl_channels = lvl_channels  # encode channels per level
        if n_levels == 1:
            # A single level has no growth step, and the derivation below
            # would divide by zero. The scale is only ever raised to the power
            # of the level index (0 here), so the argument is a safe stand-in.
            self.per_level_scale = per_level_scale
        else:
            self.per_level_scale = 2 ** (
                math.log2(desired_resolution / base_resolution) / (n_levels - 1)
            )
        self.log2_hashmap_size = log2_hashmap_size
        self.base_resolution = base_resolution
        self.output_dim = n_levels * lvl_channels
        self.gridtype = gridtype
        self.gridtype_id = 0 if gridtype == "hash" else 1
        self.align_corners = align_corners

        # allocate parameters
        offsets = []
        offset = 0
        self.max_params = 2**log2_hashmap_size
        for i in range(n_levels):
            # NOTE(review): sizing uses the ``per_level_scale`` argument, while
            # the kernel receives ``self.per_level_scale`` (derived from
            # ``desired_resolution``). Left untouched on purpose: switching
            # would change the shapes of ``embeddings`` and ``offsets`` and
            # break loading of existing checkpoints. Confirm before changing.
            resolution = int(math.ceil(base_resolution * per_level_scale**i))
            params_in_level = min(
                self.max_params,
                (resolution if align_corners else resolution + 1) ** in_channels,
            )  # limit max number
            # round up to a multiple of 8
            params_in_level = int(math.ceil(params_in_level / 8) * 8)
            offsets.append(offset)
            offset += params_in_level

        offsets.append(offset)
        offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
        self.register_buffer("offsets", offsets)
        # Kept as a 0-dim tensor (not an int) to match the existing attribute.
        self.n_params = offsets[-1] * lvl_channels
        self.embeddings = torch.nn.Parameter(torch.empty(offset, lvl_channels))
        self._init_weights()

    def _init_weights(self):
        """Initialise the embedding table uniformly in [-1e-4, 1e-4)."""
        self.embeddings.data.uniform_(-1e-4, 1e-4)

    def forward(self, inputs, bound=1):
        """Encode positions.

        Args:
            inputs: [..., in_channels] positions in [-bound, bound].
            bound: half-extent used to map the positions into [0, 1].

        Returns:
            Tensor of shape [..., n_levels * lvl_channels].
        """
        inputs = (inputs + bound) / (2 * bound)  # map to [0, 1]
        prefix_shape = list(inputs.shape[:-1])
        # reshape, not view: the elementwise result above can inherit the
        # non-contiguous strides of a permuted input, which view() rejects.
        inputs = inputs.reshape(-1, self.in_channels)

        outputs = GridEncoderFunction.apply(
            inputs,
            self.embeddings,
            self.offsets,
            self.per_level_scale,
            self.base_resolution,
            inputs.requires_grad,
            self.gridtype_id,
            self.align_corners,
        )
        return outputs.view(prefix_shape + [self.output_dim])