from PIL import Image
from torch.utils.data import Dataset, DataLoader

from augmentations import ha_augment_sample, resize_sample, spatial_augment_sample
from lanet_utils import to_tensor_sample


def image_transforms(shape, jittering):
    """Build the per-sample transform pipelines.

    Args:
        shape: Forwarded to ``resize_sample`` as ``image_shape``.
        jittering: Forwarded to ``ha_augment_sample`` as its jitter
            parameters.

    Returns:
        A dict with the single key ``"train"`` mapping to a callable that
        takes a sample and returns the transformed sample.
    """

    def train_transforms(sample):
        """Resize, spatially augment, tensorize, then apply ``ha_augment_sample``."""
        sample = resize_sample(sample, image_shape=shape)
        sample = spatial_augment_sample(sample)
        sample = to_tensor_sample(sample)
        # The keyword spelling "jitter_paramters" is kept exactly as in the
        # original call; presumably it matches the helper's signature.
        sample = ha_augment_sample(sample, jitter_paramters=jittering)
        return sample

    return {"train": train_transforms}


class GetData(Dataset):
    """Dataset of images listed in a text file, one entry per line."""

    def __init__(self, config, transforms=None):
        """Read the list of image paths.

        Args:
            config: Object providing ``train_txt`` (path of the list file) and
                ``train_root`` (prefix prepended to every listed path).
            transforms: Optional callable applied to each sample dict in
                ``__getitem__``.

        Only the first whitespace-separated token of each line is kept.
        Blank lines are skipped instead of raising ``IndexError``.
        """
        # Context manager guarantees the list file is closed after reading.
        with open(config.train_txt, "r") as datafile:
            tokenized = (line.split() for line in datafile)
            dataset = [tokens[0] for tokens in tokenized if tokens]

        self.config = config
        self.dataset = dataset
        self.root = config.train_root
        self.transforms = transforms

    def __getitem__(self, index):
        """Load the image at ``index`` and return it as a sample dict.

        Returns:
            ``{"image": <PIL image>, "idx": index}``, passed through
            ``self.transforms`` when one was supplied.
        """
        # Plain string concatenation: ``root`` is expected to already end
        # with a path separator when one is needed.
        img = Image.open(self.root + self.dataset[index])

        # Mode "L" means a single-channel grayscale image; paste it onto an
        # RGB canvas so such images come out with three channels.
        if img.mode == "L":
            rgb = Image.new("RGB", img.size)
            rgb.paste(img)
            img = rgb

        sample = {"image": img, "idx": index}
        if self.transforms:
            sample = self.transforms(sample)
        return sample

    def __len__(self):
        """Return the number of listed images."""
        return len(self.dataset)


def get_data_loader(
    config,
    transforms=None,
    sampler=None,
    drop_last=True,
):
    """Build the training ``DataLoader``.

    Args:
        config: Object providing ``image_shape``, ``jittering``,
            ``batch_size``, ``shuffle``, ``num_workers`` and ``pin_memory``,
            plus the fields read by ``GetData``.
        transforms: Accepted for interface compatibility but currently
            ignored; see the review note in the body.
        sampler: Optional sampler forwarded to ``DataLoader``.
        drop_last: Forwarded to ``DataLoader``.

    Returns:
        A ``DataLoader`` over ``GetData`` using the ``"train"`` pipeline from
        ``image_transforms``.
    """
    # NOTE(review): the ``transforms`` argument has always been overwritten
    # here, so callers cannot inject their own pipeline. Behaviour is kept
    # as-is because the intended type (dict vs. callable) is not visible
    # from this file; confirm before honouring it.
    pipelines = image_transforms(shape=config.image_shape, jittering=config.jittering)
    dataset = GetData(config, transforms=pipelines["train"])

    # DataLoader rejects a custom sampler combined with shuffle=True, so the
    # sampler takes over ordering whenever one is supplied.
    shuffle = config.shuffle if sampler is None else False

    train_loader = DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=shuffle,
        sampler=sampler,
        num_workers=config.num_workers,
        pin_memory=config.pin_memory,
        drop_last=drop_last,
    )
    return train_loader