import importlib
import warnings
from copy import deepcopy
from math import ceil

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from einops import rearrange
from huggingface_hub import PyTorchModelHubMixin

from unidepth.models.unidepthv2.decoder import Decoder
from unidepth.utils.constants import (IMAGENET_DATASET_MEAN,
                                      IMAGENET_DATASET_STD)
from unidepth.utils.distributed import is_main_process
from unidepth.utils.geometric import (generate_rays,
                                      spherical_zbuffer_to_euclidean)
from unidepth.utils.misc import (first_stack, last_stack, max_stack,
                                 mean_stack, softmax_stack)

STACKING_FNS = {
    "max": max_stack,
    "mean": mean_stack,
    "first": first_stack,
    "last": last_stack,
    "softmax": softmax_stack,
}
RESOLUTION_LEVELS = 10


# inference helpers
def _check_ratio(image_ratio, ratio_bounds):
    if ratio_bounds is None:
        return
    ratio_bounds = sorted(ratio_bounds)
    if image_ratio < ratio_bounds[0] or image_ratio > ratio_bounds[1]:
        warnings.warn(
            f"Input image ratio ({image_ratio:.3f}) is out of training "
            f"distribution: {ratio_bounds}. This may lead to unexpected results. "
            f"Consider resizing/padding the image to match the training distribution."
        )


def _check_resolution(shape_constraints, resolution_level):
    if resolution_level is None:
        warnings.warn(
            "Resolution level is not set. Using max resolution. "
            "You can trade off resolution for speed by setting a number in [0,10]. "
            "This can be achieved by setting model's `resolution_level` attribute."
        )
        resolution_level = RESOLUTION_LEVELS
    pixel_bounds = sorted(shape_constraints["pixels_bounds_ori"])
    pixel_range = pixel_bounds[-1] - pixel_bounds[0]
    clipped_resolution_level = min(max(resolution_level, 0), RESOLUTION_LEVELS)
    if clipped_resolution_level != resolution_level:
        warnings.warn(
            f"Resolution level {resolution_level} is out of bounds ([0,{RESOLUTION_LEVELS}]). "
            f"Clipping to {clipped_resolution_level}."
        )
    shape_constraints["pixels_bounds"] = [
        pixel_bounds[0]
        + ceil(pixel_range * clipped_resolution_level / RESOLUTION_LEVELS),
        pixel_bounds[0]
        + ceil(pixel_range * clipped_resolution_level / RESOLUTION_LEVELS),
    ]
    return shape_constraints
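

# Rough worked example for _check_resolution (assuming the default
# pixels_bounds_ori = [1400, 2400] set in `build`): with resolution_level = 5,
# pixel_range = 1000 and both entries of "pixels_bounds" collapse to
# 1400 + ceil(1000 * 5 / 10) = 1900, i.e. the resolution level pins the
# patch-pixel budget of the network input to a single target value.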


def _get_closest_num_pixels(image_shape, pixels_bounds):
    h, w = image_shape
    num_pixels = h * w
    pixels_bounds = sorted(pixels_bounds)
    num_pixels = max(min(num_pixels, pixels_bounds[1]), pixels_bounds[0])
    return num_pixels


def _shapes(image_shape, shape_constraints):
    h, w = image_shape
    image_ratio = w / h
    # _check_ratio(image_ratio, shape_constraints["ratio_bounds"])
    num_pixels = _get_closest_num_pixels(
        (h / shape_constraints["patch_size"], w / shape_constraints["patch_size"]),
        shape_constraints["pixels_bounds"],
    )
    h = ceil((num_pixels / image_ratio) ** 0.5 - 0.5)
    w = ceil(h * image_ratio - 0.5)
    ratio = h / image_shape[0] * shape_constraints["patch_size"]
    return (
        h * shape_constraints["patch_size"],
        w * shape_constraints["patch_size"],
    ), ratio
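

# Rough worked example for _shapes (assumed values: a 480x640 input,
# patch_size = 14, pixels_bounds pinned to [1900, 1900] as above):
# the patch-level pixel count 34.3 * 45.7 ~ 1567 is clamped to 1900, giving
# h = ceil(sqrt(1900 / 1.333) - 0.5) = 38 and w = ceil(38 * 1.333 - 0.5) = 51
# patches, so the network input is (38 * 14, 51 * 14) = (532, 714) and
# ratio = 38 / 480 * 14 ~ 1.108 is returned to rescale the intrinsics.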


def _preprocess(rgbs, intrinsics, shapes, ratio):
    rgbs = F.interpolate(rgbs, size=shapes, mode="bilinear", antialias=True)
    if intrinsics is not None:
        intrinsics = intrinsics.clone()
        intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * ratio
        intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * ratio
        intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * ratio
        intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * ratio
        return rgbs, intrinsics
    return rgbs, None


def _postprocess(outs, ratio, original_shapes, mode="nearest-exact"):
    outs["depth"] = F.interpolate(outs["depth"], size=original_shapes, mode=mode)
    outs["confidence"] = F.interpolate(
        outs["confidence"], size=original_shapes, mode="bilinear", antialias=True
    )
    outs["K"][:, 0, 0] = outs["K"][:, 0, 0] / ratio
    outs["K"][:, 1, 1] = outs["K"][:, 1, 1] / ratio
    outs["K"][:, 0, 2] = outs["K"][:, 0, 2] / ratio
    outs["K"][:, 1, 2] = outs["K"][:, 1, 2] / ratio
    return outs
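

# Rough worked example for the intrinsics rescaling (assumed values): with
# ratio ~ 1.108 from the _shapes example, _preprocess maps fx = 500.0 to
# 500.0 * 1.108 ~ 554.2 (same for fy, cx, cy), and _postprocess divides the
# predicted K by the same ratio to express it at the original resolution.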


class UniDepthV2(
    nn.Module,
    PyTorchModelHubMixin,
    library_name="UniDepth",
    repo_url="https://github.com/lpiccinelli-eth/UniDepth",
    tags=["monocular-metric-depth-estimation"],
):
    def __init__(
        self,
        config,
        eps: float = 1e-6,
        **kwargs,
    ):
        super().__init__()
        self.build(config)
        self.interpolation_mode = "bilinear"
        self.eps = eps
        self.resolution_level = 10

    def forward(self, inputs, image_metas=None):
        H, W = inputs["depth"].shape[-2:]
        if "K" in inputs:
            rays, angles = generate_rays(inputs["K"], (H, W))
            inputs["rays"] = rays
            inputs["angles"] = angles

        features, tokens = self.pixel_encoder(inputs["image"])
        cls_tokens = [x.contiguous() for x in tokens]
        features = [
            self.stacking_fn(features[i:j]).contiguous()
            for i, j in self.slices_encoder_range
        ]
        tokens = [
            self.stacking_fn(tokens[i:j]).contiguous()
            for i, j in self.slices_encoder_range
        ]
        global_tokens = [cls_tokens[i] for i in [-2, -1]]
        camera_tokens = [cls_tokens[i] for i in [-3, -2, -1]] + [tokens[-2]]

        inputs["features"] = features
        inputs["tokens"] = tokens
        inputs["global_tokens"] = global_tokens
        inputs["camera_tokens"] = camera_tokens

        outs = self.pixel_decoder(inputs, image_metas)
        angles = rearrange(
            generate_rays(outs["K"], (H, W), noisy=False)[-1],
            "b (h w) c -> b c h w",
            h=H,
            w=W,
        )
        predictions = F.interpolate(
            outs["depth"],
            size=(H, W),
            mode="bilinear",
            align_corners=False,
            antialias=True,
        )
        confidence = F.interpolate(
            outs["confidence"],
            size=(H, W),
            mode="bilinear",
            align_corners=False,
            antialias=True,
        )
        predictions_3d = torch.cat((angles, predictions), dim=1)
        predictions_3d = spherical_zbuffer_to_euclidean(
            predictions_3d.permute(0, 2, 3, 1)
        ).permute(0, 3, 1, 2)

        outputs = {
            "K": outs["K"],
            "depth": predictions.squeeze(1),
            "confidence": confidence,
            "points": predictions_3d,
            "depth_features": outs["depth_features"],
        }
        return outputs

    @torch.no_grad()
    def infer(self, rgbs: torch.Tensor, intrinsics=None):
        shape_constraints = self.shape_constraints
        if rgbs.ndim == 3:
            rgbs = rgbs.unsqueeze(0)
        if intrinsics is not None and intrinsics.ndim == 2:
            intrinsics = intrinsics.unsqueeze(0)
        B, _, H, W = rgbs.shape

        target_aspect_ratio = 1.33  # for example
        # Calculate the new width based on the target aspect ratio
        new_width = int(H * target_aspect_ratio)
        # Resize the image (padding would also work here)
        resize_transform = transforms.Resize((H, new_width))
        rgbs = resize_transform(rgbs)
        rgbs = rgbs.to(self.device)
        if intrinsics is not None:
            scale_width = new_width / W
            # Adjust the intrinsics for the horizontal rescaling
            K_new = intrinsics.clone()
            K_new[:, 0, 0] = K_new[:, 0, 0] * scale_width  # f_x
            K_new[:, 0, 2] = K_new[:, 0, 2] * scale_width  # c_x
            intrinsics = K_new.to(self.device)

        # process image and intrinsics (if any) to match network input (slow?)
        if rgbs.max() > 5 or rgbs.dtype == torch.uint8:
            rgbs = rgbs.to(torch.float32).div(255)
        if rgbs.min() >= 0.0 and rgbs.max() <= 1.0:
            rgbs = TF.normalize(
                rgbs,
                mean=IMAGENET_DATASET_MEAN,
                std=IMAGENET_DATASET_STD,
            )

        # check resolution constraints: trade off resolution and speed
        shape_constraints = _check_resolution(shape_constraints, self.resolution_level)

        # get image shape
        (h, w), ratio = _shapes((H, W), shape_constraints)
        rgbs, gt_intrinsics = _preprocess(
            rgbs,
            intrinsics,
            (h, w),
            ratio,
        )

        # run encoder
        features, tokens = self.pixel_encoder(rgbs)
        cls_tokens = [x.contiguous() for x in tokens]
        features = [
            self.stacking_fn(features[i:j]).contiguous()
            for i, j in self.slices_encoder_range
        ]
        tokens = [
            self.stacking_fn(tokens[i:j]).contiguous()
            for i, j in self.slices_encoder_range
        ]
        global_tokens = [cls_tokens[i] for i in [-2, -1]]
        camera_tokens = [cls_tokens[i] for i in [-3, -2, -1]] + [tokens[-2]]

        # get data for the decoder and adapt to the given camera
        inputs = {}
        inputs["features"] = features
        inputs["tokens"] = tokens
        inputs["global_tokens"] = global_tokens
        inputs["camera_tokens"] = camera_tokens
        inputs["image"] = rgbs
        if gt_intrinsics is not None:
            rays, angles = generate_rays(gt_intrinsics, (h, w))
            inputs["rays"] = rays
            inputs["angles"] = angles
            inputs["K"] = gt_intrinsics
        outs = self.pixel_decoder(inputs, {})

        # undo the reshaping and get original image size (slow)
        outs = _postprocess(outs, ratio, (H, W), mode=self.interpolation_mode)
        pred_intrinsics = outs["K"]
        depth = outs["depth"]
        confidence = outs["confidence"]

        # final 3D points backprojection
        intrinsics = intrinsics if intrinsics is not None else pred_intrinsics
        angles = generate_rays(intrinsics, (H, W))[-1]
        angles = rearrange(angles, "b (h w) c -> b c h w", h=H, w=W)
        points_3d = torch.cat((angles, depth), dim=1)
        points_3d = spherical_zbuffer_to_euclidean(
            points_3d.permute(0, 2, 3, 1)
        ).permute(0, 3, 1, 2)

        outputs = {
            "intrinsics": pred_intrinsics,
            "points": points_3d,
            "depth": depth.squeeze(1),
            "confidence": confidence,
        }
        return outputs

    def load_pretrained(self, model_file):
        device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        dict_model = torch.load(model_file, map_location=device)
        if "model" in dict_model:
            dict_model = dict_model["model"]
        new_state_dict = deepcopy(
            {k.replace("module.", ""): v for k, v in dict_model.items()}
        )
        info = self.load_state_dict(new_state_dict, strict=False)
        if is_main_process():
            print(
                f"Loaded from {model_file} for {self.__class__.__name__} results in:",
                info,
            )

    @property
    def device(self):
        return next(self.parameters()).device

    def build(self, config):
        mod = importlib.import_module("unidepth.models.encoder")
        pixel_encoder_factory = getattr(mod, config["model"]["pixel_encoder"]["name"])
        pixel_encoder_config = {
            **config["training"],
            **config["data"],
            **config["model"]["pixel_encoder"],
        }
        pixel_encoder = pixel_encoder_factory(pixel_encoder_config)

        config["model"]["pixel_encoder"]["patch_size"] = (
            14 if "dino" in config["model"]["pixel_encoder"]["name"] else 16
        )
        pixel_encoder_embed_dims = (
            pixel_encoder.embed_dims
            if hasattr(pixel_encoder, "embed_dims")
            else [getattr(pixel_encoder, "embed_dim") * 2**i for i in range(4)]
        )
        config["model"]["pixel_encoder"]["embed_dim"] = getattr(
            pixel_encoder, "embed_dim"
        )
        config["model"]["pixel_encoder"]["embed_dims"] = pixel_encoder_embed_dims
        config["model"]["pixel_encoder"]["depths"] = pixel_encoder.depths

        pixel_decoder = Decoder(config)

        self.pixel_encoder = pixel_encoder
        self.pixel_decoder = pixel_decoder

        stacking_fn = config["model"]["pixel_encoder"]["stacking_fn"]
        assert (
            stacking_fn in STACKING_FNS
        ), f"Stacking function {stacking_fn} not found in {STACKING_FNS.keys()}"
        self.stacking_fn = STACKING_FNS[stacking_fn]

        self.slices_encoder_range = list(
            zip([0, *pixel_encoder.depths[:-1]], pixel_encoder.depths)
        )
        self.shape_constraints = config["data"]["shape_constraints"]
        self.shape_constraints["pixels_bounds_ori"] = self.shape_constraints.get(
            "pixels_bounds", [1400, 2400]
        )
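

# Minimal usage sketch (assumptions: the Hugging Face checkpoint id below is
# only an example and may differ; a random tensor stands in for a real image).
if __name__ == "__main__":
    # Load weights via the PyTorchModelHubMixin `from_pretrained` interface.
    model = UniDepthV2.from_pretrained("lpiccinelli/unidepth-v2-vitl14")
    model = model.eval().to("cuda" if torch.cuda.is_available() else "cpu")

    # Optionally trade resolution for speed, as described in _check_resolution.
    model.resolution_level = 5

    # infer accepts a CHW (or BCHW) tensor; uint8 inputs are normalized inside.
    rgb = torch.randint(0, 255, (3, 480, 640), dtype=torch.uint8)

    predictions = model.infer(rgb)  # intrinsics are predicted if not provided
    depth = predictions["depth"]            # (B, H, W) metric depth
    points = predictions["points"]          # (B, 3, H, W) backprojected 3D points
    intrinsics = predictions["intrinsics"]  # (B, 3, 3) predicted camera K
    confidence = predictions["confidence"]  # (B, 1, H, W) confidence map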