smhh24's picture
Upload 90 files
560b597 verified
import importlib
import warnings
from copy import deepcopy
from math import ceil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from einops import rearrange
from huggingface_hub import PyTorchModelHubMixin
import torchvision.transforms as transforms
from unidepth.models.unidepthv2.decoder import Decoder
from unidepth.utils.constants import (IMAGENET_DATASET_MEAN,
IMAGENET_DATASET_STD)
from unidepth.utils.distributed import is_main_process
from unidepth.utils.geometric import (generate_rays,
spherical_zbuffer_to_euclidean)
from unidepth.utils.misc import (first_stack, last_stack, max_stack,
mean_stack, softmax_stack)
# Reducers that merge a group of encoder outputs into one tensor; the active
# one is chosen by the config key ``model.pixel_encoder.stacking_fn``.
STACKING_FNS = dict(
    max=max_stack,
    mean=mean_stack,
    first=first_stack,
    last=last_stack,
    softmax=softmax_stack,
)
# Highest resolution level; a model's ``resolution_level`` is clipped to
# the inclusive range [0, RESOLUTION_LEVELS].
RESOLUTION_LEVELS = 10
# inference helpers
def _check_ratio(image_ratio, ratio_bounds):
ratio_bounds = sorted(ratio_bounds)
if ratio_bounds is not None and (
image_ratio < ratio_bounds[0] or image_ratio > ratio_bounds[1]
):
warnings.warn(
f"Input image ratio ({image_ratio:.3f}) is out of training "
f"distribution: {ratio_bounds}. This may lead to unexpected results. "
f"Consider resizing/padding the image to match the training distribution."
)
def _check_resolution(shape_constraints, resolution_level):
    """Derive the active pixel budget from a resolution level.

    Reads ``shape_constraints["pixels_bounds_ori"]`` and writes a new
    ``"pixels_bounds"`` entry into the same dict, which is also returned.
    Both entries of the new bounds are the same value, so the pixel budget
    is pinned to a single target that interpolates linearly between the
    smallest and largest original bound.

    Warns (and falls back to the maximum level) when ``resolution_level``
    is ``None``; warns and clips when it lies outside
    ``[0, RESOLUTION_LEVELS]``.
    """
    if resolution_level is None:
        warnings.warn(
            "Resolution level is not set. Using max resolution. "
            "You can tradeoff resolution for speed by setting a number in [0,10]. "
            "This can be achieved by setting model's `resolution_level` attribute."
        )
        resolution_level = RESOLUTION_LEVELS

    ordered = sorted(shape_constraints["pixels_bounds_ori"])
    lowest, highest = ordered[0], ordered[-1]
    span = highest - lowest

    level = min(max(resolution_level, 0), RESOLUTION_LEVELS)
    if level != resolution_level:
        warnings.warn(
            f"Resolution level {resolution_level} is out of bounds ([0,{RESOLUTION_LEVELS}]). "
            f"Clipping to {level}."
        )

    target = lowest + ceil(span * level / RESOLUTION_LEVELS)
    # Lower and upper bound are deliberately identical.
    shape_constraints["pixels_bounds"] = [target, target]
    return shape_constraints
def _get_closes_num_pixels(image_shape, pixels_bounds):
h, w = image_shape
num_pixels = h * w
pixels_bounds = sorted(pixels_bounds)
num_pixels = max(min(num_pixels, pixels_bounds[1]), pixels_bounds[0])
return num_pixels
def _shapes(image_shape, shape_constraints):
    """Choose the network input size for an image.

    The budget in ``shape_constraints["pixels_bounds"]`` is compared against
    the number of ``patch_size`` x ``patch_size`` patches. The patch grid is
    then sized to hit that budget while approximately keeping the image's
    width / height ratio.

    Returns:
        ``((in_h, in_w), ratio)``, where ``in_h`` and ``in_w`` are multiples of
        ``patch_size`` and ``ratio`` is ``in_h / image_height`` (the factor
        used to rescale camera intrinsics).
    """
    orig_h, orig_w = image_shape
    aspect = orig_w / orig_h
    patch = shape_constraints["patch_size"]
    # NOTE: the aspect-ratio sanity check (`_check_ratio`) is intentionally
    # not invoked here.
    budget = _get_closes_num_pixels(
        (orig_h / patch, orig_w / patch),
        shape_constraints["pixels_bounds"],
    )
    # `ceil(x - 0.5)` rounds to the nearest integer number of patches.
    grid_h = ceil((budget / aspect) ** 0.5 - 0.5)
    grid_w = ceil(grid_h * aspect - 0.5)
    scale = grid_h / orig_h * patch
    return (grid_h * patch, grid_w * patch), scale
