# text2live / Text2LIVE-main/datasets/image_dataset.py
# Uploaded by SupermanxKiaski ("Upload 356 files", commit 16d007c; 1.75 kB).
import os.path
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import os
import os.path
import torch
from util.aug_utils import RandomScale, RandomSizeCrop, DivisibleCrop
class SingleImageDataset(Dataset):
    """Dataset built from one source image that yields a new augmented crop per access.

    The source image is loaded (and optionally resized) once at construction.
    Each ``__getitem__`` call returns a dict containing an incrementing
    ``"step"`` counter and an ``"input_crop"`` produced by
    ``input_transforms``. Every ``cfg["source_image_every"]`` steps, the whole
    un-augmented image is also included under ``"input_image"``.

    Args:
        cfg: configuration mapping. Keys read by this class:
            ``"d_divisible_crops"``, ``"jitter_p"``, ``"flip_p"``,
            ``"scale_min"``, ``"scale_max"``, ``"crops_min_cover"``,
            ``"image_path"``, ``"resize_input"`` and, in ``__getitem__``,
            ``"source_image_every"``.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        # Image -> tensor with an extra leading dim (unsqueeze(0)), then
        # DivisibleCrop(cfg["d_divisible_crops"]).
        # NOTE(review): DivisibleCrop is project-local; presumably it crops
        # spatial dims to a multiple of d_divisible_crops -- confirm.
        self.base_transforms = transforms.Compose(
            [
                transforms.Lambda(lambda x: transforms.ToTensor()(x).unsqueeze(0)),
                DivisibleCrop(cfg["d_divisible_crops"]),
            ]
        )
        # Used to create the internal dataset: random color jitter, random
        # horizontal flip, random scale and random-size crop are applied in
        # order to the source image; base_transforms runs last.
        self.input_transforms = transforms.Compose(
            [
                transforms.RandomApply(
                    [transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1)],
                    p=cfg["jitter_p"],
                ),
                transforms.RandomHorizontalFlip(p=cfg["flip_p"]),
                RandomScale((cfg["scale_min"], cfg["scale_max"])),
                RandomSizeCrop(cfg["crops_min_cover"]),
                self.base_transforms,
            ]
        )
        # Open the source image in a context manager. convert("RGB") returns
        # a new, fully loaded image, so the file can be closed immediately
        # (Pillow otherwise keeps the handle open for multi-frame formats).
        with Image.open(cfg["image_path"]) as img:
            self.src_img = img.convert("RGB")
        if cfg["resize_input"] > 0:
            self.src_img = transforms.Resize(cfg["resize_input"])(self.src_img)
        # __getitem__ pre-increments, so the first sample reports step 0.
        self.step = -1

    def get_img(self):
        """Return the whole source image passed through ``base_transforms`` (no augmentation)."""
        return self.base_transforms(self.src_img)

    def __getitem__(self, index):
        """Return the next sample dict.

        ``index`` is ignored: sampling is driven by ``self.step``, which is
        incremented on every call. The returned dict always has ``"step"``
        and ``"input_crop"``; it also has ``"input_image"`` whenever ``step``
        is a multiple of ``cfg["source_image_every"]`` (a value of 0 raises
        ``ZeroDivisionError``).
        """
        self.step += 1
        sample = {"step": self.step}
        if self.step % self.cfg["source_image_every"] == 0:
            sample["input_image"] = self.get_img()
        sample["input_crop"] = self.input_transforms(self.src_img)
        return sample

    def __len__(self):
        """Report a single item; each access still produces a new random crop."""
        return 1