import os
import os.path

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

from util.aug_utils import DivisibleCrop, RandomScale, RandomSizeCrop


class SingleImageDataset(Dataset):
    """Dataset built from a single source image, returning augmented crops of it."""

    def __init__(self, cfg):
        self.cfg = cfg

        # Convert a PIL image to a batched tensor (1, C, H, W) and crop it so that
        # its height and width are divisible by cfg["d_divisible_crops"].
        self.base_transforms = transforms.Compose(
            [
                transforms.Lambda(lambda x: transforms.ToTensor()(x).unsqueeze(0)),
                DivisibleCrop(cfg["d_divisible_crops"]),
            ]
        )

        # Augmentations used to create the internal dataset of crops: color jitter,
        # horizontal flip, random rescale, and a random crop, followed by the base
        # transforms above.
        self.input_transforms = transforms.Compose(
            [
                transforms.RandomApply(
                    [transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1)],
                    p=cfg["jitter_p"],
                ),
                transforms.RandomHorizontalFlip(p=cfg["flip_p"]),
                RandomScale((cfg["scale_min"], cfg["scale_max"])),
                RandomSizeCrop(cfg["crops_min_cover"]),
                self.base_transforms,
            ]
        )

        # Open the source image and optionally resize it.
        self.src_img = Image.open(cfg["image_path"]).convert("RGB")
        if cfg["resize_input"] > 0:
            self.src_img = transforms.Resize(cfg["resize_input"])(self.src_img)

        self.step = -1

    def get_img(self):
        # Return the full source image with only the base transforms applied.
        return self.base_transforms(self.src_img)

    def __getitem__(self, index):
        self.step += 1
        sample = {"step": self.step}

        # Every cfg["source_image_every"] steps, also return the full source image.
        if self.step % self.cfg["source_image_every"] == 0:
            sample["input_image"] = self.get_img()

        # Always return a randomly augmented crop of the source image.
        sample["input_crop"] = self.input_transforms(self.src_img)
        return sample

    def __len__(self):
        return 1
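

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original module).
# The cfg values below are assumptions chosen for demonstration; only the key
# names come from SingleImageDataset, and "input.jpg" is a placeholder path.
if __name__ == "__main__":
    cfg = {
        "image_path": "input.jpg",    # placeholder path to the source image
        "resize_input": 0,            # <= 0 keeps the original resolution
        "d_divisible_crops": 8,       # assumed divisibility factor for H and W
        "jitter_p": 0.5,              # assumed augmentation probabilities
        "flip_p": 0.5,
        "scale_min": 0.8,             # assumed random-scale range
        "scale_max": 1.2,
        "crops_min_cover": 0.95,      # assumed minimum crop coverage
        "source_image_every": 75,     # full source image returned every N steps
    }
    dataset = SingleImageDataset(cfg)
    # The dataset has length 1 and is sampled repeatedly during training;
    # each call returns a new random crop (and periodically the full image).
    for _ in range(3):
        sample = dataset[0]
        print(sample["step"], sample["input_crop"].shape)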