白鹭先生 commited on
Commit
3a0bab1
1 Parent(s): 98068b3
Files changed (2) hide show
  1. utils/dataloader.py +0 -101
  2. utils/utils.py +0 -100
utils/dataloader.py DELETED
@@ -1,101 +0,0 @@
1
- from random import randint
2
-
3
- import cv2
4
- import numpy as np
5
- from PIL import Image
6
- from torch.utils.data.dataset import Dataset
7
-
8
- from .utils import cvtColor, preprocess_input
9
-
10
- def look_image(image_name, image):
11
- image = np.array(image)
12
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
13
- cv2.imshow(image_name, image)
14
- cv2.waitKey(0)
15
-
16
-
17
- def get_new_img_size(width, height, img_min_side=600):
18
- if width <= height:
19
- f = float(img_min_side) / width
20
- resized_height = int(f * height)
21
- resized_width = int(img_min_side)
22
- else:
23
- f = float(img_min_side) / height
24
- resized_width = int(f * width)
25
- resized_height = int(img_min_side)
26
-
27
- return resized_width, resized_height
28
-
29
- class MASKGANDataset(Dataset):
30
- def __init__(self, train_lines, lr_shape, hr_shape):
31
- super(MASKGANDataset, self).__init__()
32
-
33
- self.train_lines = train_lines
34
- self.train_batches = len(train_lines)
35
-
36
- self.lr_shape = lr_shape
37
- self.hr_shape = hr_shape
38
-
39
- def __len__(self):
40
- return self.train_batches
41
-
42
- def __getitem__(self, index):
43
- index = index % self.train_batches
44
- image_list = self.train_lines[index].split(' ')
45
- image_origin = Image.open(image_list[0])
46
- image_masked = Image.open(image_list[1].split()[0])
47
-
48
- image_origin, image_masked = self.get_random_data(image_origin, image_masked, self.hr_shape)
49
-
50
- image_origin = image_origin.resize((self.hr_shape[1], self.hr_shape[0]), Image.BICUBIC)
51
- image_masked = image_masked.resize((self.lr_shape[1], self.lr_shape[0]), Image.BICUBIC)
52
- # look_image('origin', image_origin)
53
- # look_image('masked', image_masked)
54
- image_origin = np.transpose(preprocess_input(np.array(image_origin, dtype=np.float32), [0.5,0.5,0.5], [0.5,0.5,0.5]), [2,0,1])
55
- image_masked = np.transpose(preprocess_input(np.array(image_masked, dtype=np.float32), [0.5,0.5,0.5], [0.5,0.5,0.5]), [2,0,1])
56
-
57
- return np.array(image_masked), np.array(image_origin)
58
-
59
- def rand(self, a=0, b=1):
60
- return np.random.rand()*(b-a) + a
61
-
62
- def get_random_data(self, image_origin, image_masked, input_shape, jitter=.3, hue=.1, sat=1.5, val=1.5, random=True):
63
- #------------------------------#
64
- # 读取图像并转换成RGB图像
65
- #------------------------------#
66
- image_origin = cvtColor(image_origin)
67
- image_masked = cvtColor(image_masked)
68
-
69
- #------------------------------------------#
70
- # 色域扭曲
71
- #------------------------------------------#
72
- hue = self.rand(-hue, hue)
73
- sat = self.rand(1, sat) if self.rand()<.5 else 1/self.rand(1, sat)
74
- val = self.rand(1, val) if self.rand()<.5 else 1/self.rand(1, val)
75
-
76
- x = cv2.cvtColor(np.array(image_origin,np.float32)/255, cv2.COLOR_RGB2HSV)
77
- x[..., 1] *= sat
78
- x[..., 2] *= val
79
- x[x[:,:, 0]>360, 0] = 360
80
- x[:, :, 1:][x[:, :, 1:]>1] = 1
81
- x[x<0] = 0
82
- image_data_origin = cv2.cvtColor(x, cv2.COLOR_HSV2RGB)*255
83
-
84
- x = cv2.cvtColor(np.array(image_masked,np.float32)/255, cv2.COLOR_RGB2HSV)
85
- x[..., 1] *= sat
86
- x[..., 2] *= val
87
- x[x[:,:, 0]>360, 0] = 360
88
- x[:, :, 1:][x[:, :, 1:]>1] = 1
89
- x[x<0] = 0
90
- image_data_masked = cv2.cvtColor(x, cv2.COLOR_HSV2RGB)*255
91
-
92
- return Image.fromarray(np.uint8(image_data_origin)), Image.fromarray(np.uint8(image_data_masked))
93
-
94
-
95
- def MASKGAN_dataset_collate(batch):
96
- images_l = []
97
- images_h = []
98
- for img_l, img_h in batch:
99
- images_l.append(img_l)
100
- images_h.append(img_h)
101
- return np.array(images_l), np.array(images_h)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/utils.py CHANGED
@@ -3,7 +3,6 @@ import numpy as np
3
  import matplotlib.pyplot as plt
4
  import torch
5
  from torch.nn import functional as F
6
- import cv2
7
  import distutils.util
8
 
9
  def show_result(num_epoch, G_net, imgs_lr, imgs_hr):
@@ -63,102 +62,3 @@ def add_arguments(argname, type, default, help, argparser, **kwargs):
63
  help=help + ' 默认: %(default)s.',
64
  **kwargs)
65
 
66
- def filter2D(img, kernel):
67
- """PyTorch version of cv2.filter2D
68
-
69
- Args:
70
- img (Tensor): (b, c, h, w)
71
- kernel (Tensor): (b, k, k)
72
- """
73
- k = kernel.size(-1)
74
- b, c, h, w = img.size()
75
- if k % 2 == 1:
76
- img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect')
77
- else:
78
- raise ValueError('Wrong kernel size')
79
-
80
- ph, pw = img.size()[-2:]
81
-
82
- if kernel.size(0) == 1:
83
- # apply the same kernel to all batch images
84
- img = img.view(b * c, 1, ph, pw)
85
- kernel = kernel.view(1, 1, k, k)
86
- return F.conv2d(img, kernel, padding=0).view(b, c, h, w)
87
- else:
88
- img = img.view(1, b * c, ph, pw)
89
- kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k)
90
- return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w)
91
-
92
-
93
- def usm_sharp(img, weight=0.5, radius=50, threshold=10):
94
- """USM sharpening.
95
-
96
- Input image: I; Blurry image: B.
97
- 1. sharp = I + weight * (I - B)
98
- 2. Mask = 1 if abs(I - B) > threshold, else: 0
99
- 3. Blur mask:
100
- 4. Out = Mask * sharp + (1 - Mask) * I
101
-
102
-
103
- Args:
104
- img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
105
- weight (float): Sharp weight. Default: 1.
106
- radius (float): Kernel size of Gaussian blur. Default: 50.
107
- threshold (int):
108
- """
109
- if radius % 2 == 0:
110
- radius += 1
111
- blur = cv2.GaussianBlur(img, (radius, radius), 0)
112
- residual = img - blur
113
- mask = np.abs(residual) * 255 > threshold
114
- mask = mask.astype('float32')
115
- soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
116
-
117
- sharp = img + weight * residual
118
- sharp = np.clip(sharp, 0, 1)
119
- return soft_mask * sharp + (1 - soft_mask) * img
120
-
121
-
122
- class USMSharp(torch.nn.Module):
123
-
124
- def __init__(self, radius=50, sigma=0):
125
- super(USMSharp, self).__init__()
126
- if radius % 2 == 0:
127
- radius += 1
128
- self.radius = radius
129
- kernel = cv2.getGaussianKernel(radius, sigma)
130
- kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0)
131
- self.register_buffer('kernel', kernel)
132
-
133
- def forward(self, img, weight=0.5, threshold=10):
134
- blur = filter2D(img, self.kernel)
135
- residual = img - blur
136
-
137
- mask = torch.abs(residual) * 255 > threshold
138
- mask = mask.float()
139
- soft_mask = filter2D(mask, self.kernel)
140
- sharp = img + weight * residual
141
- sharp = torch.clip(sharp, 0, 1)
142
- return soft_mask * sharp + (1 - soft_mask) * img
143
-
144
- class USMSharp_npy():
145
-
146
- def __init__(self, radius=50, sigma=0):
147
- super(USMSharp_npy, self).__init__()
148
- if radius % 2 == 0:
149
- radius += 1
150
- self.radius = radius
151
- kernel = cv2.getGaussianKernel(radius, sigma)
152
- self.kernel = np.dot(kernel, kernel.transpose()).astype(np.float32)
153
-
154
- def filt(self, img, weight=0.5, threshold=10):
155
- blur = cv2.filter2D(img, -1, self.kernel)
156
- residual = img - blur
157
-
158
- mask = np.abs(residual) * 255 > threshold
159
- mask = mask.astype(np.float32)
160
- soft_mask = cv2.filter2D(mask, -1, self.kernel)
161
- sharp = img + weight * residual
162
- sharp = np.clip(sharp, 0, 1)
163
- return soft_mask * sharp + (1 - soft_mask) * img
164
-
 
3
  import matplotlib.pyplot as plt
4
  import torch
5
  from torch.nn import functional as F
 
6
  import distutils.util
7
 
8
  def show_result(num_epoch, G_net, imgs_lr, imgs_hr):
 
62
  help=help + ' 默认: %(default)s.',
63
  **kwargs)
64