Spaces:
Runtime error
Runtime error
| # TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. | |
| import torch | |
| from vidar.utils.tensor import interpolate_image | |
def scale_output(pred, gt, scale_fn):
    """
    Bring predicted depth maps to the resolution of the ground-truth

    Parameters
    ----------
    pred : torch.Tensor
        Predicted depth maps [B,1,h,w] (or 5-D, with an extra dimension at index 1)
    gt : torch.Tensor
        Ground-truth depth maps [B,1,H,W] (or 5-D, matching pred)
    scale_fn : String
        How to scale the prediction to GT resolution
            none: Return the prediction untouched
            resize: Bilinear interpolation (align_corners=True) to the GT shape
            top-center: Paste the prediction into a zero map of GT shape,
                flush with the bottom edge and centered horizontally
                (i.e. zero padding at the top and on both sides)

    Returns
    -------
    pred : torch.Tensor
        Predicted depth maps at GT resolution [B,1,H,W]

    Raises
    ------
    NotImplementedError
        If scale_fn is not one of the options above
    """
    # 5-D inputs: process every slice along dimension 1 on its own, then re-stack
    if pred.dim() == 5 and gt.dim() == 5:
        per_slice = [
            scale_output(pred[:, idx], gt[:, idx], scale_fn)
            for idx in range(pred.shape[1])
        ]
        return torch.stack(per_slice, 1)

    # Nothing to do
    if scale_fn == 'none':
        return pred

    # Interpolate to GT resolution
    if scale_fn == 'resize':
        return interpolate_image(pred, gt.shape, mode='bilinear', align_corners=True)

    # Anything other than top-center uncropping is unsupported
    if scale_fn != 'top-center':
        raise NotImplementedError('Depth scale function {} not implemented.'.format(scale_fn))

    # Zero canvas with GT shape, sharing dtype and device with the prediction
    canvas = pred.new_zeros(gt.shape)
    pred_h, pred_w = pred.shape[2], pred.shape[3]
    # All missing rows go on top; missing columns are split between left and right
    row0 = gt.shape[2] - pred_h
    col0 = (gt.shape[3] - pred_w) // 2
    canvas[:, :, row0:row0 + pred_h, col0:col0 + pred_w] = pred
    return canvas
def create_crop_mask(crop, gt):
    """
    Create crop mask for evaluation

    Parameters
    ----------
    crop : String
        Type of crop ('eigen_nyu', 'bts_nyu', 'garg' or 'eigen').
        An empty string or None disables cropping.
    gt : torch.Tensor
        Ground-truth depth map (for dimensions). Only the last two
        dimensions (height, width) are used, so any rank >= 2 is accepted.

    Returns
    -------
    crop_mask : torch.Tensor
        [H,W] mask with the same dtype and device as gt, set to 1 inside the
        crop window and 0 elsewhere, or None if no crop was requested
    """
    # Return None if mask is not required
    if crop in ('', None):
        return None
    # Only the spatial dimensions matter; reading them from the end of the
    # shape keeps 5-D ground-truth tensors working as well as 4-D ones
    gt_height, gt_width = gt.shape[-2:]
    # Create empty mask matching the dtype and device of the ground-truth
    crop_mask = gt.new_zeros((gt_height, gt_width))
    # Get specific mask
    if crop == 'eigen_nyu':
        # Fixed pixel window (presumably for 480x640 NYU images -- TODO confirm)
        crop_mask[20:459, 24:615] = 1
    elif crop == 'bts_nyu':
        # Fixed pixel window (presumably for 480x640 NYU images -- TODO confirm)
        crop_mask[45:471, 41:601] = 1
    elif crop == 'garg':
        # Window given as fractions of the image height / width
        y1, y2 = int(0.40810811 * gt_height), int(0.99189189 * gt_height)
        x1, x2 = int(0.03594771 * gt_width), int(0.96405229 * gt_width)
        crop_mask[y1:y2, x1:x2] = 1
    elif crop == 'eigen':
        # Window given as fractions of the image height / width
        y1, y2 = int(0.3324324 * gt_height), int(0.91351351 * gt_height)
        x1, x2 = int(0.03594771 * gt_width), int(0.96405229 * gt_width)
        crop_mask[y1:y2, x1:x2] = 1
    # NOTE(review): an unrecognized crop name falls through and yields an
    # all-zero mask; kept for backward compatibility -- confirm whether
    # callers would prefer an error here.
    # Return crop mask
    return crop_mask