def _preprocess(rgbs, intrinsics, shapes, ratio):
rgbs = F.interpolate(rgbs, size=shapes, mode="bilinear", antialias=True)
if intrinsics is not None:
intrinsics = intrinsics.clone()
intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * ratio
intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * ratio
intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * ratio
intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * ratio
return rgbs, intrinsics
return rgbs, None
def _postprocess(outs, ratio, original_shapes, mode="nearest-exact"):
outs["depth"] = F.interpolate(outs["depth"], size=original_shapes, mode=mode)
outs["confidence"] = F.interpolate(
outs["confidence"], size=original_shapes, mode="bilinear", antialias=True
)
outs["K"][:, 0, 0] = outs["K"][:, 0, 0] / ratio
outs["K"][:, 1, 1] = outs["K"][:, 1, 1] / ratio
outs["K"][:, 0, 2] = outs["K"][:, 0, 2] / ratio
outs["K"][:, 1, 2] = outs["K"][:, 1, 2] / ratio
return outs
class UniDepthV2(
    nn.Module,
    PyTorchModelHubMixin,
    library_name="UniDepth",
    repo_url="https://github.com/lpiccinelli-eth/UniDepth",
    tags=["monocular-metric-depth-estimation"],
):
    """UniDepthV2 monocular metric depth model (pixel encoder + decoder).

    The encoder and decoder are built from ``config`` in :meth:`build`.
    :meth:`forward` works on a batched input dictionary. :meth:`infer` is the
    convenience entry point: it takes raw images (and optional intrinsics),
    resizes them to the network's pixel budget, runs the model, and maps the
    predictions back to the input resolution.
    """

    def __init__(
        self,
        config,
        eps: float = 1e-6,
        **kwargs,
    ):
        super().__init__()
        self.build(config)
        # Interpolation mode `infer` uses when resizing depth back to input size.
        self.interpolation_mode = "bilinear"
        self.eps = eps
        # Maximum resolution level by default; lower values trade
        # resolution for speed (see `_check_resolution`).
        self.resolution_level = 10

    def _encode(self, rgbs):
        """Run the pixel encoder and group its outputs for the decoder.

        Returns:
            ``(features, tokens, global_tokens, camera_tokens)``. ``features``
            and ``tokens`` are reduced with ``self.stacking_fn`` over each
            ``(start, end)`` range in ``self.slices_encoder_range``. The
            global/camera tokens are picked from the ungrouped token list.
        """
        features, tokens = self.pixel_encoder(rgbs)
        cls_tokens = [x.contiguous() for x in tokens]
        features = [
            self.stacking_fn(features[i:j]).contiguous()
            for i, j in self.slices_encoder_range
        ]
        tokens = [
            self.stacking_fn(tokens[i:j]).contiguous()
            for i, j in self.slices_encoder_range
        ]
        global_tokens = [cls_tokens[i] for i in [-2, -1]]
        # The camera head additionally receives the second-to-last grouped tokens.
        camera_tokens = [cls_tokens[i] for i in [-3, -2, -1]] + [tokens[-2]]
        return features, tokens, global_tokens, camera_tokens

    def forward(self, inputs, image_metas=None):
        """Forward pass on a batched input dictionary.

        Args:
            inputs: dict with at least ``"image"`` and ``"depth"``. Only the
                spatial size of ``"depth"`` is read here. When ``"K"`` is
                present, rays/angles are generated from it. The dict is
                mutated in place (rays, angles and encoder outputs are added).
            image_metas: forwarded unchanged to the decoder.

        Returns:
            dict with ``"K"`` (decoder intrinsics), ``"depth"`` (``squeeze(1)``
            applied), ``"confidence"``, ``"points"`` (back-projected with the
            decoder's K) and ``"depth_features"``.
        """
        H, W = inputs["depth"].shape[-2:]
        if "K" in inputs:
            rays, angles = generate_rays(inputs["K"], (H, W))
            inputs["rays"] = rays
            inputs["angles"] = angles
        features, tokens, global_tokens, camera_tokens = self._encode(inputs["image"])
        inputs["features"] = features
        inputs["tokens"] = tokens
        inputs["global_tokens"] = global_tokens
        inputs["camera_tokens"] = camera_tokens
        outs = self.pixel_decoder(inputs, image_metas)
        angles = rearrange(
            generate_rays(outs["K"], (H, W), noisy=False)[-1],
            "b (h w) c -> b c h w",
            h=H,
            w=W,
        )
        predictions = F.interpolate(
            outs["depth"],
            size=(H, W),
            mode="bilinear",
            align_corners=False,
            antialias=True,
        )
        confidence = F.interpolate(
            outs["confidence"],
            size=(H, W),
            mode="bilinear",
            align_corners=False,
            antialias=True,
        )
        predictions_3d = torch.cat((angles, predictions), dim=1)
        predictions_3d = spherical_zbuffer_to_euclidean(
            predictions_3d.permute(0, 2, 3, 1)
        ).permute(0, 3, 1, 2)
        return {
            "K": outs["K"],
            "depth": predictions.squeeze(1),
            "confidence": confidence,
            "points": predictions_3d,
            "depth_features": outs["depth_features"],
        }

    @torch.no_grad()
    def infer(self, rgbs: torch.Tensor, intrinsics=None):
        """Predict depth, confidence, intrinsics and 3D points for images.

        Args:
            rgbs: ``(C, H, W)`` or ``(B, C, H, W)`` image tensor. uint8 input,
                or any input whose max exceeds 5, is divided by 255. Input
                within [0, 1] is then ImageNet-normalized; anything else is
                assumed to be normalized already.
            intrinsics: optional ``(3, 3)`` or ``(B, 3, 3)`` camera matrices
                for the images at their given resolution.

        Returns:
            dict with ``"intrinsics"`` (predicted K at input resolution),
            ``"points"`` (back-projected with the given intrinsics if any,
            else the predicted ones), ``"depth"`` (``squeeze(1)`` applied)
            and ``"confidence"``, all at the input resolution ``(H, W)``.
        """
        shape_constraints = self.shape_constraints
        if rgbs.ndim == 3:
            rgbs = rgbs.unsqueeze(0)
        if intrinsics is not None and intrinsics.ndim == 2:
            intrinsics = intrinsics.unsqueeze(0)
        _, _, H, W = rgbs.shape

        # The image keeps its own aspect ratio: `_shapes` / `_preprocess`
        # choose the network input size from (H, W) and rescale the
        # intrinsics consistently. No fixed-aspect resize is done here.
        rgbs = rgbs.to(self.device)
        if intrinsics is not None:
            intrinsics = intrinsics.to(self.device)

        # process image and intrinsics (if any) to match network input
        if rgbs.max() > 5 or rgbs.dtype == torch.uint8:
            rgbs = rgbs.to(torch.float32).div(255)
        if rgbs.min() >= 0.0 and rgbs.max() <= 1.0:
            rgbs = TF.normalize(
                rgbs,
                mean=IMAGENET_DATASET_MEAN,
                std=IMAGENET_DATASET_STD,
            )

        # check resolution constraints: tradeoff resolution and speed
        shape_constraints = _check_resolution(shape_constraints, self.resolution_level)
        (h, w), ratio = _shapes((H, W), shape_constraints)
        rgbs, gt_intrinsics = _preprocess(rgbs, intrinsics, (h, w), ratio)

        features, tokens, global_tokens, camera_tokens = self._encode(rgbs)

        # gather decoder inputs and condition on the given camera, if any
        inputs = {
            "features": features,
            "tokens": tokens,
            "global_tokens": global_tokens,
            "camera_tokens": camera_tokens,
            "image": rgbs,
        }
        if gt_intrinsics is not None:
            rays, angles = generate_rays(gt_intrinsics, (h, w))
            inputs["rays"] = rays
            inputs["angles"] = angles
            inputs["K"] = gt_intrinsics
        outs = self.pixel_decoder(inputs, {})

        # undo the network-resolution resize and return to (H, W)
        outs = _postprocess(outs, ratio, (H, W), mode=self.interpolation_mode)
        pred_intrinsics = outs["K"]
        depth = outs["depth"]
        confidence = outs["confidence"]

        # final 3D points backprojection
        intrinsics = intrinsics if intrinsics is not None else pred_intrinsics
        angles = generate_rays(intrinsics, (H, W))[-1]
        angles = rearrange(angles, "b (h w) c -> b c h w", h=H, w=W)
        points_3d = torch.cat((angles, depth), dim=1)
        points_3d = spherical_zbuffer_to_euclidean(
            points_3d.permute(0, 2, 3, 1)
        ).permute(0, 3, 1, 2)
        return {
            "intrinsics": pred_intrinsics,
            "points": points_3d,
            "depth": depth.squeeze(1),
            "confidence": confidence,
        }

    def load_pretrained(self, model_file):
        """Load weights from a checkpoint file with ``strict=False``.

        If the checkpoint has a ``"model"`` entry, that sub-dict is used.
        Every ``"module."`` substring is removed from the keys (presumably
        the (Distributed)DataParallel prefix). The load report is printed on
        the main process.
        """
        device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
        dict_model = torch.load(model_file, map_location=device)
        if "model" in dict_model:
            dict_model = dict_model["model"]
        new_state_dict = deepcopy(
            {k.replace("module.", ""): v for k, v in dict_model.items()}
        )
        info = self.load_state_dict(new_state_dict, strict=False)
        if is_main_process():
            print(
                f"Loaded from {model_file} for {self.__class__.__name__} results in:",
                info,
            )

    @property
    def device(self):
        """Device of the model's first registered parameter."""
        return next(self.parameters()).device

    def build(self, config):
        """Instantiate the encoder/decoder from ``config`` and cache derived settings.

        Mutates ``config["model"]["pixel_encoder"]`` (sets ``patch_size``,
        ``embed_dim``, ``embed_dims``, ``depths``) before building the decoder.
        ``self.shape_constraints`` aliases ``config["data"]["shape_constraints"]``.
        """
        mod = importlib.import_module("unidepth.models.encoder")
        pixel_encoder_factory = getattr(mod, config["model"]["pixel_encoder"]["name"])
        pixel_encoder_config = {
            **config["training"],
            **config["data"],
            **config["model"]["pixel_encoder"],
        }
        pixel_encoder = pixel_encoder_factory(pixel_encoder_config)
        # Encoders whose name contains "dino" use 14-pixel patches; others use 16.
        config["model"]["pixel_encoder"]["patch_size"] = (
            14 if "dino" in config["model"]["pixel_encoder"]["name"] else 16
        )
        # Without per-stage `embed_dims`, assume four stages that double `embed_dim`.
        pixel_encoder_embed_dims = (
            pixel_encoder.embed_dims
            if hasattr(pixel_encoder, "embed_dims")
            else [getattr(pixel_encoder, "embed_dim") * 2**i for i in range(4)]
        )
        config["model"]["pixel_encoder"]["embed_dim"] = getattr(
            pixel_encoder, "embed_dim"
        )
        config["model"]["pixel_encoder"]["embed_dims"] = pixel_encoder_embed_dims
        config["model"]["pixel_encoder"]["depths"] = pixel_encoder.depths
        pixel_decoder = Decoder(config)
        self.pixel_encoder = pixel_encoder
        self.pixel_decoder = pixel_decoder
        stacking_fn = config["model"]["pixel_encoder"]["stacking_fn"]
        assert (
            stacking_fn in STACKING_FNS
        ), f"Stacking function {stacking_fn} not found in {STACKING_FNS.keys()}"
        self.stacking_fn = STACKING_FNS[stacking_fn]
        # (start, end) slices over the encoder outputs, one per group; this
        # presumes `depths` holds cumulative output indices — TODO confirm.
        self.slices_encoder_range = list(
            zip([0, *pixel_encoder.depths[:-1]], pixel_encoder.depths)
        )
        self.shape_constraints = config["data"]["shape_constraints"]
        # Remember the configured bounds so `_check_resolution` can recompute
        # "pixels_bounds" from them on every `infer` call.
        self.shape_constraints["pixels_bounds_ori"] = self.shape_constraints.get(
            "pixels_bounds", [1400, 2400]
        )