# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat, Zhenyu Li
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from tqdm.auto import tqdm
from torchvision.transforms import ToTensor, ToPILImage
from typing import List, Tuple
from PIL import Image
import cv2

# from models.monodepth.zoedepth import ZoeDepthLora
# from zoedepth.utils.align.loss import SILogLoss, gradl1_loss, edge_aware_smoothness_per_pixel, ssim_loss
from .loss import *
from zoedepth.trainers.loss import *
# from utils.misc import *

def scale_shift_linear(rendered_depth, predicted_depth, mask, fuse=True, return_params=False):
    """
    Optimize a scale and shift parameter in the least squares sense, such that
    rendered_depth and predicted_depth match. Formally, solves the objective:

        min  || (d * a + b) - d_hat ||
        a, b

    where d = predicted_depth and d_hat = rendered_depth (the fit is done
    directly in depth space, not in disparity space).

    :param rendered_depth: torch.Tensor (H, W)
    :param predicted_depth: torch.Tensor (H, W)
    :param mask: torch.Tensor (H, W) - 1: valid points of rendered_depth, 0: invalid points of rendered_depth (ignore)
    :param fuse: whether to fuse shifted/scaled predicted_depth with the rendered_depth
    :param return_params: if True, return the fitted parameters (a, b) instead of the corrected depth
    :return: scale/shift corrected depth
    """
    if mask.sum() == 0:
        return predicted_depth

    # The fit is done directly in depth space; a disparity-space variant would
    # use 1 / rendered_depth and 1 / predicted_depth here instead.
    rendered_vals = rendered_depth[mask].unsqueeze(-1)
    predicted_vals = predicted_depth[mask].unsqueeze(-1)

    # Closed-form least squares: AB = (X^T X)^{-1} X^T y with X = [d, 1].
    X = torch.cat([predicted_vals, torch.ones_like(predicted_vals)], dim=1)
    XTX_inv = (X.T @ X).inverse()
    XTY = X.T @ rendered_vals
    AB = XTX_inv @ XTY
    if return_params:
        return AB

    fixed_depth = predicted_depth * AB[0] + AB[1]
    if fuse:
        return torch.where(mask, rendered_depth, fixed_depth)
    else:
        return fixed_depth

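# Usage sketch (synthetic data, not part of the original module): on noiseless
# inputs related by rendered = a * pred + b, the closed-form least-squares fit
# recovers the target exactly (fuse=False returns the corrected prediction).
def _demo_scale_shift_linear():
    pred = torch.rand(64, 64) + 0.5
    rendered = pred * 2.0 + 0.3              # ground truth: a = 2.0, b = 0.3
    mask = torch.ones(64, 64, dtype=torch.bool)
    aligned = scale_shift_linear(rendered, pred, mask, fuse=False)
    assert torch.allclose(aligned, rendered, atol=1e-3)
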
def np_scale_shift_linear(rendered_depth: np.ndarray, predicted_depth: np.ndarray, mask: np.ndarray, fuse: bool=True):
    """
    NumPy variant of scale_shift_linear. Optimizes a scale and shift in the
    least squares sense, such that rendered_depth and predicted_depth match:

        min  || (d * a + b) - d_hat ||
        a, b

    where d = predicted_depth and d_hat = rendered_depth.

    :param rendered_depth: np.ndarray (H, W)
    :param predicted_depth: np.ndarray (H, W)
    :param mask: np.ndarray (H, W) - 1: valid points of rendered_depth, 0: invalid points of rendered_depth (ignore)
    :param fuse: whether to fuse shifted/scaled predicted_depth with the rendered_depth
    :return: scale/shift corrected depth
    """
    if mask.sum() == 0:
        return predicted_depth

    rendered_vals = rendered_depth[mask].reshape(-1, 1)
    predicted_vals = predicted_depth[mask].reshape(-1, 1)

    # Closed-form least squares, as in the torch version above.
    X = np.concatenate([predicted_vals, np.ones_like(predicted_vals)], axis=1)
    XTX_inv = np.linalg.inv(X.T @ X)
    XTY = X.T @ rendered_vals
    AB = XTX_inv @ XTY

    fixed_depth = predicted_depth * AB[0] + AB[1]
    if fuse:
        return np.where(mask, rendered_depth, fixed_depth)
    else:
        return fixed_depth

def apply_depth_smoothing(depth, mask):
    def dilate(x, k=3):
        # Morphological dilation via a k x k box-filter convolution.
        x = as_bchw_tensor(x.float(), 1)
        x = torch.nn.functional.conv2d(
            x.float(),
            torch.ones(1, 1, k, k).to(x.device),
            padding="same"
        )
        return x.squeeze() > 0

    def sobel(x):
        flipped_sobel_x = torch.tensor([
            [-1, 0, 1],
            [-2, 0, 2],
            [-1, 0, 1]
        ], dtype=torch.float32).to(x.device)
        flipped_sobel_x = torch.stack([flipped_sobel_x, flipped_sobel_x.t()]).unsqueeze(1)

        x_pad = torch.nn.functional.pad(x.float(), (1, 1, 1, 1), mode="replicate")
        x = torch.nn.functional.conv2d(
            x_pad,
            flipped_sobel_x,
            padding="valid"
        )
        dx, dy = x.unbind(dim=-3)
        # return torch.sqrt(dx**2 + dy**2).squeeze()
        # new content is created mostly in x direction, sharp edges in y direction are wanted (e.g. table --> wall)
        return dx

    depth = as_bchw_tensor(depth, 1)
    mask = as_bchw_tensor(mask, 1).float()

    # Smooth only inside a dilated band around the mask boundary.
    edges = sobel(mask)
    dilated_edges = dilate(edges, k=21)

    depth_numpy = depth.squeeze().float().cpu().numpy()
    blur_bilateral = cv2.bilateralFilter(depth_numpy, 5, 140, 140)
    blur_gaussian = cv2.GaussianBlur(blur_bilateral, (5, 5), 0)
    blur_gaussian = torch.from_numpy(blur_gaussian).to(depth)

    depth_smooth = torch.where(dilated_edges, blur_gaussian, depth)
    return depth_smooth

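# Usage sketch (hypothetical inputs): the bilateral + Gaussian blur is applied
# only inside a dilated band around the mask boundary; pixels far from the
# seam are returned unchanged.
def _demo_apply_depth_smoothing():
    depth = torch.rand(128, 128)
    mask = torch.zeros(128, 128, dtype=torch.bool)
    mask[:, 64:] = True                      # right half "known", left half new
    smoothed = apply_depth_smoothing(depth, mask).squeeze()
    assert torch.allclose(smoothed[:, :32], depth[:, :32])   # far from the seam
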
def get_dilated_only_mask(mask: torch.Tensor, k=7):
    # Pixels added by a k x k dilation of the mask, excluding the mask itself.
    x = as_bchw_tensor(mask.float(), 1)
    x = torch.nn.functional.conv2d(x, torch.ones(1, 1, k, k).to(mask.device), padding="same")
    dilated = x.squeeze() > 0
    dilated_only = dilated ^ mask
    return dilated_only

def get_boundary_mask(mask: torch.Tensor, k=7):
    # Band of pixels straddling both sides of the mask boundary.
    return get_dilated_only_mask(mask, k=k) | get_dilated_only_mask(~mask, k=k)

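# Usage sketch: the boundary mask is a band of roughly k pixels centered on the
# mask edge, useful for restricting boundary-matching losses to the seam.
def _demo_get_boundary_mask():
    mask = torch.zeros(32, 32, dtype=torch.bool)
    mask[:, 16:] = True
    band = get_boundary_mask(mask, k=7)
    assert band.any() and not band.all()
    assert not band[:, :8].any()             # far from the edge: no band
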
def ss_align_and_blur(rendered_depth: torch.Tensor, predicted_depth: torch.Tensor, mask: torch.Tensor, fuse: bool=True):
    # Scale/shift-align the prediction, then smooth across the mask boundary.
    aligned = scale_shift_linear(rendered_depth, predicted_depth, mask, fuse=fuse)
    aligned = apply_depth_smoothing(aligned, mask)
    return aligned

def np_ss_align_and_blur(rendered_depth: np.ndarray, predicted_depth: np.ndarray, mask: np.ndarray, fuse: bool=True):
    # NumPy front-end: align in numpy, smooth in torch, return numpy.
    aligned = np_scale_shift_linear(rendered_depth, predicted_depth, mask, fuse=fuse)
    aligned = apply_depth_smoothing(aligned, mask).cpu().numpy()
    return aligned

def stitch(depth_src: torch.Tensor, depth_target: torch.Tensor, mask_src: torch.Tensor, smoothen=True, device='cuda:0'):
    depth_src = as_bchw_tensor(depth_src, 1, device=device)
    depth_target = as_bchw_tensor(depth_target, 1, device=device)
    mask_src = as_bchw_tensor(mask_src, 1, device=device).to(torch.bool)  # bool, so that ~mask_src is valid

    stitched = depth_src * mask_src.float() + depth_target * (~mask_src).float()
    if smoothen:
        stitched = apply_depth_smoothing(stitched, mask_src).squeeze().float()
    return stitched

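# Usage sketch: where mask_src is True the source depth is kept verbatim; with
# smoothen=True the seam would additionally be blurred by apply_depth_smoothing.
def _demo_stitch():
    src = torch.rand(64, 64) + 1.0
    target = torch.rand(64, 64) + 1.0
    mask = torch.zeros(64, 64, dtype=torch.bool)
    mask[:, 32:] = True
    out = stitch(src, target, mask, smoothen=False, device='cpu').squeeze()
    assert torch.allclose(out[:, 32:], src[:, 32:])
    assert torch.allclose(out[:, :32], target[:, :32])
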
def smoothness_loss(depth, mask=None):
    # Mean absolute forward difference in x and y (a total-variation style penalty).
    depth_grad_x = torch.abs(depth[:, :, :, :-1] - depth[:, :, :, 1:])
    depth_grad_y = torch.abs(depth[:, :, :-1, :] - depth[:, :, 1:, :])
    if mask is not None:
        return torch.mean(depth_grad_x[mask[:, :, :, :-1]]) + torch.mean(depth_grad_y[mask[:, :, :-1, :]])
    return torch.mean(depth_grad_x) + torch.mean(depth_grad_y)

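# Usage sketch: a constant map scores zero, while noise scores strictly higher.
def _demo_smoothness_loss():
    flat = torch.full((1, 1, 16, 16), 2.0)
    noisy = torch.rand(1, 1, 16, 16)
    assert smoothness_loss(flat).item() == 0.0
    assert smoothness_loss(noisy) > smoothness_loss(flat)
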
def finetune_on_sample(model, image_pil, target_depth, mask=None,
                       iters=10, lr=0.1, beta=0.5, w_boundary_grad=1, w_grad=0.1, gamma=0.99):
    model.train()
    model_device = next(model.parameters()).device
    x = as_bchw_tensor(image_pil, 3, device=model_device)
    target_depth = as_bchw_tensor(target_depth, 1, device=model_device)
    if mask is None:
        mask = target_depth > 0
    else:
        mask = as_bchw_tensor(mask, 1, device=model_device).to(torch.bool)

    history = []
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)
    scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=lr, steps_per_epoch=iters, epochs=1)
    main_loss = SILogLoss(beta=beta)
    orig_y = model.infer(x, with_flip_aug=False).detach()
    # Alternative: fit a global scale/shift instead of finetuning, via
    # scale, shift = scale_shift_linear(target_depth, orig_y, mask, return_params=True)
    gl1 = gradl1_loss
    pbar = tqdm(range(iters), desc="Finetuning on sample")
    for i in pbar:
        optimizer.zero_grad()
        y = model.infer(x, with_flip_aug=False)
        # Composite of the (frozen) target inside the mask and the prediction outside.
        stitched = y * (~mask).float() + (target_depth * (mask).float()).detach()

        loss_si = main_loss(y[mask], target_depth[mask])
        loss_grad = gl1(y, orig_y)
        bmask = get_boundary_mask(mask)
        loss_boundary_grad = laplacian_matching_loss(stitched, orig_y, bmask)
        loss = loss_si + w_boundary_grad * loss_boundary_grad + w_grad * loss_grad

        # check if loss is nan
        if torch.isnan(loss):
            print("Loss is nan, breaking")
            break

        loss.backward()
        optimizer.step()
        scheduler.step()
        history.append(loss.item())
        pbar.set_postfix(loss=loss.item(), si=loss_si.item())

    model.eval()
    return model, history

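# Usage sketch with a stand-in network (hypothetical; the original caller passes
# a LoRA-wrapped ZoeDepth). finetune_on_sample only assumes the model exposes
# .infer(x, with_flip_aug=False) returning a positive (B, 1, H, W) depth map and
# trainable parameters(); the losses come from the .loss imports above.
class _ToyDepthModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 1, 3, padding=1)

    def infer(self, x, with_flip_aug=False):
        return F.softplus(self.conv(x))      # strictly positive "depth"

def _demo_finetune_on_sample():
    model = _ToyDepthModel()
    image = torch.rand(1, 3, 64, 64)
    target = torch.rand(1, 1, 64, 64) + 0.5  # hypothetical rendered depth
    mask = torch.zeros(1, 1, 64, 64, dtype=torch.bool)
    mask[..., 32:] = True                    # supervise only the right half
    model, history = finetune_on_sample(model, image, target, mask=mask,
                                        iters=2, lr=1e-3)
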
# def align_by_finetuning_lora(model: ZoeDepthLora, image, target_depth, mask=None, iters=10, lr=0.1, gamma=0.99, **kwargs):
#     # model.reset_lora()
#     model.set_only_lora_trainable()
#     model, history = finetune_on_sample(model, image, target_depth, mask=mask, iters=iters, lr=lr, gamma=gamma)
#     aligned_depth = model.infer(as_bchw_tensor(image, 3, device=next(model.parameters()).device))
#     return dict(model=model, history=history, aligned_depth=aligned_depth)

# from utils.misc import as_bchw_tensor
def as_bchw_tensor(input_tensor, num, device=None):
    """Convert a PIL image, numpy array, or torch tensor to a (B, C, H, W) tensor with `num` channels."""
    if isinstance(input_tensor, Image.Image):
        input_tensor = ToTensor()(input_tensor)
    elif not isinstance(input_tensor, torch.Tensor):
        input_tensor = torch.from_numpy(np.asarray(input_tensor)).float()
    while input_tensor.dim() < 4:
        input_tensor = input_tensor.unsqueeze(0)
    assert input_tensor.shape[1] == num, f"expected {num} channels, got {tuple(input_tensor.shape)}"
    return input_tensor.to(device) if device is not None else input_tensor

def optimize_depth_deformation(rendered_depth, pred_depth, mask, h=10, w=10, iters=100, init_lr=0.1, gamma=0.996,
                               init_deformation=None,
                               device='cuda:0'):
    rendered_depth = as_bchw_tensor(rendered_depth, 1, device=device)
    pred_depth = as_bchw_tensor(pred_depth, 1, device=device)
    mask = as_bchw_tensor(mask, 1, device=device).to(torch.bool)

    # Initialize an h x w grid of log-scale factors (zeros = identity) that is
    # optimized to deform the depth map.
    if init_deformation is None:
        deformation = torch.zeros((1, 1, h, w), requires_grad=True, device=device)
    else:
        deformation = init_deformation
        deformation.requires_grad = True
    assert deformation.shape == (1, 1, h, w)

    optimizer = torch.optim.Adam([deformation], lr=init_lr)
    # exponential LR schedule
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

    # optimize the deformation
    history = []
    grad_loss = GradL1Loss()
    for i in tqdm(range(iters)):
        scalar_deformation = torch.exp(deformation)
        scalar_deformation = F.interpolate(scalar_deformation, size=pred_depth.shape[-2:], mode='bilinear', align_corners=True)
        adjusted_depth = pred_depth * scalar_deformation
        loss = F.mse_loss(adjusted_depth[mask], rendered_depth[mask], reduction='none')
        loss_g = grad_loss(adjusted_depth, rendered_depth, mask)
        loss = loss.mean() + 0.1 * loss_g

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        if i % 10 == 0:
            history.append(loss.item())

    scalar_deformation = torch.exp(deformation)
    scalar_deformation = F.interpolate(scalar_deformation, size=pred_depth.shape[-2:], mode='bilinear', align_corners=True)
    adjusted_depth = pred_depth * scalar_deformation
    # Return a dict so that stage_wise_optimization below can warm-start from
    # `deformation` and collect `history`.
    return dict(aligned_depth=adjusted_depth.detach().cpu().numpy().squeeze(),
                history=history,
                deformation=deformation)

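# Usage sketch (synthetic data): a coarse grid of log-scale factors is optimized
# so that exp(grid), bilinearly upsampled, multiplies the prediction onto the
# rendered depth wherever the mask is valid.
def _demo_optimize_depth_deformation():
    pred = torch.rand(96, 96) + 0.5
    rendered = pred * 1.5                    # global 1.5x scale error
    mask = torch.rand(96, 96) > 0.5
    result = optimize_depth_deformation(rendered, pred, mask,
                                        h=4, w=4, iters=50, device='cpu')
    aligned = result['aligned_depth']        # numpy, shape (96, 96)
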
def stage_wise_optimization(rendered_depth, pred_depth, mask,
                            stages=[(4, 4), (8, 8), (16, 16), (32, 32)],
                            iters=100, init_lr=0.1, gamma=0.996, device='cuda:0'):
    # Coarse-to-fine: optimize the deformation grid at increasing resolutions,
    # warm-starting each stage from the upsampled previous result.
    h_init, w_init = stages[0]
    init_deformation = torch.zeros((1, 1, h_init, w_init), device=device)
    result = optimize_depth_deformation(rendered_depth, pred_depth, mask, h=h_init, w=w_init, iters=iters,
                                        init_lr=init_lr, gamma=gamma, init_deformation=init_deformation, device=device)
    init_deformation = result['deformation']
    history_stages = [result['history']]
    for h, w in stages[1:]:
        init_deformation = F.interpolate(init_deformation, size=(h, w), mode='bilinear', align_corners=True).detach()
        result = optimize_depth_deformation(rendered_depth, pred_depth, mask, h=h, w=w, iters=iters,
                                            init_lr=init_lr, gamma=gamma, init_deformation=init_deformation, device=device)
        init_deformation = result['deformation']
        history_stages.append(result['history'])
        init_lr *= gamma ** 2
    return dict(aligned_depth=result['aligned_depth'], history_stages=history_stages)

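# Usage sketch: the coarse-to-fine variant warm-starts each stage from the
# previous deformation grid, bilinearly upsampled to the next resolution.
def _demo_stage_wise_optimization():
    pred = torch.rand(96, 96) + 0.5
    rendered = pred * 1.5
    mask = torch.rand(96, 96) > 0.5
    out = stage_wise_optimization(rendered, pred, mask,
                                  stages=[(4, 4), (8, 8)], iters=20,
                                  device='cpu')
    aligned = out['aligned_depth']           # numpy, from the finest stage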