Spaces:
Runtime error
Runtime error
File size: 1,547 Bytes
fc16538 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
import torch
import torchvision.transforms as transforms
from vidar.utils.decorators import iterate1
@iterate1
def to_tensor(matrix, tensor_type='torch.FloatTensor'):
    """
    Casts a matrix to a torch.Tensor

    Parameters
    ----------
    matrix : np.ndarray, list, scalar or torch.Tensor
        Data to convert
    tensor_type : String
        Type of tensor we are casting to

    Returns
    -------
    tensor : torch.Tensor
        ``matrix`` as a tensor of type ``tensor_type``. It may share memory with
        ``matrix`` when no dtype conversion is needed.
    """
    # torch.as_tensor keeps the input's own dtype until the final cast, so
    # integer data cast to an integer tensor_type is not rounded through
    # float32. Unlike the legacy torch.Tensor(...) constructor, it also treats
    # a Python int as a value instead of a size (torch.Tensor(5) would allocate
    # 5 uninitialized elements).
    return torch.as_tensor(matrix).type(tensor_type)
@iterate1
def to_tensor_image(image, tensor_type='torch.FloatTensor'):
    """
    Converts an image to a torch.Tensor with torchvision's ToTensor.

    Parameters
    ----------
    image : image-like
        Image to convert (presumably a PIL image or numpy array, the inputs
        ToTensor accepts)
    tensor_type : String
        Type of tensor we are casting to

    Returns
    -------
    image : torch.Tensor
        Output of ToTensor, cast to ``tensor_type``
    """
    converted = transforms.ToTensor()(image)
    return converted.type(tensor_type)
@iterate1
def to_tensor_sample(sample, tensor_type='torch.FloatTensor'):
    """
    Casts the keys of sample to tensors.

    An entry is converted when one of the known names is ``in`` its key. For
    string keys this is a substring match, e.g. 'rgb' also matches
    'rgb_context'. Image-like entries are converted first (torchvision), then
    matrix-like entries (torch).

    Parameters
    ----------
    sample : Dict
        Input sample (its values are replaced in place)
    tensor_type : String
        Type of tensor we are casting to

    Returns
    -------
    sample : Dict
        The same dict, with matching keys cast as tensors
    """
    # Convert using torchvision
    image_keys = ('rgb', 'mask', 'input_depth', 'depth', 'disparity',
                  'optical_flow', 'scene_flow')
    # Only existing keys are reassigned, so the dict size never changes while
    # it is being iterated.
    for key_sample, val_sample in sample.items():
        # any() converts each entry once even when several names match
        # (e.g. 'input_depth' contains both 'input_depth' and 'depth').
        if any(key in key_sample for key in image_keys):
            sample[key_sample] = to_tensor_image(val_sample, tensor_type)
    # Convert from numpy
    matrix_keys = ('intrinsics', 'extrinsics', 'pose', 'pointcloud', 'semantic')
    for key_sample, val_sample in sample.items():
        if any(key in key_sample for key in matrix_keys):
            sample[key_sample] = to_tensor(val_sample, tensor_type)
    # Return converted sample
    return sample
|