# zach — initial commit based on github repo (3ef1661)
import torch
import logging
import os
from mono.utils.avg_meter import MetricAverageMeter
from mono.utils.visualization import save_val_imgs, visual_train_data, create_html, save_raw_imgs, save_normal_val_imgs
import cv2
from tqdm import tqdm
import numpy as np
from mono.utils.logger import setup_logger
from mono.utils.comm import main_process
#from scipy.optimize import minimize
#from torchmin import minimize
import torch.optim as optim
from torch.autograd import Variable
def to_cuda(data: dict):
    """Move every tensor value (and tensors inside list values) to the GPU.

    Mutates *data* in place — lists are updated element-wise so aliases to
    the same list object see the moved tensors — and returns it as well.
    """
    for key, value in data.items():
        if isinstance(value, torch.Tensor):
            data[key] = value.cuda(non_blocking=True)
        elif isinstance(value, list) and len(value) >= 1 and isinstance(value[0], torch.Tensor):
            for idx, item in enumerate(value):
                value[idx] = item.cuda(non_blocking=True)
    return data
def align_scale(pred: torch.tensor, target: torch.tensor):
    """Rescale *pred* so its median over valid pixels matches *target*'s.

    Valid pixels are those where target > 0. When there are not more than
    10 valid pixels the prediction is returned unscaled (scale = 1).

    Returns:
        (scaled prediction, scale factor)
    """
    valid = target > 0
    if torch.sum(valid) > 10:
        # epsilon guards against division by a zero median prediction
        ratio = torch.median(target[valid]) / (torch.median(pred[valid]) + 1e-8)
    else:
        ratio = 1
    return pred * ratio, ratio
def align_shift(pred: torch.tensor, target: torch.tensor):
    """Shift *pred* so its median over valid pixels matches *target*'s.

    Valid pixels are those where target > 0. When there are not more than
    10 valid pixels the prediction is returned unshifted (shift = 0).

    Returns:
        (shifted prediction, shift value)
    """
    mask = target > 0
    if torch.sum(mask) > 10:
        # FIX: the previous code subtracted a stray 1e-8 epsilon, copied
        # from the division guard in align_scale; a subtraction needs no
        # such guard, so it only biased the shift.
        shift = torch.median(target[mask]) - torch.median(pred[mask])
    else:
        shift = 0
    pred_shift = pred + shift
    return pred_shift, shift
def align_scale_shift(pred: torch.tensor, target: torch.tensor):
    """Affinely align *pred* to *target* via least-squares (scale, shift).

    A degree-1 polyfit over valid pixels (target > 0) yields scale and
    shift. If the fitted scale is negative (degenerate fit), fall back to
    robust median-ratio scaling with zero shift. When there are not more
    than 10 valid pixels the prediction is returned unchanged.

    Returns:
        (aligned prediction, scale) — the shift is applied to the returned
        prediction but is NOT returned, matching the original interface.
    """
    mask = target > 0
    if torch.sum(mask) > 10:
        # Transfer to CPU only when there are enough points to fit,
        # avoiding a wasted GPU->CPU copy in the degenerate case.
        target_mask = target[mask].cpu().numpy()
        pred_mask = pred[mask].cpu().numpy()
        scale, shift = np.polyfit(pred_mask, target_mask, deg=1)
        if scale < 0:
            # Negative scale means the least-squares fit is unusable.
            scale = torch.median(target[mask]) / (torch.median(pred[mask]) + 1e-8)
            shift = 0
    else:
        scale = 1
        shift = 0
    pred = pred * scale + shift
    return pred, scale
