File size: 4,370 Bytes
e5b70eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import torch
import torchvision

import math
import cv2
import numpy as np
from scipy.ndimage import rotate


class RandCrop(object):
    """Randomly crop an LQ/GT image pair at *independent* positions.

    The LQ image is cropped to ``crop_size`` and the GT image to
    ``scale * crop_size``. The two crop offsets are drawn separately, so the
    resulting patches are not spatially aligned with each other.

    Args:
        crop_size: LQ crop size, either an int (square crop) or a
            ``(height, width)`` tuple.
        scale: Factor between the GT and LQ crop sizes (used in slicing, so
            it must be an int).
    """

    def __init__(self, crop_size, scale):
        # if output size is tuple -> (height, width)
        assert isinstance(crop_size, (int, tuple))
        if isinstance(crop_size, int):
            self.crop_size = (crop_size, crop_size)
        else:
            assert len(crop_size) == 2
            self.crop_size = crop_size

        self.scale = scale

    @staticmethod
    def _random_offset(size, crop):
        """Return a uniform random start so ``[start, start + crop)`` fits in ``size``."""
        if crop > size:
            raise ValueError(f"crop size {crop} exceeds image size {size}")
        # np.random.randint's upper bound is exclusive: +1 makes the last
        # valid offset (size - crop) reachable and allows crop == size.
        return np.random.randint(0, size - crop + 1)

    def __call__(self, sample):
        """Crop ``sample['img_LQ']`` and ``sample['img_GT']`` (H x W x C numpy arrays).

        Returns:
            A new dict with the cropped ``'img_LQ'`` and ``'img_GT'`` arrays.

        Raises:
            ValueError: if a requested crop is larger than its image.
        """
        img_LQ, img_GT = sample['img_LQ'], sample['img_GT']
        new_h, new_w = self.crop_size

        h, w, _ = img_LQ.shape
        top = self._random_offset(h, new_h)
        left = self._random_offset(w, new_w)
        img_LQ_crop = img_LQ[top: top + new_h, left: left + new_w, :]

        gt_h, gt_w = self.scale * new_h, self.scale * new_w
        h, w, _ = img_GT.shape
        top = self._random_offset(h, gt_h)
        left = self._random_offset(w, gt_w)
        img_GT_crop = img_GT[top: top + gt_h, left: left + gt_w, :]

        return {'img_LQ': img_LQ_crop, 'img_GT': img_GT_crop}


class RandRotate(object):
    """Randomly rotate an LQ/GT pair by the same multiple of 90 degrees.

    A single ``np.random.random()`` draw selects a rotation of 90, 180 or 270
    degrees, or no rotation, each with probability 0.25. Rotation uses
    ``np.rot90`` in the (H, W) plane, which is exact (no interpolation) and
    keeps the array dtype.
    """

    def __call__(self, sample):
        # img_LQ: H x W x C (numpy array)
        img_LQ, img_GT = sample['img_LQ'], sample['img_GT']

        prob_rotate = np.random.random()
        if prob_rotate < 0.25:
            k = 1
        elif prob_rotate < 0.5:
            k = 2
        elif prob_rotate < 0.75:
            k = 3
        else:
            k = 0

        if k:
            # np.rot90 returns a (negatively strided) view; copy so the result
            # owns its memory, as torch.from_numpy rejects negative strides.
            img_LQ = np.rot90(img_LQ, k, axes=(0, 1)).copy()
            img_GT = np.rot90(img_GT, k, axes=(0, 1)).copy()

        return {'img_LQ': img_LQ, 'img_GT': img_GT}


class RandHorizontalFlip(object):
    """Mirror an LQ/GT pair left-to-right with probability 0.5.

    Either both images are flipped or neither is; flipped results are copied
    so they own their memory.
    """

    def __call__(self, sample):
        # Expects H x W x C numpy arrays under 'img_LQ' and 'img_GT'.
        lq = sample['img_LQ']
        gt = sample['img_GT']

        if np.random.random() >= 0.5:
            return {'img_LQ': lq, 'img_GT': gt}

        return {
            'img_LQ': np.fliplr(lq).copy(),
            'img_GT': np.fliplr(gt).copy(),
        }


class ToTensor(object):
    """Convert an LQ/GT pair from H x W x C numpy arrays to C x H x W tensors.

    Values and dtype are passed through unchanged (no scaling is applied);
    per the ``torch.from_numpy`` docs the tensors share memory with the
    transposed numpy views.
    """

    def __call__(self, sample):
        channels_first = (2, 0, 1)
        lq = torch.from_numpy(sample['img_LQ'].transpose(channels_first))
        gt = torch.from_numpy(sample['img_GT'].transpose(channels_first))
        return {'img_LQ': lq, 'img_GT': gt}


class VGG19PerceptualLoss(torch.nn.Module):
    """L1 distance between frozen VGG19 feature maps of two image batches.

    Args:
        feature_layer: Number of leading layers of ``vgg19.features`` to keep;
            features are read from the output of the last kept layer.

    NOTE(review): inputs are passed to VGG as-is; no ImageNet mean/std
    normalization happens here -- confirm callers normalize if required.
    """

    def __init__(self, feature_layer=35):
        super().__init__()
        backbone = torchvision.models.vgg19(weights=torchvision.models.VGG19_Weights.DEFAULT)
        truncated = list(backbone.features.children())[:feature_layer]
        self.features = torch.nn.Sequential(*truncated).eval()
        # The extractor is a fixed loss network: its weights are never trained.
        for weight in self.features.parameters():
            weight.requires_grad = False

    def forward(self, source, target):
        """Return the mean absolute difference between the two feature maps."""
        source_feats = self.features(source)
        target_feats = self.features(target)
        return torch.nn.functional.l1_loss(source_feats, target_feats)
        

class RandCrop_pair(object):
    """Randomly crop spatially aligned patches from an LQ/GT image pair.

    A ``crop_size`` patch is taken from the LQ image at a random offset. The
    GT patch is taken at that offset multiplied by ``scale``, with size
    ``scale * crop_size``. Both patches therefore cover the same region when
    the GT image is ``scale`` times the LQ image.

    Args:
        crop_size: LQ crop size, either an int (square crop) or a
            ``(height, width)`` tuple.
        scale: Factor between GT and LQ coordinates (used in slicing, so it
            must be an int).
    """

    def __init__(self, crop_size, scale):
        # if output size is tuple -> (height, width)
        assert isinstance(crop_size, (int, tuple))
        if isinstance(crop_size, int):
            self.crop_size = (crop_size, crop_size)
        else:
            assert len(crop_size) == 2
            self.crop_size = crop_size

        self.scale = scale

    @staticmethod
    def _random_offset(size, crop):
        """Return a uniform random start so ``[start, start + crop)`` fits in ``size``."""
        if crop > size:
            raise ValueError(f"crop size {crop} exceeds image size {size}")
        # np.random.randint's upper bound is exclusive: +1 makes the last
        # valid offset (size - crop) reachable and allows crop == size.
        return np.random.randint(0, size - crop + 1)

    def __call__(self, sample):
        """Crop aligned patches from ``sample['img_LQ']`` / ``sample['img_GT']``.

        Both entries are H x W x C numpy arrays.

        Returns:
            A new dict with the cropped ``'img_LQ'`` and ``'img_GT'`` arrays.

        Raises:
            ValueError: if the LQ crop is larger than the LQ image.
        """
        img_LQ, img_GT = sample['img_LQ'], sample['img_GT']
        new_h, new_w = self.crop_size

        h, w, _ = img_LQ.shape
        top = self._random_offset(h, new_h)
        left = self._random_offset(w, new_w)
        img_LQ_crop = img_LQ[top: top + new_h, left: left + new_w, :]

        # Map the LQ offset into GT coordinates so the pair stays aligned.
        # NOTE(review): the GT size is not validated; if GT is smaller than
        # scale x LQ, this slice is silently truncated.
        gt_top, gt_left = self.scale * top, self.scale * left
        img_GT_crop = img_GT[gt_top: gt_top + self.scale * new_h,
                             gt_left: gt_left + self.scale * new_w, :]

        return {'img_LQ': img_LQ_crop, 'img_GT': img_GT_crop}