import gzip
import json
import os.path as osp
import random
import socket
import time
import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image, ImageFile
from pytorch3d.renderer import PerspectiveCameras
from scipy import ndimage as nd
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm import tqdm

from diffusionsfm.utils.distortion import distort_image

HOSTNAME = socket.gethostname()

CO3D_DIR = "../co3d_data"  # update this
CO3D_ANNOTATION_DIR = osp.join(CO3D_DIR, "co3d_annotations")
# Re-point CO3D_DIR at the image root; depth maps live under the same tree.
CO3D_DIR = CO3D_DEPTH_DIR = osp.join(CO3D_DIR, "co3d")
order_path = osp.join(
    CO3D_DIR, "co3d_v2_random_order_{sample_num}/{category}.json"
)

TRAINING_CATEGORIES = [
    "apple",
    "backpack",
    "banana",
    "baseballbat",
    "baseballglove",
    "bench",
    "bicycle",
    "bottle",
    "bowl",
    "broccoli",
    "cake",
    "car",
    "carrot",
    "cellphone",
    "chair",
    "cup",
    "donut",
    "hairdryer",
    "handbag",
    "hydrant",
    "keyboard",
    "laptop",
    "microwave",
    "motorcycle",
    "mouse",
    "orange",
    "parkingmeter",
    "pizza",
    "plant",
    "stopsign",
    "teddybear",
    "toaster",
    "toilet",
    "toybus",
    "toyplane",
    "toytrain",
    "toytruck",
    "tv",
    "umbrella",
    "vase",
    "wineglass",
]

TEST_CATEGORIES = [
    "ball",
    "book",
    "couch",
    "frisbee",
    "hotdog",
    "kite",
    "remote",
    "sandwich",
    "skateboard",
    "suitcase",
]

assert len(TRAINING_CATEGORIES) + len(TEST_CATEGORIES) == 51

Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True


def fill_depths(data, invalid=None):
    """Fill invalid depth pixels with the value of the nearest valid pixel."""
    data_list = []
    for i in range(data.shape[0]):
        data_item = data[i].numpy()
        # `invalid` must be 1 where depth is invalid, 0 where it is valid; the
        # distance transform then yields the indices of the nearest valid pixel.
        ind = nd.distance_transform_edt(
            invalid[i], return_distances=False, return_indices=True
        )
        data_list.append(torch.tensor(data_item[tuple(ind)]))
    return torch.stack(data_list, dim=0)
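
# Hedged usage sketch (kept commented out so importing this module stays
# side-effect free); shapes are illustrative:
#
#   depths = torch.rand(2, 8, 8)
#   invalid = depths < 0.1                      # True where depth is a hole
#   filled = fill_depths(depths, invalid=invalid)
#   assert filled.shape == depths.shape         # holes now hold nearby depths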


def full_scene_scale(batch):
    """Distance from the centroid of the camera centers to the furthest camera."""
    cameras = PerspectiveCameras(R=batch["R"], T=batch["T"], device="cuda")
    cc = cameras.get_camera_center()
    centroid = torch.mean(cc, dim=0)

    diffs = cc - centroid
    norms = torch.linalg.norm(diffs, dim=1)

    furthest_index = torch.argmax(norms).item()
    scale = norms[furthest_index].item()
    return scale


def square_bbox(bbox, padding=0.0, astype=None, tight=False):
    """
    Computes a square bounding box, with optional padding.

    Args:
        bbox: Bounding box in xyxy format (4,).
        padding: Fractional padding added to the half side length.
        astype: Output dtype; defaults to the type of bbox[0].
        tight: If True, shrink to the shorter side (no black bars);
            otherwise expand to the longer side.

    Returns:
        square_bbox in xyxy format (4,).
    """
    if astype is None:
        astype = type(bbox[0])
    bbox = np.array(bbox)
    center = (bbox[:2] + bbox[2:]) / 2
    extents = (bbox[2:] - bbox[:2]) / 2

    # No black bars if tight
    if tight:
        s = min(extents) * (1 + padding)
    else:
        s = max(extents) * (1 + padding)

    square_bbox = np.array(
        [center[0] - s, center[1] - s, center[0] + s, center[1] + s],
        dtype=astype,
    )
    return square_bbox
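
# Worked example (commented out; the inputs are integers, so astype resolves
# to the input's integer type):
#
#   square_bbox(np.array([50, 25, 150, 75]))              # -> [50, 0, 150, 100]
#   square_bbox(np.array([50, 25, 150, 75]), tight=True)  # -> [75, 25, 125, 75]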


def unnormalize_image(image, return_numpy=True, return_int=True):
    """Invert ImageNet normalization, returning (N)CHW images in [0, 255] or [0, 1]."""
    if isinstance(image, torch.Tensor):
        image = image.detach().cpu().numpy()

    if image.ndim == 3:
        if image.shape[0] == 3:
            image = image[None, ...]
        elif image.shape[2] == 3:
            image = image.transpose(2, 0, 1)[None, ...]
        else:
            raise ValueError(f"Unexpected image shape: {image.shape}")
    elif image.ndim == 4:
        if image.shape[1] == 3:
            pass
        elif image.shape[3] == 3:
            image = image.transpose(0, 3, 1, 2)
        else:
            raise ValueError(f"Unexpected batch image shape: {image.shape}")
    else:
        raise ValueError(f"Unsupported input shape: {image.shape}")

    mean = np.array([0.485, 0.456, 0.406])[None, :, None, None]
    std = np.array([0.229, 0.224, 0.225])[None, :, None, None]
    image = image * std + mean

    if return_int:
        image = np.clip(image * 255.0, 0, 255).astype(np.uint8)
    else:
        image = np.clip(image, 0.0, 1.0)

    if image.shape[0] == 1:
        image = image[0]

    if return_numpy:
        return image
    else:
        return torch.from_numpy(image)


def unnormalize_image_for_vis(image):
    """Map ImageNet-normalized images to [-1, 1] for visualization."""
    assert len(image.shape) == 5 and image.shape[2] == 3
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 1, 3, 1, 1).to(image.device)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 1, 3, 1, 1).to(image.device)
    image = image * std + mean
    image = (image - 0.5) / 0.5
    return image


def _transform_intrinsic(image, bbox, principal_point, focal_length):
    # Rescale intrinsics to match bbox (also duplicated as a method on
    # Co3dDataset below).
    half_box = np.array([image.width, image.height]).astype(np.float32) / 2
    org_scale = min(half_box).astype(np.float32)

    # Pixel coordinates
    principal_point_px = half_box - (np.array(principal_point) * org_scale)
    focal_length_px = np.array(focal_length) * org_scale
    principal_point_px -= bbox[:2]
    new_bbox = (bbox[2:] - bbox[:2]) / 2
    new_scale = min(new_bbox)

    # NDC coordinates
    new_principal_ndc = (new_bbox - principal_point_px) / new_scale
    new_focal_ndc = focal_length_px / new_scale

    principal_point = torch.tensor(new_principal_ndc.astype(np.float32))
    focal_length = torch.tensor(new_focal_ndc.astype(np.float32))

    return principal_point, focal_length


def construct_camera_from_batch(batch, device):
    if isinstance(device, int):
        device = f"cuda:{device}"

    return PerspectiveCameras(
        R=batch["R"].reshape(-1, 3, 3),
        T=batch["T"].reshape(-1, 3),
        focal_length=batch["focal_lengths"].reshape(-1, 2),
        principal_point=batch["principal_points"].reshape(-1, 2),
        image_size=batch["image_sizes"].reshape(-1, 2),
        device=device,
    )
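
# Hedged usage sketch: assumes a collated batch whose entries flatten to
# per-camera rows, e.g. R (B, N, 3, 3), T (B, N, 3), focal_lengths and
# principal_points (B, N, 2), image_sizes (B, N, 2). Note the plural keys,
# which differ from the singular keys produced by Co3dDataset.get_data below.
#
#   cameras = construct_camera_from_batch(batch, device=0)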


def save_batch_images(images, fname):
    cmap = plt.get_cmap("hsv")
    num_frames = len(images)
    num_rows = num_frames
    num_cols = 4
    figsize = (num_cols * 2, num_rows * 2)
    fig, axs = plt.subplots(num_rows, num_cols, figsize=figsize)
    axs = axs.flatten()
    for i in range(num_rows):
        for j in range(num_cols):
            if i < num_frames:
                axs[i * num_cols + j].imshow(unnormalize_image(images[i][j]))
                for s in ["bottom", "top", "left", "right"]:
                    axs[i * num_cols + j].spines[s].set_color(cmap(i / (num_frames)))
                    axs[i * num_cols + j].spines[s].set_linewidth(5)
                axs[i * num_cols + j].set_xticks([])
                axs[i * num_cols + j].set_yticks([])
            else:
                axs[i * num_cols + j].axis("off")
    plt.tight_layout()
    plt.savefig(fname)


def jitter_bbox(
    square_bbox,
    jitter_scale=(1.1, 1.2),
    jitter_trans=(-0.07, 0.07),
    direction_from_size=None,
):
    square_bbox = np.array(square_bbox.astype(float))
    s = np.random.uniform(jitter_scale[0], jitter_scale[1])
    tx, ty = np.random.uniform(jitter_trans[0], jitter_trans[1], size=2)

    # Jitter only one dimension if center cropping
    if direction_from_size is not None:
        if direction_from_size[0] > direction_from_size[1]:
            tx = 0
        else:
            ty = 0

    side_length = square_bbox[2] - square_bbox[0]
    center = (square_bbox[:2] + square_bbox[2:]) / 2 + np.array([tx, ty]) * side_length
    extent = side_length / 2 * s

    ul = center - extent
    lr = ul + 2 * extent
    return np.concatenate((ul, lr))
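
# Augmentation sketch (commented out): the square crop is rescaled by a
# factor drawn from jitter_scale and its center is shifted by a fraction of
# the side length drawn from jitter_trans; the result stays square.
#
#   box = square_bbox(np.array([50, 25, 150, 75])).astype(float)
#   jittered = jitter_bbox(box)  # randomly scaled and shifted square box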


class Co3dDataset(Dataset):
    def __init__(
        self,
        category=("all_train",),
        split="train",
        transform=None,
        num_images=2,
        img_size=224,
        mask_images=False,
        crop_images=True,
        co3d_dir=None,
        co3d_annotation_dir=None,
        precropped_images=False,
        apply_augmentation=True,
        normalize_cameras=True,
        no_images=False,
        sample_num=None,
        seed=0,
        load_extra_cameras=False,
        distort_image=False,
        load_depths=False,
        center_crop=False,
        depth_size=256,
        mask_holes=False,
        object_mask=True,
    ):
        """
        Args:
            category: Category name(s), or one of "all_train", "all_test",
                "full".
            num_images: Number of images in each batch.
            img_size: Side length of the square image crops.
            mask_images: If True, composite images onto a black background
                using the object masks.
            load_extra_cameras: If True, sample 8 frames per sequence but keep
                only num_images images.
        """
        start_time = time.time()

        self.category = category
        self.split = split
        self.transform = transform
        self.num_images = num_images
        self.img_size = img_size
        self.mask_images = mask_images
        self.crop_images = crop_images
        self.precropped_images = precropped_images
        self.apply_augmentation = apply_augmentation
        self.normalize_cameras = normalize_cameras
        self.no_images = no_images
        self.sample_num = sample_num
        self.load_extra_cameras = load_extra_cameras
        self.distort = distort_image
        self.load_depths = load_depths
        self.center_crop = center_crop
        self.depth_size = depth_size
        self.mask_holes = mask_holes
        self.object_mask = object_mask

        if self.apply_augmentation:
            if self.center_crop:
                self.jitter_scale = (0.8, 1.1)
                self.jitter_trans = (0.0, 0.0)
            else:
                self.jitter_scale = (1.1, 1.2)
                self.jitter_trans = (-0.07, 0.07)
        else:
            # Note: a model trained with apply_augmentation should also use it
            # at test time.
            self.jitter_scale = (1, 1)
            self.jitter_trans = (0.0, 0.0)

        if self.distort:
            self.k1_max = 1.0
            self.k2_max = 1.0

        if co3d_dir is not None:
            self.co3d_dir = co3d_dir
            self.co3d_annotation_dir = co3d_annotation_dir
        else:
            self.co3d_dir = CO3D_DIR
            self.co3d_annotation_dir = CO3D_ANNOTATION_DIR
            self.co3d_depth_dir = CO3D_DEPTH_DIR

        if isinstance(self.category, str):
            self.category = [self.category]

        if "all_train" in self.category:
            self.category = TRAINING_CATEGORIES
        if "all_test" in self.category:
            self.category = TEST_CATEGORIES
        if "full" in self.category:
            self.category = TRAINING_CATEGORIES + TEST_CATEGORIES
        self.category = sorted(self.category)
        self.is_single_category = len(self.category) == 1

        # Fix random seeds
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)
| print(f"Co3d ({split}):") | |
| self.low_quality_translations = [ | |
| "411_55952_107659", | |
| "427_59915_115716", | |
| "435_61970_121848", | |
| "112_13265_22828", | |
| "110_13069_25642", | |
| "165_18080_34378", | |
| "368_39891_78502", | |
| "391_47029_93665", | |
| "20_695_1450", | |
| "135_15556_31096", | |
| "417_57572_110680", | |
| ] # Initialized with sequences with poor depth masks | |
| self.rotations = {} | |
| self.category_map = {} | |
| for c in tqdm(self.category): | |
| annotation_file = osp.join( | |
| self.co3d_annotation_dir, f"{c}_{self.split}.jgz" | |
| ) | |
| with gzip.open(annotation_file, "r") as fin: | |
| annotation = json.loads(fin.read()) | |
| counter = 0 | |
| for seq_name, seq_data in annotation.items(): | |
| counter += 1 | |
| if len(seq_data) < self.num_images: | |
| continue | |
| filtered_data = [] | |
| self.category_map[seq_name] = c | |
| bad_seq = False | |
| for data in seq_data: | |
| # Make sure translations are not ridiculous and rotations are valid | |
| det = np.linalg.det(data["R"]) | |
| if (np.abs(data["T"]) > 1e5).any() or det < 0.99 or det > 1.01: | |
| bad_seq = True | |
| self.low_quality_translations.append(seq_name) | |
| break | |
| # Ignore all unnecessary information. | |
| filtered_data.append( | |
| { | |
| "filepath": data["filepath"], | |
| "bbox": data["bbox"], | |
| "R": data["R"], | |
| "T": data["T"], | |
| "focal_length": data["focal_length"], | |
| "principal_point": data["principal_point"], | |
| }, | |
| ) | |
| if not bad_seq: | |
| self.rotations[seq_name] = filtered_data | |
| self.sequence_list = list(self.rotations.keys()) | |
| IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) | |
| IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) | |
| if self.transform is None: | |
| self.transform = transforms.Compose( | |
| [ | |
| transforms.ToTensor(), | |
| transforms.Resize(self.img_size, antialias=True), | |
| transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), | |
| ] | |
| ) | |
| self.transform_depth = transforms.Compose( | |
| [ | |
| transforms.Resize( | |
| self.depth_size, | |
| antialias=False, | |
| interpolation=transforms.InterpolationMode.NEAREST_EXACT, | |
| ), | |
| ] | |
| ) | |
| print( | |
| f"Low quality translation sequences, not used: {self.low_quality_translations}" | |
| ) | |
| print(f"Data size: {len(self)}") | |
| print(f"Data loading took {(time.time()-start_time)} seconds.") | |

    def __len__(self):
        return len(self.sequence_list)

    def __getitem__(self, index):
        num_to_load = self.num_images if not self.load_extra_cameras else 8

        sequence_name = self.sequence_list[index % len(self.sequence_list)]
        metadata = self.rotations[sequence_name]

        if self.sample_num is not None:
            with open(
                order_path.format(
                    sample_num=self.sample_num, category=self.category[0]
                )
            ) as f:
                order = json.load(f)
            ids = order[sequence_name][:num_to_load]
        else:
            # Sample with replacement only if the sequence is too short.
            replace = len(metadata) < num_to_load
            ids = np.random.choice(len(metadata), num_to_load, replace=replace)

        return self.get_data(index=index, ids=ids, num_valid_frames=num_to_load)

    def _get_scene_scale(self, sequence_name):
        n = len(self.rotations[sequence_name])

        R = torch.zeros(n, 3, 3)
        T = torch.zeros(n, 3)

        for i, ann in enumerate(self.rotations[sequence_name]):
            R[i, ...] = torch.tensor(self.rotations[sequence_name][i]["R"])
            T[i, ...] = torch.tensor(self.rotations[sequence_name][i]["T"])

        cameras = PerspectiveCameras(R=R, T=T)
        cc = cameras.get_camera_center()
        centroid = torch.mean(cc, dim=0)
        diff = cc - centroid

        norm = torch.norm(diff, dim=1)
        scale = torch.max(norm).item()
        return scale

    def _crop_image(self, image, bbox):
        image_crop = transforms.functional.crop(
            image,
            top=bbox[1],
            left=bbox[0],
            height=bbox[3] - bbox[1],
            width=bbox[2] - bbox[0],
        )
        return image_crop

    def _transform_intrinsic(self, image, bbox, principal_point, focal_length):
        half_box = np.array([image.width, image.height]).astype(np.float32) / 2
        org_scale = min(half_box).astype(np.float32)

        # Pixel coordinates
        principal_point_px = half_box - (np.array(principal_point) * org_scale)
        focal_length_px = np.array(focal_length) * org_scale
        principal_point_px -= bbox[:2]
        new_bbox = (bbox[2:] - bbox[:2]) / 2
        new_scale = min(new_bbox)

        # NDC coordinates
        new_principal_ndc = (new_bbox - principal_point_px) / new_scale
        new_focal_ndc = focal_length_px / new_scale

        return new_principal_ndc.astype(np.float32), new_focal_ndc.astype(np.float32)

    def get_data(
        self,
        index=None,
        sequence_name=None,
        ids=(0, 1),
        no_images=False,
        num_valid_frames=None,
        load_using_order=None,
    ):
        if load_using_order is not None:
            with open(
                order_path.format(
                    sample_num=self.sample_num, category=self.category[0]
                )
            ) as f:
                order = json.load(f)
            ids = order[sequence_name][:load_using_order]

        if sequence_name is None:
            index = index % len(self.sequence_list)
            sequence_name = self.sequence_list[index]
        metadata = self.rotations[sequence_name]
        category = self.category_map[sequence_name]

        # Read image & camera information from annotations
        annos = [metadata[i] for i in ids]
        images = []
        image_sizes = []
        PP = []
        FL = []
        crop_parameters = []
        filenames = []
        distortion_parameters = []
        depths = []
        depth_masks = []
        object_masks = []
        dino_images = []

        for anno in annos:
            filepath = anno["filepath"]

            if not no_images:
                image = Image.open(osp.join(self.co3d_dir, filepath)).convert("RGB")
                image_size = image.size

                # Optionally mask images with black background
                if self.mask_images:
                    black_image = Image.new("RGB", image_size, (0, 0, 0))
                    mask_name = osp.basename(filepath.replace(".jpg", ".png"))
                    mask_path = osp.join(
                        self.co3d_dir, category, sequence_name, "masks", mask_name
                    )
                    mask = Image.open(mask_path).convert("L")

                    if mask.size != image_size:
                        mask = mask.resize(image_size)
                    mask = Image.fromarray(np.array(mask) > 125)
                    image = Image.composite(image, black_image, mask)

                if self.object_mask:
                    mask_name = osp.basename(filepath.replace(".jpg", ".png"))
                    mask_path = osp.join(
                        self.co3d_dir, category, sequence_name, "masks", mask_name
                    )
                    mask = Image.open(mask_path).convert("L")

                    if mask.size != image_size:
                        mask = mask.resize(image_size)
                    mask = torch.from_numpy(np.array(mask) > 125)

                # Determine the crop; the model expects square images.
                bbox = np.array(anno["bbox"])
                good_bbox = ((bbox[2:] - bbox[:2]) > 30).all()
                bbox = (
                    anno["bbox"]
                    if not self.center_crop and good_bbox
                    else [0, 0, image.width, image.height]
                )

                # Distort image and bbox if desired
                if self.distort:
                    k1 = random.uniform(0, self.k1_max)
                    k2 = random.uniform(0, self.k2_max)

                    try:
                        image, bbox = distort_image(
                            image, np.array(bbox), k1, k2, modify_bbox=True
                        )
                    except Exception:
                        print(
                            f"Distortion failed: {sequence_name=} {index=} "
                            f"{ids=} {k1=} {k2=}"
                        )

                    distortion_parameters.append(torch.FloatTensor([k1, k2]))

                bbox = square_bbox(np.array(bbox), tight=self.center_crop)
                if self.apply_augmentation:
                    bbox = jitter_bbox(
                        bbox,
                        jitter_scale=self.jitter_scale,
                        jitter_trans=self.jitter_trans,
                        direction_from_size=image.size if self.center_crop else None,
                    )
                bbox = np.around(bbox).astype(int)

                # Crop parameters
                crop_center = (bbox[:2] + bbox[2:]) / 2
                principal_point = torch.tensor(anno["principal_point"])
                focal_length = torch.tensor(anno["focal_length"])

                # Convert the crop center to correspond to a "square" image
                width, height = image.size
                length = max(width, height)
                s = length / min(width, height)
                crop_center = crop_center + (length - np.array([width, height])) / 2

                # Convert to NDC
                cc = s - 2 * s * crop_center / length
                crop_width = 2 * s * (bbox[2] - bbox[0]) / length
                crop_params = torch.tensor([-cc[0], -cc[1], crop_width, s])
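
                # Worked example of the NDC conversion above (illustrative
                # numbers): for a 200x100 image with crop_center (100, 50),
                # length = 200 and s = 2, padding shifts the center to
                # (100, 100), so cc = 2 - 2*2*[100, 100]/200 = [0, 0]; a
                # centered crop maps to the NDC origin, and crop_width is the
                # crop side measured in NDC units of the longer side.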

                # Crop and normalize image
                if not self.precropped_images:
                    image = self._crop_image(image, bbox)

                try:
                    image = self.transform(image)
                except Exception:
                    print(f"Transform failed: {sequence_name=} {index=} {ids=}")
                    raise

                images.append(image[:, : self.img_size, : self.img_size])
                crop_parameters.append(crop_params)

                if self.load_depths:
                    # Open depth map
                    depth_name = osp.basename(
                        filepath.replace(".jpg", ".jpg.geometric.png")
                    )
                    depth_path = osp.join(
                        self.co3d_depth_dir,
                        category,
                        sequence_name,
                        "depths",
                        depth_name,
                    )
                    depth_pil = Image.open(depth_path)

                    # CO3D depth PNGs store float16 values bit-cast into
                    # uint16, so reinterpret the raw buffer rather than
                    # converting numerically.
                    depth = torch.tensor(
                        np.frombuffer(
                            np.array(depth_pil, dtype=np.uint16), dtype=np.float16
                        )
                        .astype(np.float32)
                        .reshape((depth_pil.size[1], depth_pil.size[0]))
                    )

                    # Crop and resize as with images
                    if depth_pil.size != image_size:
                        # The bbox is in image coordinates; rescale it to the
                        # depth map's resolution.
                        bbox = depth_pil.size[0] * bbox / image_size[0]

                        if self.object_mask:
                            assert mask.shape == depth.shape

                    bbox = np.around(bbox).astype(int)
                    depth = self._crop_image(depth, bbox)

                    # Resize
                    depth = self.transform_depth(depth.unsqueeze(0))[
                        0, : self.depth_size, : self.depth_size
                    ]
                    depths.append(depth)

                    if self.object_mask:
                        mask = self._crop_image(mask, bbox)
                        mask = self.transform_depth(mask.unsqueeze(0))[
                            0, : self.depth_size, : self.depth_size
                        ]
                        object_masks.append(mask)

                PP.append(principal_point)
                FL.append(focal_length)
                image_sizes.append(torch.tensor([self.img_size, self.img_size]))
                filenames.append(filepath)

        if not no_images:
            if self.load_depths:
                depths = torch.stack(depths)
                depth_masks = torch.logical_or(depths <= 0, depths.isinf())
                depth_masks = (~depth_masks).long()

                if self.object_mask:
                    object_masks = torch.stack(object_masks, dim=0)

                if self.mask_holes:
                    depths = fill_depths(depths, depth_masks == 0)

                    # Sometimes the hole filling misses pixels
                    new_masks = torch.logical_or(depths <= 0, depths.isinf())
                    new_masks = (~new_masks).long()
                    depths[new_masks == 0] = -1

                    assert torch.logical_or(depths > 0, depths == -1).all()
                    assert not (depths.isinf()).any()
                    assert not (depths.isnan()).any()

            if self.load_extra_cameras:
                # Remove the extra loaded images to save space
                images = images[: self.num_images]

            if self.distort:
                distortion_parameters = torch.stack(distortion_parameters)

            images = torch.stack(images)
            crop_parameters = torch.stack(crop_parameters)
            focal_lengths = torch.stack(FL)
            principal_points = torch.stack(PP)
            image_sizes = torch.stack(image_sizes)
        else:
            images = None
            crop_parameters = None
            distortion_parameters = None
            focal_lengths = []
            principal_points = []
            image_sizes = []

        # Assemble batch info to send back
        R = torch.stack([torch.tensor(anno["R"]) for anno in annos])
        T = torch.stack([torch.tensor(anno["T"]) for anno in annos])

        batch = {
            "model_id": sequence_name,
            "category": category,
            "n": len(metadata),
            "num_valid_frames": num_valid_frames,
            "ind": torch.tensor(ids),
            "image": images,
            "depth": depths,
            "depth_masks": depth_masks,
            "object_masks": object_masks,
            "R": R,
            "T": T,
            "focal_length": focal_lengths,
            "principal_point": principal_points,
            "image_size": image_sizes,
            "crop_parameters": crop_parameters,
            "distortion_parameters": torch.zeros(4),
            "filename": filenames,
            "dataset": "co3d",
        }

        return batch
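

if __name__ == "__main__":
    # Hedged usage sketch: assumes the CO3D data and annotations exist at the
    # paths configured above; the category and sizes here are illustrative.
    dataset = Co3dDataset(
        category=("hydrant",),
        split="train",
        num_images=2,
        img_size=224,
    )
    batch = dataset[0]
    print(batch["model_id"], batch["image"].shape, batch["R"].shape)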