File size: 7,584 Bytes
fc16538
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
# TRI-VIDAR - Copyright 2022 Toyota Research Institute.  All rights reserved.

import random
from collections import OrderedDict
from inspect import signature

import numpy as np
import torch

from vidar.utils.decorators import iterate1, iterate2
from vidar.utils.types import is_list, is_double_list, is_tuple, is_tensor, is_dict, is_seq

# Batch keys handled as image-like entries by fold_batch / expand_batch / break_batch
KEYS_IMAGE = [
    'rgb',
    'mask',
    'input_depth',
    'depth',
    'bwd_optical_flow',
    'fwd_optical_flow',
]

# Batch keys handled as matrix-like entries by fold_batch / expand_batch / break_batch
KEYS_MATRIX = [
    'intrinsics',
    'extrinsics',
    'pose',
    'semantic',
]


def modrem(v, n):
    """Return the floor-division quotient and the remainder of v divided by n, as a tuple"""
    quotient = v // n
    remainder = v % n
    return quotient, remainder


def flatten(lst):
    """Concatenate the inner lists of a double list; any other input is returned unchanged"""
    if not is_double_list(lst):
        return lst
    flat = []
    for inner in lst:
        flat.extend(inner)
    return flat


def keys_with(dic, string, without=()):
    """Return the keys that contain `string` and none of the substrings given in `without`"""
    selected = []
    for key in dic:
        if string not in key:
            continue
        # `without` may be a single string or a list/tuple of strings
        if any(excluded in key for excluded in make_list(without)):
            continue
        selected.append(key)
    return selected


def keys_startswith(dic, string):
    """Return keys from a dictionary that start with a certain string"""
    matches = []
    for key in dic:
        if key.startswith(string):
            matches.append(key)
    return matches


def keys_in(dic, keys):
    """Return the entries of `keys` that are present in the dictionary, in the order of `keys`"""
    present = []
    for key in keys:
        if key in dic:
            present.append(key)
    return present


def str_not_in(string, keys):
    """Return True if none of the keys is contained in the string"""
    return not any(key in string for key in keys)


def make_list(var, n=None):
    """
    Wrap the input into a list (lists and tuples are kept as they are).
    If n is given, a single-element sequence is repeated n times; otherwise its length must already be n.
    """
    wrapped = var if (is_list(var) or is_tuple(var)) else [var]
    if n is None:
        return wrapped
    size = len(wrapped)
    assert size == 1 or size == n, 'Wrong list length for make_list'
    if size == 1:
        return wrapped * n
    return wrapped


def filter_args(func, keys):
    """Return a new dictionary with only the entries of `keys` whose name is a parameter of `func`"""
    accepted = signature(func).parameters
    return {name: value for name, value in {**keys}.items() if name in accepted}


def dict_remove_nones(dic):
    """Return a copy of the dictionary without the entries whose value is None"""
    cleaned = {}
    for key, val in dic.items():
        if val is None:
            continue
        cleaned[key] = val
    return cleaned


@iterate1
def matmul1(v1, v2):
    """Matrix-multiply v1 by v2 (iteration over the inputs is handled by the iterate1 decorator)"""
    product = v1 @ v2
    return product


@iterate2
def matmul2(v1, v2):
    """Matrix-multiply v1 by v2 (iteration over the inputs is handled by the iterate2 decorator)"""
    product = v1 @ v2
    return product


@iterate1
def unsqueeze(x):
    """Iteratively add a leading dimension of size 1 to tensors (non-tensors are returned as they are)"""
    if not is_tensor(x):
        return x
    return x.unsqueeze(0)


@iterate1
def fold(data, n):
    """Iteratively merge the first and second dimensions of data that has n + 1 dimensions"""
    dims = list(data.shape)
    if len(dims) != n + 1:
        # Only data with exactly one extra dimension is folded
        return data
    merged = [dims[0] * dims[1], *dims[2:]]
    return data.view(*merged)


@iterate1
def expand(data, n, d):
    """Iteratively insert a new dimension at position d into data that has exactly n dimensions"""
    rank = len(data.shape)
    if rank != n:
        # Data with any other number of dimensions is returned as it is
        return data
    return data.unsqueeze(d)


def fold_batch(batch, device=None):
    """
    Combine the first (batch) and second (camera) dimensions of a batch.
    Sequences of batches are processed one by one; the batch is modified in place
    and optionally moved to `device` afterwards.
    """
    if is_seq(batch):
        return [fold_batch(item, device=device) for item in batch]
    # Image-like entries are folded with n=4, matrix-like entries with n=3
    for group, ndim in ((KEYS_IMAGE, 4), (KEYS_MATRIX, 3)):
        for key in keys_in(batch, group):
            batch[key] = fold(batch[key], ndim)
    if device is None:
        return batch
    return batch_to_device(batch, device)


