# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
import torch
from vidar.geometry.pose_utils import invert_pose, pose_vec2mat, to_global_pose, euler2mat
from vidar.utils.types import is_int
def from_dict_sample(T, to_global=False, zero_origin=False, to_matrix=False):
"""
Create poses from a sample dictionary
Parameters
----------
T : Dict
Dictionary containing input poses [B,4,4]
to_global : Bool
Whether poses should be converted to global frame of reference
zero_origin : Bool
Whether the target camera should be the center of the frame of reference
to_matrix : Bool
Whether output poses should be classes or tensors
Returns
-------
pose : Dict
Dictionary containing output poses
"""
pose = {key: Pose(val) for key, val in T.items()}
if to_global:
pose = to_global_pose(pose, zero_origin=zero_origin)
if to_matrix:
pose = {key: val.T for key, val in pose.items()}
return pose
def from_dict_batch(T, **kwargs):
"""Create poses from a batch dictionary"""
pose_batch = [from_dict_sample({key: val[b] for key, val in T.items()}, **kwargs)
for b in range(T[0].shape[0])]
return {key: torch.stack([v[key] for v in pose_batch], 0) for key in pose_batch[0]}
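# Usage sketch for Pose.from_dict (keys and shapes are assumptions, consistent
# with the T[0] indexing above): a dictionary of [B,4,4] transformations keyed
# by context index, with 0 as the target camera, becomes a dictionary of Pose objects.
#   T = {0: torch.eye(4).repeat(2, 1, 1), 1: torch.eye(4).repeat(2, 1, 1)}
#   poses = Pose.from_dict(T)   # {0: Pose, 1: Pose}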
class Pose:
"""
Pose class for 3D operations
Parameters
----------
T : torch.Tensor or Int
Transformation matrix [B,4,4], or batch size (poses initialized as identity)
"""
def __init__(self, T=1):
if is_int(T):
T = torch.eye(4).repeat(T, 1, 1)
self.T = T if T.dim() == 3 else T.unsqueeze(0)
def __len__(self):
"""Return batch size"""
return len(self.T)
    def __getitem__(self, i):
        """Return the i-th pose, keeping the batch dimension"""
        return Pose(self.T[[i]])
def __mul__(self, data):
"""Transforms data (pose or 3D points)"""
if isinstance(data, Pose):
return Pose(self.T.bmm(data.T))
elif isinstance(data, torch.Tensor):
return self.T[:, :3, :3].bmm(data) + self.T[:, :3, -1].unsqueeze(-1)
else:
raise NotImplementedError()
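    # Usage sketch for __mul__ (variable names and shapes are illustrative,
    # matching the bmm above): the right-hand operand is either another Pose
    # or a tensor of 3D points.
    #   pose_a = Pose(2)                            # batch of 2 identity poses
    #   pose_b = Pose(torch.eye(4).repeat(2, 1, 1))
    #   composed = pose_a * pose_b                  # Pose wrapping T_a @ T_b, [2,4,4]
    #   points = torch.rand(2, 3, 100)              # 3D points as [B,3,N]
    #   transformed = pose_a * points               # tensor [2,3,N], i.e. R @ p + t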
def detach(self):
"""Return detached pose"""
return Pose(self.T.detach())
@property
def shape(self):
"""Return pose shape"""
return self.T.shape
@property
def device(self):
"""Return pose device"""
return self.T.device
@property
def dtype(self):
"""Return pose type"""
return self.T.dtype
@classmethod
def identity(cls, N=1, device=None, dtype=torch.float):
"""Initializes as a [4,4] identity matrix"""
return cls(torch.eye(4, device=device, dtype=dtype).repeat([N,1,1]))
@staticmethod
def from_dict(T, to_global=False, zero_origin=False, to_matrix=False):
"""Create poses from a dictionary"""
        if T[0].dim() == 3:
            return from_dict_sample(T, to_global=to_global, zero_origin=zero_origin, to_matrix=to_matrix)
        elif T[0].dim() == 4:
            return from_dict_batch(T, to_global=to_global, zero_origin=zero_origin, to_matrix=True)
        else:
            raise ValueError('Invalid pose dimensions: {}'.format(T[0].dim()))
@classmethod
def from_vec(cls, vec, mode):
"""Initializes from a [B,6] batch vector"""
mat = pose_vec2mat(vec, mode)
pose = torch.eye(4, device=vec.device, dtype=vec.dtype).repeat([len(vec), 1, 1])
pose[:, :3, :3] = mat[:, :3, :3]
pose[:, :3, -1] = mat[:, :3, -1]
return cls(pose)
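    # Usage sketch ('euler' is an assumed mode here; see pose_vec2mat for the
    # options it actually supports): a [B,6] vector of translation + rotation
    # parameters becomes a batch of poses.
    #   vec = torch.zeros(2, 6)               # 6-DoF vectors (ordering per pose_vec2mat)
    #   pose = Pose.from_vec(vec, 'euler')    # Pose with T of shape [2,4,4]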
def repeat(self, *args, **kwargs):
"""Repeats the transformation matrix multiple times"""
self.T = self.T.repeat(*args, **kwargs)
return self
def inverse(self):
"""Returns a new Pose that is the inverse of this one"""
return Pose(invert_pose(self.T))
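    # Usage sketch (shapes assumed): inverse() returns a new Pose and leaves
    # the original untouched; composing a pose with its own inverse gives
    # (approximately) the identity.
    #   pose_a = Pose(torch.eye(4).repeat(2, 1, 1))
    #   identity = pose_a * pose_a.inverse()      # ~ torch.eye(4), up to numerics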
    def to(self, *args, **kwargs):
        """Move pose to a device and/or dtype (in place)"""
        self.T = self.T.to(*args, **kwargs)
        return self
    def cuda(self, *args, **kwargs):
        """Move pose to CUDA (in place)"""
        self.to('cuda', *args, **kwargs)
        return self
def translate(self, xyz):
"""Translate pose"""
self.T[:, :3, -1] = self.T[:, :3, -1] + xyz.to(self.device)
return self
    def rotate(self, rpw):
        """Rotate pose"""
        rot = euler2mat(rpw)
        # Apply the rotation to the rotation block of the inverted transform
        # (its translation is kept fixed), then invert back
        T = invert_pose(self.T).clone()
        T[:, :3, :3] = T[:, :3, :3] @ rot.to(self.device)
        self.T = invert_pose(T)
        return self
def rotateRoll(self, r):
"""Rotate pose in the roll axis"""
return self.rotate(torch.tensor([[0, 0, r]]))
def rotatePitch(self, p):
"""Rotate pose in the pitcv axis"""
return self.rotate(torch.tensor([[p, 0, 0]]))
def rotateYaw(self, w):
"""Rotate pose in the yaw axis"""
return self.rotate(torch.tensor([[0, w, 0]]))
def translateForward(self, t):
"""Translate pose forward"""
return self.translate(torch.tensor([[0, 0, -t]]))
def translateBackward(self, t):
"""Translate pose backward"""
return self.translate(torch.tensor([[0, 0, +t]]))
def translateLeft(self, t):
"""Translate pose left"""
return self.translate(torch.tensor([[+t, 0, 0]]))
def translateRight(self, t):
"""Translate pose right"""
return self.translate(torch.tensor([[-t, 0, 0]]))
def translateUp(self, t):
"""Translate pose up"""
return self.translate(torch.tensor([[0, +t, 0]]))
def translateDown(self, t):
"""Translate pose down"""
return self.translate(torch.tensor([[0, -t, 0]]))
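
if __name__ == '__main__':
    # Minimal usage sketch (illustrative only, not part of the module API):
    # chain the in-place translation/rotation helpers and transform points.
    pose = Pose.identity(N=1)
    pose.translateForward(1.0).rotateYaw(0.1)   # forward 1 unit, then yaw (radians assumed)
    points = torch.rand(1, 3, 8)                # [B,3,N] 3D points
    transformed = pose * points                 # [B,3,N] transformed points
    print(pose.shape, transformed.shape)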