白鹭先生 committed
Commit db5513e
1 Parent(s): 1e05415

Add SwinIR model
.gitignore CHANGED
@@ -1 +1,2 @@
  propress.py
+ __pycache__
app.py CHANGED
@@ -5,18 +5,17 @@ LastEditors: Egrt
  LastEditTime: 2022-01-13 13:48:57
  FilePath: \LicenseGAN\app.py
  '''
- import os
- os.system('pip install -r requirements.txt')
 
  from PIL import Image
+
  from esrgan import ESRGAN
  import gradio as gr
-
+ import os
  esrgan = ESRGAN()
 
  # --------模型推理---------- #
  def inference(img):
- lr_shape = [12, 24]
+ lr_shape = [32, 56]
  img = img.resize((lr_shape[1], lr_shape[0]), Image.BICUBIC)
  r_image = esrgan.generate_1x1_image(img)
  return r_image
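
The updated app.py resizes every upload to the new 32×56 low-resolution input before calling the generator (the previous 12×24 shape is dropped) and no longer shells out to pip at import time. A minimal sketch of how this inference function could be wired into a Gradio demo is shown below; the gr.Interface wiring is an assumption, since the rest of app.py is not part of this diff.

```python
from PIL import Image
from esrgan import ESRGAN
import gradio as gr

esrgan = ESRGAN()

def inference(img):
    # Downscale the upload to the 32x56 LR input the new weights expect.
    lr_shape = [32, 56]  # (height, width)
    img = img.resize((lr_shape[1], lr_shape[0]), Image.BICUBIC)
    return esrgan.generate_1x1_image(img)

# Hypothetical demo wiring (not shown in the diff): PIL image in, PIL image out.
demo = gr.Interface(fn=inference, inputs=gr.Image(type="pil"), outputs=gr.Image(type="pil"))

if __name__ == "__main__":
    demo.launch()
```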
esrgan.py CHANGED
@@ -2,7 +2,7 @@ import numpy as np
  import torch
  import torch.backends.cudnn as cudnn
  from PIL import Image
- from nets.esrgan import Generator
+ from nets.SwinIR import Generator
  from utils.utils import cvtColor, preprocess_input
 
 
@@ -14,11 +14,15 @@ class ESRGAN(object):
  #-----------------------------------------------#
  # model_path指向logs文件夹下的权值文件
  #-----------------------------------------------#
- "model_path" : 'model_data/Generator_ESRGAN6.pth',
+ "model_path" : 'model_data/Generator_SwinIR.pth',
  #-----------------------------------------------#
  # 上采样的倍数,和训练时一样
  #-----------------------------------------------#
- "scale_factor" : 8,
+ "scale_factor" : 4,
+ #-----------------------------------------------#
+ # hr_shape
+ #-----------------------------------------------#
+ "hr_shape" : [128, 224],
  #-------------------------------#
  # 是否使用Cuda
  # 没有GPU可以设置成False
@@ -36,7 +40,10 @@ class ESRGAN(object):
  self.generate()
 
  def generate(self):
- self.net = Generator(self.scale_factor)
+ # self.net = Generator(self.scale_factor)
+ self.net = Generator(upscale=self.scale_factor, img_size=tuple(self.hr_shape),
+ window_size=8, img_range=1., depths=[3, 3, 3, 3],
+ embed_dim=60, num_heads=[3, 3, 3, 3], mlp_ratio=2, upsampler='pixelshuffledirect')
 
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  self.net.load_state_dict(torch.load(self.model_path, map_location=device))
@@ -72,7 +79,7 @@ class ESRGAN(object):
  # 将归一化的结果再转成rgb格式
  #---------------------------------------------------------#
  hr_image = (hr_image.cpu().data.numpy().transpose(1, 2, 0) * 0.5 + 0.5)
- hr_image = (hr_image-np.min(hr_image))/(np.max(hr_image)-np.min(hr_image)) * 255
+ hr_image = np.clip(hr_image * 255, 0, 255)
 
  hr_image = Image.fromarray(np.uint8(hr_image))
  return hr_image
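
The ESRGAN wrapper now builds the lightweight SwinIR configuration (x4 upscaling, pixelshuffledirect reconstruction, embed_dim 60) instead of the old ESRGAN generator, and the post-processing switches from min-max normalization to a simple clip. A quick sketch, assuming the nets.SwinIR module from this commit and its timm dependency are importable, of instantiating the same configuration and checking that a 32×56 input comes back at the 128×224 hr_shape:

```python
import torch
from nets.SwinIR import Generator

# Same lightweight config that ESRGAN.generate() now builds (scale_factor=4, hr_shape=[128, 224]).
net = Generator(upscale=4, img_size=(128, 224),
                window_size=8, img_range=1., depths=[3, 3, 3, 3],
                embed_dim=60, num_heads=[3, 3, 3, 3], mlp_ratio=2,
                upsampler='pixelshuffledirect')
net.eval()

with torch.no_grad():
    lr = torch.randn(1, 3, 32, 56)   # the LR shape app.py now feeds
    sr = net(lr)

print(sr.shape)  # torch.Size([1, 3, 128, 224])
```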
model_data/{Generator_ESRGAN6.pth → Generator_SwinIR.pth} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:e9694a2253dba5c2e0365bd5d90842354d9edd46fa074509621cfcb1a54b34ae
- size 42002303
+ oid sha256:0dbb3371d937501b0fd913053d92a0c358ccbcb240cb133a319d1cd86dcbbfe9
+ size 32036063
nets/SwinIR.py ADDED
@@ -0,0 +1,912 @@
1
+ # -----------------------------------------------------------------------------------
2
+ # SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
3
+ # Originally Written by Ze Liu, Modified by Jingyun Liang.
4
+ # -----------------------------------------------------------------------------------
5
+
6
+ import math
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint as checkpoint
12
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
13
+
14
+
15
+ class Mlp(nn.Module):
16
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
17
+ super().__init__()
18
+ out_features = out_features or in_features
19
+ hidden_features = hidden_features or in_features
20
+ self.fc1 = nn.Linear(in_features, hidden_features)
21
+ self.act = act_layer()
22
+ self.fc2 = nn.Linear(hidden_features, out_features)
23
+ self.drop = nn.Dropout(drop)
24
+
25
+ def forward(self, x):
26
+ x = self.fc1(x)
27
+ x = self.act(x)
28
+ x = self.drop(x)
29
+ x = self.fc2(x)
30
+ x = self.drop(x)
31
+ return x
32
+
33
+
34
+ def window_partition(x, window_size):
35
+ """
36
+ Args:
37
+ x: (B, H, W, C)
38
+ window_size (int): window size
39
+
40
+ Returns:
41
+ windows: (num_windows*B, window_size, window_size, C)
42
+ """
43
+ B, H, W, C = x.shape
44
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
45
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
46
+ return windows
47
+
48
+
49
+ def window_reverse(windows, window_size, H, W):
50
+ """
51
+ Args:
52
+ windows: (num_windows*B, window_size, window_size, C)
53
+ window_size (int): Window size
54
+ H (int): Height of image
55
+ W (int): Width of image
56
+
57
+ Returns:
58
+ x: (B, H, W, C)
59
+ """
60
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
61
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
62
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
63
+ return x
64
+
65
+
66
+ class WindowAttention(nn.Module):
67
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
68
+ It supports both of shifted and non-shifted window.
69
+
70
+ Args:
71
+ dim (int): Number of input channels.
72
+ window_size (tuple[int]): The height and width of the window.
73
+ num_heads (int): Number of attention heads.
74
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
75
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
76
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
77
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
78
+ """
79
+
80
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
81
+
82
+ super().__init__()
83
+ self.dim = dim
84
+ self.window_size = window_size # Wh, Ww
85
+ self.num_heads = num_heads
86
+ head_dim = dim // num_heads
87
+ self.scale = qk_scale or head_dim ** -0.5
88
+
89
+ # define a parameter table of relative position bias
90
+ self.relative_position_bias_table = nn.Parameter(
91
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
92
+
93
+ # get pair-wise relative position index for each token inside the window
94
+ coords_h = torch.arange(self.window_size[0])
95
+ coords_w = torch.arange(self.window_size[1])
96
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
97
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
98
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
99
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
100
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
101
+ relative_coords[:, :, 1] += self.window_size[1] - 1
102
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
103
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
104
+ self.register_buffer("relative_position_index", relative_position_index)
105
+
106
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
107
+ self.attn_drop = nn.Dropout(attn_drop)
108
+ self.proj = nn.Linear(dim, dim)
109
+
110
+ self.proj_drop = nn.Dropout(proj_drop)
111
+
112
+ trunc_normal_(self.relative_position_bias_table, std=.02)
113
+ self.softmax = nn.Softmax(dim=-1)
114
+
115
+ def forward(self, x, mask=None):
116
+ """
117
+ Args:
118
+ x: input features with shape of (num_windows*B, N, C)
119
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
120
+ """
121
+ B_, N, C = x.shape
122
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
123
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
124
+
125
+ q = q * self.scale
126
+ attn = (q @ k.transpose(-2, -1))
127
+
128
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
129
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
130
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
131
+ attn = attn + relative_position_bias.unsqueeze(0)
132
+
133
+ if mask is not None:
134
+ nW = mask.shape[0]
135
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
136
+ attn = attn.view(-1, self.num_heads, N, N)
137
+ attn = self.softmax(attn)
138
+ else:
139
+ attn = self.softmax(attn)
140
+
141
+ attn = self.attn_drop(attn)
142
+
143
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
144
+ x = self.proj(x)
145
+ x = self.proj_drop(x)
146
+ return x
147
+
148
+ def extra_repr(self) -> str:
149
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
150
+
151
+ def flops(self, N):
152
+ # calculate flops for 1 window with token length of N
153
+ flops = 0
154
+ # qkv = self.qkv(x)
155
+ flops += N * self.dim * 3 * self.dim
156
+ # attn = (q @ k.transpose(-2, -1))
157
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
158
+ # x = (attn @ v)
159
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
160
+ # x = self.proj(x)
161
+ flops += N * self.dim * self.dim
162
+ return flops
163
+
164
+
165
+ class SwinTransformerBlock(nn.Module):
166
+ r""" Swin Transformer Block.
167
+
168
+ Args:
169
+ dim (int): Number of input channels.
170
+ input_resolution (tuple[int]): Input resolution.
171
+ num_heads (int): Number of attention heads.
172
+ window_size (int): Window size.
173
+ shift_size (int): Shift size for SW-MSA.
174
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
175
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
176
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
177
+ drop (float, optional): Dropout rate. Default: 0.0
178
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
179
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
180
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
181
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
182
+ """
183
+
184
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
185
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
186
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
187
+ super().__init__()
188
+ self.dim = dim
189
+ self.input_resolution = input_resolution
190
+ self.num_heads = num_heads
191
+ self.window_size = window_size
192
+ self.shift_size = shift_size
193
+ self.mlp_ratio = mlp_ratio
194
+ if min(self.input_resolution) <= self.window_size:
195
+ # if window size is larger than input resolution, we don't partition windows
196
+ self.shift_size = 0
197
+ self.window_size = min(self.input_resolution)
198
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
199
+
200
+ self.norm1 = norm_layer(dim)
201
+ self.attn = WindowAttention(
202
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
203
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
204
+
205
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
206
+ self.norm2 = norm_layer(dim)
207
+ mlp_hidden_dim = int(dim * mlp_ratio)
208
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
209
+
210
+ if self.shift_size > 0:
211
+ attn_mask = self.calculate_mask(self.input_resolution)
212
+ else:
213
+ attn_mask = None
214
+
215
+ self.register_buffer("attn_mask", attn_mask)
216
+
217
+ def calculate_mask(self, x_size):
218
+ # calculate attention mask for SW-MSA
219
+ H, W = x_size
220
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
221
+ h_slices = (slice(0, -self.window_size),
222
+ slice(-self.window_size, -self.shift_size),
223
+ slice(-self.shift_size, None))
224
+ w_slices = (slice(0, -self.window_size),
225
+ slice(-self.window_size, -self.shift_size),
226
+ slice(-self.shift_size, None))
227
+ cnt = 0
228
+ for h in h_slices:
229
+ for w in w_slices:
230
+ img_mask[:, h, w, :] = cnt
231
+ cnt += 1
232
+
233
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
234
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
235
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
236
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
237
+
238
+ return attn_mask
239
+
240
+ def forward(self, x, x_size):
241
+ H, W = x_size
242
+ B, L, C = x.shape
243
+ # assert L == H * W, "input feature has wrong size"
244
+
245
+ shortcut = x
246
+ x = self.norm1(x)
247
+ x = x.view(B, H, W, C)
248
+
249
+ # cyclic shift
250
+ if self.shift_size > 0:
251
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
252
+ else:
253
+ shifted_x = x
254
+
255
+ # partition windows
256
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
257
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
258
+
259
+ # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
260
+ if self.input_resolution == x_size:
261
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
262
+ else:
263
+ attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
264
+
265
+ # merge windows
266
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
267
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
268
+
269
+ # reverse cyclic shift
270
+ if self.shift_size > 0:
271
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
272
+ else:
273
+ x = shifted_x
274
+ x = x.view(B, H * W, C)
275
+
276
+ # FFN
277
+ x = shortcut + self.drop_path(x)
278
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
279
+
280
+ return x
281
+
282
+ def extra_repr(self) -> str:
283
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
284
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
285
+
286
+ def flops(self):
287
+ flops = 0
288
+ H, W = self.input_resolution
289
+ # norm1
290
+ flops += self.dim * H * W
291
+ # W-MSA/SW-MSA
292
+ nW = H * W / self.window_size / self.window_size
293
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
294
+ # mlp
295
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
296
+ # norm2
297
+ flops += self.dim * H * W
298
+ return flops
299
+
300
+
301
+ class PatchMerging(nn.Module):
302
+ r""" Patch Merging Layer.
303
+
304
+ Args:
305
+ input_resolution (tuple[int]): Resolution of input feature.
306
+ dim (int): Number of input channels.
307
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
308
+ """
309
+
310
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
311
+ super().__init__()
312
+ self.input_resolution = input_resolution
313
+ self.dim = dim
314
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
315
+ self.norm = norm_layer(4 * dim)
316
+
317
+ def forward(self, x):
318
+ """
319
+ x: B, H*W, C
320
+ """
321
+ H, W = self.input_resolution
322
+ B, L, C = x.shape
323
+ assert L == H * W, "input feature has wrong size"
324
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
325
+
326
+ x = x.view(B, H, W, C)
327
+
328
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
329
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
330
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
331
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
332
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
333
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
334
+
335
+ x = self.norm(x)
336
+ x = self.reduction(x)
337
+
338
+ return x
339
+
340
+ def extra_repr(self) -> str:
341
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
342
+
343
+ def flops(self):
344
+ H, W = self.input_resolution
345
+ flops = H * W * self.dim
346
+ flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
347
+ return flops
348
+
349
+
350
+ class BasicLayer(nn.Module):
351
+ """ A basic Swin Transformer layer for one stage.
352
+
353
+ Args:
354
+ dim (int): Number of input channels.
355
+ input_resolution (tuple[int]): Input resolution.
356
+ depth (int): Number of blocks.
357
+ num_heads (int): Number of attention heads.
358
+ window_size (int): Local window size.
359
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
360
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
361
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
362
+ drop (float, optional): Dropout rate. Default: 0.0
363
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
364
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
365
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
366
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
367
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
368
+ """
369
+
370
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
371
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
372
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
373
+
374
+ super().__init__()
375
+ self.dim = dim
376
+ self.input_resolution = input_resolution
377
+ self.depth = depth
378
+ self.use_checkpoint = use_checkpoint
379
+
380
+ # build blocks
381
+ self.blocks = nn.ModuleList([
382
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
383
+ num_heads=num_heads, window_size=window_size,
384
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
385
+ mlp_ratio=mlp_ratio,
386
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
387
+ drop=drop, attn_drop=attn_drop,
388
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
389
+ norm_layer=norm_layer)
390
+ for i in range(depth)])
391
+
392
+ # patch merging layer
393
+ if downsample is not None:
394
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
395
+ else:
396
+ self.downsample = None
397
+
398
+ def forward(self, x, x_size):
399
+ for blk in self.blocks:
400
+ if self.use_checkpoint:
401
+ x = checkpoint.checkpoint(blk, x, x_size)
402
+ else:
403
+ x = blk(x, x_size)
404
+ if self.downsample is not None:
405
+ x = self.downsample(x)
406
+ return x
407
+
408
+ def extra_repr(self) -> str:
409
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
410
+
411
+ def flops(self):
412
+ flops = 0
413
+ for blk in self.blocks:
414
+ flops += blk.flops()
415
+ if self.downsample is not None:
416
+ flops += self.downsample.flops()
417
+ return flops
418
+
419
+
420
+ class RSTB(nn.Module):
421
+ """Residual Swin Transformer Block (RSTB).
422
+
423
+ Args:
424
+ dim (int): Number of input channels.
425
+ input_resolution (tuple[int]): Input resolution.
426
+ depth (int): Number of blocks.
427
+ num_heads (int): Number of attention heads.
428
+ window_size (int): Local window size.
429
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
430
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
431
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
432
+ drop (float, optional): Dropout rate. Default: 0.0
433
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
434
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
435
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
436
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
437
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
438
+ img_size: Input image size.
439
+ patch_size: Patch size.
440
+ resi_connection: The convolutional block before residual connection.
441
+ """
442
+
443
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
444
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
445
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
446
+ img_size=224, patch_size=4, resi_connection='1conv'):
447
+ super(RSTB, self).__init__()
448
+
449
+ self.dim = dim
450
+ self.input_resolution = input_resolution
451
+
452
+ self.residual_group = BasicLayer(dim=dim,
453
+ input_resolution=input_resolution,
454
+ depth=depth,
455
+ num_heads=num_heads,
456
+ window_size=window_size,
457
+ mlp_ratio=mlp_ratio,
458
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
459
+ drop=drop, attn_drop=attn_drop,
460
+ drop_path=drop_path,
461
+ norm_layer=norm_layer,
462
+ downsample=downsample,
463
+ use_checkpoint=use_checkpoint)
464
+
465
+ if resi_connection == '1conv':
466
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
467
+ elif resi_connection == '3conv':
468
+ # to save parameters and memory
469
+ self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.GELU(),
470
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
471
+ nn.GELU(),
472
+ nn.Conv2d(dim // 4, dim, 3, 1, 1))
473
+
474
+ self.patch_embed = PatchEmbed(
475
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
476
+ norm_layer=None)
477
+
478
+ self.patch_unembed = PatchUnEmbed(
479
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
480
+ norm_layer=None)
481
+
482
+ def forward(self, x, x_size):
483
+ return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
484
+
485
+ def flops(self):
486
+ flops = 0
487
+ flops += self.residual_group.flops()
488
+ H, W = self.input_resolution
489
+ flops += H * W * self.dim * self.dim * 9
490
+ flops += self.patch_embed.flops()
491
+ flops += self.patch_unembed.flops()
492
+
493
+ return flops
494
+
495
+
496
+ class PatchEmbed(nn.Module):
497
+ r""" Image to Patch Embedding
498
+
499
+ Args:
500
+ img_size (int): Image size. Default: 224.
501
+ patch_size (int): Patch token size. Default: 4.
502
+ in_chans (int): Number of input image channels. Default: 3.
503
+ embed_dim (int): Number of linear projection output channels. Default: 96.
504
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
505
+ """
506
+
507
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
508
+ super().__init__()
509
+ img_size = to_2tuple(img_size)
510
+ patch_size = to_2tuple(patch_size)
511
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
512
+ self.img_size = img_size
513
+ self.patch_size = patch_size
514
+ self.patches_resolution = patches_resolution
515
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
516
+
517
+ self.in_chans = in_chans
518
+ self.embed_dim = embed_dim
519
+
520
+ if norm_layer is not None:
521
+ self.norm = norm_layer(embed_dim)
522
+ else:
523
+ self.norm = None
524
+
525
+ def forward(self, x):
526
+ x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
527
+ if self.norm is not None:
528
+ x = self.norm(x)
529
+ return x
530
+
531
+ def flops(self):
532
+ flops = 0
533
+ H, W = self.img_size
534
+ if self.norm is not None:
535
+ flops += H * W * self.embed_dim
536
+ return flops
537
+
538
+
539
+ class PatchUnEmbed(nn.Module):
540
+ r""" Image to Patch Unembedding
541
+
542
+ Args:
543
+ img_size (int): Image size. Default: 224.
544
+ patch_size (int): Patch token size. Default: 4.
545
+ in_chans (int): Number of input image channels. Default: 3.
546
+ embed_dim (int): Number of linear projection output channels. Default: 96.
547
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
548
+ """
549
+
550
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
551
+ super().__init__()
552
+ img_size = to_2tuple(img_size)
553
+ patch_size = to_2tuple(patch_size)
554
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
555
+ self.img_size = img_size
556
+ self.patch_size = patch_size
557
+ self.patches_resolution = patches_resolution
558
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
559
+
560
+ self.in_chans = in_chans
561
+ self.embed_dim = embed_dim
562
+
563
+ def forward(self, x, x_size):
564
+ B, HW, C = x.shape
565
+ x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
566
+ return x
567
+
568
+ def flops(self):
569
+ flops = 0
570
+ return flops
571
+
572
+
573
+ class Upsample(nn.Sequential):
574
+ """Upsample module.
575
+
576
+ Args:
577
+ scale (int): Scale factor. Supported scales: 2^n and 3.
578
+ num_feat (int): Channel number of intermediate features.
579
+ """
580
+
581
+ def __init__(self, scale, num_feat):
582
+ m = []
583
+ if (scale & (scale - 1)) == 0: # scale = 2^n
584
+ for _ in range(int(math.log(scale, 2))):
585
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
586
+ m.append(nn.PixelShuffle(2))
587
+ elif scale == 3:
588
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
589
+ m.append(nn.PixelShuffle(3))
590
+ else:
591
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
592
+ super(Upsample, self).__init__(*m)
593
+
594
+
595
+ class UpsampleOneStep(nn.Sequential):
596
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
597
+ Used in lightweight SR to save parameters.
598
+
599
+ Args:
600
+ scale (int): Scale factor. Supported scales: 2^n and 3.
601
+ num_feat (int): Channel number of intermediate features.
602
+
603
+ """
604
+
605
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
606
+ self.num_feat = num_feat
607
+ self.input_resolution = input_resolution
608
+ m = []
609
+ m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
610
+ m.append(nn.PixelShuffle(scale))
611
+ super(UpsampleOneStep, self).__init__(*m)
612
+
613
+ def flops(self):
614
+ H, W = self.input_resolution
615
+ flops = H * W * self.num_feat * 3 * 9
616
+ return flops
617
+
618
+
619
+ class Generator(nn.Module):
620
+ r""" SwinIR
621
+ A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
622
+
623
+ Args:
624
+ img_size (int | tuple(int)): Input image size. Default 64
625
+ patch_size (int | tuple(int)): Patch size. Default: 1
626
+ in_chans (int): Number of input image channels. Default: 3
627
+ embed_dim (int): Patch embedding dimension. Default: 96
628
+ depths (tuple(int)): Depth of each Swin Transformer layer.
629
+ num_heads (tuple(int)): Number of attention heads in different layers.
630
+ window_size (int): Window size. Default: 7
631
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
632
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
633
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
634
+ drop_rate (float): Dropout rate. Default: 0
635
+ attn_drop_rate (float): Attention dropout rate. Default: 0
636
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
637
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
638
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
639
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
640
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
641
+ upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
642
+ img_range: Image range. 1. or 255.
643
+ upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
644
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
645
+ """
646
+
647
+ def __init__(self, img_size=64, patch_size=1, in_chans=3,
648
+ embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
649
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
650
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
651
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
652
+ use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
653
+ **kwargs):
654
+ super(Generator, self).__init__()
655
+ num_in_ch = in_chans
656
+ num_out_ch = in_chans
657
+ num_feat = 64
658
+ self.img_range = img_range
659
+ if in_chans == 3:
660
+ rgb_mean = (0.4488, 0.4371, 0.4040)
661
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
662
+ else:
663
+ self.mean = torch.zeros(1, 1, 1, 1)
664
+ self.upscale = upscale
665
+ self.upsampler = upsampler
666
+ self.window_size = window_size
667
+
668
+ #####################################################################################################
669
+ ################################### 1, shallow feature extraction ###################################
670
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
671
+
672
+ #####################################################################################################
673
+ ################################### 2, deep feature extraction ######################################
674
+ self.num_layers = len(depths)
675
+ self.embed_dim = embed_dim
676
+ self.ape = ape
677
+ self.patch_norm = patch_norm
678
+ self.num_features = embed_dim
679
+ self.mlp_ratio = mlp_ratio
680
+
681
+ # split image into non-overlapping patches
682
+ self.patch_embed = PatchEmbed(
683
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
684
+ norm_layer=norm_layer if self.patch_norm else None)
685
+ num_patches = self.patch_embed.num_patches
686
+ patches_resolution = self.patch_embed.patches_resolution
687
+ self.patches_resolution = patches_resolution
688
+
689
+ # merge non-overlapping patches into image
690
+ self.patch_unembed = PatchUnEmbed(
691
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
692
+ norm_layer=norm_layer if self.patch_norm else None)
693
+
694
+ # absolute position embedding
695
+ if self.ape:
696
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
697
+ trunc_normal_(self.absolute_pos_embed, std=.02)
698
+
699
+ self.pos_drop = nn.Dropout(p=drop_rate)
700
+
701
+ # stochastic depth
702
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
703
+
704
+ # build Residual Swin Transformer blocks (RSTB)
705
+ self.layers = nn.ModuleList()
706
+ for i_layer in range(self.num_layers):
707
+ layer = RSTB(dim=embed_dim,
708
+ input_resolution=(patches_resolution[0],
709
+ patches_resolution[1]),
710
+ depth=depths[i_layer],
711
+ num_heads=num_heads[i_layer],
712
+ window_size=window_size,
713
+ mlp_ratio=self.mlp_ratio,
714
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
715
+ drop=drop_rate, attn_drop=attn_drop_rate,
716
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
717
+ norm_layer=norm_layer,
718
+ downsample=None,
719
+ use_checkpoint=use_checkpoint,
720
+ img_size=img_size,
721
+ patch_size=patch_size,
722
+ resi_connection=resi_connection
723
+
724
+ )
725
+ self.layers.append(layer)
726
+ self.norm = norm_layer(self.num_features)
727
+
728
+ # build the last conv layer in deep feature extraction
729
+ if resi_connection == '1conv':
730
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
731
+ elif resi_connection == '3conv':
732
+ # to save parameters and memory
733
+ self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
734
+ nn.GELU(),
735
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
736
+ nn.GELU(),
737
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
738
+
739
+ #####################################################################################################
740
+ ################################ 3, high quality image reconstruction ################################
741
+ if self.upsampler == 'pixelshuffle':
742
+ # for classical SR
743
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
744
+ nn.GELU())
745
+ self.upsample = Upsample(upscale, num_feat)
746
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
747
+ elif self.upsampler == 'pixelshuffledirect':
748
+ # for lightweight SR (to save parameters)
749
+ self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
750
+ (patches_resolution[0], patches_resolution[1]))
751
+ elif self.upsampler == 'nearest+conv':
752
+ # for real-world SR (less artifacts)
753
+ assert self.upscale == 4, 'only support x4 now.'
754
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
755
+ nn.GELU())
756
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
757
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
758
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
759
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
760
+ self.lrelu = nn.GELU()
761
+ else:
762
+ # for image denoising and JPEG compression artifact reduction
763
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
764
+
765
+ self.apply(self._init_weights)
766
+
767
+ def _init_weights(self, m):
768
+ if isinstance(m, nn.Linear):
769
+ trunc_normal_(m.weight, std=.02)
770
+ if isinstance(m, nn.Linear) and m.bias is not None:
771
+ nn.init.constant_(m.bias, 0)
772
+ elif isinstance(m, nn.LayerNorm):
773
+ nn.init.constant_(m.bias, 0)
774
+ nn.init.constant_(m.weight, 1.0)
775
+
776
+ @torch.jit.ignore
777
+ def no_weight_decay(self):
778
+ return {'absolute_pos_embed'}
779
+
780
+ @torch.jit.ignore
781
+ def no_weight_decay_keywords(self):
782
+ return {'relative_position_bias_table'}
783
+
784
+ def check_image_size(self, x):
785
+ _, _, h, w = x.size()
786
+ mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
787
+ mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
788
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
789
+ return x
790
+
791
+ def forward_features(self, x):
792
+ x_size = (x.shape[2], x.shape[3])
793
+ x = self.patch_embed(x)
794
+ if self.ape:
795
+ x = x + self.absolute_pos_embed
796
+ x = self.pos_drop(x)
797
+
798
+ for layer in self.layers:
799
+ x = layer(x, x_size)
800
+
801
+ x = self.norm(x) # B L C
802
+ x = self.patch_unembed(x, x_size)
803
+
804
+ return x
805
+
806
+ def forward(self, x):
807
+ H, W = x.shape[2:]
808
+ x = self.check_image_size(x)
809
+
810
+ self.mean = self.mean.type_as(x)
811
+ x = (x - self.mean) * self.img_range
812
+
813
+ if self.upsampler == 'pixelshuffle':
814
+ # for classical SR
815
+ x = self.conv_first(x)
816
+ x = self.conv_after_body(self.forward_features(x)) + x
817
+ x = self.conv_before_upsample(x)
818
+ x = self.conv_last(self.upsample(x))
819
+ elif self.upsampler == 'pixelshuffledirect':
820
+ # for lightweight SR
821
+ x = self.conv_first(x)
822
+ x = self.conv_after_body(self.forward_features(x)) + x
823
+ x = self.upsample(x)
824
+ elif self.upsampler == 'nearest+conv':
825
+ # for real-world SR
826
+ x = self.conv_first(x)
827
+ x = self.conv_after_body(self.forward_features(x)) + x
828
+ x = self.conv_before_upsample(x)
829
+ x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
830
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
831
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
832
+ else:
833
+ # for image denoising and JPEG compression artifact reduction
834
+ x_first = self.conv_first(x)
835
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
836
+ x = x + self.conv_last(res)
837
+
838
+ x = x / self.img_range + self.mean
839
+
840
+ return x[:, :, :H*self.upscale, :W*self.upscale]
841
+
842
+ def flops(self):
843
+ flops = 0
844
+ H, W = self.patches_resolution
845
+ flops += H * W * 3 * self.embed_dim * 9
846
+ flops += self.patch_embed.flops()
847
+ for i, layer in enumerate(self.layers):
848
+ flops += layer.flops()
849
+ flops += H * W * 3 * self.embed_dim * self.embed_dim
850
+ flops += self.upsample.flops()
851
+ return flops
852
+
853
+
854
+ class Discriminator(nn.Module):
855
+ def __init__(self):
856
+ super(Discriminator, self).__init__()
857
+ self.net = nn.Sequential(
858
+ nn.Conv2d(3, 64, kernel_size=3, padding=1),
859
+ nn.GELU(),
860
+
861
+ nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
862
+ nn.BatchNorm2d(64),
863
+ nn.GELU(),
864
+
865
+ nn.Conv2d(64, 128, kernel_size=3, padding=1),
866
+ nn.BatchNorm2d(128),
867
+ nn.GELU(),
868
+
869
+ nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
870
+ nn.BatchNorm2d(128),
871
+ nn.GELU(),
872
+
873
+ nn.Conv2d(128, 256, kernel_size=3, padding=1),
874
+ nn.BatchNorm2d(256),
875
+ nn.GELU(),
876
+
877
+ nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
878
+ nn.BatchNorm2d(256),
879
+ nn.GELU(),
880
+
881
+ nn.Conv2d(256, 512, kernel_size=3, padding=1),
882
+ nn.BatchNorm2d(512),
883
+ nn.GELU(),
884
+
885
+ nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
886
+ nn.BatchNorm2d(512),
887
+ nn.GELU(),
888
+
889
+ nn.AdaptiveAvgPool2d(1),
890
+ nn.Conv2d(512, 1024, kernel_size=1),
891
+ nn.GELU(),
892
+ nn.Conv2d(1024, 1, kernel_size=1)
893
+ )
894
+
895
+ def forward(self, x):
896
+ batch_size = x.size(0)
897
+ return torch.sigmoid(self.net(x).view(batch_size))
898
+
899
+ if __name__ == '__main__':
900
+ upscale = 8
901
+ window_size = 8
902
+ height = (96 // upscale // window_size + 1) * window_size
903
+ width = (192 // upscale // window_size + 1) * window_size
904
+ model = Generator(upscale=upscale, img_size=(height, width),
905
+ window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
906
+ embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
907
+ print(model)
908
+ print(height, width, model.flops() / 1e9)
909
+
910
+ x = torch.randn((1, 3, height, width))
911
+ x = model(x)
912
+ print(x.shape)
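
Besides the SwinIR Generator (and the shape self-test under __main__ above), the new file also bundles a small convolutional Discriminator that pools its feature map to a single sigmoid realness score per image, presumably for GAN training. A minimal call sketch:

```python
import torch
from nets.SwinIR import Discriminator

d = Discriminator()
fake_hr = torch.randn(2, 3, 128, 224)  # e.g. generator output at the 128x224 HR shape
with torch.no_grad():
    scores = d(fake_hr)

print(scores.shape)  # torch.Size([2]) -- one score in (0, 1) per image
```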
nets/__pycache__/esrgan.cpython-38.pyc CHANGED
Binary files a/nets/__pycache__/esrgan.cpython-38.pyc and b/nets/__pycache__/esrgan.cpython-38.pyc differ
utils/__pycache__/__init__.cpython-38.pyc CHANGED
Binary files a/utils/__pycache__/__init__.cpython-38.pyc and b/utils/__pycache__/__init__.cpython-38.pyc differ
utils/__pycache__/dataloader.cpython-38.pyc CHANGED
Binary files a/utils/__pycache__/dataloader.cpython-38.pyc and b/utils/__pycache__/dataloader.cpython-38.pyc differ
utils/__pycache__/utils.cpython-38.pyc CHANGED
Binary files a/utils/__pycache__/utils.cpython-38.pyc and b/utils/__pycache__/utils.cpython-38.pyc differ
utils/__pycache__/utils_fit.cpython-38.pyc CHANGED
Binary files a/utils/__pycache__/utils_fit.cpython-38.pyc and b/utils/__pycache__/utils_fit.cpython-38.pyc differ
utils/dataloader.py CHANGED
@@ -1,12 +1,23 @@
1
- from random import randint
 
2
 
3
  import cv2
4
  import numpy as np
5
  from PIL import Image
6
  from torch.utils.data.dataset import Dataset
7
 
8
- from utils import cvtColor, preprocess_input
9
- from torch.utils.data import DataLoader
10
 
11
  def get_new_img_size(width, height, img_min_side=600):
12
  if width <= height:
@@ -29,6 +40,49 @@ class SRGANDataset(Dataset):
29
 
30
  self.lr_shape = lr_shape
31
  self.hr_shape = hr_shape
32
 
33
  def __len__(self):
34
  return self.train_batches
@@ -37,22 +91,20 @@ class SRGANDataset(Dataset):
37
  index = index % self.train_batches
38
 
39
  image_origin = Image.open(self.train_lines[index].split()[0])
40
- if self.rand()<.5:
41
- img_h = self.get_random_data(image_origin, self.hr_shape)
42
- else:
43
- img_h = self.random_crop(image_origin, self.hr_shape[1], self.hr_shape[0])
44
- img_l = img_h.resize((self.lr_shape[1], self.lr_shape[0]), Image.BICUBIC)
45
 
46
- img_h = np.transpose(preprocess_input(np.array(img_h, dtype=np.float32), [0.5,0.5,0.5], [0.5,0.5,0.5]), [2,0,1])
47
- img_l = np.transpose(preprocess_input(np.array(img_l, dtype=np.float32), [0.5,0.5,0.5], [0.5,0.5,0.5]), [2,0,1])
48
- return np.array(img_l), np.array(img_h)
 
49
 
50
  def rand(self, a=0, b=1):
51
  return np.random.rand()*(b-a) + a
52
 
53
- def get_random_data(self, image, input_shape, jitter=.3, hue=.1, sat=1.5, val=1.5, random=True):
54
  #------------------------------#
55
  # 读取图像并转换成RGB图像
 
56
  #------------------------------#
57
  image = cvtColor(image)
58
  #------------------------------#
@@ -61,50 +113,19 @@ class SRGANDataset(Dataset):
61
  iw, ih = image.size
62
  h, w = input_shape
63
 
64
- if not random:
65
- scale = min(w/iw, h/ih)
66
- nw = int(iw*scale)
67
- nh = int(ih*scale)
68
- dx = (w-nw)//2
69
- dy = (h-nh)//2
70
-
71
- #---------------------------------#
72
- # 将图像多余的部分加上灰条
73
- #---------------------------------#
74
- image = image.resize((nw,nh), Image.BICUBIC)
75
- new_image = Image.new('RGB', (w,h), (128,128,128))
76
- new_image.paste(image, (dx, dy))
77
- image_data = np.array(new_image, np.float32)
78
-
79
- return image_data
80
-
81
- #------------------------------------------#
82
- # 对图像进行缩放并且进行长和宽的扭曲
83
- #------------------------------------------#
84
- new_ar = w/h * self.rand(1-jitter,1+jitter)/self.rand(1-jitter,1+jitter)
85
- scale = self.rand(1, 1.5)
86
- if new_ar < 1:
87
- nh = int(scale*h)
88
- nw = int(nh*new_ar)
89
- else:
90
- nw = int(scale*w)
91
- nh = int(nw/new_ar)
92
- image = image.resize((nw,nh), Image.BICUBIC)
93
 
94
- #------------------------------------------#
95
  # 将图像多余的部分加上灰条
96
- #------------------------------------------#
97
- dx = int(self.rand(0, w-nw))
98
- dy = int(self.rand(0, h-nh))
99
- new_image = Image.new('RGB', (w,h), (128,128,128))
100
  new_image.paste(image, (dx, dy))
101
- image = new_image
102
-
103
- #------------------------------------------#
104
- # 翻转图像
105
- #------------------------------------------#
106
- flip = self.rand()<.5
107
- if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT)
108
 
109
  rotate = self.rand()<.5
110
  if rotate:
@@ -113,41 +134,191 @@ class SRGANDataset(Dataset):
113
  M = cv2.getRotationMatrix2D((a,b),angle,1)
114
  image = cv2.warpAffine(np.array(image), M, (w,h), borderValue=[128,128,128])
115
 
116
- #------------------------------------------#
117
- # 色域扭曲
118
- #------------------------------------------#
119
- hue = self.rand(-hue, hue)
120
- sat = self.rand(1, sat) if self.rand()<.5 else 1/self.rand(1, sat)
121
- val = self.rand(1, val) if self.rand()<.5 else 1/self.rand(1, val)
122
- x = cv2.cvtColor(np.array(image,np.float32)/255, cv2.COLOR_RGB2HSV)
123
- x[..., 1] *= sat
124
- x[..., 2] *= val
125
- x[x[:,:, 0]>360, 0] = 360
126
- x[:, :, 1:][x[:, :, 1:]>1] = 1
127
- x[x<0] = 0
128
- image_data = cv2.cvtColor(x, cv2.COLOR_HSV2RGB)*255
129
- return Image.fromarray(np.uint8(image_data))
130
-
131
- def random_crop(self, image, width, height):
132
- #--------------------------------------------#
133
- # 如果图像过小无法截取,先对图像进行放大
134
- #--------------------------------------------#
135
- if image.size[0] < self.hr_shape[1] or image.size[1] < self.hr_shape[0]:
136
- resized_width, resized_height = get_new_img_size(width, height, img_min_side=np.max(self.hr_shape))
137
- image = image.resize((resized_width, resized_height), Image.BICUBIC)
138
-
139
- #--------------------------------------------#
140
- # 随机截取一部分
141
- #--------------------------------------------#
142
- width1 = randint(0, image.size[0] - width)
143
- height1 = randint(0, image.size[1] - height)
144
-
145
- width2 = width1 + width
146
- height2 = height1 + height
147
-
148
- image = image.crop((width1, height1, width2, height2))
149
- return image
150
-
151
  def SRGAN_dataset_collate(batch):
152
  images_l = []
153
  images_h = []
1
+ import math
2
+ from random import choice, choices, randint
3
 
4
  import cv2
5
  import numpy as np
6
  from PIL import Image
7
  from torch.utils.data.dataset import Dataset
8
 
9
+ from utils import USMSharp_npy, cvtColor, preprocess_input
10
+
11
+ from .degradations import (circular_lowpass_kernel, random_add_gaussian_noise,
12
+ random_add_poisson_noise, random_mixed_kernels)
13
+ from .transforms import augment, paired_random_crop
14
+
15
+ def cv_show(image):
16
+ image = np.array(image)
17
+ image = cv2.resize(image, (256, 128), interpolation=cv2.INTER_CUBIC)
18
+ cv2.imshow('image', image)
19
+ cv2.waitKey(0)
20
+ cv2.destroyAllWindows()
21
 
22
  def get_new_img_size(width, height, img_min_side=600):
23
  if width <= height:
40
 
41
  self.lr_shape = lr_shape
42
  self.hr_shape = hr_shape
43
+ self.scale = int(hr_shape[0]/lr_shape[0])
44
+ self.usmsharp = USMSharp_npy()
45
+ # 第一次滤波的参数
46
+ self.blur_kernel_size = 21
47
+ self.kernel_list = ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
48
+ self.kernel_prob = [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
49
+ self.sinc_prob = 0.1
50
+ self.blur_sigma = [0.2, 3]
51
+ self.betag_range = [0.5, 4]
52
+ self.betap_range = [1, 2]
53
+ # 第二次滤波的参数
54
+ self.blur_kernel_size2 = 21
55
+ self.kernel_list2 = ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
56
+ self.kernel_prob2 = [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
57
+ self.sinc_prob2 = 0.1
58
+ self.blur_sigma2 = [0.2, 3]
59
+ self.betag_range2 = [0.5, 4]
60
+ self.betap_range2 = [1, 2]
61
+ # 最后的sinc滤波
62
+ self.final_sinc_prob = 0.8
63
+ # 卷积核大小从7到21分布
64
+ self.kernel_range = [2 * v + 1 for v in range(3, 11)]
65
+ # 使用脉冲张量进行卷积不会产生模糊效果
66
+ self.pulse_tensor = np.zeros(shape=[21, 21], dtype='float32')
67
+ self.pulse_tensor[10, 10] = 1
68
+ # 第一次退化的参数
69
+ self.resize_prob = [0.2, 0.7, 0.1] # up, down, keep
70
+ self.resize_range = [0.15, 1.5]
71
+ self.gaussian_noise_prob = 0.5
72
+ self.noise_range = [1, 30]
73
+ self.poisson_scale_range = [0.05, 3]
74
+ self.gray_noise_prob = 0.4
75
+ self.jpeg_range = [30, 95]
76
+
77
+ # 第二次退化的参数
78
+ self.second_blur_prob = 0.8
79
+ self.resize_prob2 = [0.3, 0.4, 0.3] # up, down, keep
80
+ self.resize_range2 = [0.3, 1.2]
81
+ self.gaussian_noise_prob2= 0.5
82
+ self.noise_range2 = [1, 25]
83
+ self.poisson_scale_range2= [0.05, 2.5]
84
+ self.gray_noise_prob2 = 0.4
85
+ self.jpeg_range2 = [30, 95]
86
 
87
  def __len__(self):
88
  return self.train_batches
91
  index = index % self.train_batches
92
 
93
  image_origin = Image.open(self.train_lines[index].split()[0])
94
+ lq, gt = self.get_random_data(image_origin, self.hr_shape)
95
 
96
+ gt = np.transpose(preprocess_input(np.array(gt, dtype=np.float32), [0.5,0.5,0.5], [0.5,0.5,0.5]), [2,0,1])
97
+ lq = np.transpose(preprocess_input(np.array(lq, dtype=np.float32), [0.5,0.5,0.5], [0.5,0.5,0.5]), [2,0,1])
98
+
99
+ return lq, gt
100
 
101
  def rand(self, a=0, b=1):
102
  return np.random.rand()*(b-a) + a
103
 
104
+ def get_random_data(self, image, input_shape):
105
  #------------------------------#
106
  # 读取图像并转换成RGB图像
107
+ # cvtColor将np转Image
108
  #------------------------------#
109
  image = cvtColor(image)
110
  #------------------------------#
113
  iw, ih = image.size
114
  h, w = input_shape
115
 
116
+ scale = min(w/iw, h/ih)
117
+ nw = int(iw*scale)
118
+ nh = int(ih*scale)
119
+ dx = (w-nw)//2
120
+ dy = (h-nh)//2
121
 
122
+ #---------------------------------#
123
  # 将图像多余的部分加上灰条
124
+ #---------------------------------#
125
+ image = image.resize((nw,nh), Image.BICUBIC)
126
+ new_image = Image.new('RGB', (w,h), (128,128,128))
 
127
  new_image.paste(image, (dx, dy))
128
+ image = np.array(new_image, np.float32)
 
 
130
  rotate = self.rand()<.5
131
  if rotate:
134
  M = cv2.getRotationMatrix2D((a,b),angle,1)
135
  image = cv2.warpAffine(np.array(image), M, (w,h), borderValue=[128,128,128])
136
 
137
+ # ------------------------ 生成卷积核以进行第一次退化处理 ------------------------ #
138
+ kernel_size = choice(self.kernel_range)
139
+ if np.random.uniform() < self.sinc_prob:
140
+ # 此sinc过滤器设置适用于[7,21]范围内的内核
141
+ if kernel_size < 13:
142
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
143
+ else:
144
+ omega_c = np.random.uniform(np.pi / 5, np.pi)
145
+ kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
146
+ else:
147
+ kernel = random_mixed_kernels(
148
+ self.kernel_list,
149
+ self.kernel_prob,
150
+ kernel_size,
151
+ self.blur_sigma,
152
+ self.blur_sigma, [-math.pi, math.pi],
153
+ self.betag_range,
154
+ self.betap_range,
155
+ noise_range=None)
156
+ # pad kernel
157
+ pad_size = (21 - kernel_size) // 2
158
+ kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
159
+ kernel = kernel.astype(np.float32)
160
+ # ------------------------ 生成卷积核以进行第二次退化处理 ------------------------ #
161
+ kernel_size = choice(self.kernel_range)
162
+ if np.random.uniform() < self.sinc_prob2:
163
+ if kernel_size < 13:
164
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
165
+ else:
166
+ omega_c = np.random.uniform(np.pi / 5, np.pi)
167
+ kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
168
+ else:
169
+ kernel2 = random_mixed_kernels(
170
+ self.kernel_list2,
171
+ self.kernel_prob2,
172
+ kernel_size,
173
+ self.blur_sigma2,
174
+ self.blur_sigma2, [-math.pi, math.pi],
175
+ self.betag_range2,
176
+ self.betap_range2,
177
+ noise_range=None)
178
+ # pad kernel
179
+ pad_size = (21 - kernel_size) // 2
180
+ kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
181
+ kernel2 = kernel2.astype(np.float32)
182
+ # ----------------------the final sinc kernel ------------------------- #
183
+ if np.random.uniform() < self.final_sinc_prob:
184
+ kernel_size = choice(self.kernel_range)
185
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
186
+ sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
187
+ else:
188
+ sinc_kernel = self.pulse_tensor
189
+ sinc_kernel = sinc_kernel.astype(np.float32)
190
+ lq, gt = self.feed_data(image, kernel, kernel2, sinc_kernel)
191
+
192
+ return lq, gt
193
+
194
+ def feed_data(self, img_gt, kernel1, kernel2, sinc_kernel):
195
+
196
+ img_gt = np.array(img_gt, dtype=np.float32)
197
+ # 对gt进行锐化
198
+ img_gt = np.clip(img_gt / 255, 0, 1)
199
+ gt = self.usmsharp.filt(img_gt)
200
+ [ori_w, ori_h, _] = gt.shape
201
+
202
+ # ---------------------- 根据参数进行第一次退化 -------------------- #
203
+ # 模糊处理
204
+ out = cv2.filter2D(img_gt, -1, kernel1)
205
+ # 随机 resize
206
+ updown_type = choices(['up', 'down', 'keep'], self.resize_prob)[0]
207
+ if updown_type == 'up':
208
+ scale = np.random.uniform(1, self.resize_range[1])
209
+ elif updown_type == 'down':
210
+ scale = np.random.uniform(self.resize_range[0], 1)
211
+ else:
212
+ scale = 1
213
+ mode = choice(['area', 'bilinear', 'bicubic'])
214
+ if mode=='area':
215
+ out = cv2.resize(out, (int(ori_h * scale), int(ori_w * scale)), interpolation=cv2.INTER_AREA)
216
+ elif mode=='bilinear':
217
+ out = cv2.resize(out, (int(ori_h * scale), int(ori_w * scale)), interpolation=cv2.INTER_LINEAR)
218
+ else:
219
+ out = cv2.resize(out, (int(ori_h * scale), int(ori_w * scale)), interpolation=cv2.INTER_CUBIC)
220
+
221
+ # 灰度噪声
222
+ gray_noise_prob = self.gray_noise_prob
223
+ if np.random.uniform() < self.gaussian_noise_prob:
224
+ out = random_add_gaussian_noise(
225
+ out, sigma_range=self.noise_range, clip=True, rounds=False, gray_prob=gray_noise_prob)
226
+ else:
227
+ out = random_add_poisson_noise(
228
+ out,
229
+ scale_range=self.poisson_scale_range,
230
+ gray_prob=gray_noise_prob,
231
+ clip=True,
232
+ rounds=False)
233
+
234
+ # JPEG 压缩
235
+ jpeg_p = np.random.uniform(low=self.jpeg_range[0], high=self.jpeg_range[1])
236
+ jpeg_p = int(jpeg_p)
237
+ out = np.clip(out, 0, 1)
238
+
239
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_p]
240
+ _, encimg = cv2.imencode('.jpg', out * 255., encode_param)
241
+ out = np.float32(cv2.imdecode(encimg, 1))/255
242
+
243
+ # ---------------------- 根据参数进行第二次退化 -------------------- #
244
+ # 模糊
245
+ if np.random.uniform() < self.second_blur_prob:
246
+ out = cv2.filter2D(out, -1, kernel2)
247
+ # 随机 resize
248
+ updown_type = choices(['up', 'down', 'keep'], self.resize_prob2)[0]
249
+ if updown_type == 'up':
250
+ scale = np.random.uniform(1, self.resize_range2[1])
251
+ elif updown_type == 'down':
252
+ scale = np.random.uniform(self.resize_range2[0], 1)
253
+ else:
254
+ scale = 1
255
+ mode = choice(['area', 'bilinear', 'bicubic'])
256
+ if mode == 'area':
257
+ out = cv2.resize(out, (int(ori_h / self.scale * scale), int(ori_w / self.scale * scale)), interpolation=cv2.INTER_AREA)
258
+ elif mode == 'bilinear':
259
+ out = cv2.resize(out, (int(ori_h / self.scale * scale), int(ori_w / self.scale * scale)), interpolation=cv2.INTER_LINEAR)
260
+ else:
261
+ out = cv2.resize(out, (int(ori_h / self.scale * scale), int(ori_w / self.scale * scale)), interpolation=cv2.INTER_CUBIC)
262
+ # 添加噪声
263
+ gray_noise_prob = self.gray_noise_prob2
264
+ if np.random.uniform() < self.gaussian_noise_prob2:
265
+ out = random_add_gaussian_noise(
266
+ out, sigma_range=self.noise_range2, clip=True, rounds=False, gray_prob=gray_noise_prob)
267
+ else:
268
+ out = random_add_poisson_noise(
269
+ out,
270
+ scale_range=self.poisson_scale_range2,
271
+ gray_prob=gray_noise_prob,
272
+ clip=True,
273
+ rounds=False)
274
+
275
+ # JPEG压缩+最后的sinc滤波器
276
+ # 我们还需要将图像的大小调整到所需的尺寸。我们把[调整大小+sinc过滤器]组合在一起
277
+ # 作为一个操作。
278
+ # 我们考虑两个顺序。
279
+ # 1. [调整大小+sinc filter] + JPEG压缩
280
+ # 2. JPEG压缩+[调整大小+sinc滤波器]。
281
+ # 根据经验,我们发现其他组合(sinc + JPEG + Resize)会引入扭曲的线条。
282
+ if np.random.uniform() < 0.5:
283
+ # resize back + the final sinc filter
284
+ mode = choice(['area', 'bilinear', 'bicubic'])
285
+ if mode == 'area':
286
+ out = cv2.resize(out, (ori_h // self.scale, ori_w // self.scale), interpolation=cv2.INTER_AREA)
287
+ elif mode == 'bilinear':
288
+ out = cv2.resize(out, (ori_h // self.scale, ori_w // self.scale), interpolation=cv2.INTER_LINEAR)
289
+ else:
290
+ out = cv2.resize(out, (ori_h // self.scale, ori_w // self.scale), interpolation=cv2.INTER_CUBIC)
291
+
292
+ out = cv2.filter2D(out, -1, sinc_kernel)
293
+ # JPEG 压缩
294
+ jpeg_p = np.random.uniform(low=self.jpeg_range[0], high=self.jpeg_range[1])
295
+ jpeg_p = int(jpeg_p)
296
+ out = np.clip(out, 0, 1)
297
+
298
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_p]
299
+ _, encimg = cv2.imencode('.jpg', out * 255., encode_param)
300
+ out = np.float32(cv2.imdecode(encimg, 1)) / 255
301
+ else:
302
+ # JPEG 压缩
303
+ jpeg_p = np.random.uniform(low=self.jpeg_range[0], high=self.jpeg_range[1])
304
+ jpeg_p = int(jpeg_p)
305
+ out = np.clip(out, 0, 1)
306
+
307
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_p]
308
+ _, encimg = cv2.imencode('.jpg', out * 255., encode_param)
309
+ out = np.float32(cv2.imdecode(encimg, 1)) / 255
310
+ # resize back + the final sinc filter
311
+ mode = choice(['area', 'bilinear', 'bicubic'])
312
+ if mode == 'area':
313
+ out = cv2.resize(out, (ori_h // self.scale, ori_w // self.scale),interpolation=cv2.INTER_AREA)
314
+ elif mode == 'bilinear':
315
+ out = cv2.resize(out, (ori_h // self.scale, ori_w // self.scale),interpolation=cv2.INTER_LINEAR)
316
+ else:
317
+ out = cv2.resize(out, (ori_h // self.scale, ori_w // self.scale),interpolation=cv2.INTER_CUBIC)
318
+ lq = np.clip((out * 255.0), 0, 255)
319
+ gt = np.clip((gt * 255.0), 0, 255)
320
+ return Image.fromarray(np.uint8(lq)), Image.fromarray(np.uint8(gt))
321
+
322
  def SRGAN_dataset_collate(batch):
323
  images_l = []
324
  images_h = []
utils/degradations.py ADDED
@@ -0,0 +1,765 @@
1
+ import cv2
2
+ import math
3
+ import numpy as np
4
+ import random
5
+ import torch
6
+ from scipy import special
7
+ from scipy.stats import multivariate_normal
8
+ from torchvision.transforms.functional_tensor import rgb_to_grayscale
9
+
10
+ # -------------------------------------------------------------------- #
11
+ # --------------------------- blur kernels --------------------------- #
12
+ # -------------------------------------------------------------------- #
13
+
14
+
15
+ # --------------------------- util functions --------------------------- #
16
+ def sigma_matrix2(sig_x, sig_y, theta):
17
+ """Calculate the rotated sigma matrix (two dimensional matrix).
18
+
19
+ Args:
20
+ sig_x (float):
21
+ sig_y (float):
22
+ theta (float): Radian measurement.
23
+
24
+ Returns:
25
+ ndarray: Rotated sigma matrix.
26
+ """
27
+ d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
28
+ u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
29
+ return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))
30
+
31
+
32
+ def mesh_grid(kernel_size):
33
+ """Generate the mesh grid, centering at zero.
34
+
35
+ Args:
36
+ kernel_size (int):
37
+
38
+ Returns:
39
+ xy (ndarray): with the shape (kernel_size, kernel_size, 2)
40
+ xx (ndarray): with the shape (kernel_size, kernel_size)
41
+ yy (ndarray): with the shape (kernel_size, kernel_size)
42
+ """
43
+ ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
44
+ xx, yy = np.meshgrid(ax, ax)
45
+ xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size,
46
+ 1))).reshape(kernel_size, kernel_size, 2)
47
+ return xy, xx, yy
48
+
49
+
50
+ def pdf2(sigma_matrix, grid):
51
+ """Calculate PDF of the bivariate Gaussian distribution.
52
+
53
+ Args:
54
+ sigma_matrix (ndarray): with the shape (2, 2)
55
+ grid (ndarray): generated by :func:`mesh_grid`,
56
+ with the shape (K, K, 2), K is the kernel size.
57
+
58
+ Returns:
59
+ kernel (ndarray): un-normalized kernel.
60
+ """
61
+ inverse_sigma = np.linalg.inv(sigma_matrix)
62
+ kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
63
+ return kernel
64
+
65
+
66
+ def cdf2(d_matrix, grid):
67
+ """Calculate the CDF of the standard bivariate Gaussian distribution.
68
+ Used in skewed Gaussian distribution.
69
+
70
+ Args:
71
+ d_matrix (ndarray): skew matrix.
72
+ grid (ndarray): generated by :func:`mesh_grid`,
73
+ with the shape (K, K, 2), K is the kernel size.
74
+
75
+ Returns:
76
+ cdf (ndarray): skewed cdf.
77
+ """
78
+ rv = multivariate_normal([0, 0], [[1, 0], [0, 1]])
79
+ grid = np.dot(grid, d_matrix)
80
+ cdf = rv.cdf(grid)
81
+ return cdf
82
+
83
+
84
+ def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
85
+ """Generate a bivariate isotropic or anisotropic Gaussian kernel.
86
+
87
+ In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.
88
+
89
+ Args:
90
+ kernel_size (int):
91
+ sig_x (float):
92
+ sig_y (float):
93
+ theta (float): Radian measurement.
94
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
95
+ with the shape (K, K, 2), K is the kernel size. Default: None
96
+ isotropic (bool):
97
+
98
+ Returns:
99
+ kernel (ndarray): normalized kernel.
100
+ """
101
+ if grid is None:
102
+ grid, _, _ = mesh_grid(kernel_size)
103
+ if isotropic:
104
+ sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
105
+ else:
106
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
107
+ kernel = pdf2(sigma_matrix, grid)
108
+ kernel = kernel / np.sum(kernel)
109
+ return kernel
110
+
111
+
112
+ def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
113
+ """Generate a bivariate generalized Gaussian kernel.
114
+ Described in `Parameter Estimation For Multivariate Generalized
115
+ Gaussian Distributions`_
116
+ by Pascal et al. (2013).
117
+
118
+ In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.
119
+
120
+ Args:
121
+ kernel_size (int):
122
+ sig_x (float):
123
+ sig_y (float):
124
+ theta (float): Radian measurement.
125
+ beta (float): shape parameter, beta = 1 is the normal distribution.
126
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
127
+ with the shape (K, K, 2), K is the kernel size. Default: None
128
+
129
+ Returns:
130
+ kernel (ndarray): normalized kernel.
131
+
132
+ .. _Parameter Estimation For Multivariate Generalized Gaussian
133
+ Distributions: https://arxiv.org/abs/1302.6498
134
+ """
135
+ if grid is None:
136
+ grid, _, _ = mesh_grid(kernel_size)
137
+ if isotropic:
138
+ sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
139
+ else:
140
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
141
+ inverse_sigma = np.linalg.inv(sigma_matrix)
142
+ kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
143
+ kernel = kernel / np.sum(kernel)
144
+ return kernel
145
+
146
+
147
+ def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
148
+ """Generate a plateau-like anisotropic kernel.
149
+ 1 / (1+x^(beta))
150
+
151
+ Ref: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution
152
+
153
+ In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.
154
+
155
+ Args:
156
+ kernel_size (int):
157
+ sig_x (float):
158
+ sig_y (float):
159
+ theta (float): Radian measurement.
160
+ beta (float): shape parameter, beta = 1 is the normal distribution.
161
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
162
+ with the shape (K, K, 2), K is the kernel size. Default: None
163
+
164
+ Returns:
165
+ kernel (ndarray): normalized kernel.
166
+ """
167
+ if grid is None:
168
+ grid, _, _ = mesh_grid(kernel_size)
169
+ if isotropic:
170
+ sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
171
+ else:
172
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
173
+ inverse_sigma = np.linalg.inv(sigma_matrix)
174
+ kernel = np.reciprocal(np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
175
+ kernel = kernel / np.sum(kernel)
176
+ return kernel
177
+
178
+
179
+ def random_bivariate_Gaussian(kernel_size,
180
+ sigma_x_range,
181
+ sigma_y_range,
182
+ rotation_range,
183
+ noise_range=None,
184
+ isotropic=True):
185
+ """Randomly generate bivariate isotropic or anisotropic Gaussian kernels.
186
+
187
+ In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` are ignored.
188
+
189
+ Args:
190
+ kernel_size (int):
191
+ sigma_x_range (tuple): [0.6, 5]
192
+ sigma_y_range (tuple): [0.6, 5]
193
+ rotation_range (tuple): [-math.pi, math.pi]
194
+ noise_range(tuple, optional): multiplicative kernel noise,
195
+ [0.75, 1.25]. Default: None
196
+
197
+ Returns:
198
+ kernel (ndarray):
199
+ """
200
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
201
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
202
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
203
+ if isotropic is False:
204
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
205
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
206
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
207
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
208
+ else:
209
+ sigma_y = sigma_x
210
+ rotation = 0
211
+
212
+ kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic)
213
+
214
+ # add multiplicative noise
215
+ if noise_range is not None:
216
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
217
+ noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
218
+ kernel = kernel * noise
219
+ kernel = kernel / np.sum(kernel)
220
+ return kernel
221
+
222
+
223
+ def random_bivariate_generalized_Gaussian(kernel_size,
224
+ sigma_x_range,
225
+ sigma_y_range,
226
+ rotation_range,
227
+ beta_range,
228
+ noise_range=None,
229
+ isotropic=True):
230
+ """Randomly generate bivariate generalized Gaussian kernels.
231
+
232
+ In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` are ignored.
233
+
234
+ Args:
235
+ kernel_size (int):
236
+ sigma_x_range (tuple): [0.6, 5]
237
+ sigma_y_range (tuple): [0.6, 5]
238
+ rotation_range (tuple): [-math.pi, math.pi]
239
+ beta_range (tuple): [0.5, 8]
240
+ noise_range(tuple, optional): multiplicative kernel noise,
241
+ [0.75, 1.25]. Default: None
242
+
243
+ Returns:
244
+ kernel (ndarray):
245
+ """
246
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
247
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
248
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
249
+ if isotropic is False:
250
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
251
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
252
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
253
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
254
+ else:
255
+ sigma_y = sigma_x
256
+ rotation = 0
257
+
258
+ # assume beta_range[0] < 1 < beta_range[1]
259
+ if np.random.uniform() < 0.5:
260
+ beta = np.random.uniform(beta_range[0], 1)
261
+ else:
262
+ beta = np.random.uniform(1, beta_range[1])
263
+
264
+ kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
265
+
266
+ # add multiplicative noise
267
+ if noise_range is not None:
268
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
269
+ noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
270
+ kernel = kernel * noise
271
+ kernel = kernel / np.sum(kernel)
272
+ return kernel
273
+
274
+
275
+ def random_bivariate_plateau(kernel_size,
276
+ sigma_x_range,
277
+ sigma_y_range,
278
+ rotation_range,
279
+ beta_range,
280
+ noise_range=None,
281
+ isotropic=True):
282
+ """Randomly generate bivariate plateau kernels.
283
+
284
+ In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` are ignored.
285
+
286
+ Args:
287
+ kernel_size (int):
288
+ sigma_x_range (tuple): [0.6, 5]
289
+ sigma_y_range (tuple): [0.6, 5]
290
+ rotation_range (tuple): [-math.pi/2, math.pi/2]
291
+ beta_range (tuple): [1, 4]
292
+ noise_range(tuple, optional): multiplicative kernel noise,
293
+ [0.75, 1.25]. Default: None
294
+
295
+ Returns:
296
+ kernel (ndarray):
297
+ """
298
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
299
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
300
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
301
+ if isotropic is False:
302
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
303
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
304
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
305
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
306
+ else:
307
+ sigma_y = sigma_x
308
+ rotation = 0
309
+
310
+ # TODO: this may be not proper
311
+ if np.random.uniform() < 0.5:
312
+ beta = np.random.uniform(beta_range[0], 1)
313
+ else:
314
+ beta = np.random.uniform(1, beta_range[1])
315
+
316
+ kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
317
+ # add multiplicative noise
318
+ if noise_range is not None:
319
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
320
+ noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
321
+ kernel = kernel * noise
322
+ kernel = kernel / np.sum(kernel)
323
+
324
+ return kernel
325
+
326
+
327
+ def random_mixed_kernels(kernel_list,
328
+ kernel_prob,
329
+ kernel_size=21,
330
+ sigma_x_range=(0.6, 5),
331
+ sigma_y_range=(0.6, 5),
332
+ rotation_range=(-math.pi, math.pi),
333
+ betag_range=(0.5, 8),
334
+ betap_range=(0.5, 8),
335
+ noise_range=None):
336
+ """Randomly generate mixed kernels.
337
+
338
+ Args:
339
+ kernel_list (tuple): a list name of kernel types,
340
+ support ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso',
341
+ 'plateau_aniso']
342
+ kernel_prob (tuple): corresponding kernel probability for each
343
+ kernel type
344
+ kernel_size (int):
345
+ sigma_x_range (tuple): [0.6, 5]
346
+ sigma_y_range (tuple): [0.6, 5]
347
+ rotation_range (tuple): [-math.pi, math.pi]
348
+ beta_range (tuple): [0.5, 8]
349
+ noise_range(tuple, optional): multiplicative kernel noise,
350
+ [0.75, 1.25]. Default: None
351
+
352
+ Returns:
353
+ kernel (ndarray):
354
+ """
355
+ kernel_type = random.choices(kernel_list, kernel_prob)[0]
356
+ if kernel_type == 'iso':
357
+ kernel = random_bivariate_Gaussian(
358
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True)
359
+ elif kernel_type == 'aniso':
360
+ kernel = random_bivariate_Gaussian(
361
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False)
362
+ elif kernel_type == 'generalized_iso':
363
+ kernel = random_bivariate_generalized_Gaussian(
364
+ kernel_size,
365
+ sigma_x_range,
366
+ sigma_y_range,
367
+ rotation_range,
368
+ betag_range,
369
+ noise_range=noise_range,
370
+ isotropic=True)
371
+ elif kernel_type == 'generalized_aniso':
372
+ kernel = random_bivariate_generalized_Gaussian(
373
+ kernel_size,
374
+ sigma_x_range,
375
+ sigma_y_range,
376
+ rotation_range,
377
+ betag_range,
378
+ noise_range=noise_range,
379
+ isotropic=False)
380
+ elif kernel_type == 'plateau_iso':
381
+ kernel = random_bivariate_plateau(
382
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True)
383
+ elif kernel_type == 'plateau_aniso':
384
+ kernel = random_bivariate_plateau(
385
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False)
386
+ return kernel
387
+
388
+
389
+ np.seterr(divide='ignore', invalid='ignore')
390
+
391
+
392
+ def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0):
393
+ """2D sinc filter, ref: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter
394
+
395
+ Args:
396
+ cutoff (float): cutoff frequency in radians (pi is max)
397
+ kernel_size (int): horizontal and vertical size, must be odd.
398
+ pad_to (int): pad kernel size to desired size, must be odd or zero.
399
+ """
400
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
401
+ kernel = np.fromfunction(
402
+ lambda x, y: cutoff * special.j1(cutoff * np.sqrt(
403
+ (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)) / (2 * np.pi * np.sqrt(
404
+ (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)), [kernel_size, kernel_size])
405
+ kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff**2 / (4 * np.pi)
406
+ kernel = kernel / np.sum(kernel)
407
+ if pad_to > kernel_size:
408
+ pad_size = (pad_to - kernel_size) // 2
409
+ kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
410
+ return kernel
411
+
412
+
413
+ # ------------------------------------------------------------- #
414
+ # --------------------------- noise --------------------------- #
415
+ # ------------------------------------------------------------- #
416
+
417
+ # ----------------------- Gaussian Noise ----------------------- #
418
+
419
+
420
+ def generate_gaussian_noise(img, sigma=10, gray_noise=False):
421
+ """Generate Gaussian noise.
422
+
423
+ Args:
424
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
425
+ sigma (float): Noise scale (measured in range 255). Default: 10.
426
+
427
+ Returns:
428
+ (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
429
+ float32.
430
+ """
431
+ if gray_noise:
432
+ noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255.
433
+ noise = np.expand_dims(noise, axis=2).repeat(3, axis=2)
434
+ else:
435
+ noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255.
436
+ return noise
437
+
438
+
439
+ def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_noise=False):
440
+ """Add Gaussian noise.
441
+
442
+ Args:
443
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
444
+ sigma (float): Noise scale (measured in range 255). Default: 10.
445
+
446
+ Returns:
447
+ (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
448
+ float32.
449
+ """
450
+ noise = generate_gaussian_noise(img, sigma, gray_noise)
451
+ out = img + noise
452
+ if clip and rounds:
453
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.
454
+ elif clip:
455
+ out = np.clip(out, 0, 1)
456
+ elif rounds:
457
+ out = (out * 255.0).round() / 255.
458
+ return out
459
+
460
+
461
+ def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0):
462
+ """Add Gaussian noise (PyTorch version).
463
+
464
+ Args:
465
+ img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
466
+ sigma (float | Tensor): Noise sigma (measured in range 255). Default: 10.
467
+
468
+ Returns:
469
+ (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
470
+ float32.
471
+ """
472
+ b, _, h, w = img.size()
473
+ if not isinstance(sigma, (float, int)):
474
+ sigma = sigma.view(img.size(0), 1, 1, 1)
475
+ if isinstance(gray_noise, (float, int)):
476
+ cal_gray_noise = gray_noise > 0
477
+ else:
478
+ gray_noise = gray_noise.view(b, 1, 1, 1)
479
+ cal_gray_noise = torch.sum(gray_noise) > 0
480
+
481
+ if cal_gray_noise:
482
+ noise_gray = torch.randn(*img.size()[2:4], dtype=img.dtype, device=img.device) * sigma / 255.
483
+ noise_gray = noise_gray.view(b, 1, h, w)
484
+
485
+ # always calculate color noise
486
+ noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255.
487
+
488
+ if cal_gray_noise:
489
+ noise = noise * (1 - gray_noise) + noise_gray * gray_noise
490
+ return noise
491
+
492
+
493
+ def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False):
494
+ """Add Gaussian noise (PyTorch version).
495
+
496
+ Args:
497
+ img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
498
+ sigma (float | Tensor): Noise sigma (measured in range 255). Default: 10.
499
+
500
+ Returns:
501
+ (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
502
+ float32.
503
+ """
504
+ noise = generate_gaussian_noise_pt(img, sigma, gray_noise)
505
+ out = img + noise
506
+ if clip and rounds:
507
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
508
+ elif clip:
509
+ out = torch.clamp(out, 0, 1)
510
+ elif rounds:
511
+ out = (out * 255.0).round() / 255.
512
+ return out
513
+
514
+
515
+ # ----------------------- Random Gaussian Noise ----------------------- #
516
+ def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0):
517
+ sigma = np.random.uniform(sigma_range[0], sigma_range[1])
518
+ if np.random.uniform() < gray_prob:
519
+ gray_noise = True
520
+ else:
521
+ gray_noise = False
522
+ return generate_gaussian_noise(img, sigma, gray_noise)
523
+
524
+
525
+ def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
526
+ noise = random_generate_gaussian_noise(img, sigma_range, gray_prob)
527
+ out = img + noise
528
+ if clip and rounds:
529
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.
530
+ elif clip:
531
+ out = np.clip(out, 0, 1)
532
+ elif rounds:
533
+ out = (out * 255.0).round() / 255.
534
+ return out
535
+
536
+
537
+ def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0):
538
+ sigma = torch.rand(
539
+ img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0]
540
+ gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
541
+ gray_noise = (gray_noise < gray_prob).float()
542
+ return generate_gaussian_noise_pt(img, sigma, gray_noise)
543
+
544
+
545
+ def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
546
+ noise = random_generate_gaussian_noise_pt(img, sigma_range, gray_prob)
547
+ out = img + noise
548
+ if clip and rounds:
549
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
550
+ elif clip:
551
+ out = torch.clamp(out, 0, 1)
552
+ elif rounds:
553
+ out = (out * 255.0).round() / 255.
554
+ return out
555
+
556
+
557
+ # ----------------------- Poisson (Shot) Noise ----------------------- #
558
+
559
+
560
+ def generate_poisson_noise(img, scale=1.0, gray_noise=False):
561
+ """Generate poisson noise.
562
+
563
+ Ref: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/noise.py#L37-L219
564
+
565
+ Args:
566
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
567
+ scale (float): Noise scale. Default: 1.0.
568
+ gray_noise (bool): Whether generate gray noise. Default: False.
569
+
570
+ Returns:
571
+ (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
572
+ float32.
573
+ """
574
+ if gray_noise:
575
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
576
+ # round and clip image for counting vals correctly
577
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
578
+ vals = len(np.unique(img))
579
+ vals = 2**np.ceil(np.log2(vals))
580
+ out = np.float32(np.random.poisson(img * vals) / float(vals))
581
+ noise = out - img
582
+ if gray_noise:
583
+ noise = np.repeat(noise[:, :, np.newaxis], 3, axis=2)
584
+ return noise * scale
585
+
586
+
587
+ def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_noise=False):
588
+ """Add poisson noise.
589
+
590
+ Args:
591
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
592
+ scale (float): Noise scale. Default: 1.0.
593
+ gray_noise (bool): Whether generate gray noise. Default: False.
594
+
595
+ Returns:
596
+ (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
597
+ float32.
598
+ """
599
+ noise = generate_poisson_noise(img, scale, gray_noise)
600
+ out = img + noise
601
+ if clip and rounds:
602
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.
603
+ elif clip:
604
+ out = np.clip(out, 0, 1)
605
+ elif rounds:
606
+ out = (out * 255.0).round() / 255.
607
+ return out
608
+
609
+
610
+ def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
611
+ """Generate a batch of poisson noise (PyTorch version)
612
+
613
+ Args:
614
+ img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
615
+ scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
616
+ Default: 1.0.
617
+ gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
618
+ 0 for False, 1 for True. Default: 0.
619
+
620
+ Returns:
621
+ (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
622
+ float32.
623
+ """
624
+ b, _, h, w = img.size()
625
+ if isinstance(gray_noise, (float, int)):
626
+ cal_gray_noise = gray_noise > 0
627
+ else:
628
+ gray_noise = gray_noise.view(b, 1, 1, 1)
629
+ cal_gray_noise = torch.sum(gray_noise) > 0
630
+ if cal_gray_noise:
631
+ img_gray = rgb_to_grayscale(img, num_output_channels=1)
632
+ # round and clip image for counting vals correctly
633
+ img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255.
634
+ # use for-loop to get the unique values for each sample
635
+ vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)]
636
+ vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
637
+ vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1)
638
+ out = torch.poisson(img_gray * vals) / vals
639
+ noise_gray = out - img_gray
640
+ noise_gray = noise_gray.expand(b, 3, h, w)
641
+
642
+ # always calculate color noise
643
+ # round and clip image for counting vals correctly
644
+ img = torch.clamp((img * 255.0).round(), 0, 255) / 255.
645
+ # use for-loop to get the unique values for each sample
646
+ vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)]
647
+ vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
648
+ vals = img.new_tensor(vals_list).view(b, 1, 1, 1)
649
+ out = torch.poisson(img * vals) / vals
650
+ noise = out - img
651
+ if cal_gray_noise:
652
+ noise = noise * (1 - gray_noise) + noise_gray * gray_noise
653
+ if not isinstance(scale, (float, int)):
654
+ scale = scale.view(b, 1, 1, 1)
655
+ return noise * scale
656
+
657
+
658
+ def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0):
659
+ """Add poisson noise to a batch of images (PyTorch version).
660
+
661
+ Args:
662
+ img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
663
+ scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
664
+ Default: 1.0.
665
+ gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
666
+ 0 for False, 1 for True. Default: 0.
667
+
668
+ Returns:
669
+ (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
670
+ float32.
671
+ """
672
+ noise = generate_poisson_noise_pt(img, scale, gray_noise)
673
+ out = img + noise
674
+ if clip and rounds:
675
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
676
+ elif clip:
677
+ out = torch.clamp(out, 0, 1)
678
+ elif rounds:
679
+ out = (out * 255.0).round() / 255.
680
+ return out
681
+
682
+
683
+ # ----------------------- Random Poisson (Shot) Noise ----------------------- #
684
+
685
+
686
+ def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0):
687
+ scale = np.random.uniform(scale_range[0], scale_range[1])
688
+ if np.random.uniform() < gray_prob:
689
+ gray_noise = True
690
+ else:
691
+ gray_noise = False
692
+ return generate_poisson_noise(img, scale, gray_noise)
693
+
694
+
695
+ def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
696
+ noise = random_generate_poisson_noise(img, scale_range, gray_prob)
697
+ out = img + noise
698
+ if clip and rounds:
699
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.
700
+ elif clip:
701
+ out = np.clip(out, 0, 1)
702
+ elif rounds:
703
+ out = (out * 255.0).round() / 255.
704
+ return out
705
+
706
+
707
+ def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0):
708
+ scale = torch.rand(
709
+ img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0]
710
+ gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
711
+ gray_noise = (gray_noise < gray_prob).float()
712
+ return generate_poisson_noise_pt(img, scale, gray_noise)
713
+
714
+
715
+ def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
716
+ noise = random_generate_poisson_noise_pt(img, scale_range, gray_prob)
717
+ out = img + noise
718
+ if clip and rounds:
719
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
720
+ elif clip:
721
+ out = torch.clamp(out, 0, 1)
722
+ elif rounds:
723
+ out = (out * 255.0).round() / 255.
724
+ return out
725
+
726
+
727
+ # ------------------------------------------------------------------------ #
728
+ # --------------------------- JPEG compression --------------------------- #
729
+ # ------------------------------------------------------------------------ #
730
+
731
+
732
+ def add_jpg_compression(img, quality=90):
733
+ """Add JPG compression artifacts.
734
+
735
+ Args:
736
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
737
+ quality (float): JPG compression quality. 0 for lowest quality, 100 for
738
+ best quality. Default: 90.
739
+
740
+ Returns:
741
+ (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
742
+ float32.
743
+ """
744
+ img = np.clip(img, 0, 1)
745
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
746
+ _, encimg = cv2.imencode('.jpg', img * 255., encode_param)
747
+ img = np.float32(cv2.imdecode(encimg, 1)) / 255.
748
+ return img
749
+
750
+
751
+ def random_add_jpg_compression(img, quality_range=(90, 100)):
752
+ """Randomly add JPG compression artifacts.
753
+
754
+ Args:
755
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
756
+ quality_range (tuple[float] | list[float]): JPG compression quality
757
+ range. 0 for lowest quality, 100 for best quality.
758
+ Default: (90, 100).
759
+
760
+ Returns:
761
+ (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
762
+ float32.
763
+ """
764
+ quality = np.random.uniform(quality_range[0], quality_range[1])
765
+ return add_jpg_compression(img, quality)
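A minimal usage sketch of the kernel utilities above, mirroring how the dataset code builds its blur and sinc kernels (the import path `utils.degradations` and the concrete ranges below are assumptions for illustration, not the training defaults):

    import math
    import numpy as np
    import cv2
    from utils.degradations import random_mixed_kernels, circular_lowpass_kernel

    # 21x21 mixed blur kernel, drawn from the same kernel families as the first degradation
    kernel = random_mixed_kernels(
        kernel_list=['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'],
        kernel_prob=[0.45, 0.25, 0.12, 0.03, 0.12, 0.03],
        kernel_size=21,
        sigma_x_range=(0.2, 3), sigma_y_range=(0.2, 3),
        rotation_range=(-math.pi, math.pi),
        betag_range=(0.5, 4), betap_range=(1, 2),
        noise_range=None)

    # final 2D sinc kernel with a random cutoff frequency, padded to 21x21
    omega_c = np.random.uniform(np.pi / 3, np.pi)
    sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size=13, pad_to=21)

    img = np.random.rand(64, 112, 3).astype(np.float32)  # placeholder HR image in [0, 1]
    out = cv2.filter2D(img, -1, kernel)                   # blur
    out = cv2.filter2D(out, -1, sinc_kernel)              # ringing / overshoot artifacts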
utils/preprocess.py CHANGED
@@ -9,8 +9,56 @@ from dask import bag as dbag
9
  from dask.diagnostics import ProgressBar
10
  from typing import Tuple
11
  from PIL import Image
12
-
13
-
14
 
15
  # Dataset statistics that I gathered in development
16
  #-----------------------------------#
@@ -47,7 +95,9 @@ def parseLabel(label: str) -> Tuple[np.ndarray, np.ndarray]:
47
  #-----------------------------------#
48
  # 根据车牌坐标裁剪出车牌图像
49
  #-----------------------------------#
50
-
 
 
51
 
52
  def cropImage(image: np.ndarray, coor: np.ndarray, center: np.ndarray) -> np.ndarray:
53
  maxW = np.max(coor[:, 0] - center[0]) # max plate width
@@ -63,7 +113,7 @@ def cropImage(image: np.ndarray, coor: np.ndarray, center: np.ndarray) -> np.nda
63
  maxW = w//2
64
  found = True
65
  break
66
- if not found: # 车牌太大则丢弃
67
  return np.array([])
68
  elif center[1]-maxH < 0 or center[1]+maxH >= image.shape[1] or \
69
  center[0]-maxW < 0 or center[0] + maxW >= image.shape[0]:
@@ -107,10 +157,10 @@ def processImage(file: str, inputDir: str, outputDir: str, subFolder: str) -> in
107
  return 0
108
  mean = np.mean(plate/255.0)
109
  std = np.std(plate/255.0)
110
- # 亮度不好的
111
  if mean <= IMAGE_MEAN - 10*IMAGE_MEAN_STD or mean >= IMAGE_MEAN + 10*IMAGE_MEAN_STD:
112
  return 0
113
- # 低对比度的
114
  if std <= IMG_STD - 10*IMG_STD_STD:
115
  return 0
116
  status = saveImage(plate, file, outputDir)
@@ -126,26 +176,27 @@ def main(argv):
126
  for shape in ['64_32', '128_64', '192_96']:
127
  os.mkdir(os.path.join(outputDir, shape))
128
  except OSError:
129
- pass # 地址已经存在
130
- client = LocalCluster(n_workers=jobNum, threads_per_worker=5) # 开启多线程
131
- for subFolder in ['ccpd_base', 'ccpd_db', 'ccpd_fn', 'ccpd_rotate', 'ccpd_tilt', 'ccpd_weather']:
 
132
  fileList = os.listdir(os.path.join(inputDir, subFolder))
133
  print('* {} images found in {}. Start processing ...'.format(len(fileList), subFolder))
134
  toDo = dbag.from_sequence(fileList, npartitions=jobNum*30).persist() # persist the bag in memory
135
  toDo = toDo.map(processImage, inputDir, outputDir, subFolder)
136
  pbar = ProgressBar(minimum=2.0)
137
- pbar.register() # 登记所有的计算,以便更好地跟踪
138
  result = toDo.compute()
139
  print('* image cropped: {}. Done ...'.format(sum(result)))
140
- client.close() # 关闭集群
141
 
142
 
143
  if __name__ == "__main__":
144
  parser = argparse.ArgumentParser(description=__doc__)
145
  add_arg = functools.partial(add_arguments, argparser=parser)
146
  add_arg('jobNum', int, 4, '处理图片的线程数')
147
- add_arg('inputDir', str, 'datasets/CCPD2019', '输入图片目录')
148
- add_arg('outputDir', str, 'datasets/CCPD2019_new', '保存图片目录')
149
  args = parser.parse_args()
150
  print_arguments(args)
151
  main(args)
9
  from dask.diagnostics import ProgressBar
10
  from typing import Tuple
11
  from PIL import Image
12
+ import cv2
13
+ #-----------------------------------#
14
+ # 对四个点坐标排序
15
+ #-----------------------------------#
16
+ def order_points(pts):
17
+ # 一共4个坐标点
18
+ rect = np.zeros((4, 2), dtype = "float32")
19
+
20
+ # 按顺序找到对应坐标0123分别是 左上,右上,右下,左下
21
+ # 计算左上,右下
22
+ s = pts.sum(axis = 1)
23
+ rect[0] = pts[np.argmin(s)]
24
+ rect[2] = pts[np.argmax(s)]
25
+
26
+ # 计算右上和左下
27
+ diff = np.diff(pts, axis = 1)
28
+ rect[1] = pts[np.argmin(diff)]
29
+ rect[3] = pts[np.argmax(diff)]
30
+
31
+ return rect
32
+ #-----------------------------------#
33
+ # 透射变换纠正车牌图片
34
+ #-----------------------------------#
35
+ def four_point_transform(image, pts):
36
+ # 获取输入坐标点
37
+ rect = order_points(pts)
38
+ (tl, tr, br, bl) = rect
39
+
40
+ # 计算输入的w和h值
41
+ widthA = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
42
+ widthB = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
43
+ maxWidth = max(int(widthA), int(widthB))
44
+
45
+ heightA = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
46
+ heightB = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
47
+ maxHeight = max(int(heightA), int(heightB))
48
+
49
+ # 变换后对应坐标位置
50
+ dst = np.array([
51
+ [0, 0],
52
+ [maxWidth - 1, 0],
53
+ [maxWidth - 1, maxHeight - 1],
54
+ [0, maxHeight - 1]], dtype = "float32")
55
+
56
+ # 计算变换矩阵
57
+ M = cv2.getPerspectiveTransform(rect, dst)
58
+ warped = cv2.warpPerspective(image, M, (maxWidth, maxHeight))
59
+
60
+ # 返回变换后结果
61
+ return warped
62
 
63
  # Dataset statistics that I gathered in development
64
  #-----------------------------------#
95
  #-----------------------------------#
96
  # 根据车牌坐标裁剪出车牌图像
97
  #-----------------------------------#
98
+ # def cropImage(image: np.ndarray, coor: np.ndarray, center: np.ndarray) -> np.ndarray:
99
+ # image = four_point_transform(image, coor)
100
+ # return image
101
 
102
  def cropImage(image: np.ndarray, coor: np.ndarray, center: np.ndarray) -> np.ndarray:
103
  maxW = np.max(coor[:, 0] - center[0]) # max plate width
113
  maxW = w//2
114
  found = True
115
  break
116
+ if not found: # plate too large, discard
117
  return np.array([])
118
  elif center[1]-maxH < 0 or center[1]+maxH >= image.shape[1] or \
119
  center[0]-maxW < 0 or center[0] + maxW >= image.shape[0]:
157
  return 0
158
  mean = np.mean(plate/255.0)
159
  std = np.std(plate/255.0)
160
+ # bad brightness
161
  if mean <= IMAGE_MEAN - 10*IMAGE_MEAN_STD or mean >= IMAGE_MEAN + 10*IMAGE_MEAN_STD:
162
  return 0
163
+ # low contrast
164
  if std <= IMG_STD - 10*IMG_STD_STD:
165
  return 0
166
  status = saveImage(plate, file, outputDir)
176
  for shape in ['64_32', '128_64', '192_96']:
177
  os.mkdir(os.path.join(outputDir, shape))
178
  except OSError:
179
+ pass # path already exists
180
+ client = LocalCluster(n_workers=jobNum, threads_per_worker=5) # IO intensive, more threads
181
+ print('* number of workers:{}, \n* input dir:{}, \n* output dir:{}\n\n'.format(jobNum, inputDir, outputDir))
182
+ for subFolder in ['ccpd_green', 'ccpd_base', 'ccpd_db', 'ccpd_fn', 'ccpd_rotate', 'ccpd_tilt', 'ccpd_weather']:
183
  fileList = os.listdir(os.path.join(inputDir, subFolder))
184
  print('* {} images found in {}. Start processing ...'.format(len(fileList), subFolder))
185
  toDo = dbag.from_sequence(fileList, npartitions=jobNum*30).persist() # persist the bag in memory
186
  toDo = toDo.map(processImage, inputDir, outputDir, subFolder)
187
  pbar = ProgressBar(minimum=2.0)
188
+ pbar.register() # register all computations for better tracking
189
  result = toDo.compute()
190
  print('* image cropped: {}. Done ...'.format(sum(result)))
191
+ client.close() # shut down the cluster
192
 
193
 
194
  if __name__ == "__main__":
195
  parser = argparse.ArgumentParser(description=__doc__)
196
  add_arg = functools.partial(add_arguments, argparser=parser)
197
  add_arg('jobNum', int, 4, '处理图片的线程数')
198
+ add_arg('inputDir', str, 'datasets/CCPD2020', '输入图片目录')
199
+ add_arg('outputDir', str, 'datasets/CCPD2020_new', '保存图片目录')
200
  args = parser.parse_args()
201
  print_arguments(args)
202
  main(args)
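A minimal sketch of the perspective-correction helpers added above (order_points / four_point_transform) rectifying a tilted plate; the file name and corner coordinates are made up for illustration, and the helpers are assumed to be in scope:

    import cv2
    import numpy as np

    image = cv2.imread('ccpd_example.jpg')                       # hypothetical CCPD image
    corners = np.array([[410, 260], [640, 268], [652, 352],
                        [405, 345]], dtype='float32')            # 4 plate corners (x, y), any order
    plate = four_point_transform(image, corners)                 # axis-aligned, perspective-corrected crop
    cv2.imwrite('plate_rectified.jpg', plate)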
utils/transforms.py ADDED
@@ -0,0 +1,179 @@
1
+ import cv2
2
+ import random
3
+ import torch
4
+
5
+
6
+ def mod_crop(img, scale):
7
+ """Mod crop images, used during testing.
8
+
9
+ Args:
10
+ img (ndarray): Input image.
11
+ scale (int): Scale factor.
12
+
13
+ Returns:
14
+ ndarray: Result image.
15
+ """
16
+ img = img.copy()
17
+ if img.ndim in (2, 3):
18
+ h, w = img.shape[0], img.shape[1]
19
+ h_remainder, w_remainder = h % scale, w % scale
20
+ img = img[:h - h_remainder, :w - w_remainder, ...]
21
+ else:
22
+ raise ValueError(f'Wrong img ndim: {img.ndim}.')
23
+ return img
24
+
25
+
26
+ def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
27
+ """Paired random crop. Support Numpy array and Tensor inputs.
28
+
29
+ It crops lists of lq and gt images with corresponding locations.
30
+
31
+ Args:
32
+ img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
33
+ should have the same shape. If the input is an ndarray, it will
34
+ be transformed to a list containing itself.
35
+ img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
36
+ should have the same shape. If the input is an ndarray, it will
37
+ be transformed to a list containing itself.
38
+ gt_patch_size (int): GT patch size.
39
+ scale (int): Scale factor.
40
+ gt_path (str): Path to ground-truth. Default: None.
41
+
42
+ Returns:
43
+ list[ndarray] | ndarray: GT images and LQ images. If returned results
44
+ only have one element, just return ndarray.
45
+ """
46
+
47
+ if not isinstance(img_gts, list):
48
+ img_gts = [img_gts]
49
+ if not isinstance(img_lqs, list):
50
+ img_lqs = [img_lqs]
51
+
52
+ # determine input type: Numpy array or Tensor
53
+ input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'
54
+
55
+ if input_type == 'Tensor':
56
+ h_lq, w_lq = img_lqs[0].size()[-2:]
57
+ h_gt, w_gt = img_gts[0].size()[-2:]
58
+ else:
59
+ h_lq, w_lq = img_lqs[0].shape[0:2]
60
+ h_gt, w_gt = img_gts[0].shape[0:2]
61
+ lq_patch_size = gt_patch_size // scale
62
+
63
+ if h_gt != h_lq * scale or w_gt != w_lq * scale:
64
+ raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ',
65
+ f'multiplication of LQ ({h_lq}, {w_lq}).')
66
+ if h_lq < lq_patch_size or w_lq < lq_patch_size:
67
+ raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
68
+ f'({lq_patch_size}, {lq_patch_size}). '
69
+ f'Please remove {gt_path}.')
70
+
71
+ # randomly choose top and left coordinates for lq patch
72
+ top = random.randint(0, h_lq - lq_patch_size)
73
+ left = random.randint(0, w_lq - lq_patch_size)
74
+
75
+ # crop lq patch
76
+ if input_type == 'Tensor':
77
+ img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
78
+ else:
79
+ img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
80
+
81
+ # crop corresponding gt patch
82
+ top_gt, left_gt = int(top * scale), int(left * scale)
83
+ if input_type == 'Tensor':
84
+ img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
85
+ else:
86
+ img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
87
+ if len(img_gts) == 1:
88
+ img_gts = img_gts[0]
89
+ if len(img_lqs) == 1:
90
+ img_lqs = img_lqs[0]
91
+ return img_gts, img_lqs
92
+
93
+
94
+ def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
95
+ """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
96
+
97
+ We use vertical flip and transpose for rotation implementation.
98
+ All the images in the list use the same augmentation.
99
+
100
+ Args:
101
+ imgs (list[ndarray] | ndarray): Images to be augmented. If the input
102
+ is an ndarray, it will be transformed to a list.
103
+ hflip (bool): Horizontal flip. Default: True.
104
+ rotation (bool): Rotation. Default: True.
105
+ flows (list[ndarray]): Flows to be augmented. If the input is an
106
+ ndarray, it will be transformed to a list.
107
+ Dimension is (h, w, 2). Default: None.
108
+ return_status (bool): Return the status of flip and rotation.
109
+ Default: False.
110
+
111
+ Returns:
112
+ list[ndarray] | ndarray: Augmented images and flows. If returned
113
+ results only have one element, just return ndarray.
114
+
115
+ """
116
+ hflip = hflip and random.random() < 0.5
117
+ vflip = rotation and random.random() < 0.5
118
+ rot90 = rotation and random.random() < 0.5
119
+
120
+ def _augment(img):
121
+ if hflip: # horizontal
122
+ cv2.flip(img, 1, img)
123
+ if vflip: # vertical
124
+ cv2.flip(img, 0, img)
125
+ if rot90:
126
+ img = img.transpose(1, 0, 2)
127
+ return img
128
+
129
+ def _augment_flow(flow):
130
+ if hflip: # horizontal
131
+ cv2.flip(flow, 1, flow)
132
+ flow[:, :, 0] *= -1
133
+ if vflip: # vertical
134
+ cv2.flip(flow, 0, flow)
135
+ flow[:, :, 1] *= -1
136
+ if rot90:
137
+ flow = flow.transpose(1, 0, 2)
138
+ flow = flow[:, :, [1, 0]]
139
+ return flow
140
+
141
+ if not isinstance(imgs, list):
142
+ imgs = [imgs]
143
+ imgs = [_augment(img) for img in imgs]
144
+ if len(imgs) == 1:
145
+ imgs = imgs[0]
146
+
147
+ if flows is not None:
148
+ if not isinstance(flows, list):
149
+ flows = [flows]
150
+ flows = [_augment_flow(flow) for flow in flows]
151
+ if len(flows) == 1:
152
+ flows = flows[0]
153
+ return imgs, flows
154
+ else:
155
+ if return_status:
156
+ return imgs, (hflip, vflip, rot90)
157
+ else:
158
+ return imgs
159
+
160
+
161
+ def img_rotate(img, angle, center=None, scale=1.0):
162
+ """Rotate image.
163
+
164
+ Args:
165
+ img (ndarray): Image to be rotated.
166
+ angle (float): Rotation angle in degrees. Positive values mean
167
+ counter-clockwise rotation.
168
+ center (tuple[int]): Rotation center. If the center is None,
169
+ initialize it as the center of the image. Default: None.
170
+ scale (float): Isotropic scale factor. Default: 1.0.
171
+ """
172
+ (h, w) = img.shape[:2]
173
+
174
+ if center is None:
175
+ center = (w // 2, h // 2)
176
+
177
+ matrix = cv2.getRotationMatrix2D(center, angle, scale)
178
+ rotated_img = cv2.warpAffine(img, matrix, (w, h))
179
+ return rotated_img
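A minimal sketch of the cropping and augmentation helpers above as they would be used in a paired SR data pipeline (the module path `utils.transforms`, patch size, and scale below are illustrative assumptions):

    import numpy as np
    from utils.transforms import paired_random_crop, augment

    scale = 4
    img_gt = np.random.rand(128, 224, 3).astype(np.float32)                    # placeholder HR image
    img_lq = np.random.rand(128 // scale, 224 // scale, 3).astype(np.float32)  # placeholder LR image

    # crop matching 64x64 (GT) / 16x16 (LQ) patches at corresponding locations
    gt_patch, lq_patch = paired_random_crop(img_gt, img_lq, gt_patch_size=64, scale=scale)

    # the same random flip / rotation is applied to both patches
    gt_patch, lq_patch = augment([gt_patch, lq_patch], hflip=True, rotation=True)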
utils/utils.py CHANGED
@@ -2,6 +2,8 @@ import itertools
2
  import numpy as np
3
  import matplotlib.pyplot as plt
4
  import torch
 
 
5
  import distutils.util
6
 
7
  def show_result(num_epoch, G_net, imgs_lr, imgs_hr):
@@ -57,4 +59,104 @@ def add_arguments(argname, type, default, help, argparser, **kwargs):
57
  default=default,
58
  type=type,
59
  help=help + ' 默认: %(default)s.',
60
- **kwargs)
2
  import numpy as np
3
  import matplotlib.pyplot as plt
4
  import torch
5
+ from torch.nn import functional as F
6
+ import cv2
7
  import distutils.util
8
 
9
  def show_result(num_epoch, G_net, imgs_lr, imgs_hr):
59
  default=default,
60
  type=type,
61
  help=help + ' 默认: %(default)s.',
62
+ **kwargs)
63
+
64
+ def filter2D(img, kernel):
65
+ """PyTorch version of cv2.filter2D
66
+
67
+ Args:
68
+ img (Tensor): (b, c, h, w)
69
+ kernel (Tensor): (b, k, k)
70
+ """
71
+ k = kernel.size(-1)
72
+ b, c, h, w = img.size()
73
+ if k % 2 == 1:
74
+ img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect')
75
+ else:
76
+ raise ValueError('Wrong kernel size')
77
+
78
+ ph, pw = img.size()[-2:]
79
+
80
+ if kernel.size(0) == 1:
81
+ # apply the same kernel to all batch images
82
+ img = img.view(b * c, 1, ph, pw)
83
+ kernel = kernel.view(1, 1, k, k)
84
+ return F.conv2d(img, kernel, padding=0).view(b, c, h, w)
85
+ else:
86
+ img = img.view(1, b * c, ph, pw)
87
+ kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k)
88
+ return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w)
89
+
90
+
91
+ def usm_sharp(img, weight=0.5, radius=50, threshold=10):
92
+ """USM sharpening.
93
+
94
+ Input image: I; Blurry image: B.
95
+ 1. sharp = I + weight * (I - B)
96
+ 2. Mask = 1 if abs(I - B) > threshold, else: 0
97
+ 3. Blur the mask to get a soft mask
98
+ 4. Out = Mask * sharp + (1 - Mask) * I
99
+
100
+
101
+ Args:
102
+ img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
103
+ weight (float): Sharp weight. Default: 0.5.
104
+ radius (float): Kernel size of Gaussian blur. Default: 50.
105
+ threshold (int):
106
+ """
107
+ if radius % 2 == 0:
108
+ radius += 1
109
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
110
+ residual = img - blur
111
+ mask = np.abs(residual) * 255 > threshold
112
+ mask = mask.astype('float32')
113
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
114
+
115
+ sharp = img + weight * residual
116
+ sharp = np.clip(sharp, 0, 1)
117
+ return soft_mask * sharp + (1 - soft_mask) * img
118
+
119
+
120
+ class USMSharp(torch.nn.Module):
121
+
122
+ def __init__(self, radius=50, sigma=0):
123
+ super(USMSharp, self).__init__()
124
+ if radius % 2 == 0:
125
+ radius += 1
126
+ self.radius = radius
127
+ kernel = cv2.getGaussianKernel(radius, sigma)
128
+ kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0)
129
+ self.register_buffer('kernel', kernel)
130
+
131
+ def forward(self, img, weight=0.5, threshold=10):
132
+ blur = filter2D(img, self.kernel)
133
+ residual = img - blur
134
+
135
+ mask = torch.abs(residual) * 255 > threshold
136
+ mask = mask.float()
137
+ soft_mask = filter2D(mask, self.kernel)
138
+ sharp = img + weight * residual
139
+ sharp = torch.clip(sharp, 0, 1)
140
+ return soft_mask * sharp + (1 - soft_mask) * img
141
+
142
+ class USMSharp_npy():
143
+
144
+ def __init__(self, radius=50, sigma=0):
145
+ super(USMSharp_npy, self).__init__()
146
+ if radius % 2 == 0:
147
+ radius += 1
148
+ self.radius = radius
149
+ kernel = cv2.getGaussianKernel(radius, sigma)
150
+ self.kernel = np.dot(kernel, kernel.transpose()).astype(np.float32)
151
+
152
+ def filt(self, img, weight=0.5, threshold=10):
153
+ blur = cv2.filter2D(img, -1, self.kernel)
154
+ residual = img - blur
155
+
156
+ mask = np.abs(residual) * 255 > threshold
157
+ mask = mask.astype(np.float32)
158
+ soft_mask = cv2.filter2D(mask, -1, self.kernel)
159
+ sharp = img + weight * residual
160
+ sharp = np.clip(sharp, 0, 1)
161
+ return soft_mask * sharp + (1 - soft_mask) * img
162
+
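A minimal sketch of the USM sharpening utilities above, as applied to a ground-truth image before training (a float32 HWC image in [0, 1] and the import path `utils.utils` are assumed):

    import numpy as np
    from utils.utils import USMSharp_npy

    usmsharp = USMSharp_npy()                                  # radius 50 is bumped to 51 internally
    img_gt = np.random.rand(128, 224, 3).astype(np.float32)   # placeholder HR image
    gt = usmsharp.filt(img_gt, weight=0.5, threshold=10)      # sharpened target, still in [0, 1]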
utils/utils_fit.py CHANGED
@@ -1,11 +1,11 @@
1
  import torch
2
  from tqdm import tqdm
3
 
4
- from .utils import show_result, get_lr
5
  from .utils_metrics import PSNR, SSIM
6
 
7
 
8
- def fit_one_epoch(G_model_train, D_model_train, G_model, D_model, VGG_feature_model, G_optimizer, D_optimizer, BCE_loss, MSE_loss, epoch, epoch_size, gen, Epoch, cuda, batch_size, save_interval):
9
  G_total_loss = 0
10
  D_total_loss = 0
11
  G_total_PSNR = 0
@@ -28,33 +28,38 @@ def fit_one_epoch(G_model_train, D_model_train, G_model, D_model, VGG_feature_mo
28
  #-------------------------------------------------#
29
  D_optimizer.zero_grad()
30
 
31
- D_result = D_model_train(hr_images)
32
- D_real_loss = BCE_loss(D_result, y_real)
33
- D_real_loss.backward()
34
 
35
  G_result = G_model_train(lr_images)
36
- D_result = D_model_train(G_result).squeeze()
37
- D_fake_loss = BCE_loss(D_result, y_fake)
38
- D_fake_loss.backward()
 
 
 
 
39
 
40
  D_optimizer.step()
41
 
42
- D_train_loss = D_real_loss + D_fake_loss
43
-
44
  #-------------------------------------------------#
45
  # 训练生成器
46
  #-------------------------------------------------#
47
  G_optimizer.zero_grad()
48
-
49
  G_result = G_model_train(lr_images)
50
- image_loss = MSE_loss(G_result, hr_images)
51
 
52
- D_result = D_model_train(G_result).squeeze()
53
- adversarial_loss = BCE_loss(D_result, y_real)
 
 
 
 
 
54
 
55
- perception_loss = MSE_loss(VGG_feature_model(G_result), VGG_feature_model(hr_images))
56
 
57
- G_train_loss = image_loss + 1e-3 * adversarial_loss + 2e-6 * perception_loss
58
 
59
  G_train_loss.backward()
60
  G_optimizer.step()
1
  import torch
2
  from tqdm import tqdm
3
 
4
+ from .utils import get_lr, show_result
5
  from .utils_metrics import PSNR, SSIM
6
 
7
 
8
+ def fit_one_epoch(G_model_train, D_model_train, G_model, D_model, VGG_feature_model, G_optimizer, D_optimizer, BCEWithLogits_loss, L1_loss, epoch, epoch_size, gen, Epoch, cuda, batch_size, save_interval):
9
  G_total_loss = 0
10
  D_total_loss = 0
11
  G_total_PSNR = 0
28
  #-------------------------------------------------#
29
  D_optimizer.zero_grad()
30
 
31
+ D_result_r = D_model_train(hr_images)
 
 
32
 
33
  G_result = G_model_train(lr_images)
34
+ D_result_f = D_model_train(G_result).squeeze()
35
+ D_result_rf = D_result_r - D_result_f.mean()
36
+ D_result_fr = D_result_f - D_result_r.mean()
37
+ D_train_loss_rf = BCEWithLogits_loss(D_result_rf, y_real)
38
+ D_train_loss_fr = BCEWithLogits_loss(D_result_fr, y_fake)
39
+ D_train_loss = (D_train_loss_rf + D_train_loss_fr) / 2
40
+ D_train_loss.backward()
41
 
42
  D_optimizer.step()
43
 
 
 
44
  #-------------------------------------------------#
45
  # 训练生成器
46
  #-------------------------------------------------#
47
  G_optimizer.zero_grad()
48
+
49
  G_result = G_model_train(lr_images)
50
+ image_loss = L1_loss(G_result, hr_images)
51
 
52
+ D_result_r = D_model_train(hr_images)
53
+ D_result_f = D_model_train(G_result).squeeze()
54
+ D_result_rf = D_result_r - D_result_f.mean()
55
+ D_result_fr = D_result_f - D_result_r.mean()
56
+ D_train_loss_rf = BCEWithLogits_loss(D_result_rf, y_fake)
57
+ D_train_loss_fr = BCEWithLogits_loss(D_result_fr, y_real)
58
+ adversarial_loss = (D_train_loss_rf + D_train_loss_fr) / 2
59
 
60
+ perception_loss = L1_loss(VGG_feature_model(G_result), VGG_feature_model(hr_images))
61
 
62
+ G_train_loss = image_loss + 1e-1 * adversarial_loss + 1e-1 * perception_loss
63
 
64
  G_train_loss.backward()
65
  G_optimizer.step()
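The discriminator and generator losses above follow the relativistic average GAN formulation used by ESRGAN; below is a standalone sketch of the discriminator side, where the tensor shapes and the criterion instance are placeholders for illustration:

    import torch
    from torch import nn

    bce_logits = nn.BCEWithLogitsLoss()
    d_real = torch.randn(8, 1)                 # discriminator logits for HR images
    d_fake = torch.randn(8, 1)                 # discriminator logits for generated images
    y_real = torch.ones_like(d_real)
    y_fake = torch.zeros_like(d_fake)

    # relativistic average logits: real relative to the mean fake score, and vice versa
    d_rf = d_real - d_fake.mean()
    d_fr = d_fake - d_real.mean()
    d_loss = (bce_logits(d_rf, y_real) + bce_logits(d_fr, y_fake)) / 2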