Spaces:
Runtime error
Runtime error
File size: 9,158 Bytes
fc16538 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 |
# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
from functools import reduce
import torch
import torch.nn.functional as tfn
from vidar.utils.decorators import iterate1
from vidar.utils.types import is_tensor, is_dict, is_seq
@iterate1
def interpolate(tensor, size, scale_factor, mode, align_corners):
    """
    Interpolate a tensor to a different resolution

    Parameters
    ----------
    tensor : torch.Tensor
        Input tensor [B,?,H,W]
    size : Tuple or torch.Tensor
        Interpolation size (H,W); a tensor stands in for its own spatial shape
    scale_factor : Float
        Scale factor for interpolation (alternative to size)
    mode : String
        Interpolation mode (e.g. 'bilinear', 'nearest')
    align_corners : Bool
        Corner alignment flag

    Returns
    -------
    tensor : torch.Tensor
        Interpolated tensor [B,?,h,w]
    """
    # A tensor argument provides the target resolution via its last two dims
    target = size.shape[-2:] if is_tensor(size) else size
    return tfn.interpolate(
        tensor, size=target, scale_factor=scale_factor, mode=mode,
        align_corners=align_corners, recompute_scale_factor=False,
    )
def masked_average(loss, mask, eps=1e-7):
    """Average of `loss` over the valid entries of `mask` (eps guards against empty masks)"""
    masked_sum = (loss * mask).sum()
    valid_count = mask.sum() + eps
    return masked_sum / valid_count
def multiply_mask(data, mask):
    """Applies a mask to a tensor, passing `data` through unchanged when either is None"""
    if data is None or mask is None:
        return data
    return data * mask
def multiply_args(*args):
    """Multiplies all non-None arguments together; returns None if none remain"""
    valid = [a for a in args if a is not None]
    if not valid:
        return None
    product = valid[0]
    for value in valid[1:]:
        product = product * value
    return product
def grid_sample(tensor, grid, padding_mode, mode, align_corners):
    """Thin wrapper around tfn.grid_sample with all options made explicit"""
    return tfn.grid_sample(
        tensor, grid, mode=mode,
        padding_mode=padding_mode, align_corners=align_corners)
def pixel_grid(hw, b=None, with_ones=False, device=None, normalize=False):
    """
    Creates a pixel grid for image operations

    Parameters
    ----------
    hw : Tuple or torch.Tensor
        Height/width of the grid; a tensor provides batch size and spatial shape
    b : Int
        Batch size
    with_ones : Bool
        Stack an extra channel with 1s
    device : String or torch.Tensor
        Device where the grid will be created (a tensor stands in for its device)
    normalize : Bool
        Whether the grid is normalized between [-1,1]

    Returns
    -------
    grid : torch.Tensor
        Output pixel grid [B,2,H,W]
    """
    # A tensor argument provides both batch size and spatial dimensions
    if is_tensor(hw):
        b, hw = hw.shape[0], hw.shape[-2:]
    if is_tensor(device):
        device = device.device
    h, w = hw[0], hw[1]
    ys = torch.linspace(0, h - 1, h, device=device)
    xs = torch.linspace(0, w - 1, w, device=device)
    yy, xx = torch.meshgrid([ys, xs], indexing='ij')
    # Channel order is (x, y), optionally followed by a ones channel
    channels = [xx, yy]
    if with_ones:
        channels.append(torch.ones(hw, device=device))
    grid = torch.stack(channels, 0)
    if b is not None:
        grid = grid.unsqueeze(0).repeat(b, 1, 1, 1)
    if normalize:
        grid = norm_pixel_grid(grid)
    return grid
def norm_pixel_grid(grid, hw=None, in_place=False):
    """
    Normalize a pixel grid from pixel coordinates to [-1,1]

    Maps x in [0,W-1] and y in [0,H-1] linearly onto [-1,1].

    Parameters
    ----------
    grid : torch.Tensor
        Grid to be normalized [B,2,H,W]
    hw : Tuple
        Height/Width for normalization (defaults to the grid's own shape)
    in_place : Bool
        Whether the operation is done in place or not

    Returns
    -------
    grid : torch.Tensor
        Normalized grid [B,2,H,W]
    """
    hw = grid.shape[-2:] if hw is None else hw
    out = grid if in_place else grid.clone()
    # Channel 0 holds x (width axis), channel 1 holds y (height axis)
    out[:, 0] = 2.0 * out[:, 0] / (hw[1] - 1) - 1.0
    out[:, 1] = 2.0 * out[:, 1] / (hw[0] - 1) - 1.0
    return out
def unnorm_pixel_grid(grid, hw=None, in_place=False):
    """
    Unnormalize a pixel grid from [-1,1] back to pixel coordinates

    Maps x from [-1,1] onto [0,W-1] and y from [-1,1] onto [0,H-1].

    Parameters
    ----------
    grid : torch.Tensor
        Grid to be unnormalized [B,2,H,W]
    hw : Tuple
        Height/width for unnormalization (defaults to the grid's own shape)
    in_place : Bool
        Whether the operation is done in place or not

    Returns
    -------
    grid : torch.Tensor
        Unnormalized grid [B,2,H,W]
    """
    hw = grid.shape[-2:] if hw is None else hw
    out = grid if in_place else grid.clone()
    # Channel 0 holds x (width axis), channel 1 holds y (height axis)
    out[:, 0] = 0.5 * (hw[1] - 1) * (out[:, 0] + 1)
    out[:, 1] = 0.5 * (hw[0] - 1) * (out[:, 1] + 1)
    return out
