File size: 3,363 Bytes
561c629
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.models import vgg19
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image, make_grid
from torchvision.transforms import ToTensor

import numpy as np
import cv2
import glob
import random
from PIL import Image
from tqdm import tqdm


# from degradation.degradation_main import degredate_process, preparation
from opt import opt


class ImageDataset(Dataset):
    """Paired dataset of (low-res, degraded high-res, high-res) image triples.

    The three path lists are aligned by index: item ``i`` reads
    ``train_lr_paths[i]``, ``degrade_hr_paths[i]`` and ``train_hr_paths[i]``.
    Images are read with OpenCV (channel order is OpenCV's BGR), optionally
    augmented with one shared random flip/transpose, and converted to tensors.
    """

    def __init__(self, train_lr_paths, degrade_hr_paths, train_hr_paths):
        """Store the aligned path lists and build the tensor transform.

        Args:
            train_lr_paths (Sequence[str]): Low-resolution image paths.
            degrade_hr_paths (Sequence[str]): Degraded high-resolution image paths.
            train_hr_paths (Sequence[str]): Ground-truth high-resolution image paths.

        Raises:
            ValueError: If the three path lists differ in length.
        """
        # Explicit exception rather than ``assert``: asserts vanish under ``python -O``.
        if not (len(train_lr_paths) == len(degrade_hr_paths) == len(train_hr_paths)):
            raise ValueError(
                "path lists must have equal length, got "
                f"lr={len(train_lr_paths)}, degrade_hr={len(degrade_hr_paths)}, "
                f"hr={len(train_hr_paths)}"
            )

        self.files_lr = train_lr_paths
        self.files_degrade_hr = degrade_hr_paths
        self.files_hr = train_hr_paths

        # ToTensor maps an HWC uint8 ndarray to a CHW float tensor in [0, 1].
        self.transform = transforms.Compose([transforms.ToTensor()])

    def augment(self, imgs, hflip=True, rotation=True):
        """Randomly apply the same flip/transpose combination to every image.

        Three independent coin flips (each with probability 0.5) decide whether
        the images are flipped horizontally, flipped vertically, and transposed
        (height and width swapped). Combined, these yield the 8 dihedral
        transforms, e.g. rotations by 0/90/180/270 degrees with or without a
        mirror. Every image in ``imgs`` receives the identical transform.

        Args:
            imgs (list[ndarray] | ndarray): Images indexed as (H, W, ...). A bare
                ndarray is treated as a one-element list.
            hflip (bool): Allow the horizontal flip. Default: True.
            rotation (bool): Allow the vertical flip and the transpose (which,
                together with ``hflip``, produce the rotations). Default: True.

        Returns:
            list[ndarray] | ndarray: C-contiguous augmented images, or a single
            ndarray when only one image was supplied. The inputs are never
            modified in place.
        """
        # Draw order (hflip, vflip, rot90) and short-circuiting are kept so the
        # random stream is consumed exactly as before.
        hflip = hflip and random.random() < 0.5
        vflip = rotation and random.random() < 0.5
        rot90 = rotation and random.random() < 0.5

        def _augment(img):
            if hflip:  # horizontal: reverse the width axis
                img = img[:, ::-1]
            if vflip:  # vertical: reverse the height axis
                img = img[::-1]
            if rot90:  # swap height and width; works for 2-D and 3-D arrays
                img = np.swapaxes(img, 0, 1)
            # Slicing yields negative strides, which torch.from_numpy (used by
            # ToTensor) rejects, so materialize a contiguous array.
            return np.ascontiguousarray(img)

        if not isinstance(imgs, list):
            imgs = [imgs]

        imgs = [_augment(img) for img in imgs]
        if len(imgs) == 1:
            imgs = imgs[0]

        return imgs

    @staticmethod
    def _read_image(path):
        """Read ``path`` with OpenCV, raising instead of silently returning None."""
        img = cv2.imread(path)  # default flag: 3-channel BGR
        if img is None:
            raise FileNotFoundError(f"cv2.imread could not read image: {path!r}")
        return img

    def __getitem__(self, index):
        """Return the ``index``-th triple as a dict of tensors.

        Args:
            index (int): Item index; wrapped modulo the dataset length.

        Returns:
            dict[str, Tensor]: ``{"lr", "degrade_hr", "hr"}`` CHW float tensors
            in [0, 1], channel order BGR (as read by OpenCV).

        Raises:
            FileNotFoundError: If any of the three images cannot be read.
        """
        img_lr = self._read_image(self.files_lr[index % len(self.files_lr)])
        img_degrade_hr = self._read_image(
            self.files_degrade_hr[index % len(self.files_degrade_hr)]
        )
        img_hr = self._read_image(self.files_hr[index % len(self.files_hr)])

        # Augment all three together so they stay spatially aligned.
        if random.random() < opt["augment_prob"]:
            img_lr, img_degrade_hr, img_hr = self.augment([img_lr, img_degrade_hr, img_hr])

        return {
            "lr": self.transform(img_lr),
            "degrade_hr": self.transform(img_degrade_hr),
            "hr": self.transform(img_hr),
        }

    def __len__(self):
        """Number of triples (the lists' shared length, validated in ``__init__``)."""
        return len(self.files_hr)