import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
from torch.utils.data import Dataset
from torchvision import transforms

from diffusionsfm.dataset.co3d_v2 import square_bbox


class CustomDataset(Dataset):
    def __init__(
        self,
        image_list,
    ):
        self.images = []
        for image_path in sorted(image_list):
            img = Image.open(image_path)
            img = ImageOps.exif_transpose(img).convert("RGB")  # Apply EXIF rotation
            self.images.append(img)

        self.n = len(self.images)
        # Kept for interface parity with the CO3D dataset; no jitter is applied here.
        self.jitter_scale = [1, 1]
        self.jitter_trans = [0, 0]
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize(224),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        self.transform_for_vis = transforms.Compose(
            [
                transforms.Resize(224),
            ]
        )

    def __len__(self):
        # The whole image set is treated as a single scene, i.e. one batch.
        return 1

    def _crop_image(self, image, bbox, white_bg=False):
        if white_bg:
            # Only supports PIL Images
            image_crop = Image.new(
                "RGB", (bbox[2] - bbox[0], bbox[3] - bbox[1]), (255, 255, 255)
            )
            image_crop.paste(image, (-bbox[0], -bbox[1]))
        else:
            image_crop = transforms.functional.crop(
                image,
                top=bbox[1],
                left=bbox[0],
                height=bbox[3] - bbox[1],
                width=bbox[2] - bbox[0],
            )
        return image_crop

    def __getitem__(self, index):
        return self.get_data()

    def get_data(self):
        cmap = plt.get_cmap("hsv")
        ids = list(range(len(self.images)))
        images = [self.images[i] for i in ids]
        images_transformed = []
        images_for_vis = []
        crop_parameters = []

        for i, image in enumerate(images):
            # Crop each image to a tight square bounding box.
            bbox = np.array([0, 0, image.width, image.height])
            bbox = square_bbox(bbox, tight=True)
            bbox = np.around(bbox).astype(int)
            image = self._crop_image(image, bbox)
            images_transformed.append(self.transform(image))

            # Color-code each frame's border for visualization.
            image_for_vis = self.transform_for_vis(image)
            color_float = cmap(i / len(images))
            color_rgb = tuple(int(255 * c) for c in color_float[:3])
            image_for_vis = ImageOps.expand(image_for_vis, border=3, fill=color_rgb)
            images_for_vis.append(image_for_vis)

            # Compute crop parameters (crop center, crop width, scale).
            width, height = image.size
            length = max(width, height)
            s = length / min(width, height)
            crop_center = (bbox[:2] + bbox[2:]) / 2
            crop_center = crop_center + (length - np.array([width, height])) / 2
            # Convert to NDC
            cc = s - 2 * s * crop_center / length
            crop_width = 2 * s * (bbox[2] - bbox[0]) / length
            crop_params = torch.tensor([-cc[0], -cc[1], crop_width, s])
            crop_parameters.append(crop_params)

        images = images_transformed

        batch = {}
        batch["image"] = torch.stack(images)
        batch["image_for_vis"] = images_for_vis
        batch["n"] = len(images)
        batch["ind"] = torch.tensor(ids)
        batch["crop_parameters"] = torch.stack(crop_parameters)
        batch["distortion_parameters"] = torch.zeros(4)

        return batch
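

# --- Usage sketch (not part of the module's public API) ---
# A minimal example of driving CustomDataset, assuming a hypothetical list of
# frame paths. The dataset treats all frames as one scene, so __len__ is 1 and
# indexing any element returns the full batch dict built by get_data().
if __name__ == "__main__":
    image_paths = ["frames/000.jpg", "frames/001.jpg", "frames/002.jpg"]  # hypothetical paths
    dataset = CustomDataset(image_paths)
    batch = dataset[0]
    print(batch["image"].shape)            # (n, 3, 224, 224): square crops, resized and normalized
    print(batch["crop_parameters"].shape)  # (n, 4): [-cc_x, -cc_y, crop_width, s] in NDC
    print(len(batch["image_for_vis"]), batch["image_for_vis"][0].size)  # PIL images with colored borders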