| | from dataclasses import dataclass |
| | from typing import Literal, Optional, List |
| |
|
| | import torch |
| | from einops import rearrange, repeat |
| | from jaxtyping import Float |
| | from torch import Tensor, nn |
| | import MinkowskiEngine as ME |
| | import torch.nn.init as init |
| |
|
| | from ...dataset.shims.patch_shim import apply_patch_shim |
| | from ...dataset.types import BatchedExample, DataShim |
| | from ...geometry.projection import sample_image_grid |
| | from ..types import Gaussians |
| |
|
| | |
| | |
| | from .common.guassian_adapter_depth import GaussianAdapter_depth, GaussianAdapterCfg |
| |
|
| |
|
| | from .encoder import Encoder |
| | from .visualization.encoder_visualizer_depthsplat_cfg import EncoderVisualizerDepthSplatCfg |
| |
|
| | import torchvision.transforms as T |
| | import torch.nn.functional as F |
| |
|
| | from .unimatch.mv_unimatch import MultiViewUniMatch |
| | from .unimatch.dpt_head import DPTHead |
| |
|
| | from .common.voxel_feature import project_features_to_3d, project_features_to_voxel, adapte_features_to_voxel, adapte_project_features_to_3d |
| | from .common.me_fea import project_features_to_me |
| |
|
| | from ...geometry.projection import get_world_rays |
| | from .common.sparse_net import SparseGaussianHead, SparseUNetWithAttention |
| | from .common.mink_resnet import MultiScaleSparseHead |
| |
|
| | from ...test.export_ply import save_point_cloud_to_ply |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | |
def print_mem(tag: str = ""):
    """Print a one-line, tagged snapshot of CUDA memory usage in MB.

    Reports both `memory_allocated` (tensors currently held) and
    `memory_reserved` (caching-allocator pool). Prints a fallback line
    when no CUDA device is available.
    """
    if torch.cuda.is_available():
        # Convert bytes -> MiB for readability.
        allocated = torch.cuda.memory_allocated() / 1024**2
        reserved = torch.cuda.memory_reserved() / 1024**2
        print(f"[MEM] {tag} | allocated={allocated:.1f} MB reserved={reserved:.1f} MB")
    else:
        print(f"[MEM] {tag} - no CUDA")
| |
|
@dataclass
class EncoderDepthSplatCfg:
    """Configuration for the DepthSplat encoder (multi-view depth prediction
    followed by per-voxel Gaussian regression).

    Field semantics below are inferred from how EncoderDepthSplat_test uses
    them in this file; fields not read here are presumably consumed by other
    encoder variants sharing this config — TODO confirm.
    """

    name: Literal["depthsplat"]  # registry key selecting this encoder
    d_feature: int
    num_depth_candidates: int
    num_surfaces: int  # srf axis used when splitting raw Gaussian params
    visualizer: EncoderVisualizerDepthSplatCfg
    gaussian_adapter: GaussianAdapterCfg  # forwarded to GaussianAdapter_depth
    gaussians_per_pixel: int
    unimatch_weights_path: str | None  # None -> no pretrained UniMatch weights
    downscale_factor: int  # multiplies shim_patch_size in get_data_shim
    shim_patch_size: int
    multiview_trans_attn_split: int
    costvolume_unet_feat_dim: int
    costvolume_unet_channel_mult: List[int]
    costvolume_unet_attn_res: List[int]
    depth_unet_feat_dim: int
    depth_unet_attn_res: List[int]
    depth_unet_channel_mult: List[int]

    # --- MultiViewUniMatch depth-predictor settings ---
    num_scales: int
    upsample_factor: int
    lowest_feature_resolution: int
    depth_unet_channels: int
    grid_sample_disable_cudnn: bool

    # --- Gaussian head / feature upsampler settings ---
    large_gaussian_head: bool
    color_large_unet: bool
    init_sh_input_img: bool
    feature_upsampler_channels: int
    gaussian_regressor_channels: int  # hidden width of the 2D gaussian_regressor

    # Supervise intermediate (coarser) depth predictions in addition to the
    # final one; also emits intermediate Gaussians in forward().
    supervise_intermediate_depth: bool
    return_depth: bool  # forward() returns a dict with "depths" when True

    # Skip building all Gaussian-related modules; forward() returns depths only.
    train_depth_only: bool

    # ViT backbone variant for monocular features: 'vits' | 'vitb' | 'vitl'
    monodepth_vit_type: str

    # Number of nearest neighbour views used for local multi-view matching
    # when more than 3 context views are given.
    local_mv_match: int