def get_prediction(
    model: torch.nn.Module,
    input: torch.tensor,
    cam_model: torch.tensor,
    pad_info: torch.tensor,
    scale_info: torch.tensor,
    gt_depth: torch.tensor,
    normalize_scale: float,
    intrinsic = None,
    clip_range = None,
    flip_aug = False):
    """Run depth inference on one sample and post-process the prediction.

    Args:
        model: wrapped (DataParallel/DDP) model; inference goes through
            ``model.module.inference``.
        input: network input image batch (assumed NCHW; flip is over dim 3
            — TODO confirm layout).
        cam_model: camera-model tensor fed to the network with the image.
        pad_info: four pad sizes [top, bottom, left, right] used to crop the
            prediction back to the unpadded region.
        scale_info: preprocessing resize factor; divided out below to
            recover metric depth.
        gt_depth: ground-truth depth map; defines the output resolution and
            drives median scale alignment.
        normalize_scale: depth normalization constant (upper depth bound).
        intrinsic: optional camera intrinsics, passed through in ``output``.
        clip_range: optional [min, max] depth clamp.
        flip_aug: if True, average with a horizontally-flipped-input
            prediction (test-time augmentation).

    Returns:
        (pred_depth, pred_depth_scale, scale, output):
            pred_depth: metric depth resized to gt resolution.
            pred_depth_scale: pred_depth after median scale alignment to gt.
            scale: the median alignment factor.
            output: raw model output dict augmented with pad/mask/scale_info
                (and intrinsic when provided).
    """
    #clip_range = [0, 10],
    #flip_aug = True):
    data = dict(
        input=input,
        #ref_input=ref_input,
        cam_model=cam_model
    )
    #output = model.module.inference(data)
    output = model.module.inference(data)
    pred_depth, confidence = output['prediction'], output['confidence']
    # abs() guards against negative raw network outputs before treating
    # them as depths.
    pred_depth = torch.abs(pred_depth)
    pred_depth = pred_depth.squeeze()
    if flip_aug == True:
        # Test-time augmentation: infer on the horizontally flipped image,
        # un-flip the result, and average with the original prediction.
        output_flip = model.module.inference(dict(
            input=torch.flip(input, [3]),
            #ref_input=ref_input,
            cam_model=cam_model
        ))
        if clip_range != None:
            # NOTE(review): the flipped branch is clamped in normalized
            # units (clip_range / normalize_scale * scale_info) while the
            # plain branch uses raw clip_range — asymmetric; confirm this
            # is intended.
            output['prediction'] = torch.clamp(output['prediction'], clip_range[0], clip_range[1])
            output_flip['prediction'] = torch.clamp(output_flip['prediction'], clip_range[0] / normalize_scale * scale_info , clip_range[1] / normalize_scale * scale_info)
        output['prediction'] = 0.5 * (output['prediction'] + torch.flip(output_flip['prediction'], [3]))
        output['confidence'] = 0.5 * (output['confidence'] + torch.flip(output_flip['confidence'], [3]))
    # Attach padding/mask/scale metadata for downstream consumers
    # (visualization / consistency optimization).
    output['pad'] = torch.Tensor(pad_info).cuda().unsqueeze(0).int()
    output['mask'] = torch.ones_like(pred_depth).bool().unsqueeze(0).unsqueeze(1)
    output['scale_info'] = scale_info
    if intrinsic is not None:
        output['intrinsic'] = intrinsic
    # Crop away the padded border, resize to gt resolution, then undo the
    # normalization and preprocessing scale to get metric depth.
    pred_depth = pred_depth[pad_info[0]: pred_depth.shape[0]-pad_info[1], pad_info[2]: pred_depth.shape[1]-pad_info[3]]
    pred_depth = torch.nn.functional.interpolate(pred_depth[None, None, :, :], gt_depth.shape, mode='bilinear').squeeze() # to original size
    pred_depth = pred_depth * normalize_scale / scale_info
    if clip_range != None:
        pred_depth = torch.clamp(pred_depth, clip_range[0], clip_range[1])
    # Median scale alignment against gt for the scale-aligned metric.
    pred_depth_scale, scale = align_scale(pred_depth, gt_depth) #align_scale_shift(pred_depth, gt_depth)
    if clip_range != None:
        pred_depth_scale = torch.clamp(pred_depth_scale, clip_range[0], clip_range[1])
    return pred_depth, pred_depth_scale, scale, output
# def depth_normal_consistency_optimization(output_dict, consistency_fn):
# s = torch.zeros_like(output_dict['scale_info'])
# def closure(x):
# output_dict['scale'] = torch.exp(x) * output_dict['scale_info']
# error = consistency_fn(**output_dict)
# return error + x * x
# result = minimize(closure, s, method='newton-exact', disp=1, options={'max_iter':10, 'lr':0.1})
# return float(torch.exp(-result.x))
def do_test_with_dataloader(
    model: torch.nn.Module,
    cfg: dict,
    dataloader: torch.utils.data,
    logger: logging.RootLogger,
    is_distributed: bool = True,
    local_rank: int = 0):
    """Evaluate *model* over *dataloader*, accumulating depth (and, when
    available, surface-normal) metrics and periodically saving
    visualizations under ``cfg.show_dir``.

    Args:
        model: wrapped model exposing ``model.module.inference``.
        cfg: config providing show_dir, data_basic.depth_range,
            test_metrics and (optionally) clip_depth settings.
        dataloader: yields dicts with 'input', 'target', 'cam_model',
            'pad', 'scale', 'intrinsic', 'filename' (optionally 'normal').
        logger: logger for per-sample and summary output.
        is_distributed: whether depth-metric reduction spans processes.
        local_rank: rank of this process (informational only).
    """
    show_dir = cfg.show_dir
    save_interval = 100
    save_html_path = show_dir + '/index.html'
    save_imgs_dir = show_dir + '/vis'
    os.makedirs(save_imgs_dir, exist_ok=True)
    save_raw_dir = show_dir + '/raw'
    os.makedirs(save_raw_dir, exist_ok=True)

    # Predictions are normalized by the upper end of the depth range.
    normalize_scale = cfg.data_basic.depth_range[1]
    dam = MetricAverageMeter(cfg.test_metrics)        # raw metric depth
    dam_scale = MetricAverageMeter(cfg.test_metrics)  # median-scale-aligned depth

    # Optional depth clipping; older configs may lack clip_depth entirely,
    # so fall back to "no clipping" on any config-access failure.
    try:
        depth_range = cfg.data_basic.clip_depth_range if cfg.clip_depth else None
    except Exception:  # was a bare except; kept best-effort semantics
        depth_range = None

    for i, data in enumerate(tqdm(dataloader)):
        data = to_cuda(data)
        gt_depth = data['target'].squeeze()
        mask = gt_depth > 1e-6  # valid gt pixels
        pad_info = data['pad']

        pred_depth, pred_depth_scale, scale, output = get_prediction(
            model,
            data['input'],
            data['cam_model'],
            pad_info,
            data['scale'],
            gt_depth,
            normalize_scale,
            data['intrinsic'],
        )
        logger.info(f'{data["filename"]}: {scale}')

        # Update depth metrics, with and without median scale alignment.
        dam_scale.update_metrics_gpu(pred_depth_scale, gt_depth, mask, is_distributed)
        dam.update_metrics_gpu(pred_depth, gt_depth, mask, is_distributed)

        # Periodically save qualitative depth results.
        if i % save_interval == 0:
            # Crop the padded border off the input and resize to gt size.
            rgb = data['input'][:, :, pad_info[0]: data['input'].shape[2]-pad_info[1], pad_info[2]: data['input'].shape[3]-pad_info[3]]
            rgb = torch.nn.functional.interpolate(rgb, gt_depth.shape, mode='bilinear').squeeze()
            max_scale = save_val_imgs(i,
                                      pred_depth,
                                      gt_depth,
                                      rgb,
                                      data['filename'][0],
                                      save_imgs_dir,
                                      )
            logger.info(f'{data["filename"]}, max_scale: {max_scale}')

        # Surface-normal metrics, when the model predicts normals.
        if "normal_out_list" in output.keys():
            normal_out_list = output['normal_out_list']
            gt_normal = data['normal']
            pred_normal = normal_out_list[-1][:, :3, :, :]  # (B, 3, H, W)
            H, W = pred_normal.shape[2:]
            pred_normal = pred_normal[:, :, pad_info[0]:H-pad_info[1], pad_info[2]:W-pad_info[3]]
            pred_normal = torch.nn.functional.interpolate(pred_normal, size=gt_normal.shape[2:], mode='bilinear', align_corners=True)
            # Valid where the gt normal is not the all-zero vector.
            gt_normal_mask = ~torch.all(gt_normal == 0, dim=1, keepdim=True)
            # NOTE(review): uses cfg.distributed here but the is_distributed
            # argument for depth metrics above — confirm they always agree.
            dam.update_normal_metrics_gpu(pred_normal, gt_normal, gt_normal_mask, cfg.distributed)
            if i % save_interval == 0:
                # BUG FIX: previously passed the builtin `iter` instead of
                # the loop index `i`.
                save_normal_val_imgs(i,
                                     pred_normal,
                                     gt_normal,
                                     rgb,  # data['input']
                                     'normal_' + data['filename'][0],
                                     save_imgs_dir,
                                     )

    # Aggregate and report metrics on the main process only.
    if main_process():
        eval_error = dam.get_metrics()
        print('>>>>>W/o scale: ', eval_error)
        eval_error_scale = dam_scale.get_metrics()
        print('>>>>>W scale: ', eval_error_scale)
        logger.info(eval_error)
        logger.info(eval_error_scale)