Spaces:
Runtime error
Runtime error
| import math | |
| from dataclasses import dataclass | |
| from typing import List, Optional, Union | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| # import trimesh | |
| from PIL import Image | |
| from torch import BoolTensor, FloatTensor | |
# Inputs that the helpers below accept wherever a per-view value is expected.
LIST_TYPE = Union[list, np.ndarray, torch.Tensor]


def list_to_pt(
    x: LIST_TYPE, dtype: Optional[torch.dtype] = None, device: Optional[str] = None
) -> torch.Tensor:
    """Convert ``x`` to a tensor with the requested dtype and device.

    Args:
        x: A list, NumPy array, or tensor. Anything else that
            ``torch.tensor`` accepts (tuples, Python scalars) also works.
        dtype: Target dtype; ``None`` leaves the choice to PyTorch.
        device: Target device; ``None`` keeps a tensor where it already is
            and uses PyTorch's default device for newly built tensors.

    Returns:
        A tensor holding the values of ``x``.
    """
    if isinstance(x, torch.Tensor):
        # Honour ``device`` for tensors as well; converting only the dtype
        # would leave the result on the input's device.
        return x.to(device=device, dtype=dtype)
    return torch.tensor(x, dtype=dtype, device=device)
def get_c2w(
    elevation_deg: LIST_TYPE,
    distance: LIST_TYPE,
    azimuth_deg: Optional[LIST_TYPE],
    num_views: Optional[int] = 1,
    device: Optional[str] = None,
) -> torch.FloatTensor:
    """Build camera-to-world matrices for cameras looking at the origin.

    Each camera sits at
    ``distance * (cos(el) * cos(az), cos(el) * sin(az), sin(el))`` and looks
    at the origin, with +Z as the world up direction. The rotation columns
    of the result are ``right``, ``up`` and ``-lookat``.

    Args:
        elevation_deg: Per-view elevation angles in degrees.
        distance: Per-view distances from the origin.
        azimuth_deg: Per-view azimuth angles in degrees. When ``None``,
            ``num_views`` azimuths are spread evenly over ``[0, 360)``.
        num_views: Number of views; only used when ``azimuth_deg`` is None.
        device: Device on which tensors are created.

    Returns:
        A float32 tensor of shape ``(N, 4, 4)``, where ``N`` is the common
        length of the per-view inputs after broadcasting.

    Raises:
        ValueError: If both ``azimuth_deg`` and ``num_views`` are None.

    Note:
        NOTE(review): a zero distance yields a zero look direction and an
        unusable rotation; near-vertical elevations look numerically
        ill-conditioned as well. Confirm callers avoid both.
    """
    if azimuth_deg is None:
        if num_views is None:
            raise ValueError("num_views must be provided if azimuth_deg is None.")
        # Drop the last sample so 0 and 360 degrees are not both emitted.
        azimuth_deg = torch.linspace(
            0, 360, num_views + 1, dtype=torch.float32, device=device
        )[:-1]
    else:
        azimuth_deg = list_to_pt(azimuth_deg, dtype=torch.float32, device=device)
    elevation_deg = list_to_pt(elevation_deg, dtype=torch.float32, device=device)
    camera_distances = list_to_pt(distance, dtype=torch.float32, device=device)

    # Bring every per-view quantity to one common length, so that e.g. a
    # single elevation can be combined with several azimuths. Without this
    # the three stacked components below can end up with different lengths.
    azimuth_deg, elevation_deg, camera_distances = torch.broadcast_tensors(
        azimuth_deg.reshape(-1),
        elevation_deg.reshape(-1),
        camera_distances.reshape(-1),
    )
    num_views = azimuth_deg.shape[0]

    elevation = elevation_deg * math.pi / 180
    azimuth = azimuth_deg * math.pi / 180
    camera_positions = torch.stack(
        [
            camera_distances * torch.cos(elevation) * torch.cos(azimuth),
            camera_distances * torch.cos(elevation) * torch.sin(azimuth),
            camera_distances * torch.sin(elevation),
        ],
        dim=-1,
    )

    center = torch.zeros_like(camera_positions)
    world_up = torch.tensor(
        [0.0, 0.0, 1.0], dtype=torch.float32, device=camera_positions.device
    ).expand(num_views, 3)

    # Orthonormal camera frame: look direction, then right, then the up
    # vector recomputed so that it is perpendicular to both.
    lookat = F.normalize(center - camera_positions, dim=-1)
    right = F.normalize(torch.cross(lookat, world_up, dim=-1), dim=-1)
    up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)

    c2w = torch.zeros(
        num_views, 4, 4, dtype=torch.float32, device=camera_positions.device
    )
    c2w[:, :3, :3] = torch.stack([right, up, -lookat], dim=-1)
    c2w[:, :3, 3] = camera_positions
    c2w[:, 3, 3] = 1.0
    return c2w
