whoismygrandson commited on
Commit
83c904a
1 Parent(s): 8a8db9b

Add application file

Browse files
.gitattributes CHANGED
@@ -1,34 +1,27 @@
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
 
4
  *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
  *.ftz filter=lfs diff=lfs merge=lfs -text
7
  *.gz filter=lfs diff=lfs merge=lfs -text
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
9
  *.joblib filter=lfs diff=lfs merge=lfs -text
10
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
  *.model filter=lfs diff=lfs merge=lfs -text
13
  *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
  *.onnx filter=lfs diff=lfs merge=lfs -text
17
  *.ot filter=lfs diff=lfs merge=lfs -text
18
  *.parquet filter=lfs diff=lfs merge=lfs -text
19
  *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
  *.pth filter=lfs diff=lfs merge=lfs -text
24
  *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
28
  *.tflite filter=lfs diff=lfs merge=lfs -text
29
  *.tgz filter=lfs diff=lfs merge=lfs -text
30
- *.wasm filter=lfs diff=lfs merge=lfs -text
31
  *.xz filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
- *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
5
  *.bz2 filter=lfs diff=lfs merge=lfs -text
 
6
  *.ftz filter=lfs diff=lfs merge=lfs -text
7
  *.gz filter=lfs diff=lfs merge=lfs -text
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
9
  *.joblib filter=lfs diff=lfs merge=lfs -text
10
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
 
11
  *.model filter=lfs diff=lfs merge=lfs -text
12
  *.msgpack filter=lfs diff=lfs merge=lfs -text
 
 
13
  *.onnx filter=lfs diff=lfs merge=lfs -text
14
  *.ot filter=lfs diff=lfs merge=lfs -text
15
  *.parquet filter=lfs diff=lfs merge=lfs -text
16
  *.pb filter=lfs diff=lfs merge=lfs -text
 
 
17
  *.pt filter=lfs diff=lfs merge=lfs -text
18
  *.pth filter=lfs diff=lfs merge=lfs -text
19
  *.rar filter=lfs diff=lfs merge=lfs -text
 
20
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
  *.tar.* filter=lfs diff=lfs merge=lfs -text
22
  *.tflite filter=lfs diff=lfs merge=lfs -text
23
  *.tgz filter=lfs diff=lfs merge=lfs -text
 
24
  *.xz filter=lfs diff=lfs merge=lfs -text
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
  *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,48 @@
1
  ---
2
- title: Cf
3
- emoji: 🏃
4
- colorFrom: green
5
- colorTo: red
6
  sdk: gradio
7
- sdk_version: 3.9.1
8
  app_file: app.py
9
- pinned: false
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Real Cascade U-Nets for Anime Image Super Resolution
3
+ emoji: 👀
4
+ colorFrom: blue
5
+ colorTo: green
6
  sdk: gradio
 
7
  app_file: app.py
8
+ pinned: true
9
+ license: mit
10
  ---
11
 
12
+ > From <https://github.com/bilibili/ailab/tree/main/Real-CUGAN>
13
+
14
+ # Configuration
15
+
16
+ `title`: _string_
17
+ Display title for the Space
18
+
19
+ `emoji`: _string_
20
+ Space emoji (emoji-only character allowed)
21
+
22
+ `colorFrom`: _string_
23
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
24
+
25
+ `colorTo`: _string_
26
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
27
+
28
+ `sdk`: _string_
29
+ Can be either `gradio`, `streamlit`, or `static`
30
+
31
+ `sdk_version` : _string_
32
+ Only applicable for `streamlit` SDK.
33
+ See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
34
+
35
+ `app_file`: _string_
36
+ Path to your main application file (which contains either `gradio` or `streamlit` Python code, or `static` html code).
37
+ Path is relative to the root of the repository.
38
+
39
+ `models`: _List[string]_
40
+ HF model IDs (like "gpt2" or "deepset/roberta-base-squad2") used in the Space.
41
+ Will be parsed automatically from your code if not specified here.
42
+
43
+ `datasets`: _List[string]_
44
+ HF dataset IDs (like "common_voice" or "oscar-corpus/OSCAR-2109") used in the Space.
45
+ Will be parsed automatically from your code if not specified here.
46
+
47
+ `pinned`: _boolean_
48
+ Whether the Space stays on top of your list.
app.py ADDED
@@ -0,0 +1,860 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from enum import IntEnum
3
+ from pathlib import Path
4
+ from tempfile import mktemp
5
+ from typing import IO, Dict, Type
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from gradio import Interface, inputs, outputs
13
+
14
+ DEVICE = "gpu"
15
+
16
+ WEIGHTS_PATH = Path(__file__).parent / "weights"
17
+
18
+ AVALIABLE_WEIGHTS = {
19
+ basename: path
20
+ for basename, ext in (
21
+ os.path.splitext(filename) for filename in os.listdir(WEIGHTS_PATH)
22
+ )
23
+ if (path := WEIGHTS_PATH / (basename + ext)).is_file() and ext.endswith("pth")
24
+ }
25
+
26
+
27
+ class ScaleMode(IntEnum):
28
+ up2x = 2
29
+ up3x = 3
30
+ up4x = 4
31
+
32
+
33
+ class TileMode(IntEnum):
34
+ full = 0
35
+ half = 1
36
+ quarter = 2
37
+ ninth = 3
38
+ sixteenth = 4
39
+
40
+
41
+ class SEBlock(nn.Module):
42
+ def __init__(self, in_channels, reduction=8, bias=False):
43
+ super(SEBlock, self).__init__()
44
+ self.conv1 = nn.Conv2d(
45
+ in_channels, in_channels // reduction, 1, 1, 0, bias=bias
46
+ )
47
+ self.conv2 = nn.Conv2d(
48
+ in_channels // reduction, in_channels, 1, 1, 0, bias=bias
49
+ )
50
+
51
+ def forward(self, x):
52
+ if "Half" in x.type(): # torch.HalfTensor/torch.cuda.HalfTensor
53
+ x0 = torch.mean(x.float(), dim=(2, 3), keepdim=True).half()
54
+ else:
55
+ x0 = torch.mean(x, dim=(2, 3), keepdim=True)
56
+ x0 = self.conv1(x0)
57
+ x0 = F.relu(x0, inplace=True)
58
+ x0 = self.conv2(x0)
59
+ x0 = torch.sigmoid(x0)
60
+ x = torch.mul(x, x0)
61
+ return x
62
+
63
+ def forward_mean(self, x, x0):
64
+ x0 = self.conv1(x0)
65
+ x0 = F.relu(x0, inplace=True)
66
+ x0 = self.conv2(x0)
67
+ x0 = torch.sigmoid(x0)
68
+ x = torch.mul(x, x0)
69
+ return x
70
+
71
+
72
+ class UNetConv(nn.Module):
73
+ def __init__(self, in_channels, mid_channels, out_channels, se):
74
+ super(UNetConv, self).__init__()
75
+ self.conv = nn.Sequential(
76
+ nn.Conv2d(in_channels, mid_channels, 3, 1, 0),
77
+ nn.LeakyReLU(0.1, inplace=True),
78
+ nn.Conv2d(mid_channels, out_channels, 3, 1, 0),
79
+ nn.LeakyReLU(0.1, inplace=True),
80
+ )
81
+ if se:
82
+ self.seblock = SEBlock(out_channels, reduction=8, bias=True)
83
+ else:
84
+ self.seblock = None
85
+
86
+ def forward(self, x):
87
+ z = self.conv(x)
88
+ if self.seblock is not None:
89
+ z = self.seblock(z)
90
+ return z
91
+
92
+
93
+ class UNet1(nn.Module):
94
+ def __init__(self, in_channels, out_channels, deconv):
95
+ super(UNet1, self).__init__()
96
+ self.conv1 = UNetConv(in_channels, 32, 64, se=False)
97
+ self.conv1_down = nn.Conv2d(64, 64, 2, 2, 0)
98
+ self.conv2 = UNetConv(64, 128, 64, se=True)
99
+ self.conv2_up = nn.ConvTranspose2d(64, 64, 2, 2, 0)
100
+ self.conv3 = nn.Conv2d(64, 64, 3, 1, 0)
101
+
102
+ if deconv:
103
+ self.conv_bottom = nn.ConvTranspose2d(64, out_channels, 4, 2, 3)
104
+ else:
105
+ self.conv_bottom = nn.Conv2d(64, out_channels, 3, 1, 0)
106
+
107
+ for m in self.modules():
108
+ if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
109
+ nn.init.kaiming_normal_(
110
+ m.weight, mode="fan_out", nonlinearity="relu")
111
+ elif isinstance(m, nn.Linear):
112
+ nn.init.normal_(m.weight, 0, 0.01)
113
+ if m.bias is not None:
114
+ nn.init.constant_(m.bias, 0)
115
+
116
+ def forward(self, x):
117
+ x1 = self.conv1(x)
118
+ x2 = self.conv1_down(x1)
119
+ x2 = F.leaky_relu(x2, 0.1, inplace=True)
120
+ x2 = self.conv2(x2)
121
+ x2 = self.conv2_up(x2)
122
+ x2 = F.leaky_relu(x2, 0.1, inplace=True)
123
+
124
+ x1 = F.pad(x1, (-4, -4, -4, -4))
125
+ x3 = self.conv3(x1 + x2)
126
+ x3 = F.leaky_relu(x3, 0.1, inplace=True)
127
+ z = self.conv_bottom(x3)
128
+ return z
129
+
130
+ def forward_a(self, x):
131
+ x1 = self.conv1(x)
132
+ x2 = self.conv1_down(x1)
133
+ x2 = F.leaky_relu(x2, 0.1, inplace=True)
134
+ x2 = self.conv2.conv(x2)
135
+ return x1, x2
136
+
137
+ def forward_b(self, x1, x2):
138
+ x2 = self.conv2_up(x2)
139
+ x2 = F.leaky_relu(x2, 0.1, inplace=True)
140
+
141
+ x1 = F.pad(x1, (-4, -4, -4, -4))
142
+ x3 = self.conv3(x1 + x2)
143
+ x3 = F.leaky_relu(x3, 0.1, inplace=True)
144
+ z = self.conv_bottom(x3)
145
+ return z
146
+
147
+
148
+ class UNet1x3(nn.Module):
149
+ def __init__(self, in_channels, out_channels, deconv):
150
+ super(UNet1x3, self).__init__()
151
+ self.conv1 = UNetConv(in_channels, 32, 64, se=False)
152
+ self.conv1_down = nn.Conv2d(64, 64, 2, 2, 0)
153
+ self.conv2 = UNetConv(64, 128, 64, se=True)
154
+ self.conv2_up = nn.ConvTranspose2d(64, 64, 2, 2, 0)
155
+ self.conv3 = nn.Conv2d(64, 64, 3, 1, 0)
156
+
157
+ if deconv:
158
+ self.conv_bottom = nn.ConvTranspose2d(64, out_channels, 5, 3, 2)
159
+ else:
160
+ self.conv_bottom = nn.Conv2d(64, out_channels, 3, 1, 0)
161
+
162
+ for m in self.modules():
163
+ if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
164
+ nn.init.kaiming_normal_(
165
+ m.weight, mode="fan_out", nonlinearity="relu")
166
+ elif isinstance(m, nn.Linear):
167
+ nn.init.normal_(m.weight, 0, 0.01)
168
+ if m.bias is not None:
169
+ nn.init.constant_(m.bias, 0)
170
+
171
+ def forward(self, x):
172
+ x1 = self.conv1(x)
173
+ x2 = self.conv1_down(x1)
174
+ x2 = F.leaky_relu(x2, 0.1, inplace=True)
175
+ x2 = self.conv2(x2)
176
+ x2 = self.conv2_up(x2)
177
+ x2 = F.leaky_relu(x2, 0.1, inplace=True)
178
+
179
+ x1 = F.pad(x1, (-4, -4, -4, -4))
180
+ x3 = self.conv3(x1 + x2)
181
+ x3 = F.leaky_relu(x3, 0.1, inplace=True)
182
+ z = self.conv_bottom(x3)
183
+ return z
184
+
185
+ def forward_a(self, x):
186
+ x1 = self.conv1(x)
187
+ x2 = self.conv1_down(x1)
188
+ x2 = F.leaky_relu(x2, 0.1, inplace=True)
189
+ x2 = self.conv2.conv(x2)
190
+ return x1, x2
191
+
192
+ def forward_b(self, x1, x2):
193
+ x2 = self.conv2_up(x2)
194
+ x2 = F.leaky_relu(x2, 0.1, inplace=True)
195
+
196
+ x1 = F.pad(x1, (-4, -4, -4, -4))
197
+ x3 = self.conv3(x1 + x2)
198
+ x3 = F.leaky_relu(x3, 0.1, inplace=True)
199
+ z = self.conv_bottom(x3)
200
+ return z
201
+
202
+
203
+ class UNet2(nn.Module):
204
+ def __init__(self, in_channels, out_channels, deconv):
205
+ super(UNet2, self).__init__()
206
+
207
+ self.conv1 = UNetConv(in_channels, 32, 64, se=False)
208
+ self.conv1_down = nn.Conv2d(64, 64, 2, 2, 0)
209
+ self.conv2 = UNetConv(64, 64, 128, se=True)
210
+ self.conv2_down = nn.Conv2d(128, 128, 2, 2, 0)
211
+ self.conv3 = UNetConv(128, 256, 128, se=True)
212
+ self.conv3_up = nn.ConvTranspose2d(128, 128, 2, 2, 0)
213
+ self.conv4 = UNetConv(128, 64, 64, se=True)
214
+ self.conv4_up = nn.ConvTranspose2d(64, 64, 2, 2, 0)
215
+ self.conv5 = nn.Conv2d(64, 64, 3, 1, 0)
216
+
217
+ if deconv:
218
+ self.conv_bottom = nn.ConvTranspose2d(64, out_channels, 4, 2, 3)
219
+ else:
220
+ self.conv_bottom = nn.Conv2d(64, out_channels, 3, 1, 0)
221
+
222
+ for m in self.modules():
223
+ if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
224
+ nn.init.kaiming_normal_(
225
+ m.weight, mode="fan_out", nonlinearity="relu")
226
+ elif isinstance(m, nn.Linear):
227
+ nn.init.normal_(m.weight, 0, 0.01)
228
+ if m.bias is not None:
229
+ nn.init.constant_(m.bias, 0)
230
+
231
+ def forward(self, x):
232
+ x1 = self.conv1(x)
233
+ x2 = self.conv1_down(x1)
234
+ x2 = F.leaky_relu(x2, 0.1, inplace=True)
235
+ x2 = self.conv2(x2)
236
+
237
+ x3 = self.conv2_down(x2)
238
+ x3 = F.leaky_relu(x3, 0.1, inplace=True)
239
+ x3 = self.conv3(x3)
240
+ x3 = self.conv3_up(x3)
241
+ x3 = F.leaky_relu(x3, 0.1, inplace=True)
242
+
243
+ x2 = F.pad(x2, (-4, -4, -4, -4))
244
+ x4 = self.conv4(x2 + x3)
245
+ x4 = self.conv4_up(x4)
246
+ x4 = F.leaky_relu(x4, 0.1, inplace=True)
247
+
248
+ x1 = F.pad(x1, (-16, -16, -16, -16))
249
+ x5 = self.conv5(x1 + x4)
250
+ x5 = F.leaky_relu(x5, 0.1, inplace=True)
251
+
252
+ z = self.conv_bottom(x5)
253
+ return z
254
+
255
+ def forward_a(self, x): # conv234结尾有se
256
+ x1 = self.conv1(x)
257
+ x2 = self.conv1_down(x1)
258
+ x2 = F.leaky_relu(x2, 0.1, inplace=True)
259
+ x2 = self.conv2.conv(x2)
260
+ return x1, x2
261
+
262
+ def forward_b(self, x2): # conv234结尾有se
263
+ x3 = self.conv2_down(x2)
264
+ x3 = F.leaky_relu(x3, 0.1, inplace=True)
265
+ x3 = self.conv3.conv(x3)
266
+ return x3
267
+
268
+ def forward_c(self, x2, x3): # conv234结尾有se
269
+ x3 = self.conv3_up(x3)
270
+ x3 = F.leaky_relu(x3, 0.1, inplace=True)
271
+
272
+ x2 = F.pad(x2, (-4, -4, -4, -4))
273
+ x4 = self.conv4.conv(x2 + x3)
274
+ return x4
275
+
276
+ def forward_d(self, x1, x4): # conv234结尾有se
277
+ x4 = self.conv4_up(x4)
278
+ x4 = F.leaky_relu(x4, 0.1, inplace=True)
279
+
280
+ x1 = F.pad(x1, (-16, -16, -16, -16))
281
+ x5 = self.conv5(x1 + x4)
282
+ x5 = F.leaky_relu(x5, 0.1, inplace=True)
283
+
284
+ z = self.conv_bottom(x5)
285
+ return z
286
+
287
+
288
+ class UpCunet2x(nn.Module): # 完美tile,全程无损
289
+ def __init__(self, in_channels=3, out_channels=3):
290
+ super(UpCunet2x, self).__init__()
291
+ self.unet1 = UNet1(in_channels, out_channels, deconv=True)
292
+ self.unet2 = UNet2(in_channels, out_channels, deconv=False)
293
+
294
+ def forward(self, x, tile_mode): # 1.7G
295
+ n, c, h0, w0 = x.shape
296
+ if tile_mode == 0: # 不tile
297
+ ph = ((h0 - 1) // 2 + 1) * 2
298
+ pw = ((w0 - 1) // 2 + 1) * 2
299
+ x = F.pad(x, (18, 18 + pw - w0, 18, 18 + ph - h0),
300
+ "reflect") # 需要保证被2整除
301
+ x = self.unet1.forward(x)
302
+ x0 = self.unet2.forward(x)
303
+ x1 = F.pad(x, (-20, -20, -20, -20))
304
+ x = torch.add(x0, x1)
305
+ if w0 != pw or h0 != ph:
306
+ x = x[:, :, : h0 * 2, : w0 * 2]
307
+ return x
308
+ elif tile_mode == 1: # 对长边减半
309
+ if w0 >= h0:
310
+ crop_size_w = ((w0 - 1) // 4 * 4 + 4) // 2 # 减半后能被2整除,所以要先被4整除
311
+ crop_size_h = (h0 - 1) // 2 * 2 + 2 # 能被2整除
312
+ else:
313
+ crop_size_h = ((h0 - 1) // 4 * 4 + 4) // 2 # 减半后能被2整除,所以要先被4整除
314
+ crop_size_w = (w0 - 1) // 2 * 2 + 2 # 能被2整除
315
+ crop_size = (crop_size_h, crop_size_w) # 6.6G
316
+ elif tile_mode == 2: # hw都减半
317
+ crop_size = (
318
+ ((h0 - 1) // 4 * 4 + 4) // 2,
319
+ ((w0 - 1) // 4 * 4 + 4) // 2,
320
+ ) # 5.6G
321
+ elif tile_mode == 3: # hw都三分之一
322
+ crop_size = (
323
+ ((h0 - 1) // 6 * 6 + 6) // 3,
324
+ ((w0 - 1) // 6 * 6 + 6) // 3,
325
+ ) # 4.2G
326
+ elif tile_mode == 4: # hw都四分之一
327
+ crop_size = (
328
+ ((h0 - 1) // 8 * 8 + 8) // 4,
329
+ ((w0 - 1) // 8 * 8 + 8) // 4,
330
+ ) # 3.7G
331
+ ph = ((h0 - 1) // crop_size[0] + 1) * crop_size[0]
332
+ pw = ((w0 - 1) // crop_size[1] + 1) * crop_size[1]
333
+ x = F.pad(x, (18, 18 + pw - w0, 18, 18 + ph - h0), "reflect")
334
+ n, c, h, w = x.shape
335
+ se_mean0 = torch.zeros((n, 64, 1, 1)).to(x.device)
336
+ if "Half" in x.type():
337
+ se_mean0 = se_mean0.half()
338
+ n_patch = 0
339
+ tmp_dict = {}
340
+ opt_res_dict = {}
341
+ for i in range(0, h - 36, crop_size[0]):
342
+ tmp_dict[i] = {}
343
+ for j in range(0, w - 36, crop_size[1]):
344
+ x_crop = x[:, :, i: i + crop_size[0] +
345
+ 36, j: j + crop_size[1] + 36]
346
+ n, c1, h1, w1 = x_crop.shape
347
+ tmp0, x_crop = self.unet1.forward_a(x_crop)
348
+ if "Half" in x.type(): # torch.HalfTensor/torch.cuda.HalfTensor
349
+ tmp_se_mean = torch.mean(
350
+ x_crop.float(), dim=(2, 3), keepdim=True
351
+ ).half()
352
+ else:
353
+ tmp_se_mean = torch.mean(x_crop, dim=(2, 3), keepdim=True)
354
+ se_mean0 += tmp_se_mean
355
+ n_patch += 1
356
+ tmp_dict[i][j] = (tmp0, x_crop)
357
+ se_mean0 /= n_patch
358
+ se_mean1 = torch.zeros((n, 128, 1, 1)).to(x.device) # 64#128#128#64
359
+ if "Half" in x.type():
360
+ se_mean1 = se_mean1.half()
361
+ for i in range(0, h - 36, crop_size[0]):
362
+ for j in range(0, w - 36, crop_size[1]):
363
+ tmp0, x_crop = tmp_dict[i][j]
364
+ x_crop = self.unet1.conv2.seblock.forward_mean(
365
+ x_crop, se_mean0)
366
+ opt_unet1 = self.unet1.forward_b(tmp0, x_crop)
367
+ tmp_x1, tmp_x2 = self.unet2.forward_a(opt_unet1)
368
+ if "Half" in x.type(): # torch.HalfTensor/torch.cuda.HalfTensor
369
+ tmp_se_mean = torch.mean(
370
+ tmp_x2.float(), dim=(2, 3), keepdim=True
371
+ ).half()
372
+ else:
373
+ tmp_se_mean = torch.mean(tmp_x2, dim=(2, 3), keepdim=True)
374
+ se_mean1 += tmp_se_mean
375
+ tmp_dict[i][j] = (opt_unet1, tmp_x1, tmp_x2)
376
+ se_mean1 /= n_patch
377
+ se_mean0 = torch.zeros((n, 128, 1, 1)).to(x.device) # 64#128#128#64
378
+ if "Half" in x.type():
379
+ se_mean0 = se_mean0.half()
380
+ for i in range(0, h - 36, crop_size[0]):
381
+ for j in range(0, w - 36, crop_size[1]):
382
+ opt_unet1, tmp_x1, tmp_x2 = tmp_dict[i][j]
383
+ tmp_x2 = self.unet2.conv2.seblock.forward_mean(
384
+ tmp_x2, se_mean1)
385
+ tmp_x3 = self.unet2.forward_b(tmp_x2)
386
+ if "Half" in x.type(): # torch.HalfTensor/torch.cuda.HalfTensor
387
+ tmp_se_mean = torch.mean(
388
+ tmp_x3.float(), dim=(2, 3), keepdim=True
389
+ ).half()
390
+ else:
391
+ tmp_se_mean = torch.mean(tmp_x3, dim=(2, 3), keepdim=True)
392
+ se_mean0 += tmp_se_mean
393
+ tmp_dict[i][j] = (opt_unet1, tmp_x1, tmp_x2, tmp_x3)
394
+ se_mean0 /= n_patch
395
+ se_mean1 = torch.zeros((n, 64, 1, 1)).to(x.device) # 64#128#128#64
396
+ if "Half" in x.type():
397
+ se_mean1 = se_mean1.half()
398
+ for i in range(0, h - 36, crop_size[0]):
399
+ for j in range(0, w - 36, crop_size[1]):
400
+ opt_unet1, tmp_x1, tmp_x2, tmp_x3 = tmp_dict[i][j]
401
+ tmp_x3 = self.unet2.conv3.seblock.forward_mean(
402
+ tmp_x3, se_mean0)
403
+ tmp_x4 = self.unet2.forward_c(tmp_x2, tmp_x3)
404
+ if "Half" in x.type(): # torch.HalfTensor/torch.cuda.HalfTensor
405
+ tmp_se_mean = torch.mean(
406
+ tmp_x4.float(), dim=(2, 3), keepdim=True
407
+ ).half()
408
+ else:
409
+ tmp_se_mean = torch.mean(tmp_x4, dim=(2, 3), keepdim=True)
410
+ se_mean1 += tmp_se_mean
411
+ tmp_dict[i][j] = (opt_unet1, tmp_x1, tmp_x4)
412
+ se_mean1 /= n_patch
413
+ for i in range(0, h - 36, crop_size[0]):
414
+ opt_res_dict[i] = {}
415
+ for j in range(0, w - 36, crop_size[1]):
416
+ opt_unet1, tmp_x1, tmp_x4 = tmp_dict[i][j]
417
+ tmp_x4 = self.unet2.conv4.seblock.forward_mean(
418
+ tmp_x4, se_mean1)
419
+ x0 = self.unet2.forward_d(tmp_x1, tmp_x4)
420
+ x1 = F.pad(opt_unet1, (-20, -20, -20, -20))
421
+ x_crop = torch.add(x0, x1) # x0是unet2的最终输出
422
+ opt_res_dict[i][j] = x_crop
423
+ del tmp_dict
424
+ torch.cuda.empty_cache()
425
+ res = torch.zeros((n, c, h * 2 - 72, w * 2 - 72)).to(x.device)
426
+ if "Half" in x.type():
427
+ res = res.half()
428
+ for i in range(0, h - 36, crop_size[0]):
429
+ for j in range(0, w - 36, crop_size[1]):
430
+ res[
431
+ :, :, i * 2: i * 2 + h1 * 2 - 72, j * 2: j * 2 + w1 * 2 - 72
432
+ ] = opt_res_dict[i][j]
433
+ del opt_res_dict
434
+ torch.cuda.empty_cache()
435
+ if w0 != pw or h0 != ph:
436
+ res = res[:, :, : h0 * 2, : w0 * 2]
437
+ return res #
438
+
439
+
440
+ class UpCunet3x(nn.Module): # 完美tile,全程无损
441
+ def __init__(self, in_channels=3, out_channels=3):
442
+ super(UpCunet3x, self).__init__()
443
+ self.unet1 = UNet1x3(in_channels, out_channels, deconv=True)
444
+ self.unet2 = UNet2(in_channels, out_channels, deconv=False)
445
+
446
+ def forward(self, x, tile_mode): # 1.7G
447
+ n, c, h0, w0 = x.shape
448
+ if tile_mode == 0: # 不tile
449
+ ph = ((h0 - 1) // 4 + 1) * 4
450
+ pw = ((w0 - 1) // 4 + 1) * 4
451
+ x = F.pad(x, (14, 14 + pw - w0, 14, 14 + ph - h0),
452
+ "reflect") # 需要保证被2整除
453
+ x = self.unet1.forward(x)
454
+ x0 = self.unet2.forward(x)
455
+ x1 = F.pad(x, (-20, -20, -20, -20))
456
+ x = torch.add(x0, x1)
457
+ if w0 != pw or h0 != ph:
458
+ x = x[:, :, : h0 * 3, : w0 * 3]
459
+ return x
460
+ elif tile_mode == 1: # 对长边减半
461
+ if w0 >= h0:
462
+ crop_size_w = ((w0 - 1) // 8 * 8 + 8) // 2 # 减半后能被4整除,所以要先被8整除
463
+ crop_size_h = (h0 - 1) // 4 * 4 + 4 # 能被4整除
464
+ else:
465
+ crop_size_h = ((h0 - 1) // 8 * 8 + 8) // 2 # 减半后能被4整除,所以要先被8整除
466
+ crop_size_w = (w0 - 1) // 4 * 4 + 4 # 能被4整除
467
+ crop_size = (crop_size_h, crop_size_w) # 6.6G
468
+ elif tile_mode == 2: # hw都减半
469
+ crop_size = (
470
+ ((h0 - 1) // 8 * 8 + 8) // 2,
471
+ ((w0 - 1) // 8 * 8 + 8) // 2,
472
+ ) # 5.6G
473
+ elif tile_mode == 3: # hw都三分之一
474
+ crop_size = (
475
+ ((h0 - 1) // 12 * 12 + 12) // 3,
476
+ ((w0 - 1) // 12 * 12 + 12) // 3,
477
+ ) # 4.2G
478
+ elif tile_mode == 4: # hw都四分之一
479
+ crop_size = (
480
+ ((h0 - 1) // 16 * 16 + 16) // 4,
481
+ ((w0 - 1) // 16 * 16 + 16) // 4,
482
+ ) # 3.7G
483
+ ph = ((h0 - 1) // crop_size[0] + 1) * crop_size[0]
484
+ pw = ((w0 - 1) // crop_size[1] + 1) * crop_size[1]
485
+ x = F.pad(x, (14, 14 + pw - w0, 14, 14 + ph - h0), "reflect")
486
+ n, c, h, w = x.shape
487
+ se_mean0 = torch.zeros((n, 64, 1, 1)).to(x.device)
488
+ if "Half" in x.type():
489
+ se_mean0 = se_mean0.half()
490
+ n_patch = 0
491
+ tmp_dict = {}
492
+ opt_res_dict = {}
493
+ for i in range(0, h - 28, crop_size[0]):
494
+ tmp_dict[i] = {}
495
+ for j in range(0, w - 28, crop_size[1]):
496
+ x_crop = x[:, :, i: i + crop_size[0] +
497
+ 28, j: j + crop_size[1] + 28]
498
+ n, c1, h1, w1 = x_crop.shape
499
+ tmp0, x_crop = self.unet1.forward_a(x_crop)
500
+ if "Half" in x.type(): # torch.HalfTensor/torch.cuda.HalfTensor
501
+ tmp_se_mean = torch.mean(
502
+ x_crop.float(), dim=(2, 3), keepdim=True
503
+ ).half()
504
+ else:
505
+ tmp_se_mean = torch.mean(x_crop, dim=(2, 3), keepdim=True)
506
+ se_mean0 += tmp_se_mean
507
+ n_patch += 1
508
+ tmp_dict[i][j] = (tmp0, x_crop)
509
+ se_mean0 /= n_patch
510
+ se_mean1 = torch.zeros((n, 128, 1, 1)).to(x.device) # 64#128#128#64
511
+ if "Half" in x.type():
512
+ se_mean1 = se_mean1.half()
513
+ for i in range(0, h - 28, crop_size[0]):
514
+ for j in range(0, w - 28, crop_size[1]):
515
+ tmp0, x_crop = tmp_dict[i][j]
516
+ x_crop = self.unet1.conv2.seblock.forward_mean(
517
+ x_crop, se_mean0)
518
+ opt_unet1 = self.unet1.forward_b(tmp0, x_crop)
519
+ tmp_x1, tmp_x2 = self.unet2.forward_a(opt_unet1)
520
+ if "Half" in x.type(): # torch.HalfTensor/torch.cuda.HalfTensor
521
+ tmp_se_mean = torch.mean(
522
+ tmp_x2.float(), dim=(2, 3), keepdim=True
523
+ ).half()
524
+ else:
525
+ tmp_se_mean = torch.mean(tmp_x2, dim=(2, 3), keepdim=True)
526
+ se_mean1 += tmp_se_mean
527
+ tmp_dict[i][j] = (opt_unet1, tmp_x1, tmp_x2)
528
+ se_mean1 /= n_patch
529
+ se_mean0 = torch.zeros((n, 128, 1, 1)).to(x.device) # 64#128#128#64
530
+ if "Half" in x.type():
531
+ se_mean0 = se_mean0.half()
532
+ for i in range(0, h - 28, crop_size[0]):
533
+ for j in range(0, w - 28, crop_size[1]):
534
+ opt_unet1, tmp_x1, tmp_x2 = tmp_dict[i][j]
535
+ tmp_x2 = self.unet2.conv2.seblock.forward_mean(
536
+ tmp_x2, se_mean1)
537
+ tmp_x3 = self.unet2.forward_b(tmp_x2)
538
+ if "Half" in x.type(): # torch.HalfTensor/torch.cuda.HalfTensor
539
+ tmp_se_mean = torch.mean(
540
+ tmp_x3.float(), dim=(2, 3), keepdim=True
541
+ ).half()
542
+ else:
543
+ tmp_se_mean = torch.mean(tmp_x3, dim=(2, 3), keepdim=True)
544
+ se_mean0 += tmp_se_mean
545
+ tmp_dict[i][j] = (opt_unet1, tmp_x1, tmp_x2, tmp_x3)
546
+ se_mean0 /= n_patch
547
+ se_mean1 = torch.zeros((n, 64, 1, 1)).to(x.device) # 64#128#128#64
548
+ if "Half" in x.type():
549
+ se_mean1 = se_mean1.half()
550
+ for i in range(0, h - 28, crop_size[0]):
551
+ for j in range(0, w - 28, crop_size[1]):
552
+ opt_unet1, tmp_x1, tmp_x2, tmp_x3 = tmp_dict[i][j]
553
+ tmp_x3 = self.unet2.conv3.seblock.forward_mean(
554
+ tmp_x3, se_mean0)
555
+ tmp_x4 = self.unet2.forward_c(tmp_x2, tmp_x3)
556
+ if "Half" in x.type(): # torch.HalfTensor/torch.cuda.HalfTensor
557
+ tmp_se_mean = torch.mean(
558
+ tmp_x4.float(), dim=(2, 3), keepdim=True
559
+ ).half()
560
+ else:
561
+ tmp_se_mean = torch.mean(tmp_x4, dim=(2, 3), keepdim=True)
562
+ se_mean1 += tmp_se_mean
563
+ tmp_dict[i][j] = (opt_unet1, tmp_x1, tmp_x4)
564
+ se_mean1 /= n_patch
565
+ for i in range(0, h - 28, crop_size[0]):
566
+ opt_res_dict[i] = {}
567
+ for j in range(0, w - 28, crop_size[1]):
568
+ opt_unet1, tmp_x1, tmp_x4 = tmp_dict[i][j]
569
+ tmp_x4 = self.unet2.conv4.seblock.forward_mean(
570
+ tmp_x4, se_mean1)
571
+ x0 = self.unet2.forward_d(tmp_x1, tmp_x4)
572
+ x1 = F.pad(opt_unet1, (-20, -20, -20, -20))
573
+ x_crop = torch.add(x0, x1) # x0是unet2的最终输出
574
+ opt_res_dict[i][j] = x_crop #
575
+ del tmp_dict
576
+ torch.cuda.empty_cache()
577
+ res = torch.zeros((n, c, h * 3 - 84, w * 3 - 84)).to(x.device)
578
+ if "Half" in x.type():
579
+ res = res.half()
580
+ for i in range(0, h - 28, crop_size[0]):
581
+ for j in range(0, w - 28, crop_size[1]):
582
+ res[
583
+ :, :, i * 3: i * 3 + h1 * 3 - 84, j * 3: j * 3 + w1 * 3 - 84
584
+ ] = opt_res_dict[i][j]
585
+ del opt_res_dict
586
+ torch.cuda.empty_cache()
587
+ if w0 != pw or h0 != ph:
588
+ res = res[:, :, : h0 * 3, : w0 * 3]
589
+ return res
590
+
591
+
592
+ class UpCunet4x(nn.Module): # 完美tile,全程无损
593
+ def __init__(self, in_channels=3, out_channels=3):
594
+ super(UpCunet4x, self).__init__()
595
+ self.unet1 = UNet1(in_channels, 64, deconv=True)
596
+ self.unet2 = UNet2(64, 64, deconv=False)
597
+ self.ps = nn.PixelShuffle(2)
598
+ self.conv_final = nn.Conv2d(64, 12, 3, 1, padding=0, bias=True)
599
+
600
+ def forward(self, x, tile_mode):
601
+ n, c, h0, w0 = x.shape
602
+ x00 = x
603
+ if tile_mode == 0: # 不tile
604
+ ph = ((h0 - 1) // 2 + 1) * 2
605
+ pw = ((w0 - 1) // 2 + 1) * 2
606
+ x = F.pad(x, (19, 19 + pw - w0, 19, 19 + ph - h0),
607
+ "reflect") # 需要保证被2整除
608
+ x = self.unet1.forward(x)
609
+ x0 = self.unet2.forward(x)
610
+ x1 = F.pad(x, (-20, -20, -20, -20))
611
+ x = torch.add(x0, x1)
612
+ x = self.conv_final(x)
613
+ x = F.pad(x, (-1, -1, -1, -1))
614
+ x = self.ps(x)
615
+ if w0 != pw or h0 != ph:
616
+ x = x[:, :, : h0 * 4, : w0 * 4]
617
+ x += F.interpolate(x00, scale_factor=4, mode="nearest")
618
+ return x
619
+ elif tile_mode == 1: # 对长边减半
620
+ if w0 >= h0:
621
+ crop_size_w = ((w0 - 1) // 4 * 4 + 4) // 2 # 减半后能被2整除,所以要先被4整除
622
+ crop_size_h = (h0 - 1) // 2 * 2 + 2 # 能被2整除
623
+ else:
624
+ crop_size_h = ((h0 - 1) // 4 * 4 + 4) // 2 # 减半后能被2整除,所以要先被4整除
625
+ crop_size_w = (w0 - 1) // 2 * 2 + 2 # 能被2整除
626
+ crop_size = (crop_size_h, crop_size_w) # 6.6G
627
+ elif tile_mode == 2: # hw都减半
628
+ crop_size = (
629
+ ((h0 - 1) // 4 * 4 + 4) // 2,
630
+ ((w0 - 1) // 4 * 4 + 4) // 2,
631
+ ) # 5.6G
632
+ elif tile_mode == 3: # hw都三分之一
633
+ crop_size = (
634
+ ((h0 - 1) // 6 * 6 + 6) // 3,
635
+ ((w0 - 1) // 6 * 6 + 6) // 3,
636
+ ) # 4.1G
637
+ elif tile_mode == 4: # hw都四分之一
638
+ crop_size = (
639
+ ((h0 - 1) // 8 * 8 + 8) // 4,
640
+ ((w0 - 1) // 8 * 8 + 8) // 4,
641
+ ) # 3.7G
642
+ ph = ((h0 - 1) // crop_size[0] + 1) * crop_size[0]
643
+ pw = ((w0 - 1) // crop_size[1] + 1) * crop_size[1]
644
+ x = F.pad(x, (19, 19 + pw - w0, 19, 19 + ph - h0), "reflect")
645
+ n, c, h, w = x.shape
646
+ se_mean0 = torch.zeros((n, 64, 1, 1)).to(x.device)
647
+ if "Half" in x.type():
648
+ se_mean0 = se_mean0.half()
649
+ n_patch = 0
650
+ tmp_dict = {}
651
+ opt_res_dict = {}
652
+ for i in range(0, h - 38, crop_size[0]):
653
+ tmp_dict[i] = {}
654
+ for j in range(0, w - 38, crop_size[1]):
655
+ x_crop = x[:, :, i: i + crop_size[0] +
656
+ 38, j: j + crop_size[1] + 38]
657
+ n, c1, h1, w1 = x_crop.shape
658
+ tmp0, x_crop = self.unet1.forward_a(x_crop)
659
+ if "Half" in x.type(): # torch.HalfTensor/torch.cuda.HalfTensor
660
+ tmp_se_mean = torch.mean(
661
+ x_crop.float(), dim=(2, 3), keepdim=True
662
+ ).half()
663
+ else:
664
+ tmp_se_mean = torch.mean(x_crop, dim=(2, 3), keepdim=True)
665
+ se_mean0 += tmp_se_mean
666
+ n_patch += 1
667
+ tmp_dict[i][j] = (tmp0, x_crop)
668
+ se_mean0 /= n_patch
669
+ se_mean1 = torch.zeros((n, 128, 1, 1)).to(x.device) # 64#128#128#64
670
+ if "Half" in x.type():
671
+ se_mean1 = se_mean1.half()
672
+ for i in range(0, h - 38, crop_size[0]):
673
+ for j in range(0, w - 38, crop_size[1]):
674
+ tmp0, x_crop = tmp_dict[i][j]
675
+ x_crop = self.unet1.conv2.seblock.forward_mean(
676
+ x_crop, se_mean0)
677
+ opt_unet1 = self.unet1.forward_b(tmp0, x_crop)
678
+ tmp_x1, tmp_x2 = self.unet2.forward_a(opt_unet1)
679
+ if "Half" in x.type(): # torch.HalfTensor/torch.cuda.HalfTensor
680
+ tmp_se_mean = torch.mean(
681
+ tmp_x2.float(), dim=(2, 3), keepdim=True
682
+ ).half()
683
+ else:
684
+ tmp_se_mean = torch.mean(tmp_x2, dim=(2, 3), keepdim=True)
685
+ se_mean1 += tmp_se_mean
686
+ tmp_dict[i][j] = (opt_unet1, tmp_x1, tmp_x2)
687
+ se_mean1 /= n_patch
688
+ se_mean0 = torch.zeros((n, 128, 1, 1)).to(x.device) # 64#128#128#64
689
+ if "Half" in x.type():
690
+ se_mean0 = se_mean0.half()
691
+ for i in range(0, h - 38, crop_size[0]):
692
+ for j in range(0, w - 38, crop_size[1]):
693
+ opt_unet1, tmp_x1, tmp_x2 = tmp_dict[i][j]
694
+ tmp_x2 = self.unet2.conv2.seblock.forward_mean(
695
+ tmp_x2, se_mean1)
696
+ tmp_x3 = self.unet2.forward_b(tmp_x2)
697
+ if "Half" in x.type(): # torch.HalfTensor/torch.cuda.HalfTensor
698
+ tmp_se_mean = torch.mean(
699
+ tmp_x3.float(), dim=(2, 3), keepdim=True
700
+ ).half()
701
+ else:
702
+ tmp_se_mean = torch.mean(tmp_x3, dim=(2, 3), keepdim=True)
703
+ se_mean0 += tmp_se_mean
704
+ tmp_dict[i][j] = (opt_unet1, tmp_x1, tmp_x2, tmp_x3)
705
+ se_mean0 /= n_patch
706
+ se_mean1 = torch.zeros((n, 64, 1, 1)).to(x.device) # 64#128#128#64
707
+ if "Half" in x.type():
708
+ se_mean1 = se_mean1.half()
709
+ for i in range(0, h - 38, crop_size[0]):
710
+ for j in range(0, w - 38, crop_size[1]):
711
+ opt_unet1, tmp_x1, tmp_x2, tmp_x3 = tmp_dict[i][j]
712
+ tmp_x3 = self.unet2.conv3.seblock.forward_mean(
713
+ tmp_x3, se_mean0)
714
+ tmp_x4 = self.unet2.forward_c(tmp_x2, tmp_x3)
715
+ if "Half" in x.type(): # torch.HalfTensor/torch.cuda.HalfTensor
716
+ tmp_se_mean = torch.mean(
717
+ tmp_x4.float(), dim=(2, 3), keepdim=True
718
+ ).half()
719
+ else:
720
+ tmp_se_mean = torch.mean(tmp_x4, dim=(2, 3), keepdim=True)
721
+ se_mean1 += tmp_se_mean
722
+ tmp_dict[i][j] = (opt_unet1, tmp_x1, tmp_x4)
723
+ se_mean1 /= n_patch
724
+ for i in range(0, h - 38, crop_size[0]):
725
+ opt_res_dict[i] = {}
726
+ for j in range(0, w - 38, crop_size[1]):
727
+ opt_unet1, tmp_x1, tmp_x4 = tmp_dict[i][j]
728
+ tmp_x4 = self.unet2.conv4.seblock.forward_mean(
729
+ tmp_x4, se_mean1)
730
+ x0 = self.unet2.forward_d(tmp_x1, tmp_x4)
731
+ x1 = F.pad(opt_unet1, (-20, -20, -20, -20))
732
+ x_crop = torch.add(x0, x1) # x0是unet2的最终输出
733
+ x_crop = self.conv_final(x_crop)
734
+ x_crop = F.pad(x_crop, (-1, -1, -1, -1))
735
+ x_crop = self.ps(x_crop)
736
+ opt_res_dict[i][j] = x_crop
737
+ del tmp_dict
738
+ torch.cuda.empty_cache()
739
+ res = torch.zeros((n, c, h * 4 - 152, w * 4 - 152)).to(x.device)
740
+ if "Half" in x.type():
741
+ res = res.half()
742
+ for i in range(0, h - 38, crop_size[0]):
743
+ for j in range(0, w - 38, crop_size[1]):
744
+ # print(opt_res_dict[i][j].shape,res[:, :, i * 4:i * 4 + h1 * 4 - 144, j * 4:j * 4 + w1 * 4 - 144].shape)
745
+ res[
746
+ :, :, i * 4: i * 4 + h1 * 4 - 152, j * 4: j * 4 + w1 * 4 - 152
747
+ ] = opt_res_dict[i][j]
748
+ del opt_res_dict
749
+ torch.cuda.empty_cache()
750
+ if w0 != pw or h0 != ph:
751
+ res = res[:, :, : h0 * 4, : w0 * 4]
752
+ res += F.interpolate(x00, scale_factor=4, mode="nearest")
753
+ return res #
754
+
755
+
756
+ models: Dict[str, Type[nn.Module]] = {
757
+ obj.__name__: obj
758
+ for obj in globals().values()
759
+ if isinstance(obj, type) and issubclass(obj, nn.Module)
760
+ }
761
+
762
+
763
+ class RealWaifuUpScaler:
764
+ def __init__(self, scale: int, weight_path: str, half: bool, device: str):
765
+ weight = torch.load(weight_path, map_location=device)
766
+ self.model = models[f"UpCunet{scale}x"]()
767
+
768
+ if half == True:
769
+ self.model = self.model.half().to(device)
770
+ else:
771
+ self.model = self.model.to(device)
772
+
773
+ self.model.load_state_dict(weight, strict=True)
774
+ self.model.eval()
775
+
776
+ self.half = half
777
+ self.device = device
778
+
779
+ def np2tensor(self, np_frame):
780
+ if self.half == False:
781
+ return (
782
+ torch.from_numpy(np.transpose(np_frame, (2, 0, 1)))
783
+ .unsqueeze(0)
784
+ .to(self.device)
785
+ .float()
786
+ / 255
787
+ )
788
+ else:
789
+ return (
790
+ torch.from_numpy(np.transpose(np_frame, (2, 0, 1)))
791
+ .unsqueeze(0)
792
+ .to(self.device)
793
+ .half()
794
+ / 255
795
+ )
796
+
797
+ def tensor2np(self, tensor):
798
+ if self.half == False:
799
+ return np.transpose(
800
+ (tensor.data.squeeze() * 255.0)
801
+ .round()
802
+ .clamp_(0, 255)
803
+ .byte()
804
+ .cpu()
805
+ .numpy(),
806
+ (1, 2, 0),
807
+ )
808
+ else:
809
+ return np.transpose(
810
+ (tensor.data.squeeze().float() * 255.0)
811
+ .round()
812
+ .clamp_(0, 255)
813
+ .byte()
814
+ .cpu()
815
+ .numpy(),
816
+ (1, 2, 0),
817
+ )
818
+
819
+ def __call__(self, frame, tile_mode):
820
+ with torch.no_grad():
821
+ tensor = self.np2tensor(frame)
822
+ result = self.tensor2np(self.model(tensor, tile_mode))
823
+ return result
824
+
825
+
826
+ input_image = inputs.File(label="Input image")
827
+ half_precision = inputs.Checkbox(
828
+ label="Half precision (NOT work for CPU)", default=False
829
+ )
830
+ model_weight = inputs.Dropdown(
831
+ sorted(AVALIABLE_WEIGHTS), label="Choice model weight")
832
+ tile_mode = inputs.Radio(
833
+ [mode.name for mode in TileMode], label="Output tile mode")
834
+
835
+ output_image = outputs.Image(label="Output image preview")
836
+ output_file = outputs.File(label="Output image file")
837
+
838
+
839
+ def main(file: IO[bytes], half: bool, weight: str, tile: str):
840
+ scale = next(
841
+ mode.value for mode in ScaleMode if weight.startswith(mode.name))
842
+ upscaler = RealWaifuUpScaler(
843
+ scale, weight_path=str(AVALIABLE_WEIGHTS[weight]), half=half, device=DEVICE
844
+ )
845
+
846
+ frame = cv2.cvtColor(cv2.imread(file.name), cv2.COLOR_BGR2RGB)
847
+ result = cv2.cvtColor(upscaler(frame, TileMode[tile]), cv2.COLOR_RGB2BGR)
848
+
849
+ _, ext = os.path.splitext(file.name)
850
+ tempfile = mktemp(suffix=ext)
851
+ cv2.imwrite(tempfile, result)
852
+ return tempfile, tempfile
853
+
854
+
855
+ interface = Interface(
856
+ main,
857
+ inputs=[input_image, half_precision, model_weight, tile_mode],
858
+ outputs=[output_image, output_file],
859
+ )
860
+ interface.launch()
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ python3-opencv
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ opencv-python
3
+ numpy
4
+ gradio
5
+ jinja2
weights/up2x-latest-conservative.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6cfe3b23687915d08ba96010f25198d9cfe8a683aa4131f1acf7eaa58ee1de93
3
+ size 5147249
weights/up2x-latest-denoise1x.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2e783c39da6a6394fbc250fdd069c55eaedc43971c4f2405322f18949ce38573
3
+ size 5147249
weights/up2x-latest-denoise2x.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8188b3faef4258cf748c59360cbc8086ebedf4a63eb9d5d6637d45f819d32496
3
+ size 5147249
weights/up2x-latest-denoise3x.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0a14739f3f5fcbd74ec3ce2806d13a47916c916b20afe4a39d95f6df4ca6abd8
3
+ size 5147249
weights/up2x-latest-no-denoise.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f491f9ecf6964ead9f3a36bf03e83527f32c6a341b683f7378ac6c1e2a5f0d16
3
+ size 5147249
weights/up3x-latest-conservative.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f6ea5fd20380413beb2701182483fd80c2e86f3b3f08053eb3df4975184aefe3
3
+ size 5154161
weights/up3x-latest-denoise3x.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:39f1e6e90d50e5528a63f4ba1866bad23365a737cbea22a80769b2ec4c1c3285
3
+ size 5154161
weights/up3x-latest-no-denoise.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:763f0a87e70d744673f1a41db5396d5f334d22de97fff68ffc40deb91404a584
3
+ size 5154161
weights/up4x-latest-conservative.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a8c8185def699b0883662a02df0ef2e6db3b0275170b6cc0d28089b64b273427
3
+ size 5636403
weights/up4x-latest-denoise3x.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:42bd8fcdae37c12c5b25ed59625266bfa65780071a8d38192d83756cb85e98dd
3
+ size 5636403
weights/up4x-latest-no-denoise.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aaf3ef78a488cce5d3842154925eb70ff8423b8298e2cd189ec66eb7f6f66fae
3
+ size 5636403