import numpy as np
import torch
from PIL import Image
from torch.utils.data.dataset import Dataset

from utils.utils import cvtColor, preprocess_input


class CycleGanDataset(Dataset):
    """Unpaired two-domain image dataset for CycleGAN-style training.

    Each item is a pair ``(image_A, image_B)`` drawn independently from the
    two domains. An annotation line is parsed by taking the text after the
    first ``';'`` and keeping only its first whitespace-separated token as
    the image path.
    """

    def __init__(self, annotation_lines_A, annotation_lines_B, input_shape):
        """
        Args:
            annotation_lines_A: Annotation lines for domain A.
            annotation_lines_B: Annotation lines for domain B.
            input_shape: Target size as ``(height, width)``; every image is
                resized to this shape (PIL's ``resize`` takes
                ``(width, height)``, hence the swapped indices below).
        """
        super().__init__()
        self.annotation_lines_A = annotation_lines_A
        self.annotation_lines_B = annotation_lines_B
        self.length_A = len(self.annotation_lines_A)
        self.length_B = len(self.annotation_lines_B)
        self.input_shape = input_shape

    def __len__(self):
        # The shorter domain is cycled via modulo indexing in __getitem__,
        # so one epoch visits every image of the larger domain.
        # NOTE(review): if exactly one domain is empty, __getitem__ will
        # raise ZeroDivisionError on the modulo.
        return max(self.length_A, self.length_B)

    def _load_image(self, annotation_line):
        """Load one image, resize it and return a float32 channel-first array."""
        path = annotation_line.split(';')[1].split()[0]
        # Context manager releases the file handle as soon as the pixels are
        # copied out, instead of leaving it open until garbage collection
        # (important with many DataLoader workers).
        with Image.open(path) as raw_image:
            resized = cvtColor(raw_image).resize(
                [self.input_shape[1], self.input_shape[0]], Image.BICUBIC
            )
        pixels = np.array(resized, dtype=np.float32)
        # np.array of a PIL image is (H, W, C); reorder to (C, H, W) for PyTorch.
        # Assumes preprocess_input preserves that HWC layout — TODO confirm.
        return np.transpose(preprocess_input(pixels), (2, 0, 1))

    def __getitem__(self, index):
        """Return ``(image_A, image_B)`` for ``index``, wrapping each domain independently."""
        image_A = self._load_image(self.annotation_lines_A[index % self.length_A])
        image_B = self._load_image(self.annotation_lines_B[index % self.length_B])
        return image_A, image_B


def CycleGan_dataset_collate(batch):
    """Collate a list of ``(image_A, image_B)`` pairs into two float32 tensors.

    Args:
        batch: Sequence of pairs as produced by ``CycleGanDataset.__getitem__``.

    Returns:
        Tuple ``(images_A, images_B)`` of ``torch.Tensor`` built by stacking the
        per-sample arrays along a new leading batch axis.
    """
    images_A = [image_A for image_A, _ in batch]
    images_B = [image_B for _, image_B in batch]
    images_A = torch.from_numpy(np.array(images_A, np.float32))
    images_B = torch.from_numpy(np.array(images_B, np.float32))
    return images_A, images_B