# text2live/datasets/image_dataset.py
import os.path
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import os
import os.path
import torch
from util.aug_utils import RandomScale, RandomSizeCrop, DivisibleCrop
class SingleImageDataset(Dataset):
    """Serve random augmented crops of a single source image.

    Every ``__getitem__`` call advances an internal step counter and returns an
    augmented crop of the source image; on every ``source_image_every``-th step
    the full (divisibly-cropped) source image is included as well.
    """

    def __init__(self, cfg):
        self.cfg = cfg

        # Deterministic tail pipeline shared by the full image and each crop:
        # PIL image -> (1, C, H, W) tensor, then cropped so H and W are
        # divisible by cfg["d_divisible_crops"].
        self.base_transforms = transforms.Compose(
            [
                transforms.Lambda(lambda x: transforms.ToTensor()(x).unsqueeze(0)),
                DivisibleCrop(cfg["d_divisible_crops"]),
            ]
        )

        # Stochastic augmentation pipeline used to create the internal dataset.
        jitter = transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1)
        self.input_transforms = transforms.Compose(
            [
                transforms.RandomApply([jitter], p=cfg["jitter_p"]),
                transforms.RandomHorizontalFlip(p=cfg["flip_p"]),
                RandomScale((cfg["scale_min"], cfg["scale_max"])),
                RandomSizeCrop(cfg["crops_min_cover"]),
                self.base_transforms,
            ]
        )

        # Load the single source image; optionally resize it
        # (torchvision Resize with an int scales the smaller edge).
        self.src_img = Image.open(cfg["image_path"]).convert("RGB")
        if cfg["resize_input"] > 0:
            self.src_img = transforms.Resize(cfg["resize_input"])(self.src_img)

        # Incremented on every __getitem__ call; starts at -1 so the first
        # sample has step == 0 and therefore includes the full source image.
        self.step = -1

    def get_img(self):
        """Return the full source image passed through the base transforms."""
        return self.base_transforms(self.src_img)

    def __getitem__(self, index):
        # NOTE: the sample depends on the mutable step counter, not on `index`.
        self.step += 1
        sample = {"step": self.step}
        if self.step % self.cfg["source_image_every"] == 0:
            sample["input_image"] = self.get_img()
        sample["input_crop"] = self.input_transforms(self.src_img)
        return sample

    def __len__(self):
        # The dataset conceptually holds one image; epoch length is controlled
        # by the training loop, not by this value.
        return 1