import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd.profiler as profiler
import numpy as np
from einops import rearrange, repeat, einsum

from .math_utils import get_ray_limits_box, linspace
from ...modules.diffusionmodules.openaimodel import Timestep


class ImageEncoder(nn.Module):
    def __init__(self, output_dim: int = 64) -> None:
        super().__init__()
        self.output_dim = output_dim

    def forward(self, image):
        return image


class PositionalEncoding(torch.nn.Module):
    """
    Implement NeRF's positional encoding
    """

    def __init__(self, num_freqs=6, d_in=3, freq_factor=np.pi, include_input=True):
        super().__init__()
        self.num_freqs = num_freqs
        self.d_in = d_in
        self.freqs = freq_factor * 2.0 ** torch.arange(0, num_freqs)
        self.d_out = self.num_freqs * 2 * d_in
        self.include_input = include_input
        if include_input:
            self.d_out += d_in
        # f1 f1 f2 f2 ... to multiply x by
        self.register_buffer(
            "_freqs", torch.repeat_interleave(self.freqs, 2).view(1, -1, 1)
        )
        # 0 pi/2 0 pi/2 ... so that
        # (sin(x + _phases[0]), sin(x + _phases[1]) ...) = (sin(x), cos(x)...)
        _phases = torch.zeros(2 * self.num_freqs)
        _phases[1::2] = np.pi * 0.5
        self.register_buffer("_phases", _phases.view(1, -1, 1))

    def forward(self, x):
        """
        Apply positional encoding (new implementation)
        :param x (batch, self.d_in)
        :return (batch, self.d_out)
        """
        with profiler.record_function("positional_enc"):
            # embed = x.unsqueeze(1).repeat(1, self.num_freqs * 2, 1)
            embed = repeat(x, "... C -> ... N C", N=self.num_freqs * 2)
            embed = torch.sin(torch.addcmul(self._phases, embed, self._freqs))
            embed = rearrange(embed, "... N C -> ... (N C)")
            if self.include_input:
                embed = torch.cat((x, embed), dim=-1)
            return embed
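
# A minimal usage sketch of PositionalEncoding (illustrative only; the sizes below
# are assumptions, not values used elsewhere in this file). With the defaults
# num_freqs=6, d_in=3, include_input=True, the output width is
# num_freqs * 2 * d_in + d_in = 6 * 2 * 3 + 3 = 39:
#
#   enc = PositionalEncoding()    # enc.d_out == 39
#   xyz = torch.rand(1024, 3)     # hypothetical batch of 3D points
#   emb = enc(xyz)                # (1024, 39): the raw xyz is prepended, followed by
#                                 # interleaved sin/cos terms at 6 frequencies
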

class RayGenerator(torch.nn.Module):
    """
    From camera pose and intrinsics to ray origins and directions.
    """

    def __init__(self):
        super().__init__()
        (
            self.ray_origins_h,
            self.ray_directions,
            self.depths,
            self.image_coords,
            self.rendering_options,
        ) = (None, None, None, None, None)

    def forward(self, cam2world_matrix, intrinsics, render_size):
        """
        Create batches of rays and return origins and directions.

        cam2world_matrix: (N, 4, 4)
        intrinsics: (N, 3, 3)
        render_size: int

        ray_origins: (N, M, 3)
        ray_dirs: (N, M, 3)
        """
        N, M = cam2world_matrix.shape[0], render_size**2
        cam_locs_world = cam2world_matrix[:, :3, 3]
        fx = intrinsics[:, 0, 0]
        fy = intrinsics[:, 1, 1]
        cx = intrinsics[:, 0, 2]
        cy = intrinsics[:, 1, 2]
        sk = intrinsics[:, 0, 1]

        uv = torch.stack(
            torch.meshgrid(
                torch.arange(
                    render_size, dtype=torch.float32, device=cam2world_matrix.device
                ),
                torch.arange(
                    render_size, dtype=torch.float32, device=cam2world_matrix.device
                ),
                indexing="ij",
            )
        )
        uv = uv.flip(0).reshape(2, -1).transpose(1, 0)
        uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1)

        # pixel centers in normalized [0, 1] image coordinates
        x_cam = uv[:, :, 0].view(N, -1) * (1.0 / render_size) + (0.5 / render_size)
        y_cam = uv[:, :, 1].view(N, -1) * (1.0 / render_size) + (0.5 / render_size)
        z_cam = torch.ones((N, M), device=cam2world_matrix.device)

        x_lift = (
            (
                x_cam
                - cx.unsqueeze(-1)
                + cy.unsqueeze(-1) * sk.unsqueeze(-1) / fy.unsqueeze(-1)
                - sk.unsqueeze(-1) * y_cam / fy.unsqueeze(-1)
            )
            / fx.unsqueeze(-1)
            * z_cam
        )
        y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam

        cam_rel_points = torch.stack(
            (x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1
        )

        # NOTE: this should be named _blender2opencv
        _opencv2blender = (
            torch.tensor(
                [
                    [1, 0, 0, 0],
                    [0, -1, 0, 0],
                    [0, 0, -1, 0],
                    [0, 0, 0, 1],
                ],
                dtype=torch.float32,
                device=cam2world_matrix.device,
            )
            .unsqueeze(0)
            .repeat(N, 1, 1)
        )

        cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender)

        world_rel_points = torch.bmm(
            cam2world_matrix, cam_rel_points.permute(0, 2, 1)
        ).permute(0, 2, 1)[:, :, :3]

        ray_dirs = world_rel_points - cam_locs_world[:, None, :]
        ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2)

        ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1)

        return ray_origins, ray_dirs


class RaySampler(torch.nn.Module):
    def __init__(
        self,
        num_samples_per_ray,
        bbox_length=1.0,
        near=0.5,
        far=10000.0,
        disparity=False,
    ):
        super().__init__()
        self.num_samples_per_ray = num_samples_per_ray
        self.bbox_length = bbox_length
        self.near = near
        self.far = far
        self.disparity = disparity

    def forward(self, ray_origins, ray_directions):
        if not self.disparity:
            t_start, t_end = get_ray_limits_box(
                ray_origins, ray_directions, 2 * self.bbox_length
            )
        else:
            t_start = torch.full_like(ray_origins, self.near)
            t_end = torch.full_like(ray_origins, self.far)
        is_ray_valid = t_end > t_start
        if torch.any(is_ray_valid).item():
            t_start[~is_ray_valid] = t_start[is_ray_valid].min()
            t_end[~is_ray_valid] = t_start[is_ray_valid].max()

        if not self.disparity:
            # stratified samples, uniform in depth between t_start and t_end
            depths = linspace(t_start, t_end, self.num_samples_per_ray)
            depths += (
                torch.rand_like(depths)
                * (t_end - t_start)
                / (self.num_samples_per_ray - 1)
            )
        else:
            # stratified samples, uniform in disparity (1 / depth) between near and far
            step = 1.0 / self.num_samples_per_ray
            z_steps = torch.linspace(
                0, 1 - step, self.num_samples_per_ray, device=ray_origins.device
            )
            z_steps += torch.rand_like(z_steps) * step
            depths = 1 / (1 / self.near * (1 - z_steps) + 1 / self.far * z_steps)
            depths = depths[..., None, None, None]
        return ray_origins[None] + ray_directions[None] * depths
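
# Shape sketch for the ray pipeline above (sizes are hypothetical, for reference
# only): with 2 target cameras, a 32 x 32 render, and 128 samples per ray,
#
#   generator = RayGenerator()
#   sampler = RaySampler(num_samples_per_ray=128)
#   origins, dirs = generator(cam2world, intrinsics, 32)   # both (2, 1024, 3)
#   points = sampler(origins, dirs)                        # (128, 2, 1024, 3)
#
# The sample dimension is prepended, which is why PixelNeRF.forward below rearranges
# "Ns (B N) HW C -> B N HW Ns C". RayGenerator expects intrinsics normalized by the
# image resolution (pixel coordinates are mapped to [0, 1] above).
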

class PixelNeRF(torch.nn.Module):
    def __init__(
        self,
        num_samples_per_ray: int = 128,
        feature_dim: int = 64,
        interp: str = "bilinear",
        padding: str = "border",
        disparity: bool = False,
        near: float = 0.5,
        far: float = 10000.0,
        use_feats_std: bool = False,
        use_pos_emb: bool = False,
    ) -> None:
        super().__init__()
        # self.positional_encoder = Timestep(3)  # TODO
        self.num_samples_per_ray = num_samples_per_ray
        self.ray_generator = RayGenerator()
        self.ray_sampler = RaySampler(
            num_samples_per_ray, near=near, far=far, disparity=disparity
        )  # TODO
        self.interp = interp
        self.padding = padding

        self.positional_encoder = PositionalEncoding()

        # self.feature_aggregator = nn.Linear(128, 129)  # TODO
        self.use_feats_std = use_feats_std
        self.use_pos_emb = use_pos_emb
        d_in = feature_dim
        if use_feats_std:
            d_in += feature_dim
        if use_pos_emb:
            d_in += self.positional_encoder.d_out
        self.feature_aggregator = nn.Sequential(
            nn.Linear(d_in, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 129),
        )

        # self.decoder = nn.Linear(128, 131)  # TODO
        self.decoder = nn.Sequential(
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 131),
        )

    def project(self, ray_samples, source_c2ws, source_intrinsics):
        # S for number of source cameras
        # ray_samples: [B, N, H * W, N_sample, 3]
        # source_c2ws: [B, S, 4, 4]
        # source_intrinsics: [B, S, 3, 3]
        # return [B, S, N, H * W, N_sample, 2]
        S = source_c2ws.shape[1]
        B = ray_samples.shape[0]
        N = ray_samples.shape[1]
        HW = ray_samples.shape[2]
        ray_samples = repeat(
            ray_samples, "B N HW N_sample C -> B S N HW N_sample C", S=S
        )
        padding = torch.ones((B, S, N, HW, self.num_samples_per_ray, 1)).to(
            ray_samples
        )
        ray_samples_homo = torch.cat([ray_samples, padding], dim=-1)
        source_c2ws = repeat(source_c2ws, "B S C1 C2 -> B S N 1 1 C1 C2", N=N)
        source_intrinsics = repeat(
            source_intrinsics, "B S C1 C2 -> B S N 1 1 C1 C2", N=N
        )
        source_w2c = source_c2ws.inverse()
        projected_samples = einsum(
            source_w2c, ray_samples_homo, "... i j, ... j -> ... i"
        )[..., :3]
        # NOTE: assumes opengl convention
        projected_samples = -1 * projected_samples[..., :2] / projected_samples[..., 2:]
        # NOTE: intrinsics here are normalized by resolution
        fx = source_intrinsics[..., 0, 0]
        fy = source_intrinsics[..., 1, 1]
        cx = source_intrinsics[..., 0, 2]
        cy = source_intrinsics[..., 1, 2]
        x = projected_samples[..., 0] * fx + cx
        # negative sign here is caused by opengl; F.grid_sample is consistent with
        # the openCV convention
        y = -projected_samples[..., 1] * fy + cy
        return torch.stack([x, y], dim=-1)
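
    # Worked example of the projection convention above (numbers are assumed,
    # purely illustrative): with normalized intrinsics fx = fy = 1.0, cx = cy = 0.5,
    # a point at (0.1, 0.0, -1.0) in OpenGL camera coordinates (camera looks down -z)
    # gives
    #   projected = -1 * (0.1, 0.0) / (-1.0) = (0.1, 0.0)
    #   x = 0.1 * fx + cx = 0.6,  y = -0.0 * fy + cy = 0.5
    # i.e. coordinates in [0, 1]; forward() below maps them to [-1, 1] with
    # `* 2.0 - 1.0` before calling F.grid_sample.
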
    def forward(
        self, image_feats, source_c2ws, source_intrinsics, c2ws, intrinsics, render_size
    ):
        # image_feats: [B S C H W]
        B = c2ws.shape[0]
        T = c2ws.shape[1]
        ray_origins, ray_directions = self.ray_generator(
            c2ws.reshape(-1, 4, 4), intrinsics.reshape(-1, 3, 3), render_size
        )  # [B * N, H * W, 3]

        ray_samples = self.ray_sampler(
            ray_origins, ray_directions
        )  # [N_sample, B * N, H * W, 3]

        ray_samples = rearrange(ray_samples, "Ns (B N) HW C -> B N HW Ns C", B=B)
        projected_samples = self.project(ray_samples, source_c2ws, source_intrinsics)

        # image_feats = repeat(image_feats, "B S C H W -> (B S N) C H W", N=T)
        image_feats = rearrange(image_feats, "B S C H W -> (B S) C H W")
        projected_samples = rearrange(
            projected_samples, "B S N HW Ns xy -> (B S) (N Ns) HW xy"
        )
        # make sure the projected samples are in the range of [-1, 1],
        # as required by F.grid_sample
        joint = F.grid_sample(
            image_feats,
            projected_samples * 2.0 - 1.0,
            padding_mode=self.padding,
            mode=self.interp,
            align_corners=True,
        )
        joint = rearrange(
            joint,
            "(B S) C (N Ns) HW -> B S N HW Ns C",
            B=B,
            Ns=self.num_samples_per_ray,
        )
        reduced = torch.mean(joint, dim=1)  # reduce on source dimension
        if self.use_feats_std:
            if not joint.shape[1] == 1:
                reduced = torch.cat((reduced, joint.std(dim=1)), dim=-1)
            else:
                reduced = torch.cat((reduced, torch.zeros_like(reduced)), dim=-1)
        if self.use_pos_emb:
            reduced = torch.cat(
                (reduced, self.positional_encoder(ray_samples)), dim=-1
            )
        reduced = self.feature_aggregator(reduced)
        feats, weights = reduced.split([reduced.shape[-1] - 1, 1], dim=-1)
        # feats: [B, N, H * W, N_samples, N_c]
        # weights: [B, N, H * W, N_samples, 1]
        weights = F.softmax(weights, dim=-2)
        feats = torch.sum(feats * weights, dim=-2)
        rgb, feats = self.decoder(feats).split([3, 128], dim=-1)
        rgb = F.sigmoid(rgb)
        rgb = rearrange(rgb, "B N (H W) C -> B N C H W", H=render_size)
        feats = rearrange(feats, "B N (H W) C -> B N C H W", H=render_size)

        return rgb, feats
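
# End-to-end shape sketch for PixelNeRF (hypothetical sizes, for reference only;
# `feature_dim` must match the channel dimension C of `image_feats`):
#
#   model = PixelNeRF(num_samples_per_ray=32, feature_dim=64)
#   intr = torch.tensor(                                    # normalized intrinsics
#       [[1.0, 0.0, 0.5], [0.0, 1.0, 0.5], [0.0, 0.0, 1.0]]
#   )
#   image_feats = torch.rand(1, 2, 64, 32, 32)              # [B, S, C, H, W]
#   source_c2ws = torch.eye(4).expand(1, 2, 4, 4)           # [B, S, 4, 4]
#   source_intr = intr.expand(1, 2, 3, 3)                   # [B, S, 3, 3]
#   c2ws = torch.eye(4).expand(1, 1, 4, 4)                  # [B, N, 4, 4]
#   intrinsics = intr.expand(1, 1, 3, 3)                    # [B, N, 3, 3]
#   rgb, feats = model(image_feats, source_c2ws, source_intr, c2ws, intrinsics, 32)
#   # rgb: [B, N, 3, 32, 32], feats: [B, N, 128, 32, 32]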