|
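"""UniDepth-based depth and Gaussian-splat prediction modules.

Wraps a pretrained UniDepthV1 model to predict metric depth, and builds
Gaussian primitives on top of it either with constant, globally-learned
parameters (UniDepthUnprojector) or with a learned splat head that decodes
the UniDepth features (UniDepthSplatter / GaussSplatHead).
"""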
import json |
|
from pathlib import Path |
|
from typing import List, Tuple |
|
from math import ceil |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torchvision.transforms.functional as TF |
|
from einops import rearrange |
|
|
|
from unidepth.models.unidepthv1 import UniDepthV1 |
|
from unidepth.utils.constants import IMAGENET_DATASET_MEAN, IMAGENET_DATASET_STD |
|
from unidepth.utils.geometric import ( |
|
generate_rays, |
|
spherical_zbuffer_to_euclidean, |
|
flat_interpolate, |
|
) |
|
from unidepth.layers import ( |
|
MLP, |
|
AttentionBlock, |
|
NystromBlock, |
|
PositionEmbeddingSine, |
|
ConvUpsample, |
|
) |
|
from unidepth.utils.sht import rsh_cart_8 |
|
|
|
from networks.gaussian_decoder import get_splits_and_inits |
|
|
|
|
|
|
|
def _paddings(image_shape, network_shape): |
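    """Return (left, right, top, bottom) padding that centres an image of
    `image_shape` (h, w) inside the network input size `network_shape` (h, w)."""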
|
cur_h, cur_w = image_shape |
|
h, w = network_shape |
|
pad_top, pad_bottom = (h - cur_h) // 2, h - cur_h - (h - cur_h) // 2 |
|
pad_left, pad_right = (w - cur_w) // 2, w - cur_w - (w - cur_w) // 2 |
|
return pad_left, pad_right, pad_top, pad_bottom |
|
|
|
|
|
def _shapes(image_shape, network_shape): |
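    """Return the (h, w) the image should be resized to so that it fits inside
    `network_shape` while preserving aspect ratio, together with the scale ratio."""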
|
h, w = image_shape |
|
input_ratio = w / h |
|
output_ratio = network_shape[1] / network_shape[0] |
|
if output_ratio > input_ratio: |
|
ratio = network_shape[0] / h |
|
    else:
|
ratio = network_shape[1] / w |
|
return (ceil(h * ratio - 0.5), ceil(w * ratio - 0.5)), ratio |
|
|
|
|
|
def _preprocess(rgbs, intrinsics, shapes, pads, ratio, output_shapes): |
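    """Resize `rgbs` to `shapes`, pad them to the network input size, and rescale
    the intrinsics (focal lengths and principal point) accordingly."""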
|
(pad_left, pad_right, pad_top, pad_bottom) = pads |
|
rgbs = F.interpolate( |
|
rgbs, size=shapes, mode="bilinear", align_corners=False, antialias=True |
|
) |
|
rgbs = F.pad(rgbs, (pad_left, pad_right, pad_top, pad_bottom), mode="constant") |
|
if intrinsics is not None: |
|
intrinsics = intrinsics.clone() |
|
intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * ratio |
|
intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * ratio |
|
intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * ratio + pad_left |
|
intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * ratio + pad_top |
|
return rgbs, intrinsics |
|
return rgbs, None |
|
|
|
|
|
def _postprocess(predictions, intrinsics, shapes, pads, ratio, original_shapes): |
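    """Average the (possibly multi-scale) predictions at the network resolution,
    crop the padding, resize back to `original_shapes`, and undo the intrinsics
    scaling applied in `_preprocess`."""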
|
|
|
(pad_left, pad_right, pad_top, pad_bottom) = pads |
|
|
|
predictions = sum( |
|
[ |
|
F.interpolate( |
|
x, |
|
size=shapes, |
|
mode="bilinear", |
|
align_corners=False, |
|
antialias=True, |
|
) |
|
for x in predictions |
|
] |
|
) / len(predictions) |
|
|
|
shapes = predictions.shape[2:] |
|
predictions = predictions[ |
|
..., pad_top : shapes[0] - pad_bottom, pad_left : shapes[1] - pad_right |
|
] |
|
|
|
predictions = F.interpolate( |
|
predictions, |
|
size=original_shapes, |
|
mode="bilinear", |
|
align_corners=False, |
|
antialias=True, |
|
) |
|
|
|
if intrinsics is not None: |
|
intrinsics[:, 0, 0] = intrinsics[:, 0, 0] / ratio |
|
intrinsics[:, 1, 1] = intrinsics[:, 1, 1] / ratio |
|
intrinsics[:, 0, 2] = (intrinsics[:, 0, 2] - pad_left) / ratio |
|
intrinsics[:, 1, 2] = (intrinsics[:, 1, 2] - pad_top) / ratio |
|
|
|
return predictions, intrinsics |
|
|
|
|
|
def scale_intrinsics_xy(intrinsics, x_ratio, y_ratio): |
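    """Scale a pinhole intrinsics matrix by separate x and y ratios."""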
|
intrinsics = intrinsics.clone() |
|
intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * x_ratio |
|
intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * y_ratio |
|
intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * x_ratio |
|
intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * y_ratio |
|
return intrinsics |
|
|
|
|
|
def scale_intrinsics(intrinsics, ratio): |
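    """Scale a pinhole intrinsics matrix isotropically by `ratio`."""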
|
intrinsics = intrinsics.clone() |
|
intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * ratio |
|
intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * ratio |
|
intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * ratio |
|
intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * ratio |
|
return intrinsics |
|
|
|
|
|
def unidepthv1_forward(model, rgbs, intrinsics, skip_camera, |
|
return_raw_preds=False): |
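    """Run UniDepthV1 on a batch of images.

    Normalises and resizes/pads the input, runs the pixel encoder and decoder
    (optionally with ground-truth intrinsics so the camera head is bypassed),
    restores the original resolution, and unprojects depth into 3D points.
    If `return_raw_preds` is set, the decoder inputs and raw depth predictions
    are returned instead of the full output dict.
    """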
|
B, _, H, W = rgbs.shape |
|
|
|
rgbs = TF.normalize( |
|
rgbs, |
|
mean=IMAGENET_DATASET_MEAN, |
|
std=IMAGENET_DATASET_STD, |
|
) |
|
|
|
(h, w), ratio = _shapes((H, W), model.image_shape) |
|
pad_left, pad_right, pad_top, pad_bottom = _paddings((h, w), model.image_shape) |
|
rgbs, gt_intrinsics = _preprocess( |
|
rgbs, |
|
intrinsics, |
|
(h, w), |
|
(pad_left, pad_right, pad_top, pad_bottom), |
|
ratio, |
|
model.image_shape, |
|
) |
|
|
|
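    # UniDepth pixel encoder; DINO-style backbones also return CLS tokens,
    # which are broadcast-added back onto the patch tokens.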
encoder_outputs, cls_tokens = model.pixel_encoder(rgbs) |
|
if "dino" in model.pixel_encoder.__class__.__name__.lower(): |
|
encoder_outputs = [ |
|
(x + y.unsqueeze(1)).contiguous() |
|
for x, y in zip(encoder_outputs, cls_tokens) |
|
] |
|
|
|
|
|
inputs = {} |
|
inputs["encoder_outputs"] = encoder_outputs |
|
inputs["cls_tokens"] = cls_tokens |
|
inputs["image"] = rgbs |
|
if gt_intrinsics is not None: |
|
rays, angles = generate_rays( |
|
gt_intrinsics, model.image_shape, noisy=False |
|
) |
|
inputs["rays"] = rays |
|
inputs["angles"] = angles |
|
inputs["K"] = gt_intrinsics |
|
model.pixel_decoder.test_fixed_camera = True |
|
model.pixel_decoder.skip_camera = skip_camera |
|
|
|
|
|
pred_intrinsics, predictions, features, rays = model.pixel_decoder(inputs, {}) |
|
|
|
pads = (pad_left, pad_right, pad_top, pad_bottom) |
|
|
|
|
|
predictions, pred_intrinsics = _postprocess( |
|
predictions, |
|
pred_intrinsics, |
|
model.image_shape, |
|
pads, |
|
ratio, |
|
(H, W), |
|
) |
|
|
|
if return_raw_preds: |
|
return inputs, predictions |
|
|
|
|
|
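    # Unproject: concatenate per-pixel viewing angles with depth and convert the
    # spherical z-buffer representation into Euclidean (x, y, z) points.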
intrinsics = gt_intrinsics if gt_intrinsics is not None else pred_intrinsics |
|
angles = generate_rays(intrinsics, (H, W), noisy=False)[-1] |
|
angles = rearrange(angles, "b (h w) c -> b c h w", h=H, w=W) |
|
points_3d = torch.cat((angles, predictions), dim=1) |
|
points_3d = spherical_zbuffer_to_euclidean( |
|
points_3d.permute(0, 2, 3, 1) |
|
).permute(0, 3, 1, 2) |
|
|
|
|
|
outputs = { |
|
"intrinsics": intrinsics, |
|
"points": points_3d, |
|
"depth": predictions[:, -1:], |
|
"depth_feats": features, |
|
"rays": rays, |
|
"padding": pads |
|
} |
|
model.pixel_decoder.test_fixed_camera = False |
|
model.pixel_decoder.skip_camera = False |
|
return inputs, outputs |
|
|
|
class UniDepthDepth(nn.Module): |
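    """Thin wrapper that runs a pretrained UniDepthV1 model on the batch images
    (and intrinsics, when available) and returns its depth outputs."""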
|
def __init__( |
|
self, |
|
cfg, |
|
return_raw_preds=False |
|
): |
|
super().__init__() |
|
|
|
self.cfg = cfg |
|
self.return_raw_preds = return_raw_preds |
|
|
|
if "cnvnxtl" in cfg.model.name: |
|
self.depth_prediction_model = UniDepthV1.from_pretrained("lpiccinelli/unidepth-v1-cnvnxtl") |
|
elif "vit" in cfg.model.name: |
|
self.depth_prediction_model = UniDepthV1.from_pretrained("lpiccinelli/unidepth-v1-vitl14") |
|
|
|
self.skip_camera = True |
|
|
|
def get_depth(self, img, intrinsics): |
|
depth_inputs, outputs = unidepthv1_forward( |
|
self.depth_prediction_model, |
|
img, |
|
intrinsics, |
|
self.skip_camera, |
|
return_raw_preds=self.return_raw_preds) |
|
return outputs |
|
|
|
def forward(self, inputs): |
|
input_img = inputs["color_aug", 0, 0] |
|
|
|
|
|
if ("K_src", 0) in inputs: |
|
intrinsics = inputs[("K_src", 0)] |
|
else: |
|
intrinsics = None |
|
|
|
depth_inputs, outputs = unidepthv1_forward( |
|
self.depth_prediction_model, |
|
input_img, |
|
intrinsics, |
|
self.skip_camera, |
|
return_raw_preds=self.return_raw_preds) |
|
|
|
return depth_inputs, outputs |
|
|
|
class UniDepthUnprojector(nn.Module): |
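    """Predicts depth with a frozen UniDepthV1 model and attaches constant,
    globally-learned Gaussian parameters (opacity, scale, rotation, colour)
    to every pixel."""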
|
def __init__( |
|
self, |
|
cfg |
|
): |
|
super().__init__() |
|
|
|
self.cfg = cfg |
|
|
|
        if cfg.model.name == "unidepth_unprojector_cnvnxtl":
            model = UniDepthV1.from_pretrained("lpiccinelli/unidepth-v1-cnvnxtl")
        elif cfg.model.name == "unidepth_unprojector_vit":
            model = UniDepthV1.from_pretrained("lpiccinelli/unidepth-v1-vitl14")
        else:
            raise ValueError(f"Unsupported model name: {cfg.model.name}")
|
self.unidepth = model |
|
|
|
self.skip_camera = True |
|
|
|
self.register_buffer("gauss_opacity", torch.ones(1, 1, 1).float()) |
|
self.register_buffer("gauss_scaling", torch.ones(3, 1, 1).float()) |
|
self.register_buffer("gauss_rotation", torch.ones(4, 1, 1).float() * 0.5) |
|
self.register_buffer("gauss_features_rest", torch.zeros(9, 1, 1).float()) |
|
self.register_buffer("gauss_offset", torch.zeros(3, 1, 1).float()) |
|
|
|
self.all_params = nn.ParameterDict({ |
|
"opacity_scaling": nn.Parameter(torch.tensor(cfg.model.opacity_bias).float()), |
|
"scale_scaling": nn.Parameter(torch.tensor(cfg.model.scale_bias).float()), |
|
"colour_scaling": nn.Parameter(torch.tensor(self.cfg.model.colour_scale).float())}) |
|
|
|
|
|
self.scaling_activation = torch.exp |
|
self.opacity_activation = torch.sigmoid |
|
self.relu = nn.ReLU() |
|
|
|
def get_parameter_groups(self): |
|
|
|
return [{'params': self.all_params.parameters()}] |
|
|
|
def forward(self, inputs): |
|
model = self.unidepth |
|
input_img = inputs["color_aug", 0, 0] |
|
|
|
|
|
intrinsics = inputs[("K_src", 0)] |
|
b, c, h, w = inputs["color_aug", 0, 0].shape |
|
|
|
with torch.no_grad(): |
|
_, depth_outs = unidepthv1_forward(model, input_img, intrinsics, self.skip_camera) |
|
|
|
outs = {} |
|
|
|
outs[("gauss_opacity", 0)] = self.gauss_opacity.unsqueeze(0).expand(depth_outs["depth"].shape[0], -1, h, w) \ |
|
* self.opacity_activation(self.all_params["opacity_scaling"]) |
|
if not self.cfg.model.scale_with_depth: |
|
outs[("gauss_scaling", 0)] = self.gauss_scaling.unsqueeze(0).expand(depth_outs["depth"].shape[0], -1, h, w) \ |
|
* self.scaling_activation(self.all_params["scale_scaling"]) |
|
else: |
|
outs[("gauss_scaling", 0)] = self.gauss_scaling.unsqueeze(0).expand(depth_outs["depth"].shape[0], -1, h, w) \ |
|
* self.scaling_activation(self.all_params["scale_scaling"]) * depth_outs["depth"] / 10.0 |
|
outs[("gauss_rotation", 0)] = self.gauss_rotation.unsqueeze(0).expand(depth_outs["depth"].shape[0], -1, h, w) |
|
outs[("gauss_offset", 0)] = self.gauss_offset.unsqueeze(0).expand(depth_outs["depth"].shape[0], -1, h, w) |
|
outs[("gauss_features_rest", 0)] = self.gauss_features_rest.unsqueeze(0).expand(depth_outs["depth"].shape[0], -1, h, w) |
|
|
|
outs[("gauss_features_dc", 0)] = (input_img - 0.5)* self.relu(self.all_params["colour_scaling"]) |
|
|
|
outs[("depth", 0)] = depth_outs["depth"] |
|
|
|
return outs |
|
|
|
class UniDepthSplatter(nn.Module): |
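    """Predicts per-pixel Gaussian parameters by decoding UniDepth's depth
    features with a GaussSplatHead, while fine-tuning UniDepth itself at a
    reduced learning rate."""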
|
def __init__( |
|
self, |
|
cfg |
|
): |
|
super().__init__() |
|
|
|
self.cfg = cfg |
|
|
|
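        # NOTE: machine-specific path to a local UniDepth checkout; adjust to
        # your own environment.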
config_path = Path("/work/eldar/src/UniDepth") |
|
with open(config_path / "configs/config_v1_cnvnxtl.json") as f: |
|
config = json.load(f) |
|
self.unidepth = UniDepthDepth(self.cfg) |
|
|
|
hidden_dim = config["model"]["pixel_decoder"]["hidden_dim"] |
|
expansion = config["model"]["expansion"] |
|
depth = config["model"]["pixel_decoder"]["depths"] |
|
num_heads = config["model"]["num_heads"] |
|
dropout = config["model"]["pixel_decoder"]["dropout"] |
|
layer_scale = 1.0 |
|
self.splat_decoder = GaussSplatHead( |
|
cfg, |
|
hidden_dim=hidden_dim, |
|
num_heads=num_heads, |
|
expansion=expansion, |
|
depths=depth, |
|
camera_dim=81, |
|
dropout=dropout, |
|
layer_scale=layer_scale, |
|
) |
|
|
|
self.skip_camera = True |
|
|
|
def get_parameter_groups(self): |
|
base_lr = self.cfg.optimiser.learning_rate |
|
return [ |
|
{'params': self.unidepth.parameters(), "lr": base_lr * 0.05}, |
|
{'params': self.splat_decoder.parameters()} |
|
] |
|
|
|
def forward(self, inputs): |
|
gauss_head = self.splat_decoder |
|
|
|
depth_inputs, depth_outs = self.unidepth(inputs) |
|
depth_feats = depth_outs["depth_feats"] |
|
rays = depth_outs["rays"] |
|
padding = depth_outs["padding"] |
|
|
|
B, _, H, W = depth_inputs["image"].shape |
|
|
|
|
|
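        # Spatial size of the UniDepth decoder feature tokens (assumed here to
        # match the padded network input at 1/16 resolution for this config).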
common_shape = (28, 38) |
|
gauss_head.set_shapes(common_shape) |
|
gauss_head.set_original_shapes((H, W)) |
|
|
|
depth_feats = rearrange(depth_feats, "b c h w -> b (h w) c") |
|
outs = gauss_head( |
|
latents_16=depth_feats, |
|
rays_hr=rays, |
|
) |
|
for k, v in outs.items(): |
|
pred, _ = _postprocess([v], None, self.unidepth.depth_prediction_model.image_shape, |
|
padding, None, inputs["color_aug", 0, 0].shape[2:4]) |
|
outs[k] = pred |
|
outs[("depth", 0)] = depth_outs["depth"] |
|
|
|
return outs |
|
|
|
|
|
class GaussSplatHead(nn.Module): |
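    """Multi-scale attention/upsampling decoder that turns UniDepth decoder
    tokens and camera-ray embeddings into per-pixel Gaussian parameters
    (opacity, scaling, rotation, SH colour features and an optional offset)."""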
|
def __init__( |
|
self, |
|
cfg, |
|
hidden_dim: int, |
|
num_heads: int = 8, |
|
expansion: int = 4, |
|
depths: int | list[int] = 4, |
|
camera_dim: int = 256, |
|
dropout: float = 0.0, |
|
layer_scale: float = 1.0, |
|
) -> None: |
|
super().__init__() |
|
|
|
self.cfg = cfg |
|
|
|
if isinstance(depths, int): |
|
depths = [depths] * 3 |
|
assert len(depths) == 3 |
|
|
|
self.project_rays16 = MLP( |
|
camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim |
|
) |
|
self.project_rays8 = MLP( |
|
camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim // 2 |
|
) |
|
self.project_rays4 = MLP( |
|
camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim // 4 |
|
) |
|
|
|
self.layers_8 = nn.ModuleList([]) |
|
self.layers_4 = nn.ModuleList([]) |
|
layers_16 = nn.ModuleList([]) |
|
|
|
self.up8 = ConvUpsample( |
|
hidden_dim, expansion=expansion, layer_scale=layer_scale |
|
) |
|
self.up4 = ConvUpsample( |
|
hidden_dim // 2, expansion=expansion, layer_scale=layer_scale |
|
) |
|
self.up2 = ConvUpsample( |
|
hidden_dim // 4, expansion=expansion, layer_scale=layer_scale |
|
) |
|
|
|
split_dimensions, scale, bias = get_splits_and_inits(cfg) |
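        # Drop the first split (assumed to be the depth channel of the Gaussian
        # decoder config): depth is taken from UniDepth instead.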
|
start = 1 |
|
self.split_dimensions = split_dimensions[start:] |
|
scale = scale[start:] |
|
bias = bias[start:] |
|
|
|
self.num_output_channels = sum(self.split_dimensions) |
|
|
|
self.out2 = nn.Conv2d(hidden_dim // 8, self.num_output_channels, 3, padding=1) |
|
|
|
|
|
|
|
start_channels = 0 |
|
for out_channel, b, s in zip(self.split_dimensions, bias, scale): |
|
nn.init.xavier_uniform_( |
|
self.out2.weight[start_channels:start_channels+out_channel, |
|
:, :, :], s) |
|
nn.init.constant_( |
|
self.out2.bias[start_channels:start_channels+out_channel], b) |
|
start_channels += out_channel |
|
|
|
for i, (blk_lst, depth) in enumerate( |
|
zip([layers_16, self.layers_8, self.layers_4], depths) |
|
): |
|
if i == 0: |
|
continue |
|
attn_cls = AttentionBlock if i == 0 else NystromBlock |
|
for _ in range(depth): |
|
blk_lst.append( |
|
attn_cls( |
|
hidden_dim // (2**i), |
|
num_heads=num_heads // (2**i), |
|
expansion=expansion, |
|
dropout=dropout, |
|
layer_scale=layer_scale, |
|
) |
|
) |
|
|
|
self.scaling_activation = torch.exp |
|
self.opacity_activation = torch.sigmoid |
|
self.rotation_activation = torch.nn.functional.normalize |
|
self.scaling_lambda = cfg.model.scale_lambda |
|
self.sigmoid = nn.Sigmoid() |
|
|
|
def set_original_shapes(self, shapes: Tuple[int, int]): |
|
self.original_shapes = shapes |
|
|
|
def set_shapes(self, shapes: Tuple[int, int]): |
|
self.shapes = shapes |
|
|
|
def forward( |
|
self, latents_16: torch.Tensor, rays_hr: torch.Tensor |
|
    ) -> dict:
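        """Decode flattened feature tokens plus ray embeddings into a dict of
        Gaussian parameter maps at 8x the token grid resolution."""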
|
shapes = self.shapes |
|
|
|
|
|
|
|
|
|
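        # Camera rays are resampled to each feature scale, unit-normalised and
        # lifted to degree-8 real spherical harmonics (81 coefficients, hence
        # camera_dim=81 in UniDepthSplatter) before projection to the feature width.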
rays_embedding_16 = F.normalize( |
|
flat_interpolate(rays_hr, old=self.original_shapes, new=shapes), dim=-1 |
|
) |
|
rays_embedding_8 = F.normalize( |
|
flat_interpolate( |
|
rays_hr, old=self.original_shapes, new=[x * 2 for x in shapes] |
|
), |
|
dim=-1, |
|
) |
|
rays_embedding_4 = F.normalize( |
|
flat_interpolate( |
|
rays_hr, old=self.original_shapes, new=[x * 4 for x in shapes] |
|
), |
|
dim=-1, |
|
) |
|
rays_embedding_16 = self.project_rays16(rsh_cart_8(rays_embedding_16)) |
|
rays_embedding_8 = self.project_rays8(rsh_cart_8(rays_embedding_8)) |
|
rays_embedding_4 = self.project_rays4(rsh_cart_8(rays_embedding_4)) |
|
|
|
|
|
latents_8 = self.up8( |
|
rearrange( |
|
latents_16 + rays_embedding_16, |
|
"b (h w) c -> b c h w", |
|
h=shapes[0], |
|
w=shapes[1], |
|
).contiguous() |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for layer in self.layers_8: |
|
latents_8 = layer(latents_8, pos_embed=rays_embedding_8) |
|
latents_4 = self.up4( |
|
rearrange( |
|
latents_8 + rays_embedding_8, |
|
"b (h w) c -> b c h w", |
|
h=shapes[0] * 2, |
|
w=shapes[1] * 2, |
|
).contiguous() |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for layer in self.layers_4: |
|
latents_4 = layer(latents_4, pos_embed=rays_embedding_4) |
|
latents_2 = self.up2( |
|
rearrange( |
|
latents_4 + rays_embedding_4, |
|
"b (h w) c -> b c h w", |
|
h=shapes[0] * 4, |
|
w=shapes[1] * 4, |
|
).contiguous() |
|
) |
|
out2 = self.out2( |
|
rearrange( |
|
latents_2, "b (h w) c -> b c h w", h=shapes[0] * 8, w=shapes[1] * 8 |
|
) |
|
) |
|
|
|
split_network_outputs = out2.split(self.split_dimensions, dim=1) |
|
last = 5 |
|
offset, opacity, scaling, rotation, feat_dc = split_network_outputs[:last] |
|
|
|
out = { |
|
("gauss_opacity", 0): self.opacity_activation(opacity), |
|
("gauss_scaling", 0): self.scaling_activation(scaling) * self.scaling_lambda, |
|
("gauss_rotation", 0): self.rotation_activation(rotation), |
|
("gauss_features_dc", 0): feat_dc |
|
} |
|
|
|
if self.cfg.model.max_sh_degree > 0: |
|
features_rest = split_network_outputs[last] |
|
out[("gauss_features_rest", 0)] = features_rest |
|
|
|
if self.cfg.model.predict_offset: |
|
out[("gauss_offset", 0)] = offset |
|
|
|
return out |
|
|
|
|