def get_projection_matrix(
    fovy_deg: LIST_TYPE,
    aspect_wh: float = 1.0,
    near: float = 0.1,
    far: float = 100.0,
    device: Optional[str] = None,
) -> torch.FloatTensor:
    """Build one perspective projection matrix per vertical field of view.

    Args:
        fovy_deg: Per-view vertical field of view in degrees.
        aspect_wh: Width / height aspect ratio.
        near: Near clipping distance.
        far: Far clipping distance.
        device: Device on which the matrices are created.

    Returns:
        A float32 tensor of shape ``(B, 4, 4)`` with ``B = len(fovy_deg)``.
        The y scale at ``[1, 1]`` carries a negative sign.
    """
    fovy_deg = list_to_pt(fovy_deg, dtype=torch.float32, device=device)
    num_matrices = fovy_deg.shape[0]
    # Tangent of half the field of view, with the angle converted to radians.
    half_fov_tan = torch.tan(fovy_deg * math.pi / 180 / 2)

    proj = torch.zeros(num_matrices, 4, 4, dtype=torch.float32, device=device)
    # Depth row entries and the perspective-divide entry are the same for
    # every matrix in the batch.
    proj[:, 3, 2] = -1
    proj[:, 2, 3] = -2 * far * near / (far - near)
    proj[:, 2, 2] = -(far + near) / (far - near)
    # x / y scaling depends on the per-view field of view.
    proj[:, 1, 1] = -1 / half_fov_tan
    proj[:, 0, 0] = 1 / (aspect_wh * half_fov_tan)
    return proj
def get_orthogonal_projection_matrix(
    batch_size: int,
    left: float,
    right: float,
    bottom: float,
    top: float,
    near: float = 0.1,
    far: float = 100.0,
    device: Optional[str] = None,
) -> torch.FloatTensor:
    """Build a batch of identical orthographic projection matrices.

    The view volume ``[left, right] x [bottom, top]`` at depths
    ``[-near, -far]`` is mapped onto ``[-1, 1]`` on each axis. The y axis is
    flipped, so ``bottom`` maps to ``+1`` and ``top`` maps to ``-1``.

    Args:
        batch_size: Number of copies of the matrix to return.
        left: Left edge of the view volume.
        right: Right edge of the view volume.
        bottom: Bottom edge of the view volume.
        top: Top edge of the view volume.
        near: Near clipping distance.
        far: Far clipping distance.
        device: Device on which the matrices are created.

    Returns:
        A float32 tensor of shape ``(batch_size, 4, 4)``.

    Raises:
        ZeroDivisionError: If any pair of opposite bounds is equal.
    """
    projection_matrix = torch.zeros(
        batch_size, 4, 4, dtype=torch.float32, device=device
    )
    projection_matrix[:, 0, 0] = 2 / (right - left)
    projection_matrix[:, 0, 3] = -(right + left) / (right - left)
    # The whole y row is negated, scale and offset alike. Negating only the
    # scale would push [bottom, top] outside [-1, 1] whenever top != -bottom.
    projection_matrix[:, 1, 1] = -2 / (top - bottom)
    projection_matrix[:, 1, 3] = (top + bottom) / (top - bottom)
    projection_matrix[:, 2, 2] = -2 / (far - near)
    projection_matrix[:, 2, 3] = -(far + near) / (far - near)
    projection_matrix[:, 3, 3] = 1
    return projection_matrix
# eq=False keeps identity-based equality and hashing; a generated __eq__
# would compare tensors field by field, which is ambiguous for tensors.
@dataclass(eq=False)
class Camera:
    """A batch of camera matrices that share a leading batch dimension.

    Supports ``len()``, integer and slice indexing (both return a new
    ``Camera`` that keeps the batch dimension), iteration, and moving all
    tensors to another device with ``to``.
    """

    # Camera-to-world matrices; None when only ``w2c`` is known.
    c2w: Optional[torch.FloatTensor]
    # World-to-camera matrices.
    w2c: torch.FloatTensor
    # Projection matrices.
    proj_mtx: torch.FloatTensor
    # Combined projection @ world-to-camera matrices.
    mvp_mtx: torch.FloatTensor
    # Camera positions; None when ``c2w`` is unknown.
    cam_pos: Optional[torch.FloatTensor]

    def __getitem__(self, index):
        """Return the selected view(s) as a new ``Camera``.

        Args:
            index: An ``int`` (negative values count from the end) or a
                ``slice``.

        Raises:
            IndexError: If an integer index is out of range.
            NotImplementedError: If ``index`` is neither int nor slice.
        """
        if isinstance(index, int):
            count = len(self)
            # Raising here is also what lets ``for cam in cameras`` stop.
            if index < -count or index >= count:
                raise IndexError(
                    f"Camera index {index} out of range for {count} view(s)"
                )
            # Normalise negatives first: slice(-1, 0) would be empty.
            start = index % count
            sl = slice(start, start + 1)
        elif isinstance(index, slice):
            sl = index
        else:
            raise NotImplementedError
        return Camera(
            c2w=self.c2w[sl] if self.c2w is not None else None,
            w2c=self.w2c[sl],
            proj_mtx=self.proj_mtx[sl],
            mvp_mtx=self.mvp_mtx[sl],
            cam_pos=self.cam_pos[sl] if self.cam_pos is not None else None,
        )

    def to(self, device: Optional[str] = None):
        """Move every tensor to ``device`` in place and return ``self``."""
        if self.c2w is not None:
            self.c2w = self.c2w.to(device)
        self.w2c = self.w2c.to(device)
        self.proj_mtx = self.proj_mtx.to(device)
        self.mvp_mtx = self.mvp_mtx.to(device)
        if self.cam_pos is not None:
            self.cam_pos = self.cam_pos.to(device)
        return self

    def __len__(self):
        """Return the number of views in the batch."""
        # ``w2c`` is always set, whereas ``c2w`` may be None.
        return self.w2c.shape[0]