| |
|
| |
|
class EncoderDepthSplat_test(Encoder[EncoderDepthSplatCfg]):
    """DepthSplat encoder: predicts multi-view depth with MultiViewUniMatch,
    lifts per-pixel features into a sparse voxel grid (MinkowskiEngine), and
    regresses 3D Gaussian parameters per occupied voxel.

    NOTE(review): the "_test" suffix suggests an experimental variant —
    confirm whether this or a sibling class is the production encoder.
    """

    def __init__(self, cfg: EncoderDepthSplatCfg) -> None:
        super().__init__(cfg)

        # Multi-view depth predictor (UniMatch-style matching + mono ViT).
        self.depth_predictor = MultiViewUniMatch(
            num_scales=cfg.num_scales,
            upsample_factor=cfg.upsample_factor,
            lowest_feature_resolution=cfg.lowest_feature_resolution,
            vit_type=cfg.monodepth_vit_type,
            unet_channels=cfg.depth_unet_channels,
            grid_sample_disable_cudnn=cfg.grid_sample_disable_cudnn,
        )

        # Depth-only training needs none of the Gaussian machinery below.
        if self.cfg.train_depth_only:
            return

        # Channel layouts per ViT backbone size (DINOv2-style dims — TODO confirm).
        model_configs = {
            'vits': {'in_channels': 384, 'features': 64, 'out_channels': [48, 96, 192, 384]},
            'vitb': {'in_channels': 768, 'features': 96, 'out_channels': [96, 192, 384, 768]},
            'vitl': {'in_channels': 1024, 'features': 128, 'out_channels': [128, 256, 512, 1024]},
        }

        # DPT head upsamples ViT/CNN/MV features back to full resolution.
        self.feature_upsampler = DPTHead(**model_configs[cfg.monodepth_vit_type],
                                         downsample_factor=cfg.upsample_factor,
                                         return_feature=True,
                                         num_scales=cfg.num_scales,
                                         )
        feature_upsampler_channels = model_configs[cfg.monodepth_vit_type]["features"]

        # Converts raw per-voxel parameters into Gaussian primitives.
        self.gaussian_adapter = GaussianAdapter_depth(cfg.gaussian_adapter)

        # Regressor input: RGB (3) + depth (1) + match probability (1) + upsampled features.
        in_channels = 3 + 1 + 1 + feature_upsampler_channels
        channels = self.cfg.gaussian_regressor_channels

        # Small 2D conv head producing per-pixel features for voxelization.
        modules = [
            nn.Conv2d(in_channels, channels, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(channels, channels, 3, 1, 1),
        ]

        self.gaussian_regressor = modules = nn.Sequential(*modules) if False else nn.Sequential(*modules)

        # Per-voxel output size: adapter inputs + 3 (presumably xyz offset)
        # + 1 (opacity logit) — TODO confirm the "+3+1" split against
        # GaussianAdapter_depth / SparseGaussianHead.
        num_gaussian_parameters = self.gaussian_adapter.d_in + 3 + 1

        # Sparse-UNet input: RGB + upsampled features + regressor output + match prob.
        in_channels = 3 + feature_upsampler_channels + channels + 1

        # NOTE(review): "spare_unet" is a typo for "sparse_unet"; kept because
        # renaming would break checkpoint state-dict keys.
        self.spare_unet =SparseUNetWithAttention(
            in_channels=in_channels,
            out_channels=in_channels,
            num_blocks=3,
            use_attention=False
        )

        # Maps refined sparse features to raw Gaussian parameters per voxel.
        self.gaussian_head = SparseGaussianHead(in_channels, num_gaussian_parameters)

        # Magic constants; unused in this file — presumably consumed by the
        # adapter or a caller. TODO confirm and document their units.
        self.scale = 0.04
        self.shift = 0.01

    def forward(
        self,
        context: dict,
        global_step: int,
        deterministic: bool = False,
        visualization_dump: Optional[dict] = None,
        scene_names: Optional[list] = None,
        ues_voxelnet: bool = True,  # NOTE(review): typo for "use_voxelnet"; also never read in this method
    ):
        """Predict depth for the context views and regress 3D Gaussians.

        context is expected to hold: "image" (b, v, 3, h, w), "intrinsics",
        "extrinsics" (b, v, 4, 4), "near", "far". Returns a dict with
        "gaussians"/"depths" (and optionally "intermediate_gaussians"),
        or bare Gaussians when cfg.return_depth is False.
        """
        device = context["image"].device
        b, v, _, h, w = context["image"].shape

        # With many views, restrict multi-view matching to each view's
        # local_mv_match nearest neighbours (by camera-center distance).
        if v > 3:
            with torch.no_grad():
                # Camera centers = translation column of the extrinsics.
                xyzs = context["extrinsics"][:, :, :3, -1].detach()
                cameras_dist_matrix = torch.cdist(xyzs, xyzs, p=2)
                cameras_dist_index = torch.argsort(cameras_dist_matrix)

            # +1 keeps the view itself (distance 0 sorts first).
            cameras_dist_index = cameras_dist_index[:, :, :(self.cfg.local_mv_match + 1)]
        else:
            cameras_dist_index = None

        # Inverse-depth bounds: near plane -> max disparity, far -> min.
        results_dict = self.depth_predictor(
            context["image"],
            attn_splits_list=[2],
            min_depth=1. / context["far"],
            max_depth=1. / context["near"],
            intrinsics=context["intrinsics"],
            extrinsics=context["extrinsics"],
            nn_matrix=cameras_dist_index,
        )

        # Coarse-to-fine predictions; the last one is the finest.
        depth_preds = results_dict['depth_preds']
        depth = depth_preds[-1]

        # Voxel edge length for sparse quantization — hard-coded; units
        # presumably metric (scene scale) — TODO confirm.
        voxel_resolution = 0.02

        # --- depth-only path: return stacked depth maps, no Gaussians ---
        if self.cfg.train_depth_only:
            depths = rearrange(depth, "b v h w -> b v (h w) () ()")

            if self.cfg.supervise_intermediate_depth and len(depth_preds) > 1:
                num_depths = len(depth_preds)

                # Stack all coarser predictions along the batch axis so the
                # loss supervises every scale at once.
                intermediate_depths = torch.cat(
                    depth_preds[:(num_depths - 1)], dim=0)
                intermediate_depths = rearrange(
                    intermediate_depths, "b v h w -> b v (h w) () ()")

                depths = torch.cat((intermediate_depths, depths), dim=0)

                b *= num_depths

            depths = rearrange(
                depths, "b v (h w) srf s -> b v h w srf s", h=h, w=w
            ).squeeze(-1).squeeze(-1)

            return {
                "gaussians": None,
                "depths": depths
            }

        # Full-resolution per-pixel features from mono/CNN/MV branches.
        features = self.feature_upsampler(results_dict["features_mono_intermediate"],
                                          cnn_features=results_dict["features_cnn_all_scales"][::-1],
                                          mv_features=results_dict["features_mv"][
                                              0] if self.cfg.num_scales == 1 else results_dict["features_mv"][::-1]
                                          )

        # Matching confidence: max over depth candidates, upsampled to the
        # depth map's resolution.
        match_prob = results_dict['match_probs'][-1]
        match_prob = torch.max(match_prob, dim=1, keepdim=True)[
            0]
        match_prob = F.interpolate(
            match_prob, size=depth.shape[-2:], mode='nearest')

        # Regressor input: image + depth + confidence + features, all (b*v, C, h, w).
        concat = torch.cat((
            rearrange(context["image"], "b v c h w -> (b v) c h w"),
            rearrange(depth, "b v h w -> (b v) () h w"),
            match_prob,
            features,
        ), dim=1)

        out = self.gaussian_regressor(concat)
        # Re-concatenate regressor output with skip inputs for voxelization.
        concat = [out,
                  rearrange(context["image"],
                            "b v c h w -> (b v) c h w"),
                  features,
                  match_prob]

        out = torch.cat(concat, dim=1)

        # Unproject per-pixel features into a MinkowskiEngine sparse tensor
        # (features averaged per voxel — presumably; see project_features_to_me).
        sparse_input, aggregated_points, counts = project_features_to_me(
            context["intrinsics"],
            context["extrinsics"],
            out,
            depth=depth,
            voxel_resolution=voxel_resolution,
            b=b, v=v
        )

        sparse_out = self.spare_unet(sparse_input)

        # Manual residual connection — only valid when the UNet preserved both
        # the coordinate set and the channel count.
        if torch.equal(sparse_out.C, sparse_input.C) and sparse_out.F.shape[1] == sparse_input.F.shape[1]:
            new_features = sparse_out.F + sparse_input.F

            sparse_out_with_residual = ME.SparseTensor(
                features=new_features,
                coordinate_map_key=sparse_out.coordinate_map_key,
                coordinate_manager=sparse_out.coordinate_manager
            )
        else:
            # "Warning: input/output coordinates differ, skipping residual connection"
            print("警告:输入和输出坐标不一致,跳过残差连接")
            sparse_out_with_residual = sparse_out

        # Raw Gaussian parameters per occupied voxel.
        gaussians = self.gaussian_head(sparse_out_with_residual)

        # Free the large sparse tensors before the adapter allocates more.
        # NOTE(review): if the else-branch above ran, `new_features` was never
        # bound and this `del` raises NameError — should be fixed to only
        # delete names that exist.
        del sparse_out_with_residual,sparse_out,sparse_input,new_features

        depths = rearrange(depth, "b v h w -> b v (h w) () ()")

        # "Output sparse tensor: {N} voxels"
        print(f"输出稀疏张量: {gaussians.F.shape[0]}个体素")

        # (N, C) voxel features -> (1, 1, N, C) to mimic the (b, v, ...) layout.
        gaussian_params = gaussians.F.unsqueeze(0).unsqueeze(0)

        # Channel 0 is the opacity logit; the rest are adapter inputs.
        opacities = gaussian_params[..., :1].sigmoid().unsqueeze(-1)
        raw_gaussians = gaussian_params[..., 1:]
        raw_gaussians = rearrange(
            raw_gaussians,
            "... (srf c) -> ... srf c",
            srf=self.cfg.num_surfaces,
        )

        try:
            # NOTE(review): `coordidate` is a typo for "coordinate" but must
            # match GaussianAdapter_depth.forward's parameter name.
            gaussians = self.gaussian_adapter.forward(
                extrinsics = context["extrinsics"],
                intrinsics = context["intrinsics"],
                opacities = opacities,
                raw_gaussians = rearrange(raw_gaussians,"b v r srf c -> b v r srf () c"),
                input_images =rearrange(context["image"], "b v c h w -> (b v) c h w"),
                depth = depth,
                coordidate = gaussians.C,
                points = aggregated_points,
                voxel_resolution = voxel_resolution
            )
        except Exception as e:
            # Surface the full traceback before re-raising; adapter failures
            # are otherwise hard to localize inside the training loop.
            import traceback; traceback.print_exc()
            raise

        # --- optional intermediate-depth supervision: repeat the whole
        # voxelize -> UNet -> head -> adapter pipeline on the coarsest depth ---
        if self.cfg.supervise_intermediate_depth and len(depth_preds) > 1:
            intermediate_depth = depth_preds[0]

            intermediate_voxel_feature, median_points, counts = project_features_to_me(
                context["intrinsics"],
                context["extrinsics"],
                out,
                depth=intermediate_depth,
                voxel_resolution=voxel_resolution,
                b=b, v=v
            )

            intermediate_out = self.spare_unet(intermediate_voxel_feature)

            # Same manual residual pattern as above.
            if torch.equal(intermediate_out.C, intermediate_voxel_feature.C) and intermediate_out.F.shape[1] == intermediate_voxel_feature.F.shape[1]:
                new_inter_features = intermediate_out.F + intermediate_voxel_feature.F

                # NOTE(review): unlike the main path, the residual tensor here
                # reuses the INPUT's coordinate map key/manager — confirm this
                # asymmetry is intentional.
                intermedian_out_with_residual = ME.SparseTensor(
                    features=new_inter_features,
                    coordinate_map_key=intermediate_voxel_feature.coordinate_map_key,
                    coordinate_manager=intermediate_voxel_feature.coordinate_manager
                )
            else:
                # "Warning: input/output coordinates differ, skipping residual connection"
                print("警告:输入和输出坐标不一致,跳过残差连接")
                intermedian_out_with_residual = intermediate_voxel_feature

            intermediate_gaussians = self.gaussian_head(intermedian_out_with_residual)

            del intermediate_voxel_feature,intermediate_out,intermedian_out_with_residual

            gaussian_params = intermediate_gaussians.F.unsqueeze(0).unsqueeze(0)

            intermediate_opacities = gaussian_params[..., :1].sigmoid().unsqueeze(-1)
            intermediate_raw_gaussians = gaussian_params[..., 1:]
            intermediate_raw_gaussians = rearrange(
                intermediate_raw_gaussians,
                "... (srf c) -> ... srf c",
                srf=self.cfg.num_surfaces,
            )

            intermediate_gaussians = self.gaussian_adapter.forward(
                extrinsics = context["extrinsics"],
                intrinsics = context["intrinsics"],
                opacities = intermediate_opacities,
                raw_gaussians = rearrange(intermediate_raw_gaussians,"b v r srf c -> b v r srf () c"),
                input_images =rearrange(context["image"], "b v c h w -> (b v) c h w"),
                depth = intermediate_depth,
                coordidate = intermediate_gaussians.C,
                points = median_points,
                voxel_resolution = voxel_resolution
            )

            # Flatten (view, ray, surface, sample) axes into one Gaussian axis.
            intermediate_gaussians = Gaussians(
                rearrange(
                    intermediate_gaussians.means,
                    "b v r srf spp xyz -> b (v r srf spp) xyz",
                ),
                rearrange(
                    intermediate_gaussians.covariances,
                    "b v r srf spp i j -> b (v r srf spp) i j",
                ),
                rearrange(
                    intermediate_gaussians.harmonics,
                    "b v r srf spp c d_sh -> b (v r srf spp) c d_sh",
                ),
                rearrange(
                    intermediate_gaussians.opacities,
                    "b v r srf spp -> b (v r srf spp)",
                ),
            )
        else:
            intermediate_gaussians = None

        # Flatten the final Gaussians the same way.
        gaussians = Gaussians(
            rearrange(
                gaussians.means,
                "b v r srf spp xyz -> b (v r srf spp) xyz",
            ),
            rearrange(
                gaussians.covariances,
                "b v r srf spp i j -> b (v r srf spp) i j",
            ),
            rearrange(
                gaussians.harmonics,
                "b v r srf spp c d_sh -> b (v r srf spp) c d_sh",
            ),
            rearrange(
                gaussians.opacities,
                "b v r srf spp -> b (v r srf spp)",
            ),
        )

        if self.cfg.return_depth:
            # Back to (b, v, h, w) image-shaped depth maps.
            depths = rearrange(
                depths, "b v (h w) srf s -> b v h w srf s", h=h, w=w
            ).squeeze(-1).squeeze(-1)

            if intermediate_gaussians is not None:
                return {
                    "gaussians": gaussians,
                    "depths": depths,
                    "intermediate_gaussians": intermediate_gaussians
                }
            else:
                return {
                    "gaussians": gaussians,
                    "depths": depths,
                }

        return gaussians

    def get_data_shim(self) -> DataShim:
        """Return a batch transform that crops images to a patch-size multiple."""
        def data_shim(batch: BatchedExample) -> BatchedExample:
            # Patch size must account for the encoder's downscaling.
            batch = apply_patch_shim(
                batch,
                patch_size=self.cfg.shim_patch_size
                * self.cfg.downscale_factor,
            )

            return batch

        return data_shim

    @property
    def sampler(self):
        # No custom view sampler for this encoder.
        return None
| |
|
| |
|
| |
|
| |
|
| | |
| | |
| | |
| | |
| |
|