wzhouxiff committed on
Commit
f461022
1 Parent(s): cb74f6a
Files changed (7)
  1. .gitignore +1 -0
  2. .vscode/sftp.json +24 -0
  3. RestoreFormer.py +117 -0
  4. RestoreFormer_arch.py +742 -0
  5. app.py +132 -0
  6. packages.txt +3 -0
  7. requirements.txt +12 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ model_bk*
.vscode/sftp.json ADDED
@@ -0,0 +1,24 @@
+ {
+     "name": "wzhoux",
+     "host": "9.134.229.18",
+     "protocol": "sftp",
+     "port": 36000,
+     "username": "root",
+     "remotePath": "/group/30042/zhouxiawang/project/gradio/RestoreFormerPlusPlus",
+     "uploadOnSave": true,
+     "password": "Beagirl12#",
+     "ignore": [
+         ".vscode",
+         ".git",
+         ".DS_Store",
+         ".conda",
+         "./models",
+         "./logs",
+         "outputs",
+         "eggs",
+         ".eggs",
+         "logs",
+         "experiments",
+         "./results"
+     ]
+ }
RestoreFormer.py ADDED
@@ -0,0 +1,117 @@
+ import os
+
+ import cv2
+ import torch
+ from basicsr.utils import img2tensor, tensor2img
+ from basicsr.utils.download_util import load_file_from_url
+ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
+ from torchvision.transforms.functional import normalize
+
+ from RestoreFormer_arch import VQVAEGANMultiHeadTransformer
+
+ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+
+ class RestoreFormer():
+     """Helper for restoration with RestoreFormer.
+
+     It will detect and crop faces, and then resize the faces to 512x512.
+     RestoreFormer is used to restore the resized faces.
+     The background is upsampled with the bg_upsampler.
+     Finally, the faces will be pasted back onto the upsampled background image.
+
+     Args:
+         model_path (str): The path to the RestoreFormer model. It can be a URL (the model will be downloaded automatically).
+         upscale (float): The upscale of the final output. Default: 2.
+         arch (str): The RestoreFormer architecture. Option: RestoreFormer | RestoreFormer++. Default: RestoreFormer++.
+         bg_upsampler (nn.Module): The upsampler for the background. Default: None.
+     """
+
+     def __init__(self, model_path, upscale=2, arch='RestoreFormer++', bg_upsampler=None, device=None):
+         self.upscale = upscale
+         self.bg_upsampler = bg_upsampler
+         self.arch = arch
+
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
+
+         if arch == 'RestoreFormer':
+             self.RF = VQVAEGANMultiHeadTransformer(head_size=8, ex_multi_scale_num=0)
+         elif arch == 'RestoreFormer++':
+             self.RF = VQVAEGANMultiHeadTransformer(head_size=4, ex_multi_scale_num=1)
+         else:
+             raise NotImplementedError(f'Unsupported arch: {arch}.')
+
+         # initialize face helper
+         self.face_helper = FaceRestoreHelper(
+             upscale,
+             face_size=512,
+             crop_ratio=(1, 1),
+             det_model='retinaface_resnet50',
+             save_ext='png',
+             use_parse=True,
+             device=self.device,
+             model_rootpath=None)
+
+         if model_path.startswith('https://'):
+             model_path = load_file_from_url(
+                 url=model_path, model_dir=os.path.join(ROOT_DIR, 'experiments/weights'), progress=True, file_name=None)
+         loadnet = torch.load(model_path)
+
+         strict = False
+         weights = loadnet['state_dict']
+         new_weights = {}
+         for k, v in weights.items():
+             if k.startswith('vqvae.'):
+                 k = k.replace('vqvae.', '')
+             new_weights[k] = v
+         self.RF.load_state_dict(new_weights, strict=strict)
+
+         self.RF.eval()
+         self.RF = self.RF.to(self.device)
+
+     @torch.no_grad()
+     def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True):
+         self.face_helper.clean_all()
+
+         if has_aligned:  # the inputs are already aligned
+             img = cv2.resize(img, (512, 512))
+             self.face_helper.cropped_faces = [img]
+         else:
+             self.face_helper.read_image(img)
+             self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
+             # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
+             # TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
+             # align and warp each face
+             self.face_helper.align_warp_face()
+
+         # face restoration
+         for cropped_face in self.face_helper.cropped_faces:
+             # prepare data
+             cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
+             normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+             cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)
+
+             try:
+                 output = self.RF(cropped_face_t)[0]
+                 restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
+             except RuntimeError as error:
+                 print(f'\tFailed inference for RestoreFormer: {error}.')
+                 restored_face = cropped_face
+
+             restored_face = restored_face.astype('uint8')
+             self.face_helper.add_restored_face(restored_face)
+
+         if not has_aligned and paste_back:
+             # upsample the background
+             if self.bg_upsampler is not None:
+                 # Now only RealESRGAN is supported for upsampling the background
+                 bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
+             else:
+                 bg_img = None
+
+             self.face_helper.get_inverse_affine(None)
+             # paste each restored face to the input image
+             restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img)
+             return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
+         else:
+             return self.face_helper.cropped_faces, self.face_helper.restored_faces, None
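For reference, the helper above is used roughly as follows; a minimal sketch in which the checkpoint path, the input image path, and the output path are placeholders (app.py below wires the same calls into a Gradio UI):

# Minimal usage sketch: restore one local photo with the RestoreFormer++ weights.
import cv2

from RestoreFormer import RestoreFormer

restorer = RestoreFormer(
    model_path='RestoreFormer++.ckpt',   # local checkpoint or an https:// URL
    upscale=2,
    arch='RestoreFormer++',
    bg_upsampler=None)                   # optionally a RealESRGANer instance

img = cv2.imread('old_photo.jpg', cv2.IMREAD_COLOR)   # BGR input, as expected by enhance()
cropped_faces, restored_faces, restored_img = restorer.enhance(
    img, has_aligned=False, only_center_face=False, paste_back=True)
if restored_img is not None:
    cv2.imwrite('restored.jpg', restored_img)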
RestoreFormer_arch.py ADDED
@@ -0,0 +1,742 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class VectorQuantizer(nn.Module):
+     """
+     see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
+     ____________________________________________
+     Discretization bottleneck part of the VQ-VAE.
+     Inputs:
+     - n_e : number of embeddings
+     - e_dim : dimension of embedding
+     - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
+     _____________________________________________
+     """
+
+     def __init__(self, n_e, e_dim, beta):
+         super(VectorQuantizer, self).__init__()
+         self.n_e = n_e
+         self.e_dim = e_dim
+         self.beta = beta
+
+         self.embedding = nn.Embedding(self.n_e, self.e_dim)
+         self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+     def forward(self, z):
+         """
+         Inputs the output of the encoder network z and maps it to a discrete
+         one-hot vector that is the index of the closest embedding vector e_j
+         z (continuous) -> z_q (discrete)
+         z.shape = (batch, channel, height, width)
+         quantization pipeline:
+             1. get encoder input (B,C,H,W)
+             2. flatten input to (B*H*W,C)
+         """
+         # reshape z -> (batch, height, width, channel) and flatten
+         z = z.permute(0, 2, 3, 1).contiguous()
+         z_flattened = z.view(-1, self.e_dim)
+         # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+         d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+             torch.sum(self.embedding.weight**2, dim=1) - 2 * \
+             torch.matmul(z_flattened, self.embedding.weight.t())
+
+         ## could possibly replace this here
+         # #\start...
+         # find closest encodings
+
+         min_value, min_encoding_indices = torch.min(d, dim=1)
+
+         min_encoding_indices = min_encoding_indices.unsqueeze(1)
+
+         min_encodings = torch.zeros(
+             min_encoding_indices.shape[0], self.n_e).to(z)
+         min_encodings.scatter_(1, min_encoding_indices, 1)
+
+         # dtype min encodings: torch.float32
+         # min_encodings shape: torch.Size([2048, 512])
+         # min_encoding_indices.shape: torch.Size([2048, 1])
+
+         # get quantized latent vectors
+         z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
+         #.........\end
+
+         # with:
+         # .........\start
+         #min_encoding_indices = torch.argmin(d, dim=1)
+         #z_q = self.embedding(min_encoding_indices)
+         # ......\end......... (TODO)
+
+         # compute loss for embedding
+         loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
+             torch.mean((z_q - z.detach()) ** 2)
+
+         # preserve gradients
+         z_q = z + (z_q - z).detach()
+
+         # perplexity
+
+         e_mean = torch.mean(min_encodings, dim=0)
+         perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
+
+         # reshape back to match original input shape
+         z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+         return z_q, loss, (perplexity, min_encodings, min_encoding_indices, d)
+
+     def get_codebook_entry(self, indices, shape):
+         # shape specifying (batch, height, width, channel)
+         # TODO: check for more easy handling with nn.Embedding
+         min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
+         min_encodings.scatter_(1, indices[:, None], 1)
+
+         # get quantized latent vectors
+         z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
+
+         if shape is not None:
+             z_q = z_q.view(shape)
+
+         # reshape back to match original input shape
+         z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+         return z_q
+
+ # pytorch_diffusion + derived encoder decoder
+ def nonlinearity(x):
+     # swish
+     return x*torch.sigmoid(x)
+
+
+ def Normalize(in_channels):
+     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+ class Upsample(nn.Module):
+     def __init__(self, in_channels, with_conv):
+         super().__init__()
+         self.with_conv = with_conv
+         if self.with_conv:
+             self.conv = torch.nn.Conv2d(in_channels,
+                                         in_channels,
+                                         kernel_size=3,
+                                         stride=1,
+                                         padding=1)
+
+     def forward(self, x):
+         x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+         if self.with_conv:
+             x = self.conv(x)
+         return x
+
+
+ class Downsample(nn.Module):
+     def __init__(self, in_channels, with_conv):
+         super().__init__()
+         self.with_conv = with_conv
+         if self.with_conv:
+             # no asymmetric padding in torch conv, must do it ourselves
+             self.conv = torch.nn.Conv2d(in_channels,
+                                         in_channels,
+                                         kernel_size=3,
+                                         stride=2,
+                                         padding=0)
+
+     def forward(self, x):
+         if self.with_conv:
+             pad = (0,1,0,1)
+             x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+             x = self.conv(x)
+         else:
+             x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+         return x
+
+
+ class ResnetBlock(nn.Module):
+     def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+                  dropout, temb_channels=512):
+         super().__init__()
+         self.in_channels = in_channels
+         out_channels = in_channels if out_channels is None else out_channels
+         self.out_channels = out_channels
+         self.use_conv_shortcut = conv_shortcut
+
+         self.norm1 = Normalize(in_channels)
+         self.conv1 = torch.nn.Conv2d(in_channels,
+                                      out_channels,
+                                      kernel_size=3,
+                                      stride=1,
+                                      padding=1)
+         if temb_channels > 0:
+             self.temb_proj = torch.nn.Linear(temb_channels,
+                                              out_channels)
+         self.norm2 = Normalize(out_channels)
+         self.dropout = torch.nn.Dropout(dropout)
+         self.conv2 = torch.nn.Conv2d(out_channels,
+                                      out_channels,
+                                      kernel_size=3,
+                                      stride=1,
+                                      padding=1)
+         if self.in_channels != self.out_channels:
+             if self.use_conv_shortcut:
+                 self.conv_shortcut = torch.nn.Conv2d(in_channels,
+                                                      out_channels,
+                                                      kernel_size=3,
+                                                      stride=1,
+                                                      padding=1)
+             else:
+                 self.nin_shortcut = torch.nn.Conv2d(in_channels,
+                                                     out_channels,
+                                                     kernel_size=1,
+                                                     stride=1,
+                                                     padding=0)
+
+     def forward(self, x, temb):
+         h = x
+         h = self.norm1(h)
+         h = nonlinearity(h)
+         h = self.conv1(h)
+
+         if temb is not None:
+             h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+
+         h = self.norm2(h)
+         h = nonlinearity(h)
+         h = self.dropout(h)
+         h = self.conv2(h)
+
+         if self.in_channels != self.out_channels:
+             if self.use_conv_shortcut:
+                 x = self.conv_shortcut(x)
+             else:
+                 x = self.nin_shortcut(x)
+
+         return x+h
+
+
+ class MultiHeadAttnBlock(nn.Module):
+     def __init__(self, in_channels, head_size=1):
+         super().__init__()
+         self.in_channels = in_channels
+         self.head_size = head_size
+         self.att_size = in_channels // head_size
+         assert in_channels % head_size == 0, 'The number of channels must be divisible by the head size.'
+
+         self.norm1 = Normalize(in_channels)
+         self.norm2 = Normalize(in_channels)
+
+         self.q = torch.nn.Conv2d(in_channels,
+                                  in_channels,
+                                  kernel_size=1,
+                                  stride=1,
+                                  padding=0)
+         self.k = torch.nn.Conv2d(in_channels,
+                                  in_channels,
+                                  kernel_size=1,
+                                  stride=1,
+                                  padding=0)
+         self.v = torch.nn.Conv2d(in_channels,
+                                  in_channels,
+                                  kernel_size=1,
+                                  stride=1,
+                                  padding=0)
+         self.proj_out = torch.nn.Conv2d(in_channels,
+                                         in_channels,
+                                         kernel_size=1,
+                                         stride=1,
+                                         padding=0)
+         self.num = 0
+
+     def forward(self, x, y=None):
+         h_ = x
+         h_ = self.norm1(h_)
+         if y is None:
+             y = h_
+         else:
+             y = self.norm2(y)
+
+         q = self.q(y)
+         k = self.k(h_)
+         v = self.v(h_)
+
+         # compute attention
+         b, c, h, w = q.shape
+         q = q.reshape(b, self.head_size, self.att_size, h*w)
+         q = q.permute(0, 3, 1, 2)  # b, hw, head, att
+
+         k = k.reshape(b, self.head_size, self.att_size, h*w)
+         k = k.permute(0, 3, 1, 2)
+
+         v = v.reshape(b, self.head_size, self.att_size, h*w)
+         v = v.permute(0, 3, 1, 2)
+
+
+         q = q.transpose(1, 2)
+         v = v.transpose(1, 2)
+         k = k.transpose(1, 2).transpose(2, 3)
+
+         scale = int(self.att_size)**(-0.5)
+         q.mul_(scale)
+         w_ = torch.matmul(q, k)
+         w_ = F.softmax(w_, dim=3)
+
+         w_ = w_.matmul(v)
+
+         w_ = w_.transpose(1, 2).contiguous()  # [b, h*w, head, att]
+         w_ = w_.view(b, h, w, -1)
+         w_ = w_.permute(0, 3, 1, 2)
+
+         w_ = self.proj_out(w_)
+
+         return x+w_
+
+
+ class MultiHeadEncoder(nn.Module):
+     def __init__(self, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks=2,
+                  attn_resolutions=[16], dropout=0.0, resamp_with_conv=True, in_channels=3,
+                  resolution=512, z_channels=256, double_z=True, enable_mid=True,
+                  head_size=1, **ignore_kwargs):
+         super().__init__()
+         self.ch = ch
+         self.temb_ch = 0
+         self.num_resolutions = len(ch_mult)
+         self.num_res_blocks = num_res_blocks
+         self.resolution = resolution
+         self.in_channels = in_channels
+         self.enable_mid = enable_mid
+
+         # downsampling
+         self.conv_in = torch.nn.Conv2d(in_channels,
+                                        self.ch,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+         curr_res = resolution
+         in_ch_mult = (1,)+tuple(ch_mult)
+         self.down = nn.ModuleList()
+         for i_level in range(self.num_resolutions):
+             block = nn.ModuleList()
+             attn = nn.ModuleList()
+             block_in = ch*in_ch_mult[i_level]
+             block_out = ch*ch_mult[i_level]
+             for i_block in range(self.num_res_blocks):
+                 block.append(ResnetBlock(in_channels=block_in,
+                                          out_channels=block_out,
+                                          temb_channels=self.temb_ch,
+                                          dropout=dropout))
+                 block_in = block_out
+                 if curr_res in attn_resolutions:
+                     attn.append(MultiHeadAttnBlock(block_in, head_size))
+             down = nn.Module()
+             down.block = block
+             down.attn = attn
+             if i_level != self.num_resolutions-1:
+                 down.downsample = Downsample(block_in, resamp_with_conv)
+                 curr_res = curr_res // 2
+             self.down.append(down)
+
+         # middle
+         if self.enable_mid:
+             self.mid = nn.Module()
+             self.mid.block_1 = ResnetBlock(in_channels=block_in,
+                                            out_channels=block_in,
+                                            temb_channels=self.temb_ch,
+                                            dropout=dropout)
+             self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
+             self.mid.block_2 = ResnetBlock(in_channels=block_in,
+                                            out_channels=block_in,
+                                            temb_channels=self.temb_ch,
+                                            dropout=dropout)
+
+         # end
+         self.norm_out = Normalize(block_in)
+         self.conv_out = torch.nn.Conv2d(block_in,
+                                         2*z_channels if double_z else z_channels,
+                                         kernel_size=3,
+                                         stride=1,
+                                         padding=1)
+
+
+     def forward(self, x):
+         #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
+
+         hs = {}
+         # timestep embedding
+         temb = None
+
+         # downsampling
+         h = self.conv_in(x)
+         hs['in'] = h
+         for i_level in range(self.num_resolutions):
+             for i_block in range(self.num_res_blocks):
+                 h = self.down[i_level].block[i_block](h, temb)
+                 if len(self.down[i_level].attn) > 0:
+                     h = self.down[i_level].attn[i_block](h)
+
+             if i_level != self.num_resolutions-1:
+                 # hs.append(h)
+                 hs['block_'+str(i_level)] = h
+                 h = self.down[i_level].downsample(h)
+
+         # middle
+         # h = hs[-1]
+         if self.enable_mid:
+             h = self.mid.block_1(h, temb)
+             hs['block_'+str(i_level)+'_atten'] = h
+             h = self.mid.attn_1(h)
+             h = self.mid.block_2(h, temb)
+             hs['mid_atten'] = h
+
+         # end
+         h = self.norm_out(h)
+         h = nonlinearity(h)
+         h = self.conv_out(h)
+         # hs.append(h)
+         hs['out'] = h
+
+         return hs
+
+ class MultiHeadDecoder(nn.Module):
+     def __init__(self, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks=2,
+                  attn_resolutions=16, dropout=0.0, resamp_with_conv=True, in_channels=3,
+                  resolution=512, z_channels=256, give_pre_end=False, enable_mid=True,
+                  head_size=1, **ignorekwargs):
+         super().__init__()
+         self.ch = ch
+         self.temb_ch = 0
+         self.num_resolutions = len(ch_mult)
+         self.num_res_blocks = num_res_blocks
+         self.resolution = resolution
+         self.in_channels = in_channels
+         self.give_pre_end = give_pre_end
+         self.enable_mid = enable_mid
+
+         # compute in_ch_mult, block_in and curr_res at lowest res
+         in_ch_mult = (1,)+tuple(ch_mult)
+         block_in = ch*ch_mult[self.num_resolutions-1]
+         curr_res = resolution // 2**(self.num_resolutions-1)
+         self.z_shape = (1,z_channels,curr_res,curr_res)
+         print("Working with z of shape {} = {} dimensions.".format(
+             self.z_shape, np.prod(self.z_shape)))
+
+         # z to block_in
+         self.conv_in = torch.nn.Conv2d(z_channels,
+                                        block_in,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+         # middle
+         if self.enable_mid:
+             self.mid = nn.Module()
+             self.mid.block_1 = ResnetBlock(in_channels=block_in,
+                                            out_channels=block_in,
+                                            temb_channels=self.temb_ch,
+                                            dropout=dropout)
+             self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
+             self.mid.block_2 = ResnetBlock(in_channels=block_in,
+                                            out_channels=block_in,
+                                            temb_channels=self.temb_ch,
+                                            dropout=dropout)
+
+         # upsampling
+         self.up = nn.ModuleList()
+         for i_level in reversed(range(self.num_resolutions)):
+             block = nn.ModuleList()
+             attn = nn.ModuleList()
+             block_out = ch*ch_mult[i_level]
+             for i_block in range(self.num_res_blocks+1):
+                 block.append(ResnetBlock(in_channels=block_in,
+                                          out_channels=block_out,
+                                          temb_channels=self.temb_ch,
+                                          dropout=dropout))
+                 block_in = block_out
+                 if curr_res in attn_resolutions:
+                     attn.append(MultiHeadAttnBlock(block_in, head_size))
+             up = nn.Module()
+             up.block = block
+             up.attn = attn
+             if i_level != 0:
+                 up.upsample = Upsample(block_in, resamp_with_conv)
+                 curr_res = curr_res * 2
+             self.up.insert(0, up)  # prepend to get consistent order
+
+         # end
+         self.norm_out = Normalize(block_in)
+         self.conv_out = torch.nn.Conv2d(block_in,
+                                         out_ch,
+                                         kernel_size=3,
+                                         stride=1,
+                                         padding=1)
+
+     def forward(self, z):
+         #assert z.shape[1:] == self.z_shape[1:]
+         self.last_z_shape = z.shape
+
+         # timestep embedding
+         temb = None
+
+         # z to block_in
+         h = self.conv_in(z)
+
+         # middle
+         if self.enable_mid:
+             h = self.mid.block_1(h, temb)
+             h = self.mid.attn_1(h)
+             h = self.mid.block_2(h, temb)
+
+         # upsampling
+         for i_level in reversed(range(self.num_resolutions)):
+             for i_block in range(self.num_res_blocks+1):
+                 h = self.up[i_level].block[i_block](h, temb)
+                 if len(self.up[i_level].attn) > 0:
+                     h = self.up[i_level].attn[i_block](h)
+             if i_level != 0:
+                 h = self.up[i_level].upsample(h)
+
+         # end
+         if self.give_pre_end:
+             return h
+
+         h = self.norm_out(h)
+         h = nonlinearity(h)
+         h = self.conv_out(h)
+         return h
+
+ class MultiHeadDecoderTransformer(nn.Module):
+     def __init__(self, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks=2,
+                  attn_resolutions=16, dropout=0.0, resamp_with_conv=True, in_channels=3,
+                  resolution=512, z_channels=256, give_pre_end=False, enable_mid=True,
+                  head_size=1, **ignorekwargs):
+         super().__init__()
+         self.ch = ch
+         self.temb_ch = 0
+         self.num_resolutions = len(ch_mult)
+         self.num_res_blocks = num_res_blocks
+         self.resolution = resolution
+         self.in_channels = in_channels
+         self.give_pre_end = give_pre_end
+         self.enable_mid = enable_mid
+
+         # compute in_ch_mult, block_in and curr_res at lowest res
+         in_ch_mult = (1,)+tuple(ch_mult)
+         block_in = ch*ch_mult[self.num_resolutions-1]
+         curr_res = resolution // 2**(self.num_resolutions-1)
+         self.z_shape = (1,z_channels,curr_res,curr_res)
+         print("Working with z of shape {} = {} dimensions.".format(
+             self.z_shape, np.prod(self.z_shape)))
+
+         # z to block_in
+         self.conv_in = torch.nn.Conv2d(z_channels,
+                                        block_in,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+         # middle
+         if self.enable_mid:
+             self.mid = nn.Module()
+             self.mid.block_1 = ResnetBlock(in_channels=block_in,
+                                            out_channels=block_in,
+                                            temb_channels=self.temb_ch,
+                                            dropout=dropout)
+             self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
+             self.mid.block_2 = ResnetBlock(in_channels=block_in,
+                                            out_channels=block_in,
+                                            temb_channels=self.temb_ch,
+                                            dropout=dropout)
+
+         # upsampling
+         self.up = nn.ModuleList()
+         for i_level in reversed(range(self.num_resolutions)):
+             block = nn.ModuleList()
+             attn = nn.ModuleList()
+             block_out = ch*ch_mult[i_level]
+             for i_block in range(self.num_res_blocks+1):
+                 block.append(ResnetBlock(in_channels=block_in,
+                                          out_channels=block_out,
+                                          temb_channels=self.temb_ch,
+                                          dropout=dropout))
+                 block_in = block_out
+                 if curr_res in attn_resolutions:
+                     attn.append(MultiHeadAttnBlock(block_in, head_size))
+             up = nn.Module()
+             up.block = block
+             up.attn = attn
+             if i_level != 0:
+                 up.upsample = Upsample(block_in, resamp_with_conv)
+                 curr_res = curr_res * 2
+             self.up.insert(0, up)  # prepend to get consistent order
+
+         # end
+         self.norm_out = Normalize(block_in)
+         self.conv_out = torch.nn.Conv2d(block_in,
+                                         out_ch,
+                                         kernel_size=3,
+                                         stride=1,
+                                         padding=1)
+
+     def forward(self, z, hs):
+         #assert z.shape[1:] == self.z_shape[1:]
+         # self.last_z_shape = z.shape
+
+         # timestep embedding
+         temb = None
+
+         # z to block_in
+         h = self.conv_in(z)
+
+         # middle
+         if self.enable_mid:
+             h = self.mid.block_1(h, temb)
+             h = self.mid.attn_1(h, hs['mid_atten'])
+             h = self.mid.block_2(h, temb)
+
+         # upsampling
+         for i_level in reversed(range(self.num_resolutions)):
+             for i_block in range(self.num_res_blocks+1):
+                 h = self.up[i_level].block[i_block](h, temb)
+                 if len(self.up[i_level].attn) > 0:
+                     if 'block_'+str(i_level)+'_atten' in hs:
+                         h = self.up[i_level].attn[i_block](h, hs['block_'+str(i_level)+'_atten'])
+                     else:
+                         h = self.up[i_level].attn[i_block](h, hs['block_'+str(i_level)])
+             if i_level != 0:
+                 h = self.up[i_level].upsample(h)
+
+         # end
+         if self.give_pre_end:
+             return h
+
+         h = self.norm_out(h)
+         h = nonlinearity(h)
+         h = self.conv_out(h)
+         return h
+
+
+ class VQVAEGAN(nn.Module):
+     def __init__(self, n_embed=1024, embed_dim=256, ch=128, out_ch=3, ch_mult=(1,2,4,8),
+                  num_res_blocks=2, attn_resolutions=16, dropout=0.0, in_channels=3,
+                  resolution=512, z_channels=256, double_z=False, enable_mid=True,
+                  fix_decoder=False, fix_codebook=False, head_size=1, **ignore_kwargs):
+         super(VQVAEGAN, self).__init__()
+
+         self.encoder = MultiHeadEncoder(ch=ch, out_ch=out_ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks,
+                                         attn_resolutions=attn_resolutions, dropout=dropout, in_channels=in_channels,
+                                         resolution=resolution, z_channels=z_channels, double_z=double_z,
+                                         enable_mid=enable_mid, head_size=head_size)
+         self.decoder = MultiHeadDecoder(ch=ch, out_ch=out_ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks,
+                                         attn_resolutions=attn_resolutions, dropout=dropout, in_channels=in_channels,
+                                         resolution=resolution, z_channels=z_channels, enable_mid=enable_mid, head_size=head_size)
+
+         self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25)
+
+         self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1)
+         self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
+
+         if fix_decoder:
+             for _, param in self.decoder.named_parameters():
+                 param.requires_grad = False
+             for _, param in self.post_quant_conv.named_parameters():
+                 param.requires_grad = False
+             for _, param in self.quantize.named_parameters():
+                 param.requires_grad = False
+         elif fix_codebook:
+             for _, param in self.quantize.named_parameters():
+                 param.requires_grad = False
+
+     def encode(self, x):
+
+         hs = self.encoder(x)
+         h = self.quant_conv(hs['out'])
+         quant, emb_loss, info = self.quantize(h)
+         return quant, emb_loss, info, hs
+
+     def decode(self, quant):
+         quant = self.post_quant_conv(quant)
+         dec = self.decoder(quant)
+
+         return dec
+
+     def forward(self, input):
+         quant, diff, info, hs = self.encode(input)
+         dec = self.decode(quant)
+
+         return dec, diff, info, hs
+
+ class VQVAEGANMultiHeadTransformer(nn.Module):
+     def __init__(self,
+                  n_embed=1024,
+                  embed_dim=256,
+                  ch=64,
+                  out_ch=3,
+                  ch_mult=(1, 2, 2, 4, 4, 8),
+                  num_res_blocks=2,
+                  attn_resolutions=(16, ),
+                  dropout=0.0,
+                  in_channels=3,
+                  resolution=512,
+                  z_channels=256,
+                  double_z=False,
+                  enable_mid=True,
+                  fix_decoder=False,
+                  fix_codebook=True,
+                  fix_encoder=False,
+                  head_size=4,
+                  ex_multi_scale_num=1):
+         super(VQVAEGANMultiHeadTransformer, self).__init__()
+
+         self.encoder = MultiHeadEncoder(ch=ch, out_ch=out_ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks,
+                                         attn_resolutions=attn_resolutions, dropout=dropout, in_channels=in_channels,
+                                         resolution=resolution, z_channels=z_channels, double_z=double_z,
+                                         enable_mid=enable_mid, head_size=head_size)
+         for i in range(ex_multi_scale_num):
+             attn_resolutions = [attn_resolutions[0], attn_resolutions[-1]*2]
+         self.decoder = MultiHeadDecoderTransformer(ch=ch, out_ch=out_ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks,
+                                                    attn_resolutions=attn_resolutions, dropout=dropout, in_channels=in_channels,
+                                                    resolution=resolution, z_channels=z_channels, enable_mid=enable_mid, head_size=head_size)
+
+         self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25)
+
+         self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1)
+         self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
+
+         if fix_decoder:
+             for _, param in self.decoder.named_parameters():
+                 param.requires_grad = False
+             for _, param in self.post_quant_conv.named_parameters():
+                 param.requires_grad = False
+             for _, param in self.quantize.named_parameters():
+                 param.requires_grad = False
+         elif fix_codebook:
+             for _, param in self.quantize.named_parameters():
+                 param.requires_grad = False
+
+         if fix_encoder:
+             for _, param in self.encoder.named_parameters():
+                 param.requires_grad = False
+             for _, param in self.quant_conv.named_parameters():
+                 param.requires_grad = False
+
+
+     def encode(self, x):
+
+         hs = self.encoder(x)
+         h = self.quant_conv(hs['out'])
+         quant, emb_loss, info = self.quantize(h)
+         return quant, emb_loss, info, hs
+
+     def decode(self, quant, hs):
+         quant = self.post_quant_conv(quant)
+         dec = self.decoder(quant, hs)
+
+         return dec
+
+     def forward(self, input):
+         quant, diff, info, hs = self.encode(input)
+         dec = self.decode(quant, hs)
+
+         return dec, diff, info, hs
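The two configurations instantiated in RestoreFormer.py above differ only in head_size and ex_multi_scale_num (8/0 for RestoreFormer, 4/1 for RestoreFormer++). A minimal shape-check sketch of the RestoreFormer++ configuration on a random tensor (illustrative only; no pretrained weights are loaded here):

import torch

from RestoreFormer_arch import VQVAEGANMultiHeadTransformer

# RestoreFormer++ configuration, as used in RestoreFormer.py above.
model = VQVAEGANMultiHeadTransformer(head_size=4, ex_multi_scale_num=1)
model.eval()

with torch.no_grad():
    x = torch.randn(1, 3, 512, 512)      # a 512x512 aligned face in [-1, 1]
    dec, diff, info, hs = model(x)

print(dec.shape)          # restored image: torch.Size([1, 3, 512, 512])
print(diff.item())        # codebook/commitment loss (scalar)
print(sorted(hs.keys()))  # multi-scale encoder features consumed by the decoder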
app.py ADDED
@@ -0,0 +1,132 @@
+ import os
+
+ import cv2
+ import gradio as gr
+ import torch
+ from basicsr.archs.srvgg_arch import SRVGGNetCompact
+ from realesrgan.utils import RealESRGANer
+
+ from RestoreFormer import RestoreFormer
+
+ os.system("pip freeze")
+ # download weights
+ if not os.path.exists('realesr-general-x4v3.pth'):
+     os.system("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P .")
+ if not os.path.exists('RestoreFormer.ckpt'):
+     os.system("wget https://github.com/wzhouxiff/RestoreFormerPlusPlus/releases/download/v1.0.0/RestoreFormer.ckpt -P .")
+ if not os.path.exists('RestoreFormer++.ckpt'):
+     os.system("wget https://github.com/wzhouxiff/RestoreFormerPlusPlus/releases/download/v1.0.0/RestoreFormer++.ckpt -P .")
+
+ # torch.hub.download_url_to_file(
+ #     'https://upload.wikimedia.org/wikipedia/commons/thumb/a/ab/Abraham_Lincoln_O-77_matte_collodion_print.jpg/1024px-Abraham_Lincoln_O-77_matte_collodion_print.jpg',
+ #     'lincoln.jpg')
+ # torch.hub.download_url_to_file(
+ #     'https://user-images.githubusercontent.com/17445847/187400315-87a90ac9-d231-45d6-b377-38702bd1838f.jpg',
+ #     'AI-generate.jpg')
+ # torch.hub.download_url_to_file(
+ #     'https://user-images.githubusercontent.com/17445847/187400981-8a58f7a4-ef61-42d9-af80-bc6234cef860.jpg',
+ #     'Blake_Lively.jpg')
+ # torch.hub.download_url_to_file(
+ #     'https://user-images.githubusercontent.com/17445847/187401133-8a3bf269-5b4d-4432-b2f0-6d26ee1d3307.png',
+ #     '10045.png')
+
+ # background enhancer with RealESRGAN
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
+ model_path = 'realesr-general-x4v3.pth'
+ half = True if torch.cuda.is_available() else False
+ upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
+
+ os.makedirs('output', exist_ok=True)
+
+
+ # def inference(img, version, scale, weight):
+ def inference(img, version, scale):
+     # weight /= 100
+     print(img, version, scale)
+     if scale > 4:
+         scale = 4  # avoid too large scale value
+     try:
+         extension = os.path.splitext(os.path.basename(str(img)))[1]
+         img = cv2.imread(img, cv2.IMREAD_UNCHANGED)
+         if len(img.shape) == 3 and img.shape[2] == 4:
+             img_mode = 'RGBA'
+         elif len(img.shape) == 2:  # for gray inputs
+             img_mode = None
+             img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+         else:
+             img_mode = None
+
+         h, w = img.shape[0:2]
+         if h > 3500 or w > 3500:
+             print('too large size')
+             return None, None
+
+         if h < 300:
+             img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
+
+         if version == 'RestoreFormer':
+             face_enhancer = RestoreFormer(
+                 model_path='RestoreFormer.ckpt', upscale=2, arch='RestoreFormer', bg_upsampler=upsampler)
+         elif version == 'RestoreFormer++':
+             face_enhancer = RestoreFormer(
+                 model_path='RestoreFormer++.ckpt', upscale=2, arch='RestoreFormer++', bg_upsampler=upsampler)
+
+         try:
+             # _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True, weight=weight)
+             _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
+         except RuntimeError as error:
+             print('Error', error)
+
+         try:
+             if scale != 2:
+                 interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
+                 h, w = img.shape[0:2]
+                 output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
+         except Exception as error:
+             print('wrong scale input.', error)
+         if img_mode == 'RGBA':  # RGBA images should be saved in png format
+             extension = 'png'
+         else:
+             extension = 'jpg'
+         save_path = f'output/out.{extension}'
+         cv2.imwrite(save_path, output)
+
+         output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
+         return output, save_path
+     except Exception as error:
+         print('global exception', error)
+         return None, None
+
+
+ title = "RestoreFormer: Blind Face Restoration Algorithm"
+ description = r"""Gradio demo for <a href='https://github.com/wzhouxiff/RestoreFormerPlusPlus' target='_blank'><b>RestoreFormer++: Towards Real-World Blind Face Restoration from Undegraded Key-Value Pairs</b></a>.<br>
+ It is used to restore your **old photos**.<br>
+ To use it, simply upload your image.<br>
+ """
+ article = r"""
+ # [![download](https://img.shields.io/github/downloads/TencentARC/GFPGAN/total.svg)](https://github.com/TencentARC/GFPGAN/releases)
+ # [![GitHub Stars](https://img.shields.io/github/stars/TencentARC/GFPGAN?style=social)](https://github.com/TencentARC/GFPGAN)
+ [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/pdf/2308.07228.pdf)
+ [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://openaccess.thecvf.com/content/CVPR2022/papers/Wang_RestoreFormer_High-Quality_Blind_Face_Restoration_From_Undegraded_Key-Value_Pairs_CVPR_2022_paper.pdf)
+ If you have any questions, please email 📧 `wzhoux@connect.hku.hk`.
+ # <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_GFPGAN' alt='visitor badge'></center>
+ # <center><img src='https://visitor-badge.glitch.me/badge?page_id=Gradio_Xintao_GFPGAN' alt='visitor badge'></center>
+ """
+ demo = gr.Interface(
+     inference, [
+         gr.Image(type="filepath", label="Input"),
+         gr.Radio(['RestoreFormer', 'RestoreFormer++'], type="value", value='RestoreFormer++', label='version'),
+         gr.Number(label="Rescaling factor", value=2),
+     ], [
+         gr.Image(type="numpy", label="Output (The whole image)"),
+         gr.File(label="Download the output image")
+     ],
+     title=title,
+     description=description,
+     article=article,
+     # examples=[['AI-generate.jpg', 'v1.4', 2, 50], ['lincoln.jpg', 'v1.4', 2, 50], ['Blake_Lively.jpg', 'v1.4', 2, 50],
+     #           ['10045.png', 'v1.4', 2, 50]]).launch()
+     # examples=[['AI-generate.jpg', 'v1.4', 2], ['lincoln.jpg', 'v1.4', 2], ['Blake_Lively.jpg', 'v1.4', 2],
+     #           ['10045.png', 'v1.4', 2]]
+ )
+ demo.queue().launch(share=True)
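For a quick smoke test without the browser UI, the inference function above can also be called directly, for example in place of the launch() line; a minimal sketch in which the input image path is a placeholder:

# Placeholder smoke test: restore a local image with RestoreFormer++ at 2x.
rgb_output, save_path = inference('old_photo.jpg', 'RestoreFormer++', 2)
print(save_path)   # e.g. output/out.jpg; (None, None) is returned on failure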
packages.txt ADDED
@@ -0,0 +1,3 @@
+ ffmpeg
+ libsm6
+ libxext6
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ torch>=1.7
+ basicsr>=1.4.2
+ facexlib>=0.2.5
+ realesrgan>=0.2.5
+ numpy
+ opencv-python
+ torchvision
+ scipy
+ tqdm
+ lmdb
+ pyyaml
+ yapf