import os
import json
import random
import glob
from typing import List, Tuple, Union

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import Dataset
from pytorch3d.renderer import PerspectiveCameras

from onediffusion.dataset.utils import *
from onediffusion.dataset.raydiff_utils import cameras_to_rays, first_camera_transform, normalize_cameras
from onediffusion.dataset.transforms import CenterCropResizeImage
def _cameras_from_opencv_projection(
    R: torch.Tensor,
    tvec: torch.Tensor,
    camera_matrix: torch.Tensor,
    image_size: torch.Tensor,
    do_normalize_cameras: bool,
    normalize_scale: float,
) -> PerspectiveCameras:
    focal_length = torch.stack([camera_matrix[:, 0, 0], camera_matrix[:, 1, 1]], dim=-1)
    principal_point = camera_matrix[:, :2, 2]

    # Retype the image_size correctly and flip to width, height.
    image_size_wh = image_size.to(R).flip(dims=(1,))

    # Screen to NDC conversion:
    # For non-square images, we scale the points such that the smallest side
    # has range [-1, 1] and the largest side has range [-u, u], with u > 1.
    # This convention is consistent with the PyTorch3D renderer, as well as
    # the transformation function `get_ndc_to_screen_transform`.
    scale = image_size_wh.to(R).min(dim=1, keepdim=True)[0] / 2.0
    scale = scale.expand(-1, 2)
    c0 = image_size_wh / 2.0

    # Get the PyTorch3D focal length and principal point.
    focal_pytorch3d = focal_length / scale
    p0_pytorch3d = -(principal_point - c0) / scale

    # For R, T we flip the x, y axes (OpenCV screen space has the opposite
    # orientation of screen axes).
    # We also transpose R (OpenCV multiplies points from the left side).
    R_pytorch3d = R.clone().permute(0, 2, 1)
    T_pytorch3d = tvec.clone()
    R_pytorch3d[:, :, :2] *= -1
    T_pytorch3d[:, :2] *= -1

    cams = PerspectiveCameras(
        R=R_pytorch3d,
        T=T_pytorch3d,
        focal_length=focal_pytorch3d,
        principal_point=p0_pytorch3d,
        image_size=image_size,
        device=R.device,
    )

    if do_normalize_cameras:
        cams, _ = normalize_cameras(cams, scale=normalize_scale)

    cams = first_camera_transform(cams, rotation_only=False)
    return cams
def calculate_rays(Ks, sizes, Rs, Ts, target_size, use_plucker=True, do_normalize_cameras=False, normalize_scale=1.0):
    cameras = _cameras_from_opencv_projection(
        R=Rs,
        tvec=Ts,
        camera_matrix=Ks,
        image_size=sizes,
        do_normalize_cameras=do_normalize_cameras,
        normalize_scale=normalize_scale,
    )

    rays_embedding = cameras_to_rays(
        cameras=cameras,
        num_patches_x=target_size,
        num_patches_y=target_size,
        crop_parameters=None,
        use_plucker=use_plucker,
    )
    return rays_embedding.rays
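

# Usage sketch (illustrative only; argument shapes are inferred from how __getitem__
# calls this helper below, not guaranteed by the original code). With N views and a
# target_size x target_size latent grid:
#
#   Ks    : (N, 3, 3)  OpenCV intrinsics
#   sizes : (N, 2)     image sizes as (height, width)
#   Rs    : (N, 3, 3)  world-to-camera rotations
#   Ts    : (N, 3)     world-to-camera translations
#
#   rays = calculate_rays(Ks, sizes, Rs, Ts, target_size)
#   # `rays` holds 6D Plucker coordinates per patch and is later reshaped to
#   # (N, target_size, target_size, 6) by the dataset below.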
def convert_rgba_to_rgb_white_bg(image):
    """Convert an RGBA image to RGB by compositing it onto a white background."""
    if image.mode == 'RGBA':
        # Create a white background
        background = Image.new('RGBA', image.size, (255, 255, 255, 255))
        # Composite the image onto the white background
        return Image.alpha_composite(background, image).convert('RGB')
    return image.convert('RGB')
class MultiviewDataset(Dataset):
    def __init__(
        self,
        scene_folders: str,
        samples_per_set: Union[int, Tuple[int, int]],  # number of views per scene, either fixed or a (min, max) range
        transform=None,
        caption_keys: Union[str, List] = "caption",
        multiscale=False,
        aspect_ratio_type=ASPECT_RATIO_512,
        c2w_scaling=1.7,
        default_max_distance=1,          # default max distance from all cameras of a scene
        do_normalize=True,               # whether to normalize c2w translations with max_distance
        swap_xz=False,                   # whether to swap the x and z axes of 3D scenes
        valid_paths: str = "",
        frame_sliding_windows: float = None,  # restrict all sampled frames to a window of this size so camera poses do not differ too much
    ):
        if not isinstance(samples_per_set, (tuple, list)):
            samples_per_set = (samples_per_set, samples_per_set)
        self.samples_range = samples_per_set  # tuple of (min_samples, max_samples)
        self.transform = transform
        self.caption_keys = caption_keys if isinstance(caption_keys, list) else [caption_keys]
        self.aspect_ratio = aspect_ratio_type
        self.scene_folders = sorted(glob.glob(scene_folders))
        # Filter out scene folders that do not have transforms.json
        self.scene_folders = list(filter(lambda x: os.path.exists(os.path.join(x, "transforms.json")), self.scene_folders))
        # If a valid-paths file exists, only use the paths listed in that file
        if os.path.exists(valid_paths):
            with open(valid_paths, 'r') as f:
                valid_scene_folders = f.read().splitlines()
            self.scene_folders = sorted(valid_scene_folders)
        self.c2w_scaling = c2w_scaling
        self.do_normalize = do_normalize
        self.default_max_distance = default_max_distance
        self.swap_xz = swap_xz
        self.frame_sliding_windows = frame_sliding_windows

        if multiscale:
            assert self.aspect_ratio in [ASPECT_RATIO_512, ASPECT_RATIO_1024, ASPECT_RATIO_2048, ASPECT_RATIO_2880]
        if self.aspect_ratio in [ASPECT_RATIO_2048, ASPECT_RATIO_2880]:
            self.interpolate_model = T.InterpolationMode.LANCZOS

        self.ratio_index = {}
        self.ratio_nums = {}
        for k, v in self.aspect_ratio.items():
            self.ratio_index[float(k)] = []  # used for __getitem__
            self.ratio_nums[float(k)] = 0    # used for the batch sampler
    def __len__(self):
        return len(self.scene_folders)
    def __getitem__(self, idx):
        try:
            scene_path = self.scene_folders[idx]

            if os.path.exists(os.path.join(scene_path, "images")):
                image_folder = os.path.join(scene_path, "images")
                downscale_factor = 1
            elif os.path.exists(os.path.join(scene_path, "images_4")):
                image_folder = os.path.join(scene_path, "images_4")
                downscale_factor = 1 / 4
            elif os.path.exists(os.path.join(scene_path, "images_8")):
                image_folder = os.path.join(scene_path, "images_8")
                downscale_factor = 1 / 8
            else:
                raise NotImplementedError

            json_path = os.path.join(scene_path, "transforms.json")
            caption_path = os.path.join(scene_path, "caption.json")
            image_files = os.listdir(image_folder)

            with open(json_path, 'r') as f:
                json_data = json.load(f)
            height, width = json_data['h'], json_data['w']
            dh, dw = int(height * downscale_factor), int(width * downscale_factor)
            fl_x, fl_y = json_data['fl_x'] * downscale_factor, json_data['fl_y'] * downscale_factor
            cx = dw // 2
            cy = dh // 2
            frame_list = json_data['frames']

            # Randomly select the number of views to sample for this scene
            samples_per_set = random.randint(self.samples_range[0], self.samples_range[1])

            if self.frame_sliding_windows is None:
                # Sample uniformly over all frames of the scene
                selected_indices = random.sample(range(len(frame_list)), min(samples_per_set, len(frame_list)))
            else:
                # Limit the sampled views to a sliding window so camera poses do not differ too much
                if len(frame_list) <= self.frame_sliding_windows:
                    # If the frame list fits inside the window, use the entire list
                    window_start = 0
                    window_end = len(frame_list)
                else:
                    # Randomly select a starting point for the window
                    window_start = random.randint(0, len(frame_list) - self.frame_sliding_windows)
                    window_end = window_start + self.frame_sliding_windows
                # Randomly sample indices from within the window
                window_indices = list(range(window_start, window_end))
                selected_indices = random.sample(window_indices, min(samples_per_set, len(window_indices)))

            image_files = [os.path.basename(frame_list[i]['file_path']) for i in selected_indices]
            image_paths = [os.path.join(image_folder, file) for file in image_files]

            # Load images and convert RGBA to RGB with a white background
            images = [convert_rgba_to_rgb_white_bg(Image.open(image_path)) for image_path in image_paths]

            # The target size/ratio are also needed later for the ray resolution, so compute them in both branches
            closest_size, closest_ratio = self.aspect_ratio['1.0'], 1.0
            closest_size = tuple(map(int, closest_size))
            if self.transform:
                images = [self.transform(image) for image in images]
            else:
                transform = T.Compose([
                    T.ToTensor(),
                    CenterCropResizeImage(closest_size),
                    T.Normalize([.5], [.5]),
                ])
                images = [transform(image) for image in images]
            images = torch.stack(images)

            c2ws = [frame_list[i]['transform_matrix'] for i in selected_indices]
            c2ws = torch.tensor(c2ws).reshape(-1, 4, 4)
            # max_distance = json_data.get('max_distance', self.default_max_distance)
            # if 'max_distance' not in json_data.keys():
            #     print(f"not found `max_distance` in json path: {json_path}")

            if self.swap_xz:
                swap_xz = torch.tensor([[[0, 0, 1., 0],
                                         [0, 1., 0, 0],
                                         [-1., 0, 0, 0],
                                         [0, 0, 0, 1.]]])
                c2ws = swap_xz @ c2ws

            # OpenGL to OpenCV camera convention
            c2ws[:, 0:3, 1:3] *= -1
            c2ws = c2ws[:, [1, 0, 2, 3], :]
            c2ws[:, 2, :] *= -1

            w2cs = torch.inverse(c2ws)
            K = torch.tensor([[[fl_x, 0, cx], [0, fl_y, cy], [0, 0, 1]]]).repeat(len(c2ws), 1, 1)
            Rs = w2cs[:, :3, :3]
            Ts = w2cs[:, :3, 3]
            sizes = torch.tensor([[dh, dw]]).repeat(len(c2ws), 1)

            # Get the ray embedding and pad the last dimension to 16 (the number of VAE latent channels)
            # rays_od = calculate_rays(K, sizes, Rs, Ts, closest_size[0] // 8, use_plucker=False, do_normalize_cameras=self.do_normalize, normalize_scale=self.c2w_scaling)
            rays = calculate_rays(K, sizes, Rs, Ts, closest_size[0] // 8, do_normalize_cameras=self.do_normalize, normalize_scale=self.c2w_scaling)
            rays = rays.reshape(len(selected_indices), closest_size[0] // 8, closest_size[1] // 8, 6)
            # padding = (0, 10)  # pad the last dimension to 16
            # rays = torch.nn.functional.pad(rays, padding, "constant", 0)
            rays = torch.cat([rays, rays, rays[..., :4]], dim=-1) * 1.658

            if os.path.exists(caption_path):
                with open(caption_path, 'r') as f:
                    caption_key = random.choice(self.caption_keys)
                    caption = json.load(f).get(caption_key, "")
            else:
                caption = ""
            caption = "[[multiview]] " + caption if caption else "[[multiview]]"

            return {
                'pixel_values': images,
                'rays': rays,
                'aspect_ratio': closest_ratio,
                'caption': caption,
                'height': dh,
                'width': dw,
                # 'origins': rays_od[..., :3],
                # 'dirs': rays_od[..., 3:6],
            }
        except Exception:
            # If anything goes wrong with this scene, fall back to a random one
            return self.__getitem__(random.randint(0, len(self.scene_folders) - 1))
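

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: build the dataset from a
    # hypothetical glob of scene folders and inspect a single sample. The glob pattern
    # and the (min, max) view range below are placeholders; adapt them to your data.
    dataset = MultiviewDataset(
        scene_folders="data/scenes/*",  # hypothetical layout: each folder holds transforms.json and an images/ dir
        samples_per_set=(4, 8),         # sample between 4 and 8 views per scene
    )
    if len(dataset) > 0:
        sample = dataset[0]
        print(sample['caption'])
        print(sample['pixel_values'].shape)  # (num_views, 3, H, W)
        print(sample['rays'].shape)          # (num_views, H // 8, W // 8, 16)
    else:
        print("No scene folders matched the glob pattern.")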