def match_scales(image, targets, num_scales,
                 mode='bilinear', align_corners=True):
    """
    Creates multiple resolution versions of the input to match another list of tensors

    Parameters
    ----------
    image : torch.Tensor
        Input image [B,?,H,W]
    targets : list[torch.Tensor]
        Target resolutions
    num_scales : int
        Number of scales to consider
    mode : String
        Interpolation mode
    align_corners : Bool
        Corner alignment flag

    Returns
    -------
    images : list[torch.Tensor]
        List containing tensors in the required resolutions
    """
    # For all scales
    images = []
    image_shape = image.shape[-2:]
    for i in range(num_scales):
        # Compare only spatial dimensions: the full target shape has a
        # different rank (and possibly channels), so comparing against it
        # made the equality check always fail and forced a needless
        # interpolate_image call even at matching resolutions
        target_shape = targets[i].shape[-2:]
        # If image shape is equal to target shape, reuse the input as-is
        if same_shape(image_shape, target_shape):
            images.append(image)
        else:
            # Otherwise, interpolate
            images.append(interpolate_image(
                image, target_shape, mode=mode, align_corners=align_corners))
    # Return scaled images
    return images
def cat_channel_ones(tensor, n=1):
    """
    Concatenate tensor with an extra channel of ones

    Parameters
    ----------
    tensor : torch.Tensor
        Tensor to be concatenated
    n : Int
        Dimension along which the ones channel is appended

    Returns
    -------
    cat_tensor : torch.Tensor
        Concatenated tensor
    """
    # Shape of the ones block: same as the input, but size 1 along dim n
    ones_shape = list(tensor.shape)
    ones_shape[n] = 1
    ones = torch.ones(ones_shape, device=tensor.device, dtype=tensor.dtype)
    return torch.cat([tensor, ones], n)
def same_shape(shape1, shape2):
    """Checks whether two shapes have the same length and identical entries"""
    return len(shape1) == len(shape2) and \
        all(s1 == s2 for s1, s2 in zip(shape1, shape2))
def interpolate_image(image, shape=None, scale_factor=None, mode='bilinear',
                      align_corners=True, recompute_scale_factor=False):
    """
    Interpolate an image to a different resolution

    Parameters
    ----------
    image : torch.Tensor
        Image to be interpolated [B,?,h,w]
    shape : torch.Tensor or tuple
        Output shape [H,W]
    scale_factor : Float
        Scale factor for output shape
    mode : String
        Interpolation mode
    align_corners : Bool
        True if corners will be aligned after interpolation
    recompute_scale_factor : Bool
        True if scale factor is recomputed

    Returns
    -------
    image : torch.Tensor
        Interpolated image [B,?,H,W]
    """
    assert shape is not None or scale_factor is not None, 'Invalid option for interpolate_image'
    # Nearest-neighbor interpolation does not take align_corners
    if mode == 'nearest':
        align_corners = None
    if shape is not None:
        # A tensor argument stands in for its own shape
        if is_tensor(shape):
            shape = shape.shape
        # Keep only the last two (spatial) dimensions
        if len(shape) > 2:
            shape = shape[-2:]
        # Nothing to do when the resolution already matches
        if same_shape(image.shape[-2:], shape):
            return image
    # Interpolate image to match the shape
    return tfn.interpolate(image, size=shape, scale_factor=scale_factor,
                           mode=mode, align_corners=align_corners,
                           recompute_scale_factor=recompute_scale_factor)
def check_assert(pred, gt, atol=1e-5, rtol=1e-5):
    """
    Check two dictionaries with allclose assertions

    Recurses into nested dictionaries and sequences; only keys present in
    both dictionaries are compared.

    Parameters
    ----------
    pred : Dict
        Dictionary with predictions
    gt : Dict
        Dictionary with ground-truth
    atol : Float
        Absolute tolerance
    rtol : Float
        Relative tolerance
    """
    for key in gt.keys():
        if key in pred.keys():
            # assert key in pred and key in gt
            if is_dict(pred[key]):
                # Propagate tolerances; the recursion previously dropped them
                # and silently reverted to the defaults for nested dicts
                check_assert(pred[key], gt[key], atol=atol, rtol=rtol)
            elif is_seq(pred[key]):
                # Compare sequences element-wise
                for val1, val2 in zip(pred[key], gt[key]):
                    if is_tensor(val1):
                        assert torch.allclose(val1, val2, atol=atol, rtol=rtol), \
                            f'Assert error in {key} : {val1.mean().item()} x {val2.mean().item()}'
                    else:
                        assert val1 == val2, \
                            f'Assert error in {key} : {val1} x {val2}'
            else:
                if is_tensor(pred[key]):
                    assert torch.allclose(pred[key], gt[key], atol=atol, rtol=rtol), \
                        f'Assert error in {key} : {pred[key].mean().item()} x {gt[key].mean().item()}'
                else:
                    assert pred[key] == gt[key], \
                        f'Assert error in {key} : {pred[key]} x {gt[key]}'
def interleave(data, b):
    """Repeats each batch element b times consecutively along the batch dimension"""
    tail_shape = data.shape[1:]
    # [N, ...] -> [N, b, ...] (view, no copy) -> [N*b, ...]
    expanded = data.unsqueeze(1).expand(-1, b, *tail_shape)
    return expanded.reshape(-1, *tail_shape)
|