saber2022 committed on
Commit
5642269
·
1 Parent(s): b443508

Create upcunet_v3.py

Browse files
Files changed (1) hide show
  1. upcunet_v3.py +714 -0
upcunet_v3.py ADDED
@@ -0,0 +1,714 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn as nn
3
+ from torch.nn import functional as F
4
+ import os, sys
5
+ import numpy as np
6
+
7
# Absolute path of the current working directory; it is appended to the
# import search path and reused by the script section to locate its folders.
root_path = os.path.abspath(os.curdir)
sys.path.append(root_path)
9
+
10
+
11
class SEBlock(nn.Module):
    """Channel gate: spatial mean -> 1x1 conv -> ReLU -> 1x1 conv -> sigmoid.

    The resulting per-channel weights are multiplied onto the input.

    Args:
        in_channels: Number of channels of the gated tensor.
        reduction: Divisor applied to ``in_channels`` for the hidden width.
        bias: Whether the two 1x1 convolutions carry a bias term.
    """

    def __init__(self, in_channels, reduction=8, bias=False):
        super().__init__()
        hidden = in_channels // reduction
        self.conv1 = nn.Conv2d(in_channels, hidden, 1, 1, 0, bias=bias)
        self.conv2 = nn.Conv2d(hidden, in_channels, 1, 1, 0, bias=bias)

    def forward(self, x):
        """Gate ``x`` using the mean over its own spatial dimensions (2, 3)."""
        if "Half" in x.type():  # torch.HalfTensor / torch.cuda.HalfTensor
            # Take the mean in float32, then cast back to half.
            pooled = x.float().mean(dim=(2, 3), keepdim=True).half()
        else:
            pooled = x.mean(dim=(2, 3), keepdim=True)
        return self.forward_mean(x, pooled)

    def forward_mean(self, x, x0):
        """Gate ``x`` using an externally supplied channel statistic ``x0``."""
        squeezed = F.relu(self.conv1(x0), inplace=True)
        gate = torch.sigmoid(self.conv2(squeezed))
        return x * gate
36
+
37
+
38
class UNetConv(nn.Module):
    """Two unpadded 3x3 conv + LeakyReLU(0.1) layers, optionally followed by an SEBlock.

    Args:
        in_channels: Channels of the input tensor.
        mid_channels: Channels after the first convolution.
        out_channels: Channels after the second convolution.
        se: When truthy, append ``SEBlock(out_channels, reduction=8, bias=True)``.
    """

    def __init__(self, in_channels, mid_channels, out_channels, se):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, mid_channels, 3, 1, 0),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(mid_channels, out_channels, 3, 1, 0),
            nn.LeakyReLU(0.1, inplace=True),
        ]
        self.conv = nn.Sequential(*layers)
        self.seblock = SEBlock(out_channels, reduction=8, bias=True) if se else None

    def forward(self, x):
        """Run the conv stack and, when present, the SE gate."""
        features = self.conv(x)
        if self.seblock is None:
            return features
        return self.seblock(features)
57
+
58
+
59
class UNet1(nn.Module):
    """One-level U-Net: conv block, one stride-2 downsample, one upsample, skip add.

    All convolutions are unpadded, so the skip tensor is cropped by 4 pixels
    per side before it is added to the upsampled branch.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the output head.
        deconv: When truthy the head is ``ConvTranspose2d(64, out, 4, 2, 3)``
            (stride 2); otherwise a plain unpadded 3x3 convolution.
    """

    def __init__(self, in_channels, out_channels, deconv):
        super().__init__()
        self.conv1 = UNetConv(in_channels, 32, 64, se=False)
        self.conv1_down = nn.Conv2d(64, 64, 2, 2, 0)
        self.conv2 = UNetConv(64, 128, 64, se=True)
        self.conv2_up = nn.ConvTranspose2d(64, 64, 2, 2, 0)
        self.conv3 = nn.Conv2d(64, 64, 3, 1, 0)
        self.conv_bottom = (
            nn.ConvTranspose2d(64, out_channels, 4, 2, 3)
            if deconv
            else nn.Conv2d(64, out_channels, 3, 1, 0)
        )
        self._init_weights()

    def _init_weights(self):
        """Kaiming-init every (transposed) conv weight; small-normal init any Linear."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        """Full pass, with the SE gate of ``conv2`` computed from this input alone."""
        skip = self.conv1(x)
        down = F.leaky_relu(self.conv1_down(skip), 0.1, inplace=True)
        return self.forward_b(skip, self.conv2(down))

    def forward_a(self, x):
        """First half: return the skip tensor and ``conv2`` features *before* its SE gate."""
        skip = self.conv1(x)
        down = F.leaky_relu(self.conv1_down(skip), 0.1, inplace=True)
        return skip, self.conv2.conv(down)

    def forward_b(self, x1, x2):
        """Second half: upsample ``x2``, add the cropped skip ``x1``, apply the head."""
        up = F.leaky_relu(self.conv2_up(x2), 0.1, inplace=True)
        cropped_skip = F.pad(x1, (-4, -4, -4, -4))  # negative pad == crop
        fused = F.leaky_relu(self.conv3(cropped_skip + up), 0.1, inplace=True)
        return self.conv_bottom(fused)
111
+
112
+
113
class UNet1x3(nn.Module):
    """One-level U-Net whose transposed-conv head uses stride 3.

    Same layer layout as a one-level encoder/decoder with a cropped skip
    connection; only the output head differs.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the output head.
        deconv: When truthy the head is ``ConvTranspose2d(64, out, 5, 3, 2)``
            (stride 3); otherwise a plain unpadded 3x3 convolution.
    """

    def __init__(self, in_channels, out_channels, deconv):
        super().__init__()
        self.conv1 = UNetConv(in_channels, 32, 64, se=False)
        self.conv1_down = nn.Conv2d(64, 64, 2, 2, 0)
        self.conv2 = UNetConv(64, 128, 64, se=True)
        self.conv2_up = nn.ConvTranspose2d(64, 64, 2, 2, 0)
        self.conv3 = nn.Conv2d(64, 64, 3, 1, 0)
        self.conv_bottom = (
            nn.ConvTranspose2d(64, out_channels, 5, 3, 2)
            if deconv
            else nn.Conv2d(64, out_channels, 3, 1, 0)
        )
        self._init_weights()

    def _init_weights(self):
        """Kaiming-init every (transposed) conv weight; small-normal init any Linear."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        """Full pass, with the SE gate of ``conv2`` computed from this input alone."""
        skip = self.conv1(x)
        down = F.leaky_relu(self.conv1_down(skip), 0.1, inplace=True)
        return self.forward_b(skip, self.conv2(down))

    def forward_a(self, x):
        """First half: return the skip tensor and ``conv2`` features *before* its SE gate."""
        skip = self.conv1(x)
        down = F.leaky_relu(self.conv1_down(skip), 0.1, inplace=True)
        return skip, self.conv2.conv(down)

    def forward_b(self, x1, x2):
        """Second half: upsample ``x2``, add the cropped skip ``x1``, apply the head."""
        up = F.leaky_relu(self.conv2_up(x2), 0.1, inplace=True)
        cropped_skip = F.pad(x1, (-4, -4, -4, -4))  # negative pad == crop
        fused = F.leaky_relu(self.conv3(cropped_skip + up), 0.1, inplace=True)
        return self.conv_bottom(fused)
165
+
166
+
167
class UNet2(nn.Module):
    """Two-level U-Net with SE gates on ``conv2``, ``conv3`` and ``conv4``.

    ``forward`` runs everything at once. ``forward_a`` .. ``forward_d`` split
    the same computation at each SE gate and run only the ``.conv`` part of
    those blocks, so a caller can apply the gates itself via
    ``seblock.forward_mean`` with its own statistics.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the output head.
        deconv: When truthy the head is ``ConvTranspose2d(64, out, 4, 2, 3)``;
            otherwise a plain unpadded 3x3 convolution.
    """

    def __init__(self, in_channels, out_channels, deconv):
        super().__init__()
        self.conv1 = UNetConv(in_channels, 32, 64, se=False)
        self.conv1_down = nn.Conv2d(64, 64, 2, 2, 0)
        self.conv2 = UNetConv(64, 64, 128, se=True)
        self.conv2_down = nn.Conv2d(128, 128, 2, 2, 0)
        self.conv3 = UNetConv(128, 256, 128, se=True)
        self.conv3_up = nn.ConvTranspose2d(128, 128, 2, 2, 0)
        self.conv4 = UNetConv(128, 64, 64, se=True)
        self.conv4_up = nn.ConvTranspose2d(64, 64, 2, 2, 0)
        self.conv5 = nn.Conv2d(64, 64, 3, 1, 0)
        self.conv_bottom = (
            nn.ConvTranspose2d(64, out_channels, 4, 2, 3)
            if deconv
            else nn.Conv2d(64, out_channels, 3, 1, 0)
        )
        self._init_weights()

    def _init_weights(self):
        """Kaiming-init every (transposed) conv weight; small-normal init any Linear."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        """Full pass with every SE gate computed from this input alone."""
        level1 = self.conv1(x)
        level2 = F.leaky_relu(self.conv1_down(level1), 0.1, inplace=True)
        level2 = self.conv2(level2)

        level3 = F.leaky_relu(self.conv2_down(level2), 0.1, inplace=True)
        level3 = self.conv3(level3)
        level3 = F.leaky_relu(self.conv3_up(level3), 0.1, inplace=True)

        level2 = F.pad(level2, (-4, -4, -4, -4))  # crop skip to match level3
        merged = self.conv4(level2 + level3)
        return self.forward_d(level1, merged)

    def forward_a(self, x):
        """Stage up to ``conv2`` (its SE gate is left to the caller)."""
        level1 = self.conv1(x)
        level2 = F.leaky_relu(self.conv1_down(level1), 0.1, inplace=True)
        return level1, self.conv2.conv(level2)

    def forward_b(self, x2):
        """Stage from gated ``conv2`` output up to ``conv3`` (SE gate left to the caller)."""
        down = F.leaky_relu(self.conv2_down(x2), 0.1, inplace=True)
        return self.conv3.conv(down)

    def forward_c(self, x2, x3):
        """Stage from gated ``conv3`` output up to ``conv4`` (SE gate left to the caller)."""
        up = F.leaky_relu(self.conv3_up(x3), 0.1, inplace=True)
        cropped = F.pad(x2, (-4, -4, -4, -4))
        return self.conv4.conv(cropped + up)

    def forward_d(self, x1, x4):
        """Final stage: upsample gated ``conv4`` output, add cropped ``x1``, apply the head."""
        up = F.leaky_relu(self.conv4_up(x4), 0.1, inplace=True)
        cropped = F.pad(x1, (-16, -16, -16, -16))
        fused = F.leaky_relu(self.conv5(cropped + up), 0.1, inplace=True)
        return self.conv_bottom(fused)
249
+
250
+
251
class UpCunet2x(nn.Module):
    """2x upscaler: ``UNet1`` (stride-2 transposed-conv head) refined by ``UNet2``.

    The output is ``unet2(y) + crop(y, 20)`` where ``y`` is the ``unet1`` output.
    Original author's note: perfect tiling, lossless throughout.
    """

    _SCALE = 2   # output pixels per input pixel along each axis
    _PAD = 18    # reflect padding added on every side before the networks
    _ALIGN = 2   # every tile side is rounded up to a multiple of this
    # tile_mode -> (parts for the longer side, parts for the shorter side)
    _TILE_SPLITS = {1: (2, 1), 2: (2, 2), 3: (3, 3), 4: (4, 4)}

    def __init__(self, in_channels=3, out_channels=3):
        super(UpCunet2x, self).__init__()
        self.unet1 = UNet1(in_channels, out_channels, deconv=True)
        self.unet2 = UNet2(in_channels, out_channels, deconv=False)

    @classmethod
    def _tile_len(cls, length, parts):
        """Length of one tile when ``length`` is split into ``parts`` aligned pieces."""
        step = cls._ALIGN * parts
        return ((length - 1) // step + 1) * cls._ALIGN

    @classmethod
    def _crop_size(cls, h0, w0, tile_mode):
        """Return the ``(height, width)`` of one tile for ``tile_mode`` 1-4.

        Raises:
            ValueError: If ``tile_mode`` is not a supported tiled mode.
        """
        try:
            long_parts, short_parts = cls._TILE_SPLITS[tile_mode]
        except (KeyError, TypeError):
            raise ValueError("tile_mode must be one of 0, 1, 2, 3, 4; got %r" % (tile_mode,))
        if w0 >= h0:  # width counts as the longer side on ties
            h_parts, w_parts = short_parts, long_parts
        else:
            h_parts, w_parts = long_parts, short_parts
        return cls._tile_len(h0, h_parts), cls._tile_len(w0, w_parts)

    @staticmethod
    def _spatial_mean(feat, is_half):
        """Mean over dims (2, 3); computed in float32 for half-precision runs."""
        if is_half:
            return torch.mean(feat.float(), dim=(2, 3), keepdim=True).half()
        return torch.mean(feat, dim=(2, 3), keepdim=True)

    def forward(self, x, tile_mode):
        """Upscale ``x`` by 2.

        Args:
            x: 4-D tensor whose shape unpacks as ``(n, c, h0, w0)``.
            tile_mode: 0 processes the whole image at once; 1 halves the longer
                side; 2, 3 and 4 split both sides into halves, thirds and
                quarters. Original author's notes next to the modes read
                1.7G / 6.6G / 5.6G / 4.2G / 3.7G (presumably memory use; unverified).

        Returns:
            Tensor with spatial size ``(h0 * 2, w0 * 2)``.

        Raises:
            ValueError: If ``tile_mode`` is not one of 0-4.
        """
        if tile_mode == 0:  # no tiling
            n, c, h0, w0 = x.shape
            pad = self._PAD
            # Padded size must be divisible by 2.
            ph = ((h0 - 1) // 2 + 1) * 2
            pw = ((w0 - 1) // 2 + 1) * 2
            x = F.pad(x, (pad, pad + pw - w0, pad, pad + ph - h0), 'reflect')
            x = self.unet1.forward(x)
            x0 = self.unet2.forward(x)
            x1 = F.pad(x, (-20, -20, -20, -20))
            x = torch.add(x0, x1)
            if w0 != pw or h0 != ph:
                x = x[:, :, :h0 * 2, :w0 * 2]
            return x
        return self._forward_tiled(x, tile_mode)

    def _forward_tiled(self, x, tile_mode):
        """Tiled pass: every SE gate uses statistics averaged over all tiles.

        The networks are run stage by stage over all tiles; between stages the
        per-tile channel means are averaged and fed to the next SE gate.
        Intermediate tensors of every tile are kept until the last stage.
        """
        n, c, h0, w0 = x.shape
        crop_h, crop_w = self._crop_size(h0, w0, tile_mode)  # validates tile_mode
        pad, scale = self._PAD, self._SCALE
        margin = 2 * pad
        ph = ((h0 - 1) // crop_h + 1) * crop_h
        pw = ((w0 - 1) // crop_w + 1) * crop_w
        x = F.pad(x, (pad, pad + pw - w0, pad, pad + ph - h0), 'reflect')
        is_half = x.dtype == torch.float16
        origins = [(i, j) for i in range(0, ph, crop_h) for j in range(0, pw, crop_w)]
        n_patch = len(origins)

        # Stage 1: unet1 up to (not including) its SE gate.
        se_unet1 = x.new_zeros((n, 64, 1, 1))
        states = []
        for i, j in origins:
            x_crop = x[:, :, i:i + crop_h + margin, j:j + crop_w + margin]
            tmp0, x_crop = self.unet1.forward_a(x_crop)
            se_unet1 += self._spatial_mean(x_crop, is_half)
            states.append((tmp0, x_crop))
        se_unet1 /= n_patch

        # Stage 2: finish unet1, run unet2 up to the conv2 SE gate.
        se_conv2 = x.new_zeros((n, 128, 1, 1))
        for k, (tmp0, x_crop) in enumerate(states):
            x_crop = self.unet1.conv2.seblock.forward_mean(x_crop, se_unet1)
            opt_unet1 = self.unet1.forward_b(tmp0, x_crop)
            tmp_x1, tmp_x2 = self.unet2.forward_a(opt_unet1)
            se_conv2 += self._spatial_mean(tmp_x2, is_half)
            states[k] = (opt_unet1, tmp_x1, tmp_x2)
        se_conv2 /= n_patch

        # Stage 3: unet2 up to the conv3 SE gate.
        se_conv3 = x.new_zeros((n, 128, 1, 1))
        for k, (opt_unet1, tmp_x1, tmp_x2) in enumerate(states):
            tmp_x2 = self.unet2.conv2.seblock.forward_mean(tmp_x2, se_conv2)
            tmp_x3 = self.unet2.forward_b(tmp_x2)
            se_conv3 += self._spatial_mean(tmp_x3, is_half)
            states[k] = (opt_unet1, tmp_x1, tmp_x2, tmp_x3)
        se_conv3 /= n_patch

        # Stage 4: unet2 up to the conv4 SE gate.
        se_conv4 = x.new_zeros((n, 64, 1, 1))
        for k, (opt_unet1, tmp_x1, tmp_x2, tmp_x3) in enumerate(states):
            tmp_x3 = self.unet2.conv3.seblock.forward_mean(tmp_x3, se_conv3)
            tmp_x4 = self.unet2.forward_c(tmp_x2, tmp_x3)
            se_conv4 += self._spatial_mean(tmp_x4, is_half)
            states[k] = (opt_unet1, tmp_x1, tmp_x4)
        se_conv4 /= n_patch

        # Stage 5: finish unet2 and add the cropped unet1 output.
        outputs = []
        for opt_unet1, tmp_x1, tmp_x4 in states:
            tmp_x4 = self.unet2.conv4.seblock.forward_mean(tmp_x4, se_conv4)
            x0 = self.unet2.forward_d(tmp_x1, tmp_x4)  # final unet2 output
            x1 = F.pad(opt_unet1, (-20, -20, -20, -20))
            outputs.append(torch.add(x0, x1))
        del states
        torch.cuda.empty_cache()

        # Stitch the tiles; each one covers (crop_h * scale, crop_w * scale) pixels.
        res = x.new_zeros((n, c, ph * scale, pw * scale))
        for (i, j), patch in zip(origins, outputs):
            res[:, :, i * scale:(i + crop_h) * scale, j * scale:(j + crop_w) * scale] = patch
        del outputs
        torch.cuda.empty_cache()
        if w0 != pw or h0 != ph:
            res = res[:, :, :h0 * scale, :w0 * scale]
        return res
374
+
375
+
376
class UpCunet3x(nn.Module):
    """3x upscaler: ``UNet1x3`` (stride-3 transposed-conv head) refined by ``UNet2``.

    The output is ``unet2(y) + crop(y, 20)`` where ``y`` is the ``unet1`` output.
    Original author's note: perfect tiling, lossless throughout.
    """

    _SCALE = 3   # output pixels per input pixel along each axis
    _PAD = 14    # reflect padding added on every side before the networks
    _ALIGN = 4   # every tile side is rounded up to a multiple of this
    # tile_mode -> (parts for the longer side, parts for the shorter side)
    _TILE_SPLITS = {1: (2, 1), 2: (2, 2), 3: (3, 3), 4: (4, 4)}

    def __init__(self, in_channels=3, out_channels=3):
        super(UpCunet3x, self).__init__()
        self.unet1 = UNet1x3(in_channels, out_channels, deconv=True)
        self.unet2 = UNet2(in_channels, out_channels, deconv=False)

    @classmethod
    def _tile_len(cls, length, parts):
        """Length of one tile when ``length`` is split into ``parts`` aligned pieces."""
        step = cls._ALIGN * parts
        return ((length - 1) // step + 1) * cls._ALIGN

    @classmethod
    def _crop_size(cls, h0, w0, tile_mode):
        """Return the ``(height, width)`` of one tile for ``tile_mode`` 1-4.

        Raises:
            ValueError: If ``tile_mode`` is not a supported tiled mode.
        """
        try:
            long_parts, short_parts = cls._TILE_SPLITS[tile_mode]
        except (KeyError, TypeError):
            raise ValueError("tile_mode must be one of 0, 1, 2, 3, 4; got %r" % (tile_mode,))
        if w0 >= h0:  # width counts as the longer side on ties
            h_parts, w_parts = short_parts, long_parts
        else:
            h_parts, w_parts = long_parts, short_parts
        return cls._tile_len(h0, h_parts), cls._tile_len(w0, w_parts)

    @staticmethod
    def _spatial_mean(feat, is_half):
        """Mean over dims (2, 3); computed in float32 for half-precision runs."""
        if is_half:
            return torch.mean(feat.float(), dim=(2, 3), keepdim=True).half()
        return torch.mean(feat, dim=(2, 3), keepdim=True)

    def forward(self, x, tile_mode):
        """Upscale ``x`` by 3.

        Args:
            x: 4-D tensor whose shape unpacks as ``(n, c, h0, w0)``.
            tile_mode: 0 processes the whole image at once; 1 halves the longer
                side; 2, 3 and 4 split both sides into halves, thirds and
                quarters. Original author's notes next to the modes read
                1.7G / 6.6G / 5.6G / 4.2G / 3.7G (presumably memory use; unverified).

        Returns:
            Tensor with spatial size ``(h0 * 3, w0 * 3)``.

        Raises:
            ValueError: If ``tile_mode`` is not one of 0-4.
        """
        if tile_mode == 0:  # no tiling
            n, c, h0, w0 = x.shape
            pad = self._PAD
            # Padded size is rounded up to a multiple of 4.
            ph = ((h0 - 1) // 4 + 1) * 4
            pw = ((w0 - 1) // 4 + 1) * 4
            x = F.pad(x, (pad, pad + pw - w0, pad, pad + ph - h0), 'reflect')
            x = self.unet1.forward(x)
            x0 = self.unet2.forward(x)
            x1 = F.pad(x, (-20, -20, -20, -20))
            x = torch.add(x0, x1)
            if w0 != pw or h0 != ph:
                x = x[:, :, :h0 * 3, :w0 * 3]
            return x
        return self._forward_tiled(x, tile_mode)

    def _forward_tiled(self, x, tile_mode):
        """Tiled pass: every SE gate uses statistics averaged over all tiles.

        The networks are run stage by stage over all tiles; between stages the
        per-tile channel means are averaged and fed to the next SE gate.
        Intermediate tensors of every tile are kept until the last stage.
        """
        n, c, h0, w0 = x.shape
        crop_h, crop_w = self._crop_size(h0, w0, tile_mode)  # validates tile_mode
        pad, scale = self._PAD, self._SCALE
        margin = 2 * pad
        ph = ((h0 - 1) // crop_h + 1) * crop_h
        pw = ((w0 - 1) // crop_w + 1) * crop_w
        x = F.pad(x, (pad, pad + pw - w0, pad, pad + ph - h0), 'reflect')
        is_half = x.dtype == torch.float16
        origins = [(i, j) for i in range(0, ph, crop_h) for j in range(0, pw, crop_w)]
        n_patch = len(origins)

        # Stage 1: unet1 up to (not including) its SE gate.
        se_unet1 = x.new_zeros((n, 64, 1, 1))
        states = []
        for i, j in origins:
            x_crop = x[:, :, i:i + crop_h + margin, j:j + crop_w + margin]
            tmp0, x_crop = self.unet1.forward_a(x_crop)
            se_unet1 += self._spatial_mean(x_crop, is_half)
            states.append((tmp0, x_crop))
        se_unet1 /= n_patch

        # Stage 2: finish unet1, run unet2 up to the conv2 SE gate.
        se_conv2 = x.new_zeros((n, 128, 1, 1))
        for k, (tmp0, x_crop) in enumerate(states):
            x_crop = self.unet1.conv2.seblock.forward_mean(x_crop, se_unet1)
            opt_unet1 = self.unet1.forward_b(tmp0, x_crop)
            tmp_x1, tmp_x2 = self.unet2.forward_a(opt_unet1)
            se_conv2 += self._spatial_mean(tmp_x2, is_half)
            states[k] = (opt_unet1, tmp_x1, tmp_x2)
        se_conv2 /= n_patch

        # Stage 3: unet2 up to the conv3 SE gate.
        se_conv3 = x.new_zeros((n, 128, 1, 1))
        for k, (opt_unet1, tmp_x1, tmp_x2) in enumerate(states):
            tmp_x2 = self.unet2.conv2.seblock.forward_mean(tmp_x2, se_conv2)
            tmp_x3 = self.unet2.forward_b(tmp_x2)
            se_conv3 += self._spatial_mean(tmp_x3, is_half)
            states[k] = (opt_unet1, tmp_x1, tmp_x2, tmp_x3)
        se_conv3 /= n_patch

        # Stage 4: unet2 up to the conv4 SE gate.
        se_conv4 = x.new_zeros((n, 64, 1, 1))
        for k, (opt_unet1, tmp_x1, tmp_x2, tmp_x3) in enumerate(states):
            tmp_x3 = self.unet2.conv3.seblock.forward_mean(tmp_x3, se_conv3)
            tmp_x4 = self.unet2.forward_c(tmp_x2, tmp_x3)
            se_conv4 += self._spatial_mean(tmp_x4, is_half)
            states[k] = (opt_unet1, tmp_x1, tmp_x4)
        se_conv4 /= n_patch

        # Stage 5: finish unet2 and add the cropped unet1 output.
        outputs = []
        for opt_unet1, tmp_x1, tmp_x4 in states:
            tmp_x4 = self.unet2.conv4.seblock.forward_mean(tmp_x4, se_conv4)
            x0 = self.unet2.forward_d(tmp_x1, tmp_x4)  # final unet2 output
            x1 = F.pad(opt_unet1, (-20, -20, -20, -20))
            outputs.append(torch.add(x0, x1))
        del states
        torch.cuda.empty_cache()

        # Stitch the tiles; each one covers (crop_h * scale, crop_w * scale) pixels.
        res = x.new_zeros((n, c, ph * scale, pw * scale))
        for (i, j), patch in zip(origins, outputs):
            res[:, :, i * scale:(i + crop_h) * scale, j * scale:(j + crop_w) * scale] = patch
        del outputs
        torch.cuda.empty_cache()
        if w0 != pw or h0 != ph:
            res = res[:, :, :h0 * scale, :w0 * scale]
        return res
499
+
500
+
501
class UpCunet4x(nn.Module):
    """4x upscaler: 2x ``UNet1``/``UNet2`` trunk with 64 channels, then conv + PixelShuffle(2).

    A nearest-neighbour 4x upsample of the input is added to the result.
    ``out_channels`` is accepted but not used by the layers built here; the
    final convolution has a fixed 12 output channels.
    Original author's note: perfect tiling, lossless throughout.
    """

    _SCALE = 4   # output pixels per input pixel along each axis
    _PAD = 19    # reflect padding added on every side before the networks
    _ALIGN = 2   # every tile side is rounded up to a multiple of this
    # tile_mode -> (parts for the longer side, parts for the shorter side)
    _TILE_SPLITS = {1: (2, 1), 2: (2, 2), 3: (3, 3), 4: (4, 4)}

    def __init__(self, in_channels=3, out_channels=3):
        super(UpCunet4x, self).__init__()
        self.unet1 = UNet1(in_channels, 64, deconv=True)
        self.unet2 = UNet2(64, 64, deconv=False)
        self.ps = nn.PixelShuffle(2)
        self.conv_final = nn.Conv2d(64, 12, 3, 1, padding=0, bias=True)

    @classmethod
    def _tile_len(cls, length, parts):
        """Length of one tile when ``length`` is split into ``parts`` aligned pieces."""
        step = cls._ALIGN * parts
        return ((length - 1) // step + 1) * cls._ALIGN

    @classmethod
    def _crop_size(cls, h0, w0, tile_mode):
        """Return the ``(height, width)`` of one tile for ``tile_mode`` 1-4.

        Raises:
            ValueError: If ``tile_mode`` is not a supported tiled mode.
        """
        try:
            long_parts, short_parts = cls._TILE_SPLITS[tile_mode]
        except (KeyError, TypeError):
            raise ValueError("tile_mode must be one of 0, 1, 2, 3, 4; got %r" % (tile_mode,))
        if w0 >= h0:  # width counts as the longer side on ties
            h_parts, w_parts = short_parts, long_parts
        else:
            h_parts, w_parts = long_parts, short_parts
        return cls._tile_len(h0, h_parts), cls._tile_len(w0, w_parts)

    @staticmethod
    def _spatial_mean(feat, is_half):
        """Mean over dims (2, 3); computed in float32 for half-precision runs."""
        if is_half:
            return torch.mean(feat.float(), dim=(2, 3), keepdim=True).half()
        return torch.mean(feat, dim=(2, 3), keepdim=True)

    def forward(self, x, tile_mode):
        """Upscale ``x`` by 4.

        Args:
            x: 4-D tensor whose shape unpacks as ``(n, c, h0, w0)``.
            tile_mode: 0 processes the whole image at once; 1 halves the longer
                side; 2, 3 and 4 split both sides into halves, thirds and
                quarters. Original author's notes next to the tiled modes read
                6.6G / 5.6G / 4.1G / 3.7G (presumably memory use; unverified).

        Returns:
            Tensor with spatial size ``(h0 * 4, w0 * 4)``.

        Raises:
            ValueError: If ``tile_mode`` is not one of 0-4.
        """
        x00 = x  # kept for the nearest-neighbour residual added at the end
        if tile_mode == 0:  # no tiling
            n, c, h0, w0 = x.shape
            pad = self._PAD
            # Padded size must be divisible by 2.
            ph = ((h0 - 1) // 2 + 1) * 2
            pw = ((w0 - 1) // 2 + 1) * 2
            x = F.pad(x, (pad, pad + pw - w0, pad, pad + ph - h0), 'reflect')
            x = self.unet1.forward(x)
            x0 = self.unet2.forward(x)
            x1 = F.pad(x, (-20, -20, -20, -20))
            x = torch.add(x0, x1)
            x = self.conv_final(x)
            x = F.pad(x, (-1, -1, -1, -1))
            x = self.ps(x)
            if w0 != pw or h0 != ph:
                x = x[:, :, :h0 * 4, :w0 * 4]
            x += F.interpolate(x00, scale_factor=4, mode='nearest')
            return x
        res = self._forward_tiled(x, tile_mode)
        res += F.interpolate(x00, scale_factor=4, mode='nearest')
        return res

    def _forward_tiled(self, x, tile_mode):
        """Tiled pass (without the final residual): SE gates use tile-averaged statistics.

        The networks are run stage by stage over all tiles; between stages the
        per-tile channel means are averaged and fed to the next SE gate.
        Intermediate tensors of every tile are kept until the last stage.
        """
        n, c, h0, w0 = x.shape
        crop_h, crop_w = self._crop_size(h0, w0, tile_mode)  # validates tile_mode
        pad, scale = self._PAD, self._SCALE
        margin = 2 * pad
        ph = ((h0 - 1) // crop_h + 1) * crop_h
        pw = ((w0 - 1) // crop_w + 1) * crop_w
        x = F.pad(x, (pad, pad + pw - w0, pad, pad + ph - h0), 'reflect')
        is_half = x.dtype == torch.float16
        origins = [(i, j) for i in range(0, ph, crop_h) for j in range(0, pw, crop_w)]
        n_patch = len(origins)

        # Stage 1: unet1 up to (not including) its SE gate.
        se_unet1 = x.new_zeros((n, 64, 1, 1))
        states = []
        for i, j in origins:
            x_crop = x[:, :, i:i + crop_h + margin, j:j + crop_w + margin]
            tmp0, x_crop = self.unet1.forward_a(x_crop)
            se_unet1 += self._spatial_mean(x_crop, is_half)
            states.append((tmp0, x_crop))
        se_unet1 /= n_patch

        # Stage 2: finish unet1, run unet2 up to the conv2 SE gate.
        se_conv2 = x.new_zeros((n, 128, 1, 1))
        for k, (tmp0, x_crop) in enumerate(states):
            x_crop = self.unet1.conv2.seblock.forward_mean(x_crop, se_unet1)
            opt_unet1 = self.unet1.forward_b(tmp0, x_crop)
            tmp_x1, tmp_x2 = self.unet2.forward_a(opt_unet1)
            se_conv2 += self._spatial_mean(tmp_x2, is_half)
            states[k] = (opt_unet1, tmp_x1, tmp_x2)
        se_conv2 /= n_patch

        # Stage 3: unet2 up to the conv3 SE gate.
        se_conv3 = x.new_zeros((n, 128, 1, 1))
        for k, (opt_unet1, tmp_x1, tmp_x2) in enumerate(states):
            tmp_x2 = self.unet2.conv2.seblock.forward_mean(tmp_x2, se_conv2)
            tmp_x3 = self.unet2.forward_b(tmp_x2)
            se_conv3 += self._spatial_mean(tmp_x3, is_half)
            states[k] = (opt_unet1, tmp_x1, tmp_x2, tmp_x3)
        se_conv3 /= n_patch

        # Stage 4: unet2 up to the conv4 SE gate.
        se_conv4 = x.new_zeros((n, 64, 1, 1))
        for k, (opt_unet1, tmp_x1, tmp_x2, tmp_x3) in enumerate(states):
            tmp_x3 = self.unet2.conv3.seblock.forward_mean(tmp_x3, se_conv3)
            tmp_x4 = self.unet2.forward_c(tmp_x2, tmp_x3)
            se_conv4 += self._spatial_mean(tmp_x4, is_half)
            states[k] = (opt_unet1, tmp_x1, tmp_x4)
        se_conv4 /= n_patch

        # Stage 5: finish unet2, add the cropped unet1 output, then conv + pixel shuffle.
        outputs = []
        for opt_unet1, tmp_x1, tmp_x4 in states:
            tmp_x4 = self.unet2.conv4.seblock.forward_mean(tmp_x4, se_conv4)
            x0 = self.unet2.forward_d(tmp_x1, tmp_x4)  # final unet2 output
            x1 = F.pad(opt_unet1, (-20, -20, -20, -20))
            x_crop = torch.add(x0, x1)
            x_crop = self.conv_final(x_crop)
            x_crop = F.pad(x_crop, (-1, -1, -1, -1))
            outputs.append(self.ps(x_crop))
        del states
        torch.cuda.empty_cache()

        # Stitch the tiles; each one covers (crop_h * scale, crop_w * scale) pixels.
        res = x.new_zeros((n, c, ph * scale, pw * scale))
        for (i, j), patch in zip(origins, outputs):
            res[:, :, i * scale:(i + crop_h) * scale, j * scale:(j + crop_w) * scale] = patch
        del outputs
        torch.cuda.empty_cache()
        if w0 != pw or h0 != ph:
            res = res[:, :, :h0 * scale, :w0 * scale]
        return res
636
+
637
+
638
class RealWaifuUpScaler(object):
    """Convenience wrapper: load an ``UpCunet{2,3,4}x`` model and upscale uint8 frames.

    Args:
        scale: Upscaling factor, 2, 3 or 4 (int or its string form).
        weight_path: Path of the state-dict file passed to ``torch.load``.
        half: Truthy to run the model and its inputs in half precision.
        device: Device the model and the input tensors are moved to.

    Raises:
        ValueError: If ``scale`` is not 2, 3 or 4.
    """

    def __init__(self, scale, weight_path, half, device):
        scale_key = str(scale)
        if scale_key not in ("2", "3", "4"):
            raise ValueError("unsupported scale %r; expected 2, 3 or 4" % (scale,))
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code; only load weight files from a trusted source.
        weight = torch.load(weight_path, map_location="cpu")
        # Explicit lookup instead of eval() on a caller-supplied value.
        model_cls = {"2": UpCunet2x, "3": UpCunet3x, "4": UpCunet4x}[scale_key]
        self.model = model_cls()
        if half:
            self.model = self.model.half().to(device)
        else:
            self.model = self.model.to(device)
        self.model.load_state_dict(weight, strict=True)
        self.model.eval()
        self.half = half
        self.device = device

    def np2tensor(self, np_frame):
        """Convert an ``(H, W, C)`` array to a ``(1, C, H, W)`` tensor scaled by 1/255."""
        tensor = torch.from_numpy(np.transpose(np_frame, (2, 0, 1))).unsqueeze(0).to(self.device)
        tensor = tensor.half() if self.half else tensor.float()
        return tensor / 255

    def tensor2np(self, tensor):
        """Convert a model output back to an ``(H, W, C)`` uint8 array.

        Values are multiplied by 255, rounded and clamped to ``[0, 255]``;
        half-precision tensors are converted to float32 first.
        """
        data = tensor.data.squeeze()
        if self.half:
            data = data.float()
        image = (data * 255.0).round().clamp_(0, 255).byte().cpu().numpy()
        return np.transpose(image, (1, 2, 0))

    def __call__(self, frame, tile_mode):
        """Upscale one ``(H, W, C)`` frame with gradients disabled."""
        with torch.no_grad():
            tensor = self.np2tensor(frame)
            result = self.tensor2np(self.model(tensor, tile_mode))
        return result
670
+
671
+
672
if __name__ == "__main__":
    # Batch inference demo: run every image in <root>/input_dir1 through each
    # model (2x, 3x, 4x) and each tile mode, writing PNG files to
    # <root>/opt-dir-all-test.
    import time
    import cv2

    input_dir = "%s/input_dir1" % root_path
    output_dir = "%s/opt-dir-all-test" % root_path
    tmp_dir = os.path.join(root_path, "tmp")
    os.makedirs(output_dir, exist_ok=True)
    # The temporary link and output files below live here; create it up front
    # so the first os.symlink does not fail on a fresh checkout.
    os.makedirs(tmp_dir, exist_ok=True)

    for weight_path, scale in [("weights_v3/up2x-latest-denoise3x.pth", 2),
                               ("weights_v3/up3x-latest-denoise3x.pth", 3),
                               ("weights_v3/up4x-latest-denoise3x.pth", 4)]:
        # Load each model once and reuse it for all tile modes.
        upscaler = RealWaifuUpScaler(scale, weight_path, half=True, device="cuda:0")
        for tile_mode in [0, 1, 2, 3, 4]:
            for name in os.listdir(input_dir):
                print(name)
                # Split at the last dot; a name without a dot yields an empty prefix.
                prefix, _, ext = name.rpartition(".")
                inp_path = os.path.join(input_dir, name)
                tmp_path = os.path.join(tmp_dir, "%s.%s" % (int(time.time() * 1000000), ext))
                print(inp_path, tmp_path)
                # Read through a temporary link with a numeric name so that
                # source paths containing Chinese characters are supported.
                # os.link(inp_path, tmp_path)  # hard link variant for Windows
                os.symlink(inp_path, tmp_path)  # symbolic link on Linux
                try:
                    frame = cv2.imread(tmp_path)
                finally:
                    os.remove(tmp_path)  # never leave the link behind
                if frame is None:
                    # cv2.imread returns None for files it cannot decode.
                    print(name, "skipped: could not be read as an image")
                    continue
                frame = frame[:, :, [2, 1, 0]]  # BGR -> RGB
                t0 = time.time()
                result = upscaler(frame, tile_mode=tile_mode)[:, :, ::-1]  # RGB -> BGR
                t1 = time.time()
                print(prefix, "done", t1 - t0)
                # Encode as PNG: cv2.imwrite picks the format from the
                # extension and the final file name always ends in .png.
                tmp_opt_path = os.path.join(tmp_dir, "%s.png" % int(time.time() * 1000000))
                cv2.imwrite(tmp_opt_path, result)
                # Pick the first output name that does not exist yet.
                n = 0
                while True:
                    if n == 0:
                        suffix = "_%sx_tile%s.png" % (scale, tile_mode)
                    else:
                        suffix = "_%sx_tile%s_%s.png" % (scale, tile_mode, n)
                    final_opt_path = os.path.join(output_dir, prefix + suffix)
                    if not os.path.exists(final_opt_path):
                        break
                    n += 1
                os.rename(tmp_opt_path, final_opt_path)