File size: 1,751 Bytes
16d007c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os.path
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import os
import os.path
import torch

from util.aug_utils import RandomScale, RandomSizeCrop, DivisibleCrop


class SingleImageDataset(Dataset):
    """Dataset of length one that serves augmented views of a single image.

    Every ``__getitem__`` call re-applies the random augmentation pipeline to
    the same source image, so repeated iteration yields different crops.

    Config keys read from ``cfg``:
        image_path: path of the source image to load.
        resize_input: if greater than 0, passed to ``transforms.Resize`` and
            applied once to the source image at construction time.
        d_divisible_crops: argument forwarded to ``DivisibleCrop``.
        jitter_p: probability of applying the color jitter.
        flip_p: probability of a horizontal flip.
        scale_min, scale_max: bounds forwarded to ``RandomScale`` as a tuple.
        crops_min_cover: argument forwarded to ``RandomSizeCrop``.
        source_image_every: the un-augmented image is added to the sample
            whenever the internal step counter is a multiple of this value.
    """

    def __init__(self, cfg):
        """Build the transform pipelines and load the source image.

        Args:
            cfg: mapping holding the keys listed in the class docstring.
        """
        super().__init__()
        self.cfg = cfg

        # Deterministic part of the pipeline: PIL image -> tensor with an
        # extra leading dimension of size 1, followed by DivisibleCrop.
        self.base_transforms = transforms.Compose(
            [
                transforms.Lambda(lambda x: transforms.ToTensor()(x).unsqueeze(0)),
                DivisibleCrop(cfg["d_divisible_crops"]),
            ]
        )

        # used to create the internal dataset
        # Random augmentations run first (on the PIL image), then the base
        # transforms convert the result to a tensor.
        self.input_transforms = transforms.Compose(
            [
                transforms.RandomApply(
                    [transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1)],
                    p=cfg["jitter_p"],
                ),
                transforms.RandomHorizontalFlip(p=cfg["flip_p"]),
                RandomScale((cfg["scale_min"], cfg["scale_max"])),
                RandomSizeCrop(cfg["crops_min_cover"]),
                self.base_transforms,
            ]
        )

        # open source image
        # The context manager releases the underlying file handle right away;
        # convert() returns an independent in-memory copy, so the stored image
        # stays valid after the file is closed.
        with Image.open(cfg["image_path"]) as opened:
            self.src_img = opened.convert("RGB")

        if cfg["resize_input"] > 0:
            self.src_img = transforms.Resize(cfg["resize_input"])(self.src_img)

        # Counts __getitem__ calls; starts at -1 so the first sample is step 0.
        self.step = -1

    def get_img(self):
        """Return the source image passed through the base transforms only."""
        return self.base_transforms(self.src_img)

    def __getitem__(self, index):
        """Return one freshly augmented sample.

        ``index`` is ignored: the dataset holds a single image and tracks
        progress with its own ``step`` counter, incremented on every call.

        Returns:
            dict with keys ``"step"`` (the call counter), ``"input_crop"``
            (randomly augmented view) and, on every ``source_image_every``-th
            step, ``"input_image"`` (the un-augmented image).
        """
        self.step += 1
        sample = {"step": self.step}
        if self.step % self.cfg["source_image_every"] == 0:
            sample["input_image"] = self.get_img()

        sample["input_crop"] = self.input_transforms(self.src_img)

        return sample

    def __len__(self):
        """Return 1; the dataset always exposes exactly one item."""
        return 1