import bisect
import math
import os
import random
from dataclasses import dataclass, field

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, IterableDataset

import threestudio
from threestudio import register
from threestudio.utils.base import Updateable
from threestudio.utils.config import parse_structured
from threestudio.utils.misc import get_device
from threestudio.utils.ops import (
    get_mvp_matrix,
    get_projection_matrix,
    get_ray_directions,
    get_rays,
)
from threestudio.utils.typing import *

def safe_normalize(x, eps=1e-20):
    return x / torch.sqrt(torch.clamp(torch.sum(x * x, -1, keepdim=True), min=eps))

def convert_camera_to_world_transform(transform):
    # Convert a right-handed transform matrix into a left-handed one.
    # Copy the original transform matrix.
    converted_transform = transform.clone()
    # Reverse the viewing direction (negate the third column).
    converted_transform[:, 2] *= -1
    # Swap the first and third rows.
    converted_transform[[0, 2], :] = converted_transform[[2, 0], :]
    return converted_transform

def circle_poses(device, radius=torch.tensor([3.2]), theta=torch.tensor([60]), phi=torch.tensor([0])):
    theta = theta / 180 * np.pi
    phi = phi / 180 * np.pi

    centers = torch.stack([
        radius * torch.sin(theta) * torch.sin(phi),
        radius * torch.cos(theta),
        radius * torch.sin(theta) * torch.cos(phi),
    ], dim=-1)  # [B, 3]

    # lookat
    forward_vector = safe_normalize(centers)
    up_vector = torch.FloatTensor([0, 1, 0]).to(device).unsqueeze(0).repeat(len(centers), 1)
    right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))
    up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1))

    poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(len(centers), 1, 1)
    poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
    poses[:, :3, 3] = centers
    return poses
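
# Example usage of circle_poses (a minimal sketch; the angle values are arbitrary):
#   poses = circle_poses("cpu", radius=torch.tensor([3.2]),
#                        theta=torch.tensor([60.0]), phi=torch.tensor([30.0]))
#   poses.shape  # -> torch.Size([1, 4, 4]); camera-to-world poses in a y-up convention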

trans_t = lambda t: torch.Tensor([
    [1, 0, 0, 0],
    [0, 1, 0, 0],
    [0, 0, 1, t],
    [0, 0, 0, 1]]).float()

rot_phi = lambda phi: torch.Tensor([
    [1, 0, 0, 0],
    [0, np.cos(phi), -np.sin(phi), 0],
    [0, np.sin(phi), np.cos(phi), 0],
    [0, 0, 0, 1]]).float()

rot_theta = lambda th: torch.Tensor([
    [np.cos(th), 0, -np.sin(th), 0],
    [0, 1, 0, 0],
    [np.sin(th), 0, np.cos(th), 0],
    [0, 0, 0, 1]]).float()

def rodrigues_mat_to_rot(R):
    eps = 1e-16
    trc = np.trace(R)
    trc2 = (trc - 1.0) / 2.0
    # sinacostrc2 = np.sqrt(1 - trc2 * trc2)
    s = np.array([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]])
    if (1 - trc2 * trc2) >= eps:
        tHeta = np.arccos(trc2)
        tHetaf = tHeta / (2 * (np.sin(tHeta)))
    else:
        tHeta = np.real(np.arccos(trc2))
        tHetaf = 0.5 / (1 - tHeta / 6)
    omega = tHetaf * s
    return omega

def rodrigues_rot_to_mat(r):
    wx, wy, wz = r
    theta = np.sqrt(wx * wx + wy * wy + wz * wz)
    a = np.cos(theta)
    b = (1 - np.cos(theta)) / (theta * theta)
    c = np.sin(theta) / theta
    R = np.zeros([3, 3])
    R[0, 0] = a + b * (wx * wx)
    R[0, 1] = b * wx * wy - c * wz
    R[0, 2] = b * wx * wz + c * wy
    R[1, 0] = b * wx * wy + c * wz
    R[1, 1] = a + b * (wy * wy)
    R[1, 2] = b * wy * wz - c * wx
    R[2, 0] = b * wx * wz - c * wy
    R[2, 1] = b * wz * wy + c * wx
    R[2, 2] = a + b * (wz * wz)
    return R
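
# rodrigues_mat_to_rot / rodrigues_rot_to_mat are axis-angle <-> rotation-matrix
# conversions via the Rodrigues formula. A round-trip sketch (arbitrary example
# vector; valid for rotation angles in (0, pi)):
#   r = np.array([0.1, -0.2, 0.3])
#   np.allclose(rodrigues_mat_to_rot(rodrigues_rot_to_mat(r)), r)  # expected True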

def pose_spherical(theta, phi, radius):
    c2w = trans_t(radius)
    c2w = rot_phi(phi / 180.0 * np.pi) @ c2w
    c2w = rot_theta(theta / 180.0 * np.pi) @ c2w
    c2w = torch.Tensor(np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])) @ c2w
    return c2w
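
# Example usage of pose_spherical (a sketch; theta/phi are in degrees and the
# values below are arbitrary):
#   c2w = pose_spherical(theta=180.0, phi=-15.0, radius=4.0)
#   c2w.shape  # -> torch.Size([4, 4]) camera-to-world matrix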

def convert_camera_pose(camera_pose):
    # Clone the tensor to avoid in-place operations
    colmap_pose = camera_pose.clone()
    # Extract rotation and translation components
    rotation = colmap_pose[:, :3, :3]
    translation = colmap_pose[:, :3, 3]
    # Change rotation orientation
    rotation[:, 0, :] *= -1
    rotation[:, 1, :] *= -1
    # Change translation position
    translation[:, 0] *= -1
    translation[:, 1] *= -1
    return colmap_pose

@dataclass
class RandomCameraDataModuleConfig:
    # height, width, and batch_size should be Union[int, List[int]]
    # but OmegaConf does not support Union of containers
    height: Any = 512
    width: Any = 512
    batch_size: Any = 1
    resolution_milestones: List[int] = field(default_factory=lambda: [])
    eval_height: int = 512
    eval_width: int = 512
    eval_batch_size: int = 1
    n_val_views: int = 1
    n_test_views: int = 120
    elevation_range: Tuple[float, float] = (-10, 60)
    azimuth_range: Tuple[float, float] = (-180, 180)
    camera_distance_range: Tuple[float, float] = (4.0, 6.0)
    fovy_range: Tuple[float, float] = (
        40,
        70,
    )  # in degrees, in vertical direction (along height)
    camera_perturb: float = 0.0
    center_perturb: float = 0.0
    up_perturb: float = 0.0
    light_position_perturb: float = 1.0
    light_distance_range: Tuple[float, float] = (0.8, 1.5)
    eval_elevation_deg: float = 15.0
    eval_camera_distance: float = 6.0
    eval_fovy_deg: float = 70.0
    light_sample_strategy: str = "dreamfusion"
    batch_uniform_azimuth: bool = True
    progressive_until: int = 0  # progressive ranges for elevation, azimuth, r, fovy
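
# Sketch of resolving the config (parse_structured is imported above from
# threestudio.utils.config; the override values are illustrative only):
#   cfg = parse_structured(RandomCameraDataModuleConfig, {"height": 256, "batch_size": 4})
#   cfg.height, cfg.fovy_range  # -> 256, (40, 70)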

class RandomCameraIterableDataset(IterableDataset, Updateable):
    def __init__(self, cfg: Any) -> None:
        super().__init__()
        self.cfg: RandomCameraDataModuleConfig = cfg
        self.heights: List[int] = (
            [self.cfg.height] if isinstance(self.cfg.height, int) else self.cfg.height
        )
        self.widths: List[int] = (
            [self.cfg.width] if isinstance(self.cfg.width, int) else self.cfg.width
        )
        self.batch_sizes: List[int] = (
            [self.cfg.batch_size]
            if isinstance(self.cfg.batch_size, int)
            else self.cfg.batch_size
        )
        assert len(self.heights) == len(self.widths) == len(self.batch_sizes)
        self.resolution_milestones: List[int]
        if (
            len(self.heights) == 1
            and len(self.widths) == 1
            and len(self.batch_sizes) == 1
        ):
            if len(self.cfg.resolution_milestones) > 0:
                threestudio.warn(
                    "Ignoring resolution_milestones since height and width are not changing"
                )
            self.resolution_milestones = [-1]
        else:
            assert len(self.heights) == len(self.cfg.resolution_milestones) + 1
            self.resolution_milestones = [-1] + self.cfg.resolution_milestones

        self.directions_unit_focals = [
            get_ray_directions(H=height, W=width, focal=1.0)
            for (height, width) in zip(self.heights, self.widths)
        ]
        self.height: int = self.heights[0]
        self.width: int = self.widths[0]
        self.batch_size: int = self.batch_sizes[0]
        self.directions_unit_focal = self.directions_unit_focals[0]
        self.elevation_range = self.cfg.elevation_range
        self.azimuth_range = self.cfg.azimuth_range
        self.camera_distance_range = self.cfg.camera_distance_range
        self.fovy_range = self.cfg.fovy_range
    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        size_ind = bisect.bisect_right(self.resolution_milestones, global_step) - 1
        self.height = self.heights[size_ind]
        self.width = self.widths[size_ind]
        self.batch_size = self.batch_sizes[size_ind]
        self.directions_unit_focal = self.directions_unit_focals[size_ind]
        threestudio.debug(
            f"Training height: {self.height}, width: {self.width}, batch_size: {self.batch_size}"
        )
        # progressive view
        self.progressive_view(global_step)

    def __iter__(self):
        while True:
            yield {}

    def progressive_view(self, global_step):
        pass
        # r = min(1.0, global_step / (self.cfg.progressive_until + 1))
        # self.elevation_range = [
        #     (1 - r) * self.cfg.eval_elevation_deg + r * self.cfg.elevation_range[0],
        #     (1 - r) * self.cfg.eval_elevation_deg + r * self.cfg.elevation_range[1],
        # ]
        # self.azimuth_range = [
        #     (1 - r) * 0.0 + r * self.cfg.azimuth_range[0],
        #     (1 - r) * 0.0 + r * self.cfg.azimuth_range[1],
        # ]
        # self.camera_distance_range = [
        #     (1 - r) * self.cfg.eval_camera_distance
        #     + r * self.cfg.camera_distance_range[0],
        #     (1 - r) * self.cfg.eval_camera_distance
        #     + r * self.cfg.camera_distance_range[1],
        # ]
        # self.fovy_range = [
        #     (1 - r) * self.cfg.eval_fovy_deg + r * self.cfg.fovy_range[0],
        #     (1 - r) * self.cfg.eval_fovy_deg + r * self.cfg.fovy_range[1],
        # ]
    def collate(self, batch) -> Dict[str, Any]:
        # sample elevation angles
        elevation_deg: Float[Tensor, "B"]
        elevation: Float[Tensor, "B"]
        if random.random() < 0.5:
            # sample elevation angles uniformly with a probability 0.5 (biased towards poles)
            elevation_deg = (
                torch.rand(self.batch_size)
                * (self.elevation_range[1] - self.elevation_range[0])
                + self.elevation_range[0]
            )
            elevation = elevation_deg * math.pi / 180
        else:
            # otherwise sample uniformly on sphere
            elevation_range_percent = [
                (self.elevation_range[0] + 90.0) / 180.0,
                (self.elevation_range[1] + 90.0) / 180.0,
            ]
            # inverse transform sampling
            elevation = torch.asin(
                2
                * (
                    torch.rand(self.batch_size)
                    * (elevation_range_percent[1] - elevation_range_percent[0])
                    + elevation_range_percent[0]
                )
                - 1.0
            )
            elevation_deg = elevation / math.pi * 180.0

        # sample azimuth angles from a uniform distribution bounded by azimuth_range
        azimuth_deg: Float[Tensor, "B"]
        if self.cfg.batch_uniform_azimuth:
            # ensures sampled azimuth angles in a batch cover the whole range
            azimuth_deg = (
                torch.rand(self.batch_size) + torch.arange(self.batch_size)
            ) / self.batch_size * (
                self.azimuth_range[1] - self.azimuth_range[0]
            ) + self.azimuth_range[0]
        else:
            # simple random sampling
            azimuth_deg = (
                torch.rand(self.batch_size)
                * (self.azimuth_range[1] - self.azimuth_range[0])
                + self.azimuth_range[0]
            )
        azimuth = azimuth_deg * math.pi / 180

        # sample distances from a uniform distribution bounded by distance_range
        camera_distances: Float[Tensor, "B"] = (
            torch.rand(self.batch_size)
            * (self.camera_distance_range[1] - self.camera_distance_range[0])
            + self.camera_distance_range[0]
        )

        # convert spherical coordinates to cartesian coordinates
        # right hand coordinate system, x back, y right, z up
        # elevation in (-90, 90), azimuth from +x to +y in (-180, 180)
        camera_positions: Float[Tensor, "B 3"] = torch.stack(
            [
                camera_distances * torch.cos(elevation) * torch.cos(azimuth),
                camera_distances * torch.cos(elevation) * torch.sin(azimuth),
                camera_distances * torch.sin(elevation),
            ],
            dim=-1,
        )

        # default scene center at origin
        center: Float[Tensor, "B 3"] = torch.zeros_like(camera_positions)
        # default camera up direction as +z
        up: Float[Tensor, "B 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[
            None, :
        ].repeat(self.batch_size, 1)

        # sample camera perturbations from a uniform distribution [-camera_perturb, camera_perturb]
        camera_perturb: Float[Tensor, "B 3"] = (
            torch.rand(self.batch_size, 3) * 2 * self.cfg.camera_perturb
            - self.cfg.camera_perturb
        )
        camera_positions = camera_positions + camera_perturb
        # sample center perturbations from a normal distribution with mean 0 and std center_perturb
        center_perturb: Float[Tensor, "B 3"] = (
            torch.randn(self.batch_size, 3) * self.cfg.center_perturb
        )
        center = center + center_perturb
        # sample up perturbations from a normal distribution with mean 0 and std up_perturb
        up_perturb: Float[Tensor, "B 3"] = (
            torch.randn(self.batch_size, 3) * self.cfg.up_perturb
        )
        up = up + up_perturb

        # sample fovs from a uniform distribution bounded by fov_range
        fovy_deg: Float[Tensor, "B"] = (
            torch.rand(self.batch_size) * (self.fovy_range[1] - self.fovy_range[0])
            + self.fovy_range[0]
        )
        fovy = fovy_deg * math.pi / 180
        # sample light distance from a uniform distribution bounded by light_distance_range
        light_distances: Float[Tensor, "B"] = (
            torch.rand(self.batch_size)
            * (self.cfg.light_distance_range[1] - self.cfg.light_distance_range[0])
            + self.cfg.light_distance_range[0]
        )

        if self.cfg.light_sample_strategy in ("dreamfusion", "dreamfusion3dgs"):
            # sample light direction from a normal distribution with mean camera_position and std light_position_perturb
            light_direction: Float[Tensor, "B 3"] = F.normalize(
                camera_positions
                + torch.randn(self.batch_size, 3) * self.cfg.light_position_perturb,
                dim=-1,
            )
            # get light position by scaling light direction by light distance
            light_positions: Float[Tensor, "B 3"] = (
                light_direction * light_distances[:, None]
            )
        elif self.cfg.light_sample_strategy == "magic3d":
            # sample light direction within restricted angle range (pi/3)
            local_z = F.normalize(camera_positions, dim=-1)
            local_x = F.normalize(
                torch.stack(
                    [local_z[:, 1], -local_z[:, 0], torch.zeros_like(local_z[:, 0])],
                    dim=-1,
                ),
                dim=-1,
            )
            local_y = F.normalize(torch.cross(local_z, local_x, dim=-1), dim=-1)
            rot = torch.stack([local_x, local_y, local_z], dim=-1)
            light_azimuth = (
                torch.rand(self.batch_size) * math.pi * 2 - math.pi
            )  # [-pi, pi]
            light_elevation = (
                torch.rand(self.batch_size) * math.pi / 3 + math.pi / 6
            )  # [pi/6, pi/2]
            light_positions_local = torch.stack(
                [
                    light_distances
                    * torch.cos(light_elevation)
                    * torch.cos(light_azimuth),
                    light_distances
                    * torch.cos(light_elevation)
                    * torch.sin(light_azimuth),
                    light_distances * torch.sin(light_elevation),
                ],
                dim=-1,
            )
            light_positions = (rot @ light_positions_local[:, :, None])[:, :, 0]
        else:
            raise ValueError(
                f"Unknown light sample strategy: {self.cfg.light_sample_strategy}"
            )

        lookat: Float[Tensor, "B 3"] = F.normalize(center - camera_positions, dim=-1)
        right: Float[Tensor, "B 3"] = F.normalize(torch.cross(lookat, up, dim=-1), dim=-1)
        up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
        c2w3x4: Float[Tensor, "B 3 4"] = torch.cat(
            [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
            dim=-1,
        )
        c2w: Float[Tensor, "B 4 4"] = torch.cat(
            [c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1
        )
        c2w[:, 3, 3] = 1.0

        # get directions by dividing directions_unit_focal by focal length
        focal_length: Float[Tensor, "B"] = 0.5 * self.height / torch.tan(0.5 * fovy)
        directions: Float[Tensor, "B H W 3"] = self.directions_unit_focal[
            None, :, :, :
        ].repeat(self.batch_size, 1, 1, 1)
        directions[:, :, :, :2] = (
            directions[:, :, :, :2] / focal_length[:, None, None, None]
        )

        proj_mtx: Float[Tensor, "B 4 4"] = get_projection_matrix(
            fovy, self.width / self.height, 0.1, 1000.0
        )  # FIXME: hard-coded near and far
        mvp_mtx: Float[Tensor, "B 4 4"] = get_mvp_matrix(c2w, proj_mtx)
        c2w_3dgs = []
        for id in range(self.batch_size):
            render_pose = pose_spherical(azimuth_deg[id] + 180.0, -elevation_deg[id], camera_distances[id])
            # print(azimuth_deg[id], -elevation_deg[id], camera_distances[id] * 2.0)
            # print(render_pose)
            matrix = torch.linalg.inv(render_pose)
            # R = -np.transpose(matrix[:3, :3])
            R = -torch.transpose(matrix[:3, :3], 0, 1)
            R[:, 0] = -R[:, 0]
            T = -matrix[:3, 3]
            c2w_single = torch.cat([R, T[:, None]], 1)
            c2w_single = torch.cat([c2w_single, torch.tensor([[0.0, 0.0, 0.0, 1.0]])], 0)
            # c2w_single = convert_camera_to_world_transform(c2w_single)
            c2w_3dgs.append(c2w_single)
        c2w_3dgs = torch.stack(c2w_3dgs, 0)
        return {
            "mvp_mtx": mvp_mtx,
            "camera_positions": camera_positions,
            "c2w": c2w,
            "c2w_3dgs": c2w_3dgs,
            "light_positions": light_positions,
            "elevation": elevation_deg,
            "azimuth": azimuth_deg,
            "camera_distances": camera_distances,
            "height": self.height,
            "width": self.width,
            "fovy": fovy,
        }
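
# Sketch of driving the iterable dataset directly (illustrative values; __iter__
# only yields empty dicts, so collate() is where all per-batch sampling happens):
#   train_ds = RandomCameraIterableDataset(
#       parse_structured(RandomCameraDataModuleConfig, {"batch_size": 2})
#   )
#   batch = train_ds.collate(None)
#   batch["c2w"].shape, batch["c2w_3dgs"].shape  # -> (2, 4, 4), (2, 4, 4)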

class RandomCameraDataset(Dataset):
    def __init__(self, cfg: Any, split: str) -> None:
        super().__init__()
        self.cfg: RandomCameraDataModuleConfig = cfg
        self.split = split

        if split == "val":
            self.n_views = self.cfg.n_val_views
        else:
            self.n_views = self.cfg.n_test_views

        azimuth_deg: Float[Tensor, "B"]
        if self.split == "val":
            # make sure the first and last view are not the same
            azimuth_deg = torch.linspace(-180.0, 180.0, self.n_views + 1)[: self.n_views]
        else:
            azimuth_deg = torch.linspace(-180.0, 180.0, self.n_views)
        elevation_deg: Float[Tensor, "B"] = torch.full_like(
            azimuth_deg, self.cfg.eval_elevation_deg
        )
        camera_distances: Float[Tensor, "B"] = torch.full_like(
            elevation_deg, self.cfg.eval_camera_distance
        )

        elevation = elevation_deg * math.pi / 180
        azimuth = azimuth_deg * math.pi / 180

        # convert spherical coordinates to cartesian coordinates
        # right hand coordinate system, x back, y right, z up
        # elevation in (-90, 90), azimuth from +x to +y in (-180, 180)
        camera_positions: Float[Tensor, "B 3"] = torch.stack(
            [
                camera_distances * torch.cos(elevation) * torch.cos(azimuth),
                camera_distances * torch.cos(elevation) * torch.sin(azimuth),
                camera_distances * torch.sin(elevation),
            ],
            dim=-1,
        )

        # default scene center at origin
        center: Float[Tensor, "B 3"] = torch.zeros_like(camera_positions)
        # default camera up direction as +z
        up: Float[Tensor, "B 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[
            None, :
        ].repeat(self.cfg.eval_batch_size, 1)

        fovy_deg: Float[Tensor, "B"] = torch.full_like(
            elevation_deg, self.cfg.eval_fovy_deg
        )
        fovy = fovy_deg * math.pi / 180
        light_positions: Float[Tensor, "B 3"] = camera_positions

        lookat: Float[Tensor, "B 3"] = F.normalize(center - camera_positions, dim=-1)
        right: Float[Tensor, "B 3"] = F.normalize(torch.cross(lookat, up, dim=-1), dim=-1)
        up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
        c2w3x4: Float[Tensor, "B 3 4"] = torch.cat(
            [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
            dim=-1,
        )
        c2w: Float[Tensor, "B 4 4"] = torch.cat(
            [c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1
        )
        c2w[:, 3, 3] = 1.0

        # get directions by dividing directions_unit_focal by focal length
        focal_length: Float[Tensor, "B"] = (
            0.5 * self.cfg.eval_height / torch.tan(0.5 * fovy)
        )
        directions_unit_focal = get_ray_directions(
            H=self.cfg.eval_height, W=self.cfg.eval_width, focal=1.0
        )
        directions: Float[Tensor, "B H W 3"] = directions_unit_focal[
            None, :, :, :
        ].repeat(self.n_views, 1, 1, 1)
        directions[:, :, :, :2] = (
            directions[:, :, :, :2] / focal_length[:, None, None, None]
        )

        proj_mtx: Float[Tensor, "B 4 4"] = get_projection_matrix(
            fovy, self.cfg.eval_width / self.cfg.eval_height, 0.1, 1000.0
        )  # FIXME: hard-coded near and far
        mvp_mtx: Float[Tensor, "B 4 4"] = get_mvp_matrix(c2w, proj_mtx)
        c2w_3dgs = []
        for id in range(self.n_views):
            render_pose = pose_spherical(azimuth_deg[id] + 180.0, -elevation_deg[id], camera_distances[id])
            matrix = torch.linalg.inv(render_pose)
            # R = -np.transpose(matrix[:3, :3])
            R = -torch.transpose(matrix[:3, :3], 0, 1)
            R[:, 0] = -R[:, 0]
            T = -matrix[:3, 3]
            c2w_single = torch.cat([R, T[:, None]], 1)
            c2w_single = torch.cat([c2w_single, torch.tensor([[0.0, 0.0, 0.0, 1.0]])], 0)
            # c2w_single = convert_camera_to_world_transform(c2w_single)
            c2w_3dgs.append(c2w_single)
        c2w_3dgs = torch.stack(c2w_3dgs, 0)
        self.mvp_mtx = mvp_mtx
        self.c2w = c2w
        self.c2w_3dgs = c2w_3dgs
        self.camera_positions = camera_positions
        self.light_positions = light_positions
        self.elevation, self.azimuth = elevation, azimuth
        self.elevation_deg, self.azimuth_deg = elevation_deg, azimuth_deg
        self.camera_distances = camera_distances
        self.fovy = fovy
    def __len__(self):
        return self.n_views

    def __getitem__(self, index):
        return {
            "index": index,
            "mvp_mtx": self.mvp_mtx[index],
            "c2w": self.c2w[index],
            "c2w_3dgs": self.c2w_3dgs[index],
            "camera_positions": self.camera_positions[index],
            "light_positions": self.light_positions[index],
            "elevation": self.elevation_deg[index],
            "azimuth": self.azimuth_deg[index],
            "camera_distances": self.camera_distances[index],
            "height": self.cfg.eval_height,
            "width": self.cfg.eval_width,
            "fovy": self.fovy[index],
        }

    def collate(self, batch):
        batch = torch.utils.data.default_collate(batch)
        batch.update({"height": self.cfg.eval_height, "width": self.cfg.eval_width})
        return batch

class RandomCameraDataModule(pl.LightningDataModule):
    cfg: RandomCameraDataModuleConfig

    def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None:
        super().__init__()
        self.cfg = parse_structured(RandomCameraDataModuleConfig, cfg)

    def setup(self, stage=None) -> None:
        if stage in [None, "fit"]:
            self.train_dataset = RandomCameraIterableDataset(self.cfg)
        if stage in [None, "fit", "validate"]:
            self.val_dataset = RandomCameraDataset(self.cfg, "val")
        if stage in [None, "test", "predict"]:
            self.test_dataset = RandomCameraDataset(self.cfg, "test")

    def prepare_data(self):
        pass

    def general_loader(self, dataset, batch_size, collate_fn=None) -> DataLoader:
        return DataLoader(
            dataset,
            # very important to disable multi-processing if you want to change self attributes at runtime!
            # (for example setting self.width and self.height in update_step)
            num_workers=0,  # type: ignore
            batch_size=batch_size,
            collate_fn=collate_fn,
        )

    def train_dataloader(self) -> DataLoader:
        return self.general_loader(
            self.train_dataset, batch_size=None, collate_fn=self.train_dataset.collate
        )

    def val_dataloader(self) -> DataLoader:
        return self.general_loader(
            self.val_dataset, batch_size=1, collate_fn=self.val_dataset.collate
        )
        # return self.general_loader(self.train_dataset, batch_size=None, collate_fn=self.train_dataset.collate)

    def test_dataloader(self) -> DataLoader:
        return self.general_loader(
            self.test_dataset, batch_size=1, collate_fn=self.test_dataset.collate
        )

    def predict_dataloader(self) -> DataLoader:
        return self.general_loader(
            self.test_dataset, batch_size=1, collate_fn=self.test_dataset.collate
        )
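

if __name__ == "__main__":
    # Minimal smoke test (a sketch: assumes the threestudio environment imported at
    # the top of this file is available; the sizes below are small illustrative values).
    dm = RandomCameraDataModule({"eval_height": 64, "eval_width": 64, "n_test_views": 4})
    dm.setup("test")
    sample = dm.test_dataset[0]
    # Each sample carries a threestudio-convention c2w and the separately derived c2w_3dgs.
    print(sample["c2w"].shape, sample["c2w_3dgs"].shape)  # torch.Size([4, 4]) torch.Size([4, 4])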