Jiading Fang
add define
fc16538
# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
import torch
from vidar.utils.tensor import interpolate_image
def scale_output(pred, gt, scale_fn):
    """
    Match depth maps to ground-truth resolution

    Parameters
    ----------
    pred : torch.Tensor
        Predicted depth maps [B,1,h,w]; if both pred and gt are 5D ([B,N,1,h,w]),
        each of the N entries along dim 1 is processed independently
    gt : torch.Tensor
        Ground-truth depth maps [B,1,H,W] (or [B,N,1,H,W]); only its shape is used
    scale_fn : String
        How to scale output to GT resolution
            none: Return pred unchanged
            resize: Bilinear interpolation (align_corners=True) to GT resolution
            top-center: Place pred at the bottom of a zero map of GT resolution,
                        centered horizontally (zeros padded on top and left/right)

    Returns
    -------
    pred : torch.Tensor
        Uncropped predicted depth maps [B,1,H,W]

    Raises
    ------
    NotImplementedError
        If scale_fn is not one of 'none', 'resize', 'top-center'
    ValueError
        If scale_fn is 'top-center' and pred is larger than gt in height or width
    """
    # 5D inputs: scale each entry along dim 1 separately, then restack along dim 1
    if pred.dim() == 5 and gt.dim() == 5:
        return torch.stack([scale_output(pred[:, i], gt[:, i], scale_fn)
                            for i in range(pred.shape[1])], 1)
    # Return depth map if scaling is not required
    if scale_fn == 'none':
        return pred
    if scale_fn == 'resize':
        # Resize depth map to GT resolution
        return interpolate_image(pred, gt.shape, mode='bilinear', align_corners=True)
    if scale_fn != 'top-center':
        # Reject unknown modes before allocating a full-resolution buffer
        raise NotImplementedError('Depth scale function {} not implemented.'.format(scale_fn))

    pred_height, pred_width = pred.shape[2], pred.shape[3]
    gt_height, gt_width = gt.shape[2], gt.shape[3]
    # Negative offsets would produce wrapped slices and an opaque shape error later
    if pred_height > gt_height or pred_width > gt_width:
        raise ValueError('Prediction ({}x{}) is larger than ground-truth ({}x{}); '
                         'cannot uncrop with top-center.'.format(
                             pred_height, pred_width, gt_height, gt_width))
    # Bottom-aligned vertically (all padding goes on top), centered horizontally
    top, left = gt_height - pred_height, (gt_width - pred_width) // 2
    # Create empty depth map with GT resolution, matching pred's dtype/device
    pred_uncropped = torch.zeros(gt.shape, dtype=pred.dtype, device=pred.device)
    pred_uncropped[:, :, top:(top + pred_height), left:(left + pred_width)] = pred
    # Return uncropped depth map
    return pred_uncropped
def create_crop_mask(crop, gt):
    """
    Create crop mask for evaluation

    Parameters
    ----------
    crop : String
        Type of crop: '' or None (no mask), 'eigen_nyu', 'bts_nyu', 'garg', 'eigen'
    gt : torch.Tensor
        Ground-truth depth map; only its last two dimensions (H, W), dtype and
        device are used

    Returns
    -------
    crop_mask: torch.Tensor
        Mask for evaluation [H,W], with gt's dtype and device; 1 inside the crop
        region and 0 elsewhere. Returns None if no crop is requested.

    Raises
    ------
    NotImplementedError
        If crop is not a recognised crop type
    """
    # Return None if mask is not required
    if crop in ('', None):
        return None
    # Only the spatial size matters, so accept GT of any rank >= 2
    gt_height, gt_width = gt.shape[-2:]
    # Create empty mask with gt's dtype and device
    crop_mask = torch.zeros(gt.shape[-2:], dtype=gt.dtype, device=gt.device)
    # Get specific mask
    if crop == 'eigen_nyu':
        # Fixed pixel coordinates; they fit within a 480x640 image (assumed NYU resolution)
        crop_mask[20:459, 24:615] = 1
    elif crop == 'bts_nyu':
        # Fixed pixel coordinates; they fit within a 480x640 image (assumed NYU resolution)
        crop_mask[45:471, 41:601] = 1
    elif crop == 'garg':
        # Crop defined as fractions of the GT resolution
        y1, y2 = int(0.40810811 * gt_height), int(0.99189189 * gt_height)
        x1, x2 = int(0.03594771 * gt_width), int(0.96405229 * gt_width)
        crop_mask[y1:y2, x1:x2] = 1
    elif crop == 'eigen':
        # Crop defined as fractions of the GT resolution
        y1, y2 = int(0.3324324 * gt_height), int(0.91351351 * gt_height)
        x1, x2 = int(0.03594771 * gt_width), int(0.96405229 * gt_width)
        crop_mask[y1:y2, x1:x2] = 1
    else:
        # An all-zero mask would silently exclude every pixel from evaluation
        raise NotImplementedError('Crop {} not implemented.'.format(crop))
    # Return crop mask
    return crop_mask