def get_camera(
    elevation_deg: Optional[LIST_TYPE] = None,
    distance: Optional[LIST_TYPE] = None,
    fovy_deg: Optional[LIST_TYPE] = None,
    azimuth_deg: Optional[LIST_TYPE] = None,
    num_views: Optional[int] = 1,
    c2w: Optional[torch.FloatTensor] = None,
    w2c: Optional[torch.FloatTensor] = None,
    proj_mtx: Optional[torch.FloatTensor] = None,
    aspect_wh: float = 1.0,
    near: float = 0.1,
    far: float = 100.0,
    device: Optional[str] = None,
) -> "Camera":
    """Create a perspective ``Camera`` from angles or from explicit matrices.

    The extrinsics are resolved in this order:

    * ``w2c`` given: it is used as is. ``c2w`` and ``cam_pos`` of the result
      are ``None``, even if a ``c2w`` argument was also passed.
    * otherwise ``c2w`` given: ``w2c`` is its inverse.
    * otherwise: ``c2w`` is built with ``get_c2w`` from ``elevation_deg``,
      ``distance``, ``azimuth_deg`` and ``num_views``.

    The projection is ``proj_mtx`` when given, otherwise it is built with
    ``get_projection_matrix`` from ``fovy_deg``, ``aspect_wh``, ``near`` and
    ``far``.

    Returns:
        A ``Camera`` whose ``mvp_mtx`` is ``proj_mtx @ w2c``.

    Raises:
        ValueError: If the arguments needed to build the extrinsics or the
            projection are missing.
    """
    if w2c is None:
        if c2w is None:
            # Fail early with a clear message instead of an attribute error
            # on None deep inside the helper.
            if elevation_deg is None or distance is None:
                raise ValueError(
                    "elevation_deg and distance must be provided "
                    "if neither c2w nor w2c is given."
                )
            c2w = get_c2w(elevation_deg, distance, azimuth_deg, num_views, device)
        camera_positions = c2w[:, :3, 3]
        w2c = torch.linalg.inv(c2w)
    else:
        camera_positions = None
        c2w = None

    if proj_mtx is None:
        if fovy_deg is None:
            raise ValueError("fovy_deg must be provided if proj_mtx is None.")
        proj_mtx = get_projection_matrix(
            fovy_deg, aspect_wh=aspect_wh, near=near, far=far, device=device
        )

    mvp_mtx = proj_mtx @ w2c
    return Camera(
        c2w=c2w, w2c=w2c, proj_mtx=proj_mtx, mvp_mtx=mvp_mtx, cam_pos=camera_positions
    )
def get_orthogonal_camera(
    elevation_deg: LIST_TYPE,
    distance: LIST_TYPE,
    left: float,
    right: float,
    bottom: float,
    top: float,
    azimuth_deg: Optional[LIST_TYPE] = None,
    num_views: Optional[int] = 1,
    near: float = 0.1,
    far: float = 100.0,
    device: Optional[str] = None,
):
    """Create an orthographic ``Camera`` from viewing angles and bounds.

    The extrinsics come from ``get_c2w`` and the projection from
    ``get_orthogonal_projection_matrix``, with one projection matrix per
    view.

    Returns:
        A ``Camera`` whose ``mvp_mtx`` is ``proj_mtx @ w2c`` and whose
        ``cam_pos`` is the translation column of ``c2w``.
    """
    cam_to_world = get_c2w(elevation_deg, distance, azimuth_deg, num_views, device)
    world_to_cam = torch.linalg.inv(cam_to_world)

    # One projection matrix per view produced above.
    ortho_bounds = dict(left=left, right=right, bottom=bottom, top=top)
    projection = get_orthogonal_projection_matrix(
        batch_size=cam_to_world.shape[0],
        near=near,
        far=far,
        device=device,
        **ortho_bounds,
    )

    return Camera(
        c2w=cam_to_world,
        w2c=world_to_cam,
        proj_mtx=projection,
        mvp_mtx=projection @ world_to_cam,
        cam_pos=cam_to_world[:, :3, 3],
    )