|
import torch |
|
import logging |
|
import os |
|
from mono.utils.avg_meter import MetricAverageMeter |
|
from mono.utils.visualization import save_val_imgs, visual_train_data, create_html, save_raw_imgs, save_normal_val_imgs |
|
import cv2 |
|
from tqdm import tqdm |
|
import numpy as np |
|
from mono.utils.logger import setup_logger |
|
from mono.utils.comm import main_process |
|
|
|
|
|
import torch.optim as optim |
|
from torch.autograd import Variable |
|
|
|
|
|
def to_cuda(data: dict):
    """Move the tensors stored in ``data`` to the current CUDA device.

    Two kinds of values are handled:

    * a ``torch.Tensor`` value is replaced by its CUDA copy;
    * a non-empty ``list`` whose first element is a ``torch.Tensor`` has
      every element replaced by its CUDA copy (the list object is reused).

    All other values are left as they are. The transfers are issued with
    ``non_blocking=True``.

    Args:
        data: mapping to process; it is modified in place.

    Returns:
        The same ``data`` object that was passed in.
    """
    for key in data:
        value = data[key]

        if isinstance(value, torch.Tensor):
            data[key] = value.cuda(non_blocking=True)
            continue

        # Only the first element is inspected to decide whether this is a
        # list of tensors.
        holds_tensors = (
            isinstance(value, list)
            and len(value) >= 1
            and isinstance(value[0], torch.Tensor)
        )
        if holds_tensors:
            for idx, item in enumerate(value):
                value[idx] = item.cuda(non_blocking=True)

    return data
|
|
|
def align_scale(pred: torch.tensor, target: torch.tensor):
    """Rescale ``pred`` so that its median matches the median of ``target``.

    Only positions where ``target > 0`` take part in the medians. When 10 or
    fewer such positions exist, no rescaling is done and the scale is the
    plain integer ``1``.

    Args:
        pred: predicted values.
        target: reference values; entries ``<= 0`` are ignored.

    Returns:
        ``(pred * scale, scale)``. ``scale`` is a 0-dim tensor when it was
        estimated, otherwise the int ``1``.
    """
    valid = target > 0

    scale = 1
    if torch.sum(valid) > 10:
        target_median = torch.median(target[valid])
        pred_median = torch.median(pred[valid])
        # The small epsilon keeps the division finite for a zero median.
        scale = target_median / (pred_median + 1e-8)

    return pred * scale, scale
|
|
|
def align_shift(pred: torch.tensor, target: torch.tensor):
    """Offset ``pred`` so that its median matches the median of ``target``.

    Only positions where ``target > 0`` take part in the medians. When 10 or
    fewer such positions exist, no offset is applied and the shift is the
    plain integer ``0``.

    Args:
        pred: predicted values.
        target: reference values; entries ``<= 0`` are ignored.

    Returns:
        ``(pred + shift, shift)``. ``shift`` is a 0-dim tensor when it was
        estimated, otherwise the int ``0``.
    """
    valid = target > 0

    if torch.sum(valid) <= 10:
        shift = 0
    else:
        target_median = torch.median(target[valid])
        pred_median = torch.median(pred[valid])
        shift = target_median - (pred_median + 1e-8)

    return pred + shift, shift
|
|
|
def align_scale_shift(pred: torch.tensor, target: torch.tensor):
    """Fit ``target ~= scale * pred + shift`` and apply it to ``pred``.

    Only positions where ``target > 0`` are used for the fit. The fit is a
    degree-1 least-squares polynomial (``np.polyfit``). If the fitted scale
    is negative, the function falls back to the ratio of medians with a zero
    shift. When 10 or fewer valid positions exist, ``pred`` is returned
    unchanged in value (scale ``1``, shift ``0``).

    Args:
        pred: predicted values.
        target: reference values; entries ``<= 0`` are ignored.

    Returns:
        ``(aligned_pred, scale)``. Note that the shift is applied but not
        returned. ``scale`` is a NumPy float from the fit, a 0-dim tensor
        from the median fallback, or the int ``1`` when no fit was done.
    """
    mask = target > 0
    scale, shift = 1, 0

    if torch.sum(mask) > 10:
        target_valid = target[mask]
        pred_valid = pred[mask]
        # Copy to host only when a fit is actually performed; detach so that
        # grad-tracking tensors can be converted to NumPy.
        scale, shift = np.polyfit(
            pred_valid.detach().cpu().numpy(),
            target_valid.detach().cpu().numpy(),
            deg=1,
        )
        if scale < 0:
            # A negative slope would invert the prediction; use the
            # median-ratio scale without any shift instead.
            scale = torch.median(target_valid) / (torch.median(pred_valid) + 1e-8)
            shift = 0

    pred = pred * scale + shift
    return pred, scale
|
|
|
def get_prediction(
    model: torch.nn.Module,
    input: torch.tensor,
    cam_model: torch.tensor,
    pad_info: torch.tensor,
    scale_info: torch.tensor,
    gt_depth: torch.tensor,
    normalize_scale: float,
    intrinsic = None,
    clip_range = None,
    flip_aug = False):
    """Run model inference and bring the predicted depth to ``gt_depth``'s size.

    The prediction is un-padded using ``pad_info``, bilinearly resized to
    ``gt_depth.shape``, multiplied by ``normalize_scale / scale_info`` and
    finally median-aligned to ``gt_depth`` via ``align_scale``.

    Args:
        model: network exposing ``inference(dict)``, either directly or
            through a ``.module`` attribute (wrapped model).
        input: image batch passed to the model under the key ``'input'``.
            Flipping is done along dim 3.
        cam_model: passed to the model under the key ``'cam_model'``.
        pad_info: four paddings; indices 0/1 are removed from the start/end
            of the first prediction dim, indices 2/3 from the second.
        scale_info: divisor applied when rescaling the prediction; also
            stored in the output dict.
        gt_depth: reference depth; its shape is the resize target.
            Assumes a 2-D tensor — TODO confirm against callers.
        normalize_scale: multiplier applied when rescaling the prediction.
        intrinsic: if not ``None``, stored in the output dict unchanged.
        clip_range: optional ``(min, max)`` used to clamp predictions.
        flip_aug: if truthy, also run the model on the horizontally flipped
            input and average it into ``output['prediction']`` and
            ``output['confidence']``.

    Returns:
        ``(pred_depth, pred_depth_scale, scale, output)`` where
        ``pred_depth`` is the resized/rescaled prediction,
        ``pred_depth_scale`` is the same after ``align_scale``, ``scale`` is
        the factor ``align_scale`` returned, and ``output`` is the model's
        output dict extended with ``'pad'``, ``'mask'``, ``'scale_info'``
        and optionally ``'intrinsic'``.
    """
    # Wrapped models keep the real network under ``.module``; fall back to
    # the model itself when it is not wrapped.
    net = getattr(model, 'module', model)

    output = net.inference(dict(input=input, cam_model=cam_model))
    # The indexing below treats the squeezed prediction as 2-D, so this
    # presumably expects a single-sample batch — verify against callers.
    pred_depth = torch.abs(output['prediction']).squeeze()

    if flip_aug:
        output_flip = net.inference(dict(
            input=torch.flip(input, [3]),
            cam_model=cam_model,
        ))

        if clip_range is not None:
            # NOTE(review): the un-flipped prediction is clamped with the raw
            # clip_range while the flipped one uses clip_range rescaled by
            # scale_info / normalize_scale; the bounds differ — confirm
            # which one is intended.
            output['prediction'] = torch.clamp(
                output['prediction'], clip_range[0], clip_range[1])
            output_flip['prediction'] = torch.clamp(
                output_flip['prediction'],
                clip_range[0] / normalize_scale * scale_info,
                clip_range[1] / normalize_scale * scale_info,
            )

        # NOTE(review): pred_depth was taken before this averaging, so the
        # returned depth values do not include the flipped pass; only the
        # entries of ``output`` do. Confirm whether that is intended.
        output['prediction'] = 0.5 * (
            output['prediction'] + torch.flip(output_flip['prediction'], [3]))
        output['confidence'] = 0.5 * (
            output['confidence'] + torch.flip(output_flip['confidence'], [3]))

    output['pad'] = torch.Tensor(pad_info).cuda().unsqueeze(0).int()
    output['mask'] = torch.ones_like(pred_depth).bool().unsqueeze(0).unsqueeze(1)
    output['scale_info'] = scale_info
    if intrinsic is not None:
        output['intrinsic'] = intrinsic

    # Strip the padding, then resize to the ground-truth resolution.
    height, width = pred_depth.shape[0], pred_depth.shape[1]
    pred_depth = pred_depth[
        pad_info[0]: height - pad_info[1],
        pad_info[2]: width - pad_info[3],
    ]
    pred_depth = torch.nn.functional.interpolate(
        pred_depth[None, None, :, :], gt_depth.shape, mode='bilinear').squeeze()
    pred_depth = pred_depth * normalize_scale / scale_info

    if clip_range is not None:
        pred_depth = torch.clamp(pred_depth, clip_range[0], clip_range[1])

    pred_depth_scale, scale = align_scale(pred_depth, gt_depth)

    if clip_range is not None:
        pred_depth_scale = torch.clamp(pred_depth_scale, clip_range[0], clip_range[1])

    return pred_depth, pred_depth_scale, scale, output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def do_test_with_dataloader(
    model: torch.nn.Module,
    cfg: dict,
    dataloader: torch.utils.data,
    logger: logging.RootLogger,
    is_distributed: bool = True,
    local_rank: int = 0):
    """Evaluate ``model`` over ``dataloader`` and log the accumulated metrics.

    For every batch the depth prediction is obtained through
    ``get_prediction`` and fed into two ``MetricAverageMeter`` instances:
    one for the raw prediction and one for the median-scaled prediction.
    If the model output contains ``'normal_out_list'``, surface-normal
    metrics are accumulated as well. Every 100th batch, visualisations are
    written below ``cfg.show_dir + '/vis'``.

    Args:
        model: network forwarded to ``get_prediction``.
        cfg: configuration object; reads ``show_dir``,
            ``data_basic.depth_range`` and ``test_metrics``.
        dataloader: iterable of batches (dicts) with the keys ``'target'``,
            ``'pad'``, ``'input'``, ``'cam_model'``, ``'scale'``,
            ``'intrinsic'``, ``'filename'`` and, when normals are
            predicted, ``'normal'``.
        logger: receives per-sample scale values and the final metrics.
        is_distributed: forwarded to the metric meters' update calls.
        local_rank: unused in this function.

    Returns:
        None. Final metrics are printed and logged on the main process.
    """
    show_dir = cfg.show_dir
    save_interval = 100
    save_imgs_dir = show_dir + '/vis'
    os.makedirs(save_imgs_dir, exist_ok=True)
    save_raw_dir = show_dir + '/raw'
    os.makedirs(save_raw_dir, exist_ok=True)

    normalize_scale = cfg.data_basic.depth_range[1]

    dam = MetricAverageMeter(cfg.test_metrics)
    dam_scale = MetricAverageMeter(cfg.test_metrics)

    for i, data in enumerate(tqdm(dataloader)):
        data = to_cuda(data)
        gt_depth = data['target'].squeeze()
        mask = gt_depth > 1e-6
        pad_info = data['pad']

        # NOTE(review): no clip_range is passed, so predictions are evaluated
        # unclipped regardless of cfg.clip_depth /
        # cfg.data_basic.clip_depth_range — confirm this is intended.
        pred_depth, pred_depth_scale, scale, output = get_prediction(
            model,
            data['input'],
            data['cam_model'],
            pad_info,
            data['scale'],
            gt_depth,
            normalize_scale,
            data['intrinsic'],
        )

        logger.info(f'{data["filename"]}: {scale}')

        dam_scale.update_metrics_gpu(pred_depth_scale, gt_depth, mask, is_distributed)
        dam.update_metrics_gpu(pred_depth, gt_depth, mask, is_distributed)

        should_save = i % save_interval == 0

        if should_save:
            # Remove the padding from the input image and resize it to the
            # ground-truth resolution so it lines up with the depth maps.
            full_h, full_w = data['input'].shape[2], data['input'].shape[3]
            rgb = data['input'][
                :, :,
                pad_info[0]: full_h - pad_info[1],
                pad_info[2]: full_w - pad_info[3],
            ]
            rgb = torch.nn.functional.interpolate(
                rgb, gt_depth.shape, mode='bilinear').squeeze()
            max_scale = save_val_imgs(
                i,
                pred_depth,
                gt_depth,
                rgb,
                data['filename'][0],
                save_imgs_dir,
            )
            logger.info(f'{data["filename"]}, max_scale: {max_scale}')

        if 'normal_out_list' in output:
            gt_normal = data['normal']

            # Use the last entry of the list and keep its first 3 channels.
            pred_normal = output['normal_out_list'][-1][:, :3, :, :]
            H, W = pred_normal.shape[2:]
            pred_normal = pred_normal[
                :, :,
                pad_info[0]: H - pad_info[1],
                pad_info[2]: W - pad_info[3],
            ]
            pred_normal = torch.nn.functional.interpolate(
                pred_normal,
                size=gt_normal.shape[2:],
                mode='bilinear',
                align_corners=True,
            )

            # Positions whose ground-truth normal is all zeros are excluded.
            gt_normal_mask = ~torch.all(gt_normal == 0, dim=1, keepdim=True)
            # Same distributed flag as the depth metrics above.
            dam.update_normal_metrics_gpu(
                pred_normal, gt_normal, gt_normal_mask, is_distributed)

            if should_save:
                # ``rgb`` was prepared in the depth-saving branch above,
                # which runs under the same condition.
                save_normal_val_imgs(
                    i,
                    pred_normal,
                    gt_normal,
                    rgb,
                    'normal_' + data['filename'][0],
                    save_imgs_dir,
                )

    if main_process():
        eval_error = dam.get_metrics()
        print('>>>>>W/o scale: ', eval_error)
        eval_error_scale = dam_scale.get_metrics()
        print('>>>>>W scale: ', eval_error_scale)

        logger.info(eval_error)
        logger.info(eval_error_scale)
|
|
|
|
|
|