# TRI-VIDAR - Copyright 2022 Toyota Research Institute.  All rights reserved.

from functools import reduce

import torch
import torch.nn.functional as tfn

from vidar.utils.decorators import iterate1
from vidar.utils.types import is_tensor, is_dict, is_seq


@iterate1
def interpolate(tensor, size, scale_factor, mode, align_corners):
    """
    Interpolate a tensor to a different resolution

    Parameters
    ----------
    tensor : torch.Tensor
        Input tensor [B,?,H,W]
    size : Tuple or torch.Tensor
        Interpolation size (H,W); if a tensor is given, its last two dimensions are used
    scale_factor : Float
        Scale factor for interpolation
    mode : String
        Interpolation mode
    align_corners : Bool
        Corner alignment flag

    Returns
    -------
    tensor : torch.Tensor
        Interpolated tensor [B,?,h,w]
    """
    if is_tensor(size):
        size = size.shape[-2:]
    return tfn.interpolate(
        tensor, size=size, scale_factor=scale_factor,
        mode=mode, align_corners=align_corners, recompute_scale_factor=False,
    )
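
# Usage sketch for interpolate (illustrative values, not part of the original file):
#   x = torch.rand(2, 3, 64, 96)
#   y = interpolate(x, size=(128, 192), scale_factor=None,
#                   mode='bilinear', align_corners=True)  # y: [2, 3, 128, 192]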


def masked_average(loss, mask, eps=1e-7):
    """Calculates the average of a tensor considering mask information"""
    return (loss * mask).sum() / (mask.sum() + eps)
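
# Usage sketch for masked_average (illustrative): average a per-pixel loss over
# valid pixels only, e.g. masked_average(photometric_loss, valid_mask.float()).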


def multiply_mask(data, mask):
    """Multiplies a tensor with a mask"""
    return data if (data is None or mask is None) else data * mask


def multiply_args(*args):
    """Multiplies input arguments"""
    valids = [v for v in args if v is not None]
    return None if not valids else reduce((lambda x, y: x * y), valids)
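
# Example for multiply_args: multiply_args(mask_a, None, mask_b) returns
# mask_a * mask_b (None entries are ignored), and multiply_args(None) returns None.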


def grid_sample(tensor, grid, padding_mode, mode, align_corners):
    """Sample a tensor at the normalized coordinates given by a sampling grid"""
    return tfn.grid_sample(tensor, grid, padding_mode=padding_mode,
                           mode=mode, align_corners=align_corners)


def pixel_grid(hw, b=None, with_ones=False, device=None, normalize=False):
    """
    Creates a pixel grid for image operations

    Parameters
    ----------
    hw : Tuple or torch.Tensor
        Height/width of the grid; if a tensor is given, its batch size and last two dimensions are used
    b : Int
        Batch size
    with_ones : Bool
        Stack an extra channel with 1s
    device : String or torch.Tensor
        Device where the grid will be created; if a tensor is given, its device is used
    normalize : Bool
        Whether the grid is normalized between [-1,1]

    Returns
    -------
    grid : torch.Tensor
        Output pixel grid [B,2,H,W]
    """
    if is_tensor(hw):
        b, hw = hw.shape[0], hw.shape[-2:]
    if is_tensor(device):
        device = device.device
    hi, hf = 0, hw[0] - 1
    wi, wf = 0, hw[1] - 1
    yy, xx = torch.meshgrid([torch.linspace(hi, hf, hw[0], device=device),
                             torch.linspace(wi, wf, hw[1], device=device)], indexing='ij')
    if with_ones:
        grid = torch.stack([xx, yy, torch.ones(hw, device=device)], 0)
    else:
        grid = torch.stack([xx, yy], 0)
    if b is not None:
        grid = grid.unsqueeze(0).repeat(b, 1, 1, 1)
    if normalize:
        grid = norm_pixel_grid(grid)
    return grid
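
# Usage sketch for pixel_grid (illustrative shapes, not part of the original file):
#   grid = pixel_grid((4, 6), b=2, with_ones=True, normalize=True)
#   # grid: [2, 3, 4, 6]; channels 0/1 hold x/y coordinates in [-1, 1],
#   # channel 2 is all ones (homogeneous coordinate).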


def norm_pixel_grid(grid, hw=None, in_place=False):
    """
    Normalize a pixel grid to be between [-1,1]

    Parameters
    ----------
    grid : torch.Tensor
        Grid to be normalized [B,2,H,W]
    hw : Tuple
        Height/Width for normalization
    in_place : Bool
        Whether the operation is done in place or not

    Returns
    -------
    grid : torch.Tensor
        Normalized grid [B,2,H,W]
    """
    if hw is None:
        hw = grid.shape[-2:]
    if not in_place:
        grid = grid.clone()
    grid[:, 0] = 2.0 * grid[:, 0] / (hw[1] - 1) - 1.0
    grid[:, 1] = 2.0 * grid[:, 1] / (hw[0] - 1) - 1.0
    return grid
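
# Example for norm_pixel_grid: pixel (0, 0) maps to (-1, -1) and (W-1, H-1) maps
# to (1, 1), matching the convention grid_sample expects with align_corners=True.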


def unnorm_pixel_grid(grid, hw=None, in_place=False):
    """
    Unnormalize a pixel grid from [-1,1] back to pixel coordinates in [0,W-1] x [0,H-1]

    Parameters
    ----------
    grid : torch.Tensor
        Grid to be unnormalized [B,2,H,W]
    hw : Tuple
        Height/width for unnormalization
    in_place : Bool
        Whether the operation is done in place or not

    Returns
    -------
    grid : torch.Tensor
        Unnormalized grid [B,2,H,W]
    """
    if hw is None:
        hw = grid.shape[-2:]
    if not in_place:
        grid = grid.clone()
    grid[:, 0] = 0.5 * (hw[1] - 1) * (grid[:, 0] + 1)
    grid[:, 1] = 0.5 * (hw[0] - 1) * (grid[:, 1] + 1)
    return grid
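
# Example for unnorm_pixel_grid: it inverts norm_pixel_grid exactly, so
# unnorm_pixel_grid(norm_pixel_grid(grid)) recovers the original pixel
# coordinates (up to floating-point error).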


def match_scales(image, targets, num_scales,
                 mode='bilinear', align_corners=True):
    """
    Creates multiple resolution versions of the input to match another list of tensors

    Parameters
    ----------
    image : torch.Tensor
        Input image [B,?,H,W]
    targets : list[torch.Tensor]
        Target resolutions
    num_scales : int
        Number of scales to consider
    mode : String
        Interpolation mode
    align_corners : Bool
        Corner alignment flag

    Returns
    -------
    images : list[torch.Tensor]
        List containing tensors in the required resolutions
    """
    # For all scales
    images = []
    image_shape = image.shape[-2:]
    for i in range(num_scales):
        target_shape = targets[i].shape[-2:]
        # If image shape is equal to target shape
        if same_shape(image_shape, target_shape):
            images.append(image)
        else:
            # Otherwise, interpolate
            images.append(interpolate_image(
                image, target_shape, mode=mode, align_corners=align_corners))
    # Return scaled images
    return images
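
# Usage sketch for match_scales (illustrative; `feats` is assumed to be a list of
# multi-scale feature maps):
#   rgb_pyramid = match_scales(rgb, feats, num_scales=len(feats))
#   # rgb_pyramid[i].shape[-2:] == feats[i].shape[-2:]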


def cat_channel_ones(tensor, n=1):
    """
    Concatenate tensor with an extra channel of ones

    Parameters
    ----------
    tensor : torch.Tensor
        Tensor to be concatenated
    n : Int
        Dimension along which the channel of ones is concatenated

    Returns
    -------
    cat_tensor : torch.Tensor
        Concatenated tensor
    """
    # Get tensor shape with 1 channel
    shape = list(tensor.shape)
    shape[n] = 1
    # Return concatenation of tensor with ones
    return torch.cat([tensor, torch.ones(shape,
                      device=tensor.device, dtype=tensor.dtype)], n)
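
# Example for cat_channel_ones: a [B, 3, N] tensor of 3D points becomes a
# [B, 4, N] tensor of homogeneous coordinates with cat_channel_ones(points, n=1).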


def same_shape(shape1, shape2):
    """Checks if two shapes are the same"""
    if len(shape1) != len(shape2):
        return False
    for i in range(len(shape1)):
        if shape1[i] != shape2[i]:
            return False
    return True


def interpolate_image(image, shape=None, scale_factor=None, mode='bilinear',
                      align_corners=True, recompute_scale_factor=False):
    """
    Interpolate an image to a different resolution

    Parameters
    ----------
    image : torch.Tensor
        Image to be interpolated [B,?,h,w]
    shape : torch.Tensor or Tuple
        Output shape [H,W]; if a tensor is given, its shape is used
    scale_factor : Float
        Scale factor for output shape
    mode : String
        Interpolation mode
    align_corners : Bool
        True if corners will be aligned after interpolation
    recompute_scale_factor : Bool
        True if scale factor is recomputed

    Returns
    -------
    image : torch.Tensor
        Interpolated image [B,?,H,W]
    """
    assert shape is not None or scale_factor is not None, 'Invalid option for interpolate_image'
    if mode == 'nearest':
        align_corners = None
    # Take last two dimensions as shape
    if shape is not None:
        if is_tensor(shape):
            shape = shape.shape
        if len(shape) > 2:
            shape = shape[-2:]
        # If the shapes are the same, do nothing
        if same_shape(image.shape[-2:], shape):
            return image
    # Interpolate image to match the shape
    return tfn.interpolate(image, size=shape, scale_factor=scale_factor,
                           mode=mode, align_corners=align_corners,
                           recompute_scale_factor=recompute_scale_factor)
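
# Usage sketch for interpolate_image (illustrative):
#   half = interpolate_image(img, scale_factor=0.5)   # downsample to half resolution
#   same = interpolate_image(img, shape=img.shape)    # shapes match, returns img as-is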


def check_assert(pred, gt, atol=1e-5, rtol=1e-5):
    """
    Check two dictionaries with allclose assertions

    Parameters
    ----------
    pred : Dict
        Dictionary with predictions
    gt : Dict
        Dictionary with ground-truth
    atol : Float
        Absolute tolerance
    rtol : Float
        Relative tolerance
    """
    for key in gt.keys():
        if key in pred.keys():
            # assert key in pred and key in gt
            if is_dict(pred[key]):
                check_assert(pred[key], gt[key], atol=atol, rtol=rtol)
            elif is_seq(pred[key]):
                for val1, val2 in zip(pred[key], gt[key]):
                    if is_tensor(val1):
                        assert torch.allclose(val1, val2, atol=atol, rtol=rtol), \
                            f'Assert error in {key} : {val1.mean().item()} x {val2.mean().item()}'
                    else:
                        assert val1 == val2, \
                            f'Assert error in {key} : {val1} x {val2}'
            else:
                if is_tensor(pred[key]):
                    assert torch.allclose(pred[key], gt[key], atol=atol, rtol=rtol), \
                        f'Assert error in {key} : {pred[key].mean().item()} x {gt[key].mean().item()}'
                else:
                    assert pred[key] == gt[key], \
                        f'Assert error in {key} : {pred[key]} x {gt[key]}'
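
# Usage sketch for check_assert (illustrative; the file name is hypothetical):
#   check_assert(model_outputs, torch.load('reference_outputs.pt'))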


def interleave(data, b):
    """Interleave data considering multiple batches"""
    data_interleave = data.unsqueeze(1).expand(-1, b, *data.shape[1:])
    return data_interleave.reshape(-1, *data.shape[1:])
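
# Example for interleave: with data of shape [N, C] and b=2, row i of the input
# becomes rows 2*i and 2*i+1 of the output, giving a [2*N, C] tensor.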