import torch
import logging
import os
from mono.utils.avg_meter import MetricAverageMeter
from mono.utils.visualization import save_val_imgs, visual_train_data, create_html, save_raw_imgs, save_normal_val_imgs
import cv2
from tqdm import tqdm
import numpy as np
from mono.utils.logger import setup_logger
from mono.utils.comm import main_process
#from scipy.optimize import minimize
#from torchmin import minimize
import torch.optim as optim
from torch.autograd import Variable
def to_cuda(data: dict):
    """Move every tensor value of *data* onto the GPU, in place.

    Handles both bare tensors and lists whose first element is a tensor;
    any other value is left untouched. Returns the same dict object.
    """
    for key, value in data.items():
        if isinstance(value, torch.Tensor):
            data[key] = value.cuda(non_blocking=True)
        elif isinstance(value, list) and len(value) >= 1 and isinstance(value[0], torch.Tensor):
            # Mutate the existing list so external aliases see the move too.
            for idx, item in enumerate(value):
                value[idx] = item.cuda(non_blocking=True)
    return data
def align_scale(pred: torch.Tensor, target: torch.Tensor):
    """Median-align the scale of *pred* to *target*.

    Only pixels where ``target > 0`` are treated as valid. When 10 or fewer
    valid pixels exist, the prediction is returned unscaled (scale = 1).

    Args:
        pred: predicted depth map.
        target: ground-truth depth map of the same shape.

    Returns:
        Tuple of (scaled prediction, scale factor).
    """
    mask = target > 0
    if torch.sum(mask) > 10:
        # Ratio of medians is robust to outliers; eps guards divide-by-zero.
        scale = torch.median(target[mask]) / (torch.median(pred[mask]) + 1e-8)
    else:
        scale = 1
    pred_scale = pred * scale
    return pred_scale, scale
def align_shift(pred: torch.tensor, target: torch.tensor):
    """Median-align the offset of *pred* to *target* (additive shift only).

    Pixels with ``target > 0`` are valid; with 10 or fewer valid pixels the
    prediction is returned unshifted. Returns (shifted prediction, shift).
    """
    valid = target > 0
    shift = 0
    if torch.sum(valid) > 10:
        shift = torch.median(target[valid]) - (torch.median(pred[valid]) + 1e-8)
    return pred + shift, shift
def align_scale_shift(pred: torch.Tensor, target: torch.Tensor):
    """Least-squares align scale *and* shift of *pred* to *target*.

    Fits ``target ~= scale * pred + shift`` over valid (``target > 0``)
    pixels. If the fitted scale comes out negative (degenerate fit), falls
    back to the robust median-ratio scale with zero shift. With 10 or fewer
    valid pixels the prediction is returned unchanged.

    Returns:
        Tuple of (aligned prediction, scale). Note: the shift is applied to
        the prediction but not returned, matching the original interface.
    """
    mask = target > 0
    if torch.sum(mask) > 10:
        # polyfit needs CPU numpy arrays; only convert when actually fitting.
        target_mask = target[mask].cpu().numpy()
        pred_mask = pred[mask].cpu().numpy()
        scale, shift = np.polyfit(pred_mask, target_mask, deg=1)
        if scale < 0:
            # A negative slope means the linear fit is unusable; use the
            # robust median-ratio scale instead.
            scale = torch.median(target[mask]) / (torch.median(pred[mask]) + 1e-8)
            shift = 0
    else:
        scale = 1
        shift = 0
    pred = pred * scale + shift
    return pred, scale
def get_prediction(
        model: torch.nn.Module,
        input: torch.Tensor,
        cam_model: torch.Tensor,
        pad_info: torch.Tensor,
        scale_info: torch.Tensor,
        gt_depth: torch.Tensor,
        normalize_scale: float,
        intrinsic=None,
        clip_range=None,
        flip_aug=False):
    """Run inference on one sample and post-process the depth prediction.

    Steps: forward pass (optionally averaged with a horizontally-flipped
    pass), crop the padding, resize to ground-truth resolution, rescale from
    normalized to metric depth, and compute a median-scale-aligned copy.

    Args:
        model: (DataParallel/DDP-)wrapped model exposing ``model.module.inference``.
        input: network input image batch.
        cam_model: camera model tensor forwarded to inference.
        pad_info: (top, bottom, left, right) padding applied to the input.
        scale_info: per-sample resize/scale factor.
        gt_depth: ground-truth depth map (defines the output resolution).
        normalize_scale: depth normalization constant (depth_range upper bound).
        intrinsic: optional camera intrinsics stored in the output dict.
        clip_range: optional (min, max) metric-depth clamp.
        flip_aug: average predictions of original and flipped inputs.

    Returns:
        Tuple of (metric depth, scale-aligned depth, scale factor, raw output dict).
    """
    data = dict(
        input=input,
        cam_model=cam_model
    )
    output = model.module.inference(data)

    if flip_aug:
        output_flip = model.module.inference(dict(
            input=torch.flip(input, [3]),
            cam_model=cam_model
        ))
        if clip_range is not None:
            # NOTE(review): the two clamp ranges differ — the straight pass is
            # clamped with the raw clip_range while the flipped pass uses the
            # normalized range. One of these looks unintended; confirm.
            output['prediction'] = torch.clamp(output['prediction'], clip_range[0], clip_range[1])
            output_flip['prediction'] = torch.clamp(output_flip['prediction'], clip_range[0] / normalize_scale * scale_info, clip_range[1] / normalize_scale * scale_info)
        # Average the straight and un-flipped predictions/confidences.
        output['prediction'] = 0.5 * (output['prediction'] + torch.flip(output_flip['prediction'], [3]))
        output['confidence'] = 0.5 * (output['confidence'] + torch.flip(output_flip['confidence'], [3]))

    # BUGFIX: extract the depth *after* the flip-augmentation merge; the
    # original read output['prediction'] before averaging, so flip_aug had
    # no effect on the returned depth.
    pred_depth = torch.abs(output['prediction']).squeeze()

    output['pad'] = torch.Tensor(pad_info).cuda().unsqueeze(0).int()
    output['mask'] = torch.ones_like(pred_depth).bool().unsqueeze(0).unsqueeze(1)
    output['scale_info'] = scale_info
    if intrinsic is not None:
        output['intrinsic'] = intrinsic

    # Undo padding, resize to the GT resolution, and rescale to metric depth.
    pred_depth = pred_depth[pad_info[0]: pred_depth.shape[0] - pad_info[1], pad_info[2]: pred_depth.shape[1] - pad_info[3]]
    pred_depth = torch.nn.functional.interpolate(pred_depth[None, None, :, :], gt_depth.shape, mode='bilinear').squeeze()  # to original size
    pred_depth = pred_depth * normalize_scale / scale_info

    if clip_range is not None:
        pred_depth = torch.clamp(pred_depth, clip_range[0], clip_range[1])

    # Median-scale-aligned copy, used for scale-invariant metrics.
    pred_depth_scale, scale = align_scale(pred_depth, gt_depth)
    if clip_range is not None:
        pred_depth_scale = torch.clamp(pred_depth_scale, clip_range[0], clip_range[1])

    return pred_depth, pred_depth_scale, scale, output
# def depth_normal_consistency_optimization(output_dict, consistency_fn):
#     s = torch.zeros_like(output_dict['scale_info'])
#     def closure(x):
#         output_dict['scale'] = torch.exp(x) * output_dict['scale_info']
#         error = consistency_fn(**output_dict)
#         return error + x * x
#     result = minimize(closure, s, method='newton-exact', disp=1, options={'max_iter':10, 'lr':0.1})
#     return float(torch.exp(-result.x))
def do_test_with_dataloader(
        model: torch.nn.Module,
        cfg: dict,
        dataloader: torch.utils.data.DataLoader,
        logger: logging.Logger,
        is_distributed: bool = True,
        local_rank: int = 0):
    """Evaluate *model* over *dataloader*, accumulating depth (and, when the
    model emits them, surface-normal) metrics and periodically saving
    visualisation images.

    Args:
        model: wrapped model exposing ``model.module.inference``.
        cfg: experiment config providing ``show_dir``, ``data_basic``,
            ``test_metrics`` (and optionally ``clip_depth``).
        dataloader: test-set loader yielding dicts compatible with ``to_cuda``.
        logger: destination for per-sample and summary logs.
        is_distributed: whether depth-metric reduction runs across ranks.
        local_rank: rank of this process (currently unused in the body).
    """
    show_dir = cfg.show_dir
    save_interval = 100
    save_imgs_dir = show_dir + '/vis'
    os.makedirs(save_imgs_dir, exist_ok=True)
    save_raw_dir = show_dir + '/raw'
    os.makedirs(save_raw_dir, exist_ok=True)

    # Depth predictions are normalized by the top of the configured range.
    normalize_scale = cfg.data_basic.depth_range[1]
    dam = MetricAverageMeter(cfg.test_metrics)        # raw predictions
    dam_scale = MetricAverageMeter(cfg.test_metrics)  # median-scale-aligned
    try:
        depth_range = cfg.data_basic.clip_depth_range if cfg.clip_depth else None
    except Exception:
        # cfg may not define clip_depth at all; treat as "no clipping".
        depth_range = None
    # NOTE(review): depth_range is computed but never passed to
    # get_prediction(clip_range=...) — confirm whether that was intended.

    for i, data in enumerate(tqdm(dataloader)):
        data = to_cuda(data)
        gt_depth = data['target'].squeeze()
        mask = gt_depth > 1e-6
        pad_info = data['pad']
        pred_depth, pred_depth_scale, scale, output = get_prediction(
            model,
            data['input'],
            data['cam_model'],
            pad_info,
            data['scale'],
            gt_depth,
            normalize_scale,
            data['intrinsic'],
        )
        logger.info(f'{data["filename"]}: {scale}')

        # update depth metrics
        dam_scale.update_metrics_gpu(pred_depth_scale, gt_depth, mask, is_distributed)
        dam.update_metrics_gpu(pred_depth, gt_depth, mask, is_distributed)

        # periodically save evaluation visualisations
        if i % save_interval == 0:
            # Crop the padding off the network input and resize to GT resolution.
            rgb = data['input'][:, :, pad_info[0]: data['input'].shape[2]-pad_info[1], pad_info[2]: data['input'].shape[3]-pad_info[3]]
            rgb = torch.nn.functional.interpolate(rgb, gt_depth.shape, mode='bilinear').squeeze()
            max_scale = save_val_imgs(i,
                                      pred_depth,
                                      gt_depth,
                                      rgb,
                                      data['filename'][0],
                                      save_imgs_dir,
                                      )
            logger.info(f'{data["filename"]}, {"max_scale"}: {max_scale}')

        # surface normal metrics
        if "normal_out_list" in output.keys():
            normal_out_list = output['normal_out_list']
            gt_normal = data['normal']
            pred_normal = normal_out_list[-1][:, :3, :, :]  # (B, 3, H, W)
            H, W = pred_normal.shape[2:]
            pred_normal = pred_normal[:, :, pad_info[0]:H-pad_info[1], pad_info[2]:W-pad_info[3]]
            pred_normal = torch.nn.functional.interpolate(pred_normal, size=gt_normal.shape[2:], mode='bilinear', align_corners=True)
            # Valid where the GT normal is not the all-zero vector.
            gt_normal_mask = ~torch.all(gt_normal == 0, dim=1, keepdim=True)
            # NOTE(review): uses cfg.distributed here while the depth metrics
            # use the is_distributed argument — confirm these always agree.
            dam.update_normal_metrics_gpu(pred_normal, gt_normal, gt_normal_mask, cfg.distributed)
            if i % save_interval == 0:
                # BUGFIX: was passing the builtin `iter` instead of the loop index.
                save_normal_val_imgs(i,
                                     pred_normal,
                                     gt_normal,
                                     rgb,
                                     'normal_' + data['filename'][0],
                                     save_imgs_dir,
                                     )

    # report aggregated errors on the main process only
    if main_process():
        eval_error = dam.get_metrics()
        print('>>>>>W/o scale: ', eval_error)
        eval_error_scale = dam_scale.get_metrics()
        print('>>>>>W scale: ', eval_error_scale)
        logger.info(eval_error)
        logger.info(eval_error_scale)