# PatchFusion: zoedepth/utils/align/depth_alignment.py
# Author: Zhenyu Li
# MIT License
# Copyright (c) 2022 Intelligent Systems Lab Org
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# File author: Shariq Farooq Bhat, Zhenyu Li
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm.auto import tqdm
from torchvision.transforms import ToTensor, ToPILImage
from typing import List, Tuple
from PIL import Image
# from models.monodepth.zoedepth import ZoeDepthLora
# from zoedepth.utils.align.loss import SILogLoss, gradl1_loss, edge_aware_smoothness_per_pixel, ssim_loss
from .loss import *
import cv2
from zoedepth.trainers.loss import *
# from utils.misc import *
@torch.no_grad()
def scale_shift_linear(rendered_depth, predicted_depth, mask, fuse=True, return_params=False):
    """
    Align predicted_depth to rendered_depth with a least-squares affine fit.
    Solves  min_{a,b} || (predicted * a + b) - rendered ||  over the masked
    pixels via the normal equations, then applies the correction everywhere.
    :param rendered_depth: torch.Tensor (H, W) reference depth
    :param predicted_depth: torch.Tensor (H, W) depth to be corrected
    :param mask: torch.Tensor (H, W) bool - True: valid rendered_depth pixels
    :param fuse: if True, keep rendered_depth where valid and use the corrected
        prediction elsewhere
    :param return_params: if True, return the (2, 1) tensor [a, b] instead
    :return: corrected (or fused) depth map, or the affine parameters
    """
    # No valid reference pixels: nothing to fit against.
    if mask.sum() == 0:
        return predicted_depth

    target = rendered_depth[mask].unsqueeze(-1)
    source = predicted_depth[mask].unsqueeze(-1)

    # Normal-equation solution of the 1D affine regression target ~= a*source + b.
    design = torch.cat([source, torch.ones_like(source)], dim=1)
    params = (design.T @ design).inverse() @ (design.T @ target)
    if return_params:
        return params

    corrected = predicted_depth * params[0] + params[1]
    if not fuse:
        return corrected
    # Trust the reference wherever it is valid, corrected prediction elsewhere.
    return torch.where(mask, rendered_depth, corrected)
def np_scale_shift_linear(rendered_depth: np.ndarray, predicted_depth: np.ndarray, mask: np.ndarray, fuse: bool=True):
    """
    NumPy variant of scale_shift_linear: least-squares affine alignment of
    predicted_depth to rendered_depth over the masked pixels.
    Solves  min_{a,b} || (predicted * a + b) - rendered ||  via the normal
    equations, then applies the correction to the whole prediction.
    :param rendered_depth: np.ndarray (H, W) reference depth
    :param predicted_depth: np.ndarray (H, W) depth to be corrected
    :param mask: np.ndarray (H, W) bool - True: valid rendered_depth pixels
    :param fuse: if True, keep rendered_depth where valid and use the corrected
        prediction elsewhere
    :return: corrected (or fused) depth map
    """
    # No valid reference pixels: nothing to fit against.
    if mask.sum() == 0:
        return predicted_depth

    target = rendered_depth[mask].reshape(-1, 1)
    source = predicted_depth[mask].reshape(-1, 1)

    # Normal-equation solution of the 1D affine regression target ~= a*source + b.
    design = np.concatenate([source, np.ones_like(source)], axis=1)
    params = np.linalg.inv(design.T @ design) @ (design.T @ target)

    corrected = predicted_depth * params[0] + params[1]
    if not fuse:
        return corrected
    # Trust the reference wherever it is valid, corrected prediction elsewhere.
    return np.where(mask, rendered_depth, corrected)
@torch.no_grad()
def apply_depth_smoothing(depth, mask):
    """Smooth a depth map in a band around the edges of `mask`.

    Blurs the whole depth map (bilateral then Gaussian, via OpenCV), but keeps
    the blurred values only inside a dilated band around the mask's x-direction
    edges; everywhere else the original depth is returned unchanged.

    :param depth: depth map, converted to a (1, 1, H, W) tensor
    :param mask: binary mask whose edges mark where smoothing is applied
    :return: depth tensor with the edge band replaced by its blurred version
    """

    def dilate(x, k=3):
        # Box-filter the map with a k x k all-ones kernel; any positive
        # response marks the pixel as part of the dilated region.
        x = as_bchw_tensor(x.float(), 1)
        x = torch.nn.functional.conv2d(x.float(),
                                       torch.ones(1, 1, k, k).to(x.device),
                                       padding="same"
                                       )
        return x.squeeze() > 0

    def sobel(x):
        # Sobel kernel for the x direction (its transpose would give y).
        flipped_sobel_x = torch.tensor([
            [-1, 0, 1],
            [-2, 0, 2],
            [-1, 0, 1]
        ], dtype=torch.float32).to(x.device)
        # Stack kernel + transpose -> both gradient directions in one conv.
        flipped_sobel_x = torch.stack([flipped_sobel_x, flipped_sobel_x.t()]).unsqueeze(1)
        x_pad = torch.nn.functional.pad(x.float(), (1, 1, 1, 1), mode="replicate")
        x = torch.nn.functional.conv2d(
            x_pad,
            flipped_sobel_x,
            padding="valid"
        )
        dx, dy = x.unbind(dim=-3)
        # return torch.sqrt(dx**2 + dy**2).squeeze()
        # new content is created mostly in x direction, sharp edges in y direction are wanted (e.g. table --> wall)
        return dx

    depth = as_bchw_tensor(depth, 1)
    mask = as_bchw_tensor(mask, 1).float()

    # x-gradient of the mask marks its vertical seams; widen into a 21-px band.
    edges = sobel(mask)
    dilated_edges = dilate(edges, k=21)

    # Blur the full map on CPU with OpenCV: edge-preserving bilateral pass,
    # then a small Gaussian to remove remaining high-frequency noise.
    depth_numpy = depth.squeeze().float().cpu().numpy()
    blur_bilateral = cv2.bilateralFilter(depth_numpy, 5, 140, 140)
    blur_gaussian = cv2.GaussianBlur(blur_bilateral, (5, 5), 0)
    blur_gaussian = torch.from_numpy(blur_gaussian).to(depth)

    # Keep the blurred values only inside the seam band.
    depth_smooth = torch.where(dilated_edges, blur_gaussian, depth)
    return depth_smooth
def get_dilated_only_mask(mask: torch.Tensor, k=7):
    """Return the ring of pixels gained by dilating `mask` with a k x k kernel
    (dilated mask XOR original mask)."""
    kernel = torch.ones(1, 1, k, k).to(mask.device)
    response = torch.nn.functional.conv2d(as_bchw_tensor(mask.float(), 1), kernel, padding="same")
    dilated = response.squeeze() > 0
    # XOR removes the original mask, leaving only the newly covered border.
    return dilated ^ mask
def get_boundary_mask(mask: torch.Tensor, k=7):
    """Pixels within k of the mask boundary, on either side: the dilation ring
    of the mask united with the dilation ring of its complement."""
    outside_ring = get_dilated_only_mask(mask, k=k)
    inside_ring = get_dilated_only_mask(~mask, k=k)
    return outside_ring | inside_ring
@torch.no_grad()
def ss_align_and_blur(rendered_depth: torch.Tensor, predicted_depth: torch.Tensor, mask: torch.Tensor, fuse: bool=True):
    """Scale/shift-align the prediction to the render, then smooth the seams
    around the mask boundary."""
    result = scale_shift_linear(rendered_depth, predicted_depth, mask, fuse=fuse)
    return apply_depth_smoothing(result, mask)
def np_ss_align_and_blur(rendered_depth: np.ndarray, predicted_depth: np.ndarray, mask: np.ndarray, fuse: bool=True):
    """NumPy entry point: scale/shift-align, smooth, and return a numpy array."""
    result = np_scale_shift_linear(rendered_depth, predicted_depth, mask, fuse=fuse)
    smoothed = apply_depth_smoothing(result, mask)
    return smoothed.cpu().numpy()
def stitch(depth_src: torch.Tensor, depth_target: torch.Tensor, mask_src: torch.Tensor, smoothen=True, device='cuda:0'):
    """Composite two depth maps: take depth_src where mask_src is set and
    depth_target elsewhere; optionally smooth the resulting seam.

    :param depth_src: source depth, converted to a (1, 1, H, W) tensor
    :param depth_target: target depth, converted to a (1, 1, H, W) tensor
    :param mask_src: bool mask selecting depth_src pixels
    :param smoothen: if True, blur the seam band via apply_depth_smoothing
    :param device: device the inputs are moved to
    :return: composited (and optionally smoothed) depth
    """
    src = as_bchw_tensor(depth_src, 1, device=device)
    tgt = as_bchw_tensor(depth_target, 1, device=device)
    sel = as_bchw_tensor(mask_src, 1, device=device)

    composite = src * sel.float() + tgt * (~sel).float()
    if not smoothen:
        return composite
    return apply_depth_smoothing(composite, sel).squeeze().float()
def smoothness_loss(depth, mask=None):
    """Mean absolute finite difference of a BCHW depth map along x and y.

    :param depth: (B, C, H, W) tensor
    :param mask: optional bool tensor of the same shape; gradients are averaged
        only over the masked positions
    :return: scalar tensor (mean |dx| + mean |dy|)
    """
    dx = (depth[:, :, :, :-1] - depth[:, :, :, 1:]).abs()
    dy = (depth[:, :, :-1, :] - depth[:, :, 1:, :]).abs()
    if mask is None:
        return dx.mean() + dy.mean()
    # Crop the mask to match each gradient's reduced spatial extent.
    return dx[mask[:, :, :, :-1]].mean() + dy[mask[:, :, :-1, :]].mean()
import torch.optim as optim
from torch.optim import lr_scheduler
def finetune_on_sample(model, image_pil, target_depth, mask=None,
                       iters=10, lr=0.1, beta=0.5, w_boundary_grad=1, w_grad=0.1, gamma=0.99):
    """Fine-tune `model` in place on a single image/depth pair.

    Runs `iters` Adam + OneCycleLR steps minimizing a SILog loss on the masked
    pixels plus two regularizers against the pre-finetune prediction: a
    gradient-L1 term over the whole map and a laplacian-matching term
    restricted to the mask boundary (presumably to avoid visible seams —
    depends on the external `laplacian_matching_loss`).

    :param model: model exposing `.infer(x, with_flip_aug=...)` and
        `.parameters()`; switched to train() during and eval() after
    :param image_pil: input image, converted to a (1, 3, H, W) tensor
    :param target_depth: target depth, converted to a (1, 1, H, W) tensor
    :param mask: optional validity mask; defaults to target_depth > 0
    :param iters: number of optimization steps
    :param lr: peak learning rate for OneCycleLR
    :param beta: beta parameter passed to SILogLoss
    :param w_boundary_grad: weight of the boundary laplacian-matching loss
    :param w_grad: weight of the gradient-L1 loss against the original output
    :param gamma: currently unused (only referenced by the commented-out
        ExponentialLR alternative)
    :return: (model, history) — NOTE(review): the history append is commented
        out below, so `history` is always returned empty
    """
    model.train()
    model_device = next(model.parameters()).device
    x = as_bchw_tensor(image_pil, 3, device=model_device)
    target_depth = as_bchw_tensor(target_depth, 1, device=model_device)
    if mask is None:
        mask = target_depth > 0
    elif (not isinstance(mask, torch.Tensor)) or mask.shape != target_depth.shape:
        # Convert/reshape any non-conforming mask to a bool BCHW tensor.
        mask = as_bchw_tensor(mask, 1, device=model_device).to(torch.bool)
    history = []
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)
    scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=lr, steps_per_epoch=iters, epochs=1)
    main_loss = SILogLoss(beta=beta)
    # Frozen snapshot of the model's output before finetuning; used as the
    # reference for both regularizers.
    orig_y = model.infer(x, with_flip_aug=False).detach()
    gl1 = gradl1_loss
    pbar = tqdm(range(iters), desc="Finetuning on sample")
    for i in pbar:
        optimizer.zero_grad()
        y = model.infer(x, with_flip_aug=False)
        # Composite: ground-truth depth at valid pixels (detached), the
        # current prediction elsewhere.
        stitched = y * (~mask).float() + (target_depth * (mask).float()).detach()
        loss_si = main_loss(y[mask], target_depth[mask])
        # Keep the global structure close to the original prediction.
        loss_grad = gl1(y, orig_y)
        # Match the composite to the original prediction near the mask border.
        bmask = get_boundary_mask(mask)
        loss_boundary_grad = laplacian_matching_loss(stitched, orig_y, bmask)
        loss = loss_si + w_boundary_grad * loss_boundary_grad + w_grad * loss_grad
        # Abort on numerical blow-up instead of stepping with NaN gradients.
        if torch.isnan(loss):
            print("Loss is nan, breaking")
            break
        loss.backward()
        optimizer.step()
        scheduler.step()
        # history.append(loss.item())
        pbar.set_postfix(loss=loss.item(), si=loss_si.item())
    model.eval()
    return model, history
# def align_by_finetuning_lora(model: ZoeDepthLora, image, target_depth, mask=None, iters=10, lr=0.1, gamma=0.99, **kwargs):
# # model.reset_lora()
# model.set_only_lora_trainable()
# model, history = finetune_on_sample(model, image, target_depth, mask=mask, iters=iters, lr=lr, gamma=gamma)
# aligned_depth = model.infer(as_bchw_tensor(image, 3, device=next(model.parameters()).device))
# return dict(model=model, history=history, aligned_depth=aligned_depth)
import torch.nn as nn
import torch.nn.functional as F
# from utils.misc import as_bchw_tensor
def as_bchw_tensor(input_tensor, num, device=None):
    """Convert `input_tensor` to a 4D (B, C, H, W) torch tensor.

    Fixes three defects of the previous stub: `device` was a required
    parameter although several call sites in this file pass only two
    arguments (TypeError); `num` and `device` were ignored in favor of a
    hard-coded `.cuda()`; and the input was unconditionally unsqueezed
    twice, turning an already-4D tensor into a 6D one.

    :param input_tensor: torch.Tensor, np.ndarray, or anything np.asarray
        accepts (e.g. a PIL image); at most 4D
    :param num: expected number of channels; a 3D (H, W, num) input is
        permuted to channels-first before the batch axis is added
    :param device: optional target device; if None the tensor stays on its
        current device
    :return: 4D torch.Tensor
    """
    if not isinstance(input_tensor, torch.Tensor):
        input_tensor = torch.as_tensor(np.asarray(input_tensor))
    # (H, W, C) -> (C, H, W) when the trailing axis matches the expected channels.
    if input_tensor.dim() == 3 and input_tensor.shape[-1] == num:
        input_tensor = input_tensor.permute(2, 0, 1)
    # Prepend batch/channel axes until the tensor is BCHW.
    while input_tensor.dim() < 4:
        input_tensor = input_tensor.unsqueeze(0)
    if device is not None:
        input_tensor = input_tensor.to(device)
    return input_tensor
def optimize_depth_deformation(rendered_depth, pred_depth, mask, h=10, w=10, iters=100, init_lr=0.1, gamma=0.996,
                               init_deformation=None,
                               device='cuda:0', return_dict=False):
    """Optimize a low-resolution multiplicative deformation grid so that
    `pred_depth` matches `rendered_depth` on the masked pixels.

    A (1, 1, h, w) grid of log-scale values is optimized with Adam; at each
    step it is exponentiated, bilinearly upsampled to the depth resolution,
    and multiplied into the prediction. Loss = masked MSE + 0.1 * GradL1Loss.

    :param rendered_depth: reference depth, converted to (1, 1, H, W)
    :param pred_depth: depth to be deformed, converted to (1, 1, H, W)
    :param mask: validity mask for rendered_depth
    :param h, w: deformation grid resolution
    :param iters: number of optimization steps
    :param init_lr: Adam learning rate
    :param gamma: ExponentialLR decay factor
    :param init_deformation: optional (1, 1, h, w) warm-start grid (made
        trainable in place)
    :param device: device used for the optimization
    :param return_dict: if True, return dict(aligned_depth=..., history=...,
        deformation=...) as required by stage_wise_optimization; if False
        (default, backward compatible) return only the aligned depth tensor
    :return: aligned depth (CPU, squeezed), or the dict described above
    """
    rendered_depth = as_bchw_tensor(rendered_depth, 1, device=device)
    pred_depth = as_bchw_tensor(pred_depth, 1, device=device)
    mask = as_bchw_tensor(mask, 1, device=device).to(torch.bool)

    # Grid of log-scales (exp(0) == 1 -> identity deformation at init).
    if init_deformation is None:
        deformation = torch.zeros((1, 1, h, w), requires_grad=True, device=device)
    else:
        deformation = init_deformation
        deformation.requires_grad = True
    assert deformation.shape == (1, 1, h, w)

    optimizer = torch.optim.Adam([deformation], lr=init_lr)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

    history = []
    grad_loss = GradL1Loss()
    for i in tqdm(range(iters)):
        # exp keeps the multiplicative scale strictly positive.
        scalar_deformation = torch.exp(deformation)
        scalar_deformation = F.interpolate(scalar_deformation, size=pred_depth.shape[-2:], mode='bilinear', align_corners=True)
        adjusted_depth = pred_depth * scalar_deformation

        loss = F.mse_loss(adjusted_depth[mask], rendered_depth[mask], reduction='none')
        loss_g = grad_loss(adjusted_depth, rendered_depth, mask)
        loss = loss.mean() + 0.1 * loss_g

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        if i % 10 == 0:
            history.append(loss.item())

    # Final forward pass with the optimized grid.
    scalar_deformation = torch.exp(deformation)
    scalar_deformation = F.interpolate(scalar_deformation, size=pred_depth.shape[-2:], mode='bilinear', align_corners=True)
    adjusted_depth = pred_depth * scalar_deformation

    if return_dict:
        # Dict form consumed by stage_wise_optimization (previously this
        # return was commented out, which broke the stage-wise driver).
        return dict(aligned_depth=adjusted_depth.detach().cpu().numpy().squeeze(),
                    history=history,
                    deformation=deformation.detach())
    return adjusted_depth.detach().cpu().squeeze()


def stage_wise_optimization(rendered_depth, pred_depth, mask,
                            stages=[(4, 4), (8, 8), (16, 16), (32, 32)],
                            iters=100, init_lr=0.1, gamma=0.996, device='cuda:1'):
    """Coarse-to-fine depth deformation: optimize the grid at each resolution
    in `stages`, upsampling the previous stage's result as the warm start and
    slightly decaying the learning rate between stages.

    :return: dict(aligned_depth=<final stage's aligned depth>,
                  history_stages=<list of per-stage loss histories>)
    """
    h_init, w_init = stages[0]
    init_deformation = torch.zeros((1, 1, h_init, w_init), device=device)
    # return_dict=True is required: the tensor-only return has no
    # 'deformation'/'history' fields to carry between stages.
    result = optimize_depth_deformation(rendered_depth, pred_depth, mask, h=h_init, w=w_init, iters=iters, init_lr=init_lr, gamma=gamma, init_deformation=init_deformation, device=device, return_dict=True)
    init_deformation = result['deformation']
    history_stages = [result['history']]
    for h, w in stages[1:]:
        # Upsample the coarse grid as the warm start for the finer stage.
        init_deformation = F.interpolate(init_deformation, size=(h, w), mode='bilinear', align_corners=True).detach()
        result = optimize_depth_deformation(rendered_depth, pred_depth, mask, h=h, w=w, iters=iters, init_lr=init_lr, gamma=gamma, init_deformation=init_deformation, device=device, return_dict=True)
        init_deformation = result['deformation']
        history_stages.append(result['history'])
        init_lr *= gamma ** 2
    return dict(aligned_depth=result['aligned_depth'], history_stages=history_stages)