import torch
import torch.nn as nn
from einops import rearrange
from src.utils.gaussian_model import build_covariance
from simple_knn._C import distCUDA2
from src.utils.sh_utils import RGB2SH


class GaussianHead(nn.Module):
    """Project per-point features to 3D Gaussian attributes for two views.

    A single linear layer maps every point feature to scales, a rotation
    4-vector, an opacity, spherical-harmonics (SH) colour coefficients and,
    when ``d_gs_feats`` is given as a keyword argument, that many extra
    per-Gaussian feature channels.

    Recognised keyword arguments (all kept in ``self.args``):
        d_gs_feats: number of extra feature channels to predict. Missing or
            falsy means no extra channels.
        rgb_residual: if truthy, the input colours are converted with
            ``RGB2SH`` and added to the predicted DC SH coefficient.
    """

    # The head always emits exactly two prediction dicts (one per view).
    NUM_VIEWS = 2

    def __init__(self, d_pt_feat=64, **kwargs):
        """Build the projection layer and the SH initialisation mask.

        Args:
            d_pt_feat: channel width of the incoming point features.
            **kwargs: optional settings, see the class docstring.
        """
        super().__init__()

        # args
        self.args = kwargs
        self.d_means = 3
        self.d_scales = 3
        self.d_rotations = 4
        self.d_opacities = 1
        self.sh_degree = 3
        self.d_view_dep_features = 3  # RGB
        self.d_sh = (self.sh_degree + 1) ** 2

        self.d_attr = (
            self.d_scales
            + self.d_rotations
            + self.d_opacities
            + self.d_view_dep_features * self.d_sh
        )
        if self.args.get('d_gs_feats'):
            self.d_attr += self.args['d_gs_feats']

        # Create a mask for the spherical harmonics coefficients.
        # This ensures that at initialization, the coefficients are biased
        # towards having a large DC component and small view-dependent
        # components.
        self.register_buffer(
            "sh_mask",
            torch.ones((self.d_sh,), dtype=torch.float32),
            persistent=False,
        )
        for degree in range(1, self.sh_degree + 1):
            self.sh_mask[degree**2 : (degree + 1) ** 2] = 0.5 * 0.25**degree

        self.gaussian_proj = nn.Linear(d_pt_feat, self.d_attr)

        # Activation functions
        self.scale_activation = torch.exp
        self.rotation_activation = torch.nn.functional.normalize
        self.opacity_activation = torch.sigmoid

    def forward(self, point_transformer_output, lseg_features=None):
        """Predict Gaussian attributes and return them split per view.

        Args:
            point_transformer_output: mapping with the keys ``'scale'``,
                ``'center'``, ``'shape'`` (unpacked as ``B, H, W, _``),
                ``'coord'``, ``'color'`` and ``'feat'``. ``'coord'`` and
                ``'feat'`` are flattened in ``(b v h w)`` order with two
                views. ``'color'`` is only read when ``rgb_residual`` is set
                and is rescaled with ``x * 0.5 + 0.5`` before conversion
                (presumably from [-1, 1] to [0, 1] -- confirm with caller).
            lseg_features: optional tensor laid out as ``(v b) c h w``. When
                given and ``d_gs_feats`` is configured, it is added to the
                predicted extra features, so the head learns a residual.
                It is ignored when no extra feature channels are configured.

        Returns:
            ``(pred1, pred2)``: one dict per view with the keys ``'scales'``,
            ``'rotations'``, ``'covs'``, ``'opacities'``, ``'sh_coeffs'`` and
            ``'means'``, each with leading dimensions ``(B, H, W)``. The key
            ``'gs_feats'`` is present only when ``d_gs_feats`` is configured.
        """
        V = self.NUM_VIEWS

        scene_scale = point_transformer_output['scale']  # B, 1, 1
        scene_center = point_transformer_output['center']  # B, 1, 3
        B, H, W, _ = point_transformer_output['shape']
        normalized_means = point_transformer_output['coord']  # B * V * H * W, 3
        colors = point_transformer_output['color']  # B * V * H * W, 3
        dims = dict(b=B, v=V, h=H, w=W)

        # Undo the scene normalisation. The view axis is moved to the front
        # so that the per-batch scale/center broadcast over it.
        normalized_means = rearrange(
            normalized_means, '(b v h w) c -> v b (h w) c', **dims
        )
        means = normalized_means * scene_scale + scene_center  # V, B, H * W, 3
        means = rearrange(means, 'v b (h w) c -> b (v h w) c', **dims)

        # get features
        feat = point_transformer_output['feat']
        gaussian_attr = self.gaussian_proj(feat)

        # Split the projection output. The extra-feature chunk exists only
        # when __init__ widened the projection for it, so the split sizes
        # must follow the same condition to keep summing to d_attr.
        d_gs_feats = self.args.get('d_gs_feats') or 0
        split_sizes = [
            self.d_scales,
            self.d_rotations,
            self.d_opacities,
            self.d_view_dep_features * self.d_sh,
        ]
        if d_gs_feats:
            split_sizes.append(d_gs_feats)
        chunks = torch.split(gaussian_attr, split_sizes, dim=-1)
        scales, rotations, opacities, sh_coeffs = chunks[:4]
        gs_feats = chunks[4] if d_gs_feats else None

        # scales
        # Per-point neighbour distance from distCUDA2, computed one batch
        # element at a time (presumably a squared nearest-neighbour distance,
        # hence the sqrt -- confirm against simple_knn). The clamp keeps the
        # sqrt away from zero.
        all_dist = torch.stack(
            [
                torch.sqrt(torch.clamp_min(distCUDA2(pts3d), 0.0000001))
                for pts3d in means
            ]
        )  # B, V * H * W
        median_dist = all_dist.median(dim=-1)[0][:, None, None]  # B, 1, 1

        scales = self.scale_activation(scales)
        scales = rearrange(scales, '(b v h w) c -> b (v h w) c', **dims)
        scales = scales * all_dist[..., None]
        # clip scales relative to the per-batch median distance
        scales = torch.clamp(scales, min=0.1 * median_dist, max=3.0 * median_dist)
        scales = rearrange(scales, 'b (v h w) c -> (b v h w) c', **dims)

        # activation
        rotations = self.rotation_activation(rotations)
        opacities = self.opacity_activation(opacities)

        # build covariance matrix
        covs = build_covariance(scales, rotations)

        # sh_mask
        sh_coeffs = rearrange(
            sh_coeffs,
            '(b v h w) (c d) -> (b v h w) c d',
            c=self.d_sh,
            d=self.d_view_dep_features,
            **dims,
        )
        sh_dc = sh_coeffs[..., 0, :]
        sh_rest = sh_coeffs[..., 1:, :]
        if self.args.get('rgb_residual'):
            # denormalize colors
            colors = colors * 0.5 + 0.5
            sh_rgb = RGB2SH(colors)  # (B * V * H * W, 3)
            # add rgb residual to dc component
            sh_dc = sh_dc + sh_rgb
        # concatenate dc and rest
        sh_coeffs = torch.cat([sh_dc[..., None, :], sh_rest], dim=-2)
        sh_coeffs = sh_coeffs * self.sh_mask[None, :, None]

        # lseg_features (learning residual); skipped when either side is
        # absent so the default lseg_features=None is usable.
        if gs_feats is not None and lseg_features is not None:
            lseg_features = rearrange(
                lseg_features, '(v b) c h w -> (b v h w) c', **dims
            )
            gs_feats = gs_feats + lseg_features

        # split to 2 views
        flat_outputs = {
            'scales': scales,
            'rotations': rotations,
            'covs': covs,
            'opacities': opacities,
            'sh_coeffs': sh_coeffs,
        }
        per_view = {
            name: rearrange(tensor, '(b v h w) ... -> v b h w ...', **dims)
            for name, tensor in flat_outputs.items()
        }
        # means still carries a separate batch axis, so its pattern differs.
        per_view['means'] = rearrange(
            means, 'b (v h w) ... -> v b h w ...', **dims
        )
        if gs_feats is not None:
            per_view['gs_feats'] = rearrange(
                gs_feats, '(b v h w) ... -> v b h w ...', **dims
            )

        pred1 = {name: tensor[0] for name, tensor in per_view.items()}
        pred2 = {name: tensor[1] for name, tensor in per_view.items()}
        return pred1, pred2