jboth commited on
Commit
7046e4a
·
verified ·
1 Parent(s): 784d43d

Upload pytorch3d_stub/pytorch3d/transforms/__init__.py with huggingface_hub

Browse files
pytorch3d_stub/pytorch3d/transforms/__init__.py CHANGED
@@ -102,6 +102,14 @@ class Transform3d:
102
  S[:, 0, 0] = s[:, 0]; S[:, 1, 1] = s[:, 1]; S[:, 2, 2] = s[:, 2]
103
  new_m = self._matrix @ S
104
  return Transform3d(matrix=new_m, device=self.device, dtype=self.dtype)
 
 
 
 
 
 
 
 
105
  def rotate(self, R):
106
  if R.dim() == 2: R = R.unsqueeze(0)
107
  T = torch.eye(4, dtype=self.dtype, device=self.device).unsqueeze(0).expand(R.shape[0], -1, -1).clone()
 
102
  S[:, 0, 0] = s[:, 0]; S[:, 1, 1] = s[:, 1]; S[:, 2, 2] = s[:, 2]
103
  new_m = self._matrix @ S
104
  return Transform3d(matrix=new_m, device=self.device, dtype=self.dtype)
105
+ def to(self, device=None, dtype=None):
106
+ if device is not None: self.device = device
107
+ if dtype is not None: self.dtype = dtype
108
+ self._matrix = self._matrix.to(device=device, dtype=dtype)
109
+ return self
110
+ def inverse(self):
111
+ inv_m = torch.inverse(self._matrix)
112
+ return Transform3d(matrix=inv_m, device=self.device, dtype=self.dtype)
113
  def rotate(self, R):
114
  if R.dim() == 2: R = R.unsqueeze(0)
115
  T = torch.eye(4, dtype=self.dtype, device=self.device).unsqueeze(0).expand(R.shape[0], -1, -1).clone()