def expand_batch(batch, d, device=None):
    """
    Expand the batch to include an additional dimension.
    `d` must be 'batch' (new dimension 0) or 'camera' (new dimension 1); any other value raises KeyError.
    Sequences of batches are processed one by one; the batch is modified in place
    and optionally moved to `device` afterwards.
    """
    if is_seq(batch):
        return [expand_batch(item, d, device=device) for item in batch]
    axis = {'batch': 0, 'camera': 1}[d]
    # Image-like entries are expanded with n=4, matrix-like entries with n=3
    for group, ndim in ((KEYS_IMAGE, 4), (KEYS_MATRIX, 3)):
        for key in keys_in(batch, group):
            batch[key] = expand(batch[key], ndim, axis)
    if device is None:
        return batch
    return batch_to_device(batch, device)


def batch_to_device(batch, device):
    """
    Recursively copy the tensors of a batch to a device.
    Dictionaries and lists are rebuilt with converted contents; anything else is returned unchanged.
    """
    if is_dict(batch):
        moved = {}
        for key, val in batch.items():
            moved[key] = batch_to_device(val, device)
        return moved
    if is_list(batch):
        moved = []
        for val in batch:
            moved.append(batch_to_device(val, device))
        return moved
    if is_tensor(batch):
        return batch.to(device)
    return batch


def num_trainable_params(model):
    """Count the elements of all model parameters that have requires_grad set"""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def set_random_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) generators; negative seeds do nothing"""
    if not seed >= 0:
        return
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def make_batch(batch, device=None):
    """
    Transform a sample into a batch by adding a leading dimension to every tensor.
    Nested dictionaries are processed recursively; the input is modified in place
    and optionally moved to `device` afterwards.
    """
    for key in batch.keys():
        value = batch[key]
        if is_dict(value):
            # The device transfer is done once at the top level, so it is not forwarded here
            batch[key] = make_batch(value)
        elif is_tensor(value):
            batch[key] = value.unsqueeze(0)
    if device is None:
        return batch
    return batch_to_device(batch, device)


def break_key(sample, n=None):
    """
    Break a multi-camera sample key, so different cameras have their own entries (context, camera).

    Dictionary entries are split along dimension 1 of each value (keeping that dimension with size 1).
    Other entries are split only if they have n + 1 dimensions, and are dropped from the output otherwise.
    NOTE: n must be provided when the sample has non-dictionary entries, since n + 1 fails for n=None.
    """
    if sample is None:
        return sample
    new_sample = OrderedDict()
    for ctx, entry in sample.items():
        if is_dict(entry):
            for name, tensor in entry.items():
                if tensor.dim() == 1:
                    # Give 1-dimensional values a camera dimension so they can be indexed below
                    tensor = tensor.unsqueeze(1)
                for cam in range(tensor.shape[1]):
                    new_sample.setdefault((ctx, cam), {})[name] = tensor[:, [cam]]
        elif entry.dim() == n + 1:
            for cam in range(entry.shape[1]):
                new_sample[(ctx, cam)] = entry[:, cam]
    return new_sample


def break_batch(batch):
    """
    Break a multi-camera batch, so different cameras have their own entries (context, camera).
    The batch is modified in place: entries with a camera dimension are replaced by one entry per camera.
    """
    # Image-like entries are split when they have 5 dimensions, matrix-like entries when they have 4
    for group, ndim in ((KEYS_IMAGE, 5), (KEYS_MATRIX, 4)):
        for key in keys_in(batch, group):
            entries = batch[key]
            # Iterate over a snapshot of the contexts, since entries are added and removed in the loop
            for ctx in list(entries.keys()):
                if entries[ctx].dim() != ndim:
                    continue
                for cam in range(entries[ctx].shape[1]):
                    entries[(ctx, cam)] = entries[ctx][:, cam]
                entries.pop(ctx)
    return batch


def dict_has(dic, key):
    """Tell whether a certain key is present in a dictionary"""
    found = key in dic
    return found


def get_from_dict(dic, key):
    """Look up a key in a dictionary, returning None when the key is missing"""
    if key in dic:
        return dic[key]
    return None


def get_mask_from_list(mask, i, return_ones=None):
    """
    Retrieve mask from a list (if it's not a list, return the mask itself, and create one if requested).

    Without `return_ones`, the selected mask (or None) is returned as it is.
    With `return_ones`, a missing mask is replaced by a boolean mask of ones shaped like `return_ones`
    (or its i-th element), and an existing mask is cloned and cast to bool; only the first channel is
    returned for 4-dimensional (dimension 1) and 3-dimensional (last dimension) masks, None otherwise.
    """
    if return_ones is None:
        if mask is None:
            return None
        return mask[i] if is_list(mask) else mask

    if mask is None:
        template = return_ones[i] if is_list(return_ones) else return_ones
        mask = torch.ones_like(template).bool()
    elif is_list(mask):
        mask = mask[i].clone().bool()
    else:
        mask = mask.clone().bool()

    if mask.dim() == 4:
        return mask[:, [0]]
    if mask.dim() == 3:
        return mask[..., [0]]
    return None


def get_from_list(lst, i):
    """Return the i-th element of a sequence; None is passed through and non-sequences are returned directly"""
    if lst is None:
        return None
    if is_seq(lst):
        return lst[i]
    return lst