File size: 2,862 Bytes
fc16538
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# TRI-VIDAR - Copyright 2022 Toyota Research Institute.  All rights reserved.

import torch

from vidar.utils.tensor import interpolate_image


def scale_output(pred, gt, scale_fn):
    """
    Bring predicted depth maps to the ground-truth resolution

    Parameters
    ----------
    pred : torch.Tensor
        Predicted depth maps [B,1,h,w] (or 5D, see below)
    gt : torch.Tensor
        Ground-truth depth maps [B,1,H,W] (or 5D, see below)
    scale_fn : String
        How to bring the prediction to GT resolution
            none: Return the prediction untouched
            resize: Call interpolate_image with mode='bilinear' and align_corners=True
            top-center: Paste the prediction into a zero map of GT shape, flush with
                        the bottom edge and centered horizontally (zeros fill the top
                        rows and the left/right borders)

    Returns
    -------
    pred : torch.Tensor
        Predicted depth maps at GT resolution (the input itself for 'none')

    Raises
    ------
    NotImplementedError
        If scale_fn is not one of the modes listed above

    Notes
    -----
    When both pred and gt are 5D, each index along dimension 1 is processed
    independently and the results are stacked back along that dimension.
    """
    # 5D inputs: handle every slice along dim 1 on its own, then re-stack
    if pred.dim() == 5 and gt.dim() == 5:
        per_slice = []
        for idx in range(pred.shape[1]):
            per_slice.append(scale_output(pred[:, idx], gt[:, idx], scale_fn))
        return torch.stack(per_slice, 1)
    # Nothing to do
    if scale_fn == 'none':
        return pred
    # Interpolate prediction to the GT shape
    if scale_fn == 'resize':
        return interpolate_image(pred, gt.shape, mode='bilinear', align_corners=True)
    # Remaining modes paste the prediction into a zero-filled map of GT shape
    canvas = torch.zeros(gt.shape, dtype=pred.dtype, device=pred.device)
    if scale_fn != 'top-center':
        raise NotImplementedError('Depth scale function {} not implemented.'.format(scale_fn))
    # All vertical slack goes above the prediction, horizontal slack is split evenly
    pred_h, pred_w = pred.shape[2], pred.shape[3]
    row0 = gt.shape[2] - pred_h
    col0 = (gt.shape[3] - pred_w) // 2
    canvas[:, :, row0:row0 + pred_h, col0:col0 + pred_w] = pred
    return canvas


def create_crop_mask(crop, gt):
    """
    Create crop mask for evaluation

    Parameters
    ----------
    crop : String
        Type of crop
            '' or None: No crop (returns None)
            eigen_nyu: Fixed pixel window [20:459, 24:615]
            bts_nyu: Fixed pixel window [45:471, 41:601]
            garg: Window proportional to the GT height and width
            eigen: Window proportional to the GT height and width
    gt : torch.Tensor
        Ground-truth depth map (for dimensions). Only the last two dimensions
        are read as [H,W], so any tensor with at least two dimensions works.

    Returns
    -------
    crop_mask : torch.Tensor
        Mask for evaluation [H,W] with the same dtype and device as gt
        (1 inside the crop window, 0 outside), or None if no crop is requested.
        An unrecognised crop name yields an all-zero mask.
    """
    # Return None if mask is not required
    if crop in ('', None):
        return None
    # Only the spatial size matters; taking it from the trailing dimensions keeps
    # this working for 5D ground truth as well as the usual [B,1,H,W]
    gt_height, gt_width = gt.shape[-2:]
    # Create empty mask directly with the dtype and device of the ground truth
    crop_mask = torch.zeros((gt_height, gt_width), dtype=gt.dtype, device=gt.device)
    # Get specific mask
    if crop == 'eigen_nyu':
        # Fixed pixel window - presumably meant for 480x640 NYU images (TODO confirm)
        crop_mask[20:459, 24:615] = 1
    elif crop == 'bts_nyu':
        # Fixed pixel window - presumably meant for 480x640 NYU images (TODO confirm)
        crop_mask[45:471, 41:601] = 1
    elif crop == 'garg':
        y1, y2 = int(0.40810811 * gt_height), int(0.99189189 * gt_height)
        x1, x2 = int(0.03594771 * gt_width), int(0.96405229 * gt_width)
        crop_mask[y1:y2, x1:x2] = 1
    elif crop == 'eigen':
        y1, y2 = int(0.3324324 * gt_height), int(0.91351351 * gt_height)
        x1, x2 = int(0.03594771 * gt_width), int(0.96405229 * gt_width)
        crop_mask[y1:y2, x1:x2] = 1
    # Return crop mask
    return crop_mask