AK391 committed
Commit: 13fd34d
Parent: 5e8cde2

update files

VERSION CHANGED
@@ -1 +1 @@
-0.2.1
+0.2.4
gfpgan/__init__.py CHANGED
@@ -3,4 +3,5 @@ from .archs import *
 from .data import *
 from .models import *
 from .utils import *
-#from .version import __gitsha__, __version__
+
+# from .version import *
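
These wildcard imports are what make basicsr's registries work: importing gfpgan.archs, .data and .models executes every module in those packages, and each module's @ARCH_REGISTRY.register() (or DATASET_REGISTRY / MODEL_REGISTRY) decorator records the class by name so that build_network() can later resolve config strings into classes. A minimal sketch of the pattern follows; the Registry class here is an illustrative re-implementation, not basicsr's actual one:

    class Registry:
        """Maps class names to classes; basicsr keeps one registry per component type."""

        def __init__(self, name):
            self._name, self._obj_map = name, {}

        def register(self, obj=None):
            def deco(cls):
                self._obj_map[cls.__name__] = cls  # record under the class's own name
                return cls
            return deco(obj) if obj is not None else deco

        def get(self, name):
            return self._obj_map[name]

    ARCH_REGISTRY = Registry('arch')

    @ARCH_REGISTRY.register()
    class ResNetArcFace:
        pass  # importing the defining module is enough to register the class

    print(ARCH_REGISTRY.get('ResNetArcFace'))  # <class 'ResNetArcFace'>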
gfpgan/archs/arcface_arch.py CHANGED
@@ -2,13 +2,27 @@ import torch.nn as nn
 from basicsr.utils.registry import ARCH_REGISTRY
 
 
-def conv3x3(in_planes, out_planes, stride=1):
-    """3x3 convolution with padding"""
-    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+def conv3x3(inplanes, outplanes, stride=1):
+    """A simple wrapper for 3x3 convolution with padding.
+
+    Args:
+        inplanes (int): Channel number of inputs.
+        outplanes (int): Channel number of outputs.
+        stride (int): Stride in convolution. Default: 1.
+    """
+    return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False)
 
 
 class BasicBlock(nn.Module):
-    expansion = 1
+    """Basic residual block used in the ResNetArcFace architecture.
+
+    Args:
+        inplanes (int): Channel number of inputs.
+        planes (int): Channel number of outputs.
+        stride (int): Stride in convolution. Default: 1.
+        downsample (nn.Module): The downsample module. Default: None.
+    """
+    expansion = 1  # output channel expansion ratio
 
     def __init__(self, inplanes, planes, stride=1, downsample=None):
         super(BasicBlock, self).__init__()
@@ -40,7 +54,16 @@ class BasicBlock(nn.Module):
 
 
 class IRBlock(nn.Module):
-    expansion = 1
+    """Improved residual block (IR Block) used in the ResNetArcFace architecture.
+
+    Args:
+        inplanes (int): Channel number of inputs.
+        planes (int): Channel number of outputs.
+        stride (int): Stride in convolution. Default: 1.
+        downsample (nn.Module): The downsample module. Default: None.
+        use_se (bool): Whether to use the SEBlock (squeeze and excitation block). Default: True.
+    """
+    expansion = 1  # output channel expansion ratio
 
     def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
         super(IRBlock, self).__init__()
@@ -78,7 +101,15 @@ class IRBlock(nn.Module):
 
 
 class Bottleneck(nn.Module):
-    expansion = 4
+    """Bottleneck block used in the ResNetArcFace architecture.
+
+    Args:
+        inplanes (int): Channel number of inputs.
+        planes (int): Channel number of outputs.
+        stride (int): Stride in convolution. Default: 1.
+        downsample (nn.Module): The downsample module. Default: None.
+    """
+    expansion = 4  # output channel expansion ratio
 
     def __init__(self, inplanes, planes, stride=1, downsample=None):
         super(Bottleneck, self).__init__()
@@ -116,10 +147,16 @@ class Bottleneck(nn.Module):
 
 
 class SEBlock(nn.Module):
+    """The squeeze-and-excitation block (SEBlock) used in the IRBlock.
+
+    Args:
+        channel (int): Channel number of inputs.
+        reduction (int): Channel reduction ratio. Default: 16.
+    """
 
     def __init__(self, channel, reduction=16):
         super(SEBlock, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # pool to 1x1 without spatial information
         self.fc = nn.Sequential(
             nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),
             nn.Sigmoid())
@@ -133,6 +170,15 @@ class SEBlock(nn.Module):
 
 @ARCH_REGISTRY.register()
 class ResNetArcFace(nn.Module):
+    """ArcFace with ResNet architectures.
+
+    Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
+
+    Args:
+        block (str): Block used in the ArcFace architecture.
+        layers (tuple(int)): Block numbers in each layer.
+        use_se (bool): Whether to use the SEBlock (squeeze and excitation block). Default: True.
+    """
 
     def __init__(self, block, layers, use_se=True):
         if block == 'IRBlock':
@@ -140,6 +186,7 @@ class ResNetArcFace(nn.Module):
         self.inplanes = 64
         self.use_se = use_se
         super(ResNetArcFace, self).__init__()
+
         self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
         self.bn1 = nn.BatchNorm2d(64)
         self.prelu = nn.PReLU()
@@ -153,6 +200,7 @@ class ResNetArcFace(nn.Module):
         self.fc5 = nn.Linear(512 * 8 * 8, 512)
         self.bn5 = nn.BatchNorm1d(512)
 
+        # initialization
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
                 nn.init.xavier_normal_(m.weight)
@@ -163,7 +211,7 @@ class ResNetArcFace(nn.Module):
                 nn.init.xavier_normal_(m.weight)
                 nn.init.constant_(m.bias, 0)
 
-    def _make_layer(self, block, planes, blocks, stride=1):
+    def _make_layer(self, block, planes, num_blocks, stride=1):
         downsample = None
         if stride != 1 or self.inplanes != planes * block.expansion:
             downsample = nn.Sequential(
@@ -173,7 +221,7 @@ class ResNetArcFace(nn.Module):
         layers = []
         layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
         self.inplanes = planes
-        for _ in range(1, blocks):
+        for _ in range(1, num_blocks):
            layers.append(block(self.inplanes, planes, use_se=self.use_se))
 
         return nn.Sequential(*layers)
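
The SEBlock documented above is a squeeze-and-excitation unit: global average pooling squeezes each channel of a (b, c, h, w) feature map down to a single statistic, a small bottleneck MLP turns those statistics into per-channel gates in (0, 1), and the input is rescaled by the gates. A runnable restatement of the block; the constructor matches the file, the forward pass follows the upstream GFPGAN implementation, and the demo tensor is illustrative:

    import torch
    import torch.nn as nn

    class SEBlock(nn.Module):

        def __init__(self, channel, reduction=16):
            super(SEBlock, self).__init__()
            self.avg_pool = nn.AdaptiveAvgPool2d(1)  # squeeze: (b, c, h, w) -> (b, c, 1, 1)
            self.fc = nn.Sequential(
                nn.Linear(channel, channel // reduction), nn.PReLU(),
                nn.Linear(channel // reduction, channel), nn.Sigmoid())

        def forward(self, x):
            b, c, _, _ = x.size()
            y = self.avg_pool(x).view(b, c)  # per-channel statistics
            y = self.fc(y).view(b, c, 1, 1)  # excitation: per-channel gates in (0, 1)
            return x * y  # recalibrate the input feature map

    feat = torch.randn(2, 64, 32, 32)
    print(SEBlock(64)(feat).shape)  # torch.Size([2, 64, 32, 32]): shape kept, channels reweighted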
gfpgan/archs/gfpganv1_arch.py CHANGED
@@ -10,18 +10,18 @@ from torch.nn import functional as F
 
 
 class StyleGAN2GeneratorSFT(StyleGAN2Generator):
-    """StyleGAN2 Generator.
+    """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
 
     Args:
         out_size (int): The spatial size of outputs.
         num_style_feat (int): Channel number of style features. Default: 512.
         num_mlp (int): Layer number of MLP style layers. Default: 8.
-        channel_multiplier (int): Channel multiplier for large networks of
-            StyleGAN2. Default: 2.
-        resample_kernel (list[int]): A list indicating the 1D resample kernel
-            magnitude. A cross production will be applied to extent 1D resample
-            kenrel to 2D resample kernel. Default: [1, 3, 3, 1].
+        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
+        resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be
+            applied to extend the 1D resample kernel to a 2D resample kernel. Default: (1, 3, 3, 1).
         lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
+        narrow (float): The narrow ratio for channels. Default: 1.
+        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
     """
 
     def __init__(self,
@@ -53,21 +53,18 @@ class StyleGAN2GeneratorSFT(StyleGAN2Generator):
                 truncation_latent=None,
                 inject_index=None,
                 return_latents=False):
-        """Forward function for StyleGAN2Generator.
+        """Forward function for StyleGAN2GeneratorSFT.
 
         Args:
             styles (list[Tensor]): Sample codes of styles.
-            input_is_latent (bool): Whether input is latent style.
-                Default: False.
+            conditions (list[Tensor]): SFT conditions to generators.
+            input_is_latent (bool): Whether input is latent style. Default: False.
             noise (Tensor | None): Input noise or None. Default: None.
-            randomize_noise (bool): Randomize noise, used when 'noise' is
-                False. Default: True.
-            truncation (float): TODO. Default: 1.
-            truncation_latent (Tensor | None): TODO. Default: None.
-            inject_index (int | None): The injection index for mixing noise.
-                Default: None.
-            return_latents (bool): Whether to return style latents.
-                Default: False.
+            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
+            truncation (float): The truncation ratio. Default: 1.
+            truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
+            inject_index (int | None): The injection index for mixing noise. Default: None.
+            return_latents (bool): Whether to return style latents. Default: False.
         """
         # style codes -> latents with Style MLP layer
         if not input_is_latent:
@@ -84,7 +81,7 @@ class StyleGAN2GeneratorSFT(StyleGAN2Generator):
             for style in styles:
                 style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
             styles = style_truncation
-        # get style latent with injection
+        # get style latents with injection
         if len(styles) == 1:
             inject_index = self.num_latent
 
@@ -113,15 +110,15 @@ class StyleGAN2GeneratorSFT(StyleGAN2Generator):
             # the conditions may have fewer levels
             if i < len(conditions):
                 # SFT part to combine the conditions
-                if self.sft_half:
+                if self.sft_half:  # only apply SFT to half of the channels
                     out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
                     out_sft = out_sft * conditions[i - 1] + conditions[i]
                     out = torch.cat([out_same, out_sft], dim=1)
-                else:
+                else:  # apply SFT to all the channels
                     out = out * conditions[i - 1] + conditions[i]
 
             out = conv2(out, latent[:, i + 1], noise=noise2)
-            skip = to_rgb(out, latent[:, i + 2], skip)
+            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
             i += 2
 
         image = skip
@@ -133,17 +130,15 @@ class StyleGAN2GeneratorSFT(StyleGAN2Generator):
 
 
 class ConvUpLayer(nn.Module):
-    """Conv Up Layer. Bilinear upsample + Conv.
+    """Convolutional upsampling layer. It uses bilinear upsampler + Conv.
 
     Args:
         in_channels (int): Channel number of the input.
         out_channels (int): Channel number of the output.
         kernel_size (int): Size of the convolving kernel.
         stride (int): Stride of the convolution. Default: 1
-        padding (int): Zero-padding added to both sides of the input.
-            Default: 0.
-        bias (bool): If ``True``, adds a learnable bias to the output.
-            Default: ``True``.
+        padding (int): Zero-padding added to both sides of the input. Default: 0.
+        bias (bool): If ``True``, adds a learnable bias to the output. Default: ``True``.
         bias_init_val (float): Bias initialized value. Default: 0.
         activate (bool): Whether to use activation. Default: True.
     """
@@ -163,6 +158,7 @@ class ConvUpLayer(nn.Module):
         self.kernel_size = kernel_size
         self.stride = stride
         self.padding = padding
+        # self.scale is used to scale the convolution weights, which is related to the common initializations.
         self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
 
         self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
@@ -223,7 +219,26 @@ class ResUpBlock(nn.Module):
 
 @ARCH_REGISTRY.register()
 class GFPGANv1(nn.Module):
-    """Unet + StyleGAN2 decoder with SFT."""
+    """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
+
+    Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
+
+    Args:
+        out_size (int): The spatial size of outputs.
+        num_style_feat (int): Channel number of style features. Default: 512.
+        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
+        resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be
+            applied to extend the 1D resample kernel to a 2D resample kernel. Default: (1, 3, 3, 1).
+        decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
+        fix_decoder (bool): Whether to fix the decoder. Default: True.
+
+        num_mlp (int): Layer number of MLP style layers. Default: 8.
+        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
+        input_is_latent (bool): Whether input is latent style. Default: False.
+        different_w (bool): Whether to use different latent w for different layers. Default: False.
+        narrow (float): The narrow ratio for channels. Default: 1.
+        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
+    """
 
     def __init__(
         self,
@@ -246,7 +261,7 @@ class GFPGANv1(nn.Module):
         self.different_w = different_w
         self.num_style_feat = num_style_feat
 
-        unet_narrow = narrow * 0.5
+        unet_narrow = narrow * 0.5  # by default, use a half of input channels
         channels = {
             '4': int(512 * unet_narrow),
             '8': int(512 * unet_narrow),
@@ -295,6 +310,7 @@ class GFPGANv1(nn.Module):
         self.final_linear = EqualLinear(
             channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None)
 
+        # the decoder: stylegan2 generator with SFT modulations
         self.stylegan_decoder = StyleGAN2GeneratorSFT(
             out_size=out_size,
             num_style_feat=num_style_feat,
@@ -305,14 +321,16 @@ class GFPGANv1(nn.Module):
             narrow=narrow,
             sft_half=sft_half)
 
+        # load pre-trained stylegan2 model if necessary
         if decoder_load_path:
             self.stylegan_decoder.load_state_dict(
                 torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
+        # fix decoder without updating params
         if fix_decoder:
             for _, param in self.stylegan_decoder.named_parameters():
                 param.requires_grad = False
 
-        # for SFT
+        # for SFT modulations (scale and shift)
         self.condition_scale = nn.ModuleList()
         self.condition_shift = nn.ModuleList()
         for i in range(3, self.log_size + 1):
@@ -332,13 +350,15 @@ class GFPGANv1(nn.Module):
             ScaledLeakyReLU(0.2),
             EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0)))
 
-    def forward(self,
-                x,
-                return_latents=False,
-                save_feat_path=None,
-                load_feat_path=None,
-                return_rgb=True,
-                randomize_noise=True):
+    def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
+        """Forward function for GFPGANv1.
+
+        Args:
+            x (Tensor): Input images.
+            return_latents (bool): Whether to return style latents. Default: False.
+            return_rgb (bool): Whether to return intermediate rgb images. Default: True.
+            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
+        """
         conditions = []
         unet_skips = []
         out_rgbs = []
@@ -362,7 +382,7 @@ class GFPGANv1(nn.Module):
             feat = feat + unet_skips[i]
             # ResUpLayer
             feat = self.conv_body_up[i](feat)
-            # generate scale and shift for SFT layer
+            # generate scale and shift for SFT layers
             scale = self.condition_scale[i](feat)
             conditions.append(scale.clone())
             shift = self.condition_shift[i](feat)
@@ -371,12 +391,6 @@ class GFPGANv1(nn.Module):
             if return_rgb:
                 out_rgbs.append(self.toRGB[i](feat))
 
-        if save_feat_path is not None:
-            torch.save(conditions, save_feat_path)
-        if load_feat_path is not None:
-            conditions = torch.load(load_feat_path)
-            conditions = [v.cuda() for v in conditions]
-
         # decoder
         image, _ = self.stylegan_decoder([style_code],
                                          conditions,
@@ -389,10 +403,12 @@ class GFPGANv1(nn.Module):
 
 @ARCH_REGISTRY.register()
 class FacialComponentDiscriminator(nn.Module):
+    """Facial component (eyes, mouth, nose) discriminator used in GFPGAN.
+    """
 
     def __init__(self):
         super(FacialComponentDiscriminator, self).__init__()
-
+        # It now uses a VGG-style architecture with fixed model size
         self.conv1 = ConvLayer(3, 64, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
         self.conv2 = ConvLayer(64, 128, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
         self.conv3 = ConvLayer(128, 128, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
@@ -401,6 +417,12 @@ class FacialComponentDiscriminator(nn.Module):
         self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False)
 
     def forward(self, x, return_feats=False):
+        """Forward function for FacialComponentDiscriminator.
+
+        Args:
+            x (Tensor): Input images.
+            return_feats (bool): Whether to return intermediate features. Default: False.
+        """
         feat = self.conv1(x)
         feat = self.conv3(self.conv2(feat))
         rlt_feats = []
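
The SFT step annotated in the hunks above is a per-pixel affine modulation: the U-Net's condition branches supply a scale and a shift for the decoder features, optionally applied to only half of the channels when sft_half is set. A minimal sketch of that operation in isolation, using the same torch calls as the code above (tensor shapes are illustrative):

    import torch

    def apply_sft(out, scale, shift, sft_half=False):
        """Spatial Feature Transform: out * scale + shift, per pixel and channel."""
        if sft_half:  # only modulate the second half of the channels
            out_same, out_sft = torch.split(out, out.size(1) // 2, dim=1)
            out_sft = out_sft * scale + shift
            return torch.cat([out_same, out_sft], dim=1)
        return out * scale + shift

    feat = torch.randn(1, 512, 16, 16)
    half = 256  # condition tensors match the modulated half when sft_half=True
    scale = torch.randn(1, half, 16, 16)
    shift = torch.randn(1, half, 16, 16)
    print(apply_sft(feat, scale, shift, sft_half=True).shape)  # torch.Size([1, 512, 16, 16])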
gfpgan/archs/gfpganv1_clean_arch.py CHANGED
@@ -1,6 +1,7 @@
 import math
 import random
 import torch
+from basicsr.utils.registry import ARCH_REGISTRY
 from torch import nn
 from torch.nn import functional as F
 
@@ -8,14 +9,17 @@ from .stylegan2_clean_arch import StyleGAN2GeneratorClean
 
 
 class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
-    """StyleGAN2 Generator.
+    """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
+
+    It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
 
     Args:
         out_size (int): The spatial size of outputs.
         num_style_feat (int): Channel number of style features. Default: 512.
         num_mlp (int): Layer number of MLP style layers. Default: 8.
-        channel_multiplier (int): Channel multiplier for large networks of
-            StyleGAN2. Default: 2.
+        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
+        narrow (float): The narrow ratio for channels. Default: 1.
+        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
     """
 
     def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1, sft_half=False):
@@ -25,7 +29,6 @@ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
             num_mlp=num_mlp,
             channel_multiplier=channel_multiplier,
             narrow=narrow)
-
         self.sft_half = sft_half
 
     def forward(self,
@@ -38,21 +41,18 @@ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
                 truncation_latent=None,
                 inject_index=None,
                 return_latents=False):
-        """Forward function for StyleGAN2Generator.
+        """Forward function for StyleGAN2GeneratorCSFT.
 
         Args:
             styles (list[Tensor]): Sample codes of styles.
-            input_is_latent (bool): Whether input is latent style.
-                Default: False.
+            conditions (list[Tensor]): SFT conditions to generators.
+            input_is_latent (bool): Whether input is latent style. Default: False.
            noise (Tensor | None): Input noise or None. Default: None.
-            randomize_noise (bool): Randomize noise, used when 'noise' is
-                False. Default: True.
-            truncation (float): TODO. Default: 1.
-            truncation_latent (Tensor | None): TODO. Default: None.
-            inject_index (int | None): The injection index for mixing noise.
-                Default: None.
-            return_latents (bool): Whether to return style latents.
-                Default: False.
+            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
+            truncation (float): The truncation ratio. Default: 1.
+            truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
+            inject_index (int | None): The injection index for mixing noise. Default: None.
+            return_latents (bool): Whether to return style latents. Default: False.
         """
         # style codes -> latents with Style MLP layer
         if not input_is_latent:
@@ -69,7 +69,7 @@ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
             for style in styles:
                 style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
             styles = style_truncation
-        # get style latent with injection
+        # get style latents with injection
         if len(styles) == 1:
             inject_index = self.num_latent
 
@@ -98,15 +98,15 @@ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
             # the conditions may have fewer levels
             if i < len(conditions):
                 # SFT part to combine the conditions
-                if self.sft_half:
+                if self.sft_half:  # only apply SFT to half of the channels
                     out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
                     out_sft = out_sft * conditions[i - 1] + conditions[i]
                     out = torch.cat([out_same, out_sft], dim=1)
-                else:
+                else:  # apply SFT to all the channels
                     out = out * conditions[i - 1] + conditions[i]
 
             out = conv2(out, latent[:, i + 1], noise=noise2)
-            skip = to_rgb(out, latent[:, i + 2], skip)
+            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
             i += 2
 
         image = skip
@@ -118,11 +118,12 @@ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
 
 
 class ResBlock(nn.Module):
-    """Residual block with upsampling/downsampling.
+    """Residual block with bilinear upsampling/downsampling.
 
     Args:
         in_channels (int): Channel number of the input.
         out_channels (int): Channel number of the output.
+        mode (str): Upsampling/downsampling mode. Options: down | up. Default: down.
     """
 
     def __init__(self, in_channels, out_channels, mode='down'):
@@ -148,8 +149,27 @@ class ResBlock(nn.Module):
         return out
 
 
+@ARCH_REGISTRY.register()
 class GFPGANv1Clean(nn.Module):
-    """GFPGANv1 Clean version."""
+    """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
+
+    It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
+
+    Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
+
+    Args:
+        out_size (int): The spatial size of outputs.
+        num_style_feat (int): Channel number of style features. Default: 512.
+        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
+        decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
+        fix_decoder (bool): Whether to fix the decoder. Default: True.
+
+        num_mlp (int): Layer number of MLP style layers. Default: 8.
+        input_is_latent (bool): Whether input is latent style. Default: False.
+        different_w (bool): Whether to use different latent w for different layers. Default: False.
+        narrow (float): The narrow ratio for channels. Default: 1.
+        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
+    """
 
     def __init__(
         self,
@@ -170,7 +190,7 @@ class GFPGANv1Clean(nn.Module):
         self.different_w = different_w
         self.num_style_feat = num_style_feat
 
-        unet_narrow = narrow * 0.5
+        unet_narrow = narrow * 0.5  # by default, use a half of input channels
         channels = {
             '4': int(512 * unet_narrow),
             '8': int(512 * unet_narrow),
@@ -218,6 +238,7 @@ class GFPGANv1Clean(nn.Module):
 
         self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)
 
+        # the decoder: stylegan2 generator with SFT modulations
         self.stylegan_decoder = StyleGAN2GeneratorCSFT(
             out_size=out_size,
             num_style_feat=num_style_feat,
@@ -226,14 +247,16 @@ class GFPGANv1Clean(nn.Module):
             narrow=narrow,
             sft_half=sft_half)
 
+        # load pre-trained stylegan2 model if necessary
         if decoder_load_path:
             self.stylegan_decoder.load_state_dict(
                 torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
+        # fix decoder without updating params
         if fix_decoder:
-            for name, param in self.stylegan_decoder.named_parameters():
+            for _, param in self.stylegan_decoder.named_parameters():
                 param.requires_grad = False
 
-        # for SFT
+        # for SFT modulations (scale and shift)
         self.condition_scale = nn.ModuleList()
         self.condition_shift = nn.ModuleList()
         for i in range(3, self.log_size + 1):
@@ -251,13 +274,15 @@ class GFPGANv1Clean(nn.Module):
                 nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
                 nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
 
-    def forward(self,
-                x,
-                return_latents=False,
-                save_feat_path=None,
-                load_feat_path=None,
-                return_rgb=True,
-                randomize_noise=True):
+    def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
+        """Forward function for GFPGANv1Clean.
+
+        Args:
+            x (Tensor): Input images.
+            return_latents (bool): Whether to return style latents. Default: False.
+            return_rgb (bool): Whether to return intermediate rgb images. Default: True.
+            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
+        """
         conditions = []
         unet_skips = []
         out_rgbs = []
@@ -273,13 +298,14 @@ class GFPGANv1Clean(nn.Module):
         style_code = self.final_linear(feat.view(feat.size(0), -1))
         if self.different_w:
             style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
+
         # decode
         for i in range(self.log_size - 2):
             # add unet skip
             feat = feat + unet_skips[i]
             # ResUpLayer
             feat = self.conv_body_up[i](feat)
-            # generate scale and shift for SFT layer
+            # generate scale and shift for SFT layers
             scale = self.condition_scale[i](feat)
             conditions.append(scale.clone())
             shift = self.condition_shift[i](feat)
@@ -288,12 +314,6 @@ class GFPGANv1Clean(nn.Module):
             if return_rgb:
                 out_rgbs.append(self.toRGB[i](feat))
 
-        if save_feat_path is not None:
-            torch.save(conditions, save_feat_path)
-        if load_feat_path is not None:
-            conditions = torch.load(load_feat_path)
-            conditions = [v.cuda() for v in conditions]
-
         # decoder
         image, _ = self.stylegan_decoder([style_code],
                                          conditions,
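
Both GFPGAN variants derive their U-Net widths from the `unet_narrow = narrow * 0.5` line commented above: each resolution level gets a base channel count scaled by that factor, with `channel_multiplier` applied at the finer levels. A sketch of the resulting table; the '4' and '8' entries appear in the hunks above, while the remaining entries reproduce the upstream GFPGAN pattern and should be checked against the repository:

    def build_channels(narrow=1, channel_multiplier=2):
        unet_narrow = narrow * 0.5  # by default, use a half of input channels
        return {
            '4': int(512 * unet_narrow),
            '8': int(512 * unet_narrow),
            '16': int(512 * unet_narrow),
            '32': int(512 * unet_narrow),
            '64': int(256 * channel_multiplier * unet_narrow),
            '128': int(128 * channel_multiplier * unet_narrow),
            '256': int(64 * channel_multiplier * unet_narrow),
            '512': int(32 * channel_multiplier * unet_narrow),
        }

    # with the defaults, the coarse levels get 256 channels and the 512x512 level gets 32
    print(build_channels()['4'], build_channels()['512'])  # 256 32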
gfpgan/archs/stylegan2_clean_arch.py CHANGED
@@ -31,12 +31,9 @@ class ModulatedConv2d(nn.Module):
         out_channels (int): Channel number of the output.
         kernel_size (int): Size of the convolving kernel.
         num_style_feat (int): Channel number of style features.
-        demodulate (bool): Whether to demodulate in the conv layer.
-            Default: True.
-        sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
-            Default: None.
-        eps (float): A value added to the denominator for numerical stability.
-            Default: 1e-8.
+        demodulate (bool): Whether to demodulate in the conv layer. Default: True.
+        sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
+        eps (float): A value added to the denominator for numerical stability. Default: 1e-8.
     """
 
     def __init__(self,
@@ -87,6 +84,7 @@ class ModulatedConv2d(nn.Module):
 
         weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
 
+        # upsample or downsample if necessary
         if self.sample_mode == 'upsample':
             x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
         elif self.sample_mode == 'downsample':
@@ -101,14 +99,12 @@ class ModulatedConv2d(nn.Module):
         return out
 
     def __repr__(self):
-        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
-                f'out_channels={self.out_channels}, '
-                f'kernel_size={self.kernel_size}, '
-                f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')
+        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, '
+                f'kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})')
 
 
 class StyleConv(nn.Module):
-    """Style conv.
+    """Style conv used in StyleGAN2.
 
     Args:
         in_channels (int): Channel number of the input.
@@ -116,8 +112,7 @@ class StyleConv(nn.Module):
         kernel_size (int): Size of the convolving kernel.
         num_style_feat (int): Channel number of style features.
         demodulate (bool): Whether demodulate in the conv layer. Default: True.
-        sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
-            Default: None.
+        sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
     """
 
     def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None):
@@ -144,7 +139,7 @@ class StyleConv(nn.Module):
 
 
 class ToRGB(nn.Module):
-    """To RGB from features.
+    """To RGB (image space) from features.
 
     Args:
         in_channels (int): Channel number of input.
@@ -204,8 +199,7 @@ class StyleGAN2GeneratorClean(nn.Module):
         out_size (int): The spatial size of outputs.
         num_style_feat (int): Channel number of style features. Default: 512.
         num_mlp (int): Layer number of MLP style layers. Default: 8.
-        channel_multiplier (int): Channel multiplier for large networks of
-            StyleGAN2. Default: 2.
+        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
         narrow (float): Narrow ratio for channels. Default: 1.0.
     """
@@ -222,6 +216,7 @@ class StyleGAN2GeneratorClean(nn.Module):
         # initialization
         default_init_weights(self.style_mlp, scale=1, bias_fill=0, a=0.2, mode='fan_in', nonlinearity='leaky_relu')
 
+        # channel list
         channels = {
             '4': int(512 * narrow),
             '8': int(512 * narrow),
@@ -309,21 +304,17 @@ class StyleGAN2GeneratorClean(nn.Module):
                 truncation_latent=None,
                 inject_index=None,
                 return_latents=False):
-        """Forward function for StyleGAN2Generator.
+        """Forward function for StyleGAN2GeneratorClean.
 
         Args:
             styles (list[Tensor]): Sample codes of styles.
-            input_is_latent (bool): Whether input is latent style.
-                Default: False.
+            input_is_latent (bool): Whether input is latent style. Default: False.
             noise (Tensor | None): Input noise or None. Default: None.
-            randomize_noise (bool): Randomize noise, used when 'noise' is
-                False. Default: True.
-            truncation (float): TODO. Default: 1.
-            truncation_latent (Tensor | None): TODO. Default: None.
-            inject_index (int | None): The injection index for mixing noise.
-                Default: None.
-            return_latents (bool): Whether to return style latents.
-                Default: False.
+            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
+            truncation (float): The truncation ratio. Default: 1.
+            truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
+            inject_index (int | None): The injection index for mixing noise. Default: None.
+            return_latents (bool): Whether to return style latents. Default: False.
         """
         # style codes -> latents with Style MLP layer
         if not input_is_latent:
@@ -340,7 +331,7 @@ class StyleGAN2GeneratorClean(nn.Module):
             for style in styles:
                 style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
             styles = style_truncation
-        # get style latent with injection
+        # get style latents with injection
         if len(styles) == 1:
             inject_index = self.num_latent
 
@@ -366,7 +357,7 @@ class StyleGAN2GeneratorClean(nn.Module):
                                 noise[2::2], self.to_rgbs):
             out = conv1(out, latent[:, i], noise=noise1)
             out = conv2(out, latent[:, i + 1], noise=noise2)
-            skip = to_rgb(out, latent[:, i + 2], skip)
+            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
            i += 2
 
         image = skip
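
The forward docstrings fixed in this file replace two TODO entries: `truncation` and `truncation_latent` implement StyleGAN's truncation trick, which pulls every style code toward a mean latent so that samples trade diversity for fidelity. Isolating the loop shown in the diff above (tensors are illustrative; the mean latent is normally a running average of many mapped codes):

    import torch

    def truncate_styles(styles, truncation_latent, truncation=0.7):
        # w' = w_mean + t * (w - w_mean); t = 1 leaves the codes unchanged
        style_truncation = []
        for style in styles:
            style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
        return style_truncation

    w_mean = torch.zeros(1, 512)
    styles = [torch.randn(4, 512)]
    out = truncate_styles(styles, w_mean, truncation=0.7)
    print(out[0].std() < styles[0].std())  # tensor(True): codes contract toward the mean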
gfpgan/data/ffhq_degradation_dataset.py CHANGED
@@ -15,6 +15,19 @@ from torchvision.transforms.functional import (adjust_brightness, adjust_contras
 
 @DATASET_REGISTRY.register()
 class FFHQDegradationDataset(data.Dataset):
+    """FFHQ dataset for GFPGAN.
+
+    It reads high resolution images, and then generates low-quality (LQ) images on-the-fly.
+
+    Args:
+        opt (dict): Config for train datasets. It contains the following keys:
+            dataroot_gt (str): Data root path for gt.
+            io_backend (dict): IO backend type and other kwargs.
+            mean (list | tuple): Image mean.
+            std (list | tuple): Image std.
+            use_hflip (bool): Whether to horizontally flip.
+            Please see more options in the codes.
+    """
 
     def __init__(self, opt):
         super(FFHQDegradationDataset, self).__init__()
@@ -29,11 +42,13 @@ class FFHQDegradationDataset(data.Dataset):
         self.out_size = opt['out_size']
 
         self.crop_components = opt.get('crop_components', False)  # facial components
-        self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1)
+        self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1)  # whether enlarge eye regions
 
         if self.crop_components:
+            # load component list from a pre-processed pth file
             self.components_list = torch.load(opt.get('component_path'))
 
+        # file client (lmdb io backend)
         if self.io_backend_opt['type'] == 'lmdb':
             self.io_backend_opt['db_paths'] = self.gt_folder
             if not self.gt_folder.endswith('.lmdb'):
@@ -41,9 +56,10 @@ class FFHQDegradationDataset(data.Dataset):
             with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                 self.paths = [line.split('.')[0] for line in fin]
         else:
+            # disk backend: scan file list from a folder
             self.paths = paths_from_folder(self.gt_folder)
 
-        # degradations
+        # degradation configurations
         self.blur_kernel_size = opt['blur_kernel_size']
         self.kernel_list = opt['kernel_list']
         self.kernel_prob = opt['kernel_prob']
@@ -60,22 +76,20 @@ class FFHQDegradationDataset(data.Dataset):
         self.gray_prob = opt.get('gray_prob')
 
         logger = get_root_logger()
-        logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, '
-                    f'sigma: [{", ".join(map(str, self.blur_sigma))}]')
+        logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
         logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
         logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
         logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
 
         if self.color_jitter_prob is not None:
-            logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, '
-                        f'shift: {self.color_jitter_shift}')
+            logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
         if self.gray_prob is not None:
             logger.info(f'Use random gray. Prob: {self.gray_prob}')
-
         self.color_jitter_shift /= 255.
 
     @staticmethod
     def color_jitter(img, shift):
+        """jitter color: randomly jitter the RGB values, in numpy formats"""
         jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
         img = img + jitter_val
         img = np.clip(img, 0, 1)
@@ -83,6 +97,7 @@ class FFHQDegradationDataset(data.Dataset):
 
     @staticmethod
     def color_jitter_pt(img, brightness, contrast, saturation, hue):
+        """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
         fn_idx = torch.randperm(4)
         for fn_id in fn_idx:
             if fn_id == 0 and brightness is not None:
@@ -103,6 +118,7 @@ class FFHQDegradationDataset(data.Dataset):
         return img
 
     def get_component_coordinates(self, index, status):
+        """Get facial component (left_eye, right_eye, mouth) coordinates from a pre-loaded pth file"""
         components_bbox = self.components_list[f'{index:08d}']
         if status[0]:  # hflip
             # exchange right and left eye
@@ -131,6 +147,7 @@ class FFHQDegradationDataset(data.Dataset):
             self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
 
         # load gt image
+        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
         gt_path = self.paths[index]
         img_bytes = self.file_client.get(gt_path)
         img_gt = imfrombytes(img_bytes, float32=True)
@@ -139,6 +156,7 @@ class FFHQDegradationDataset(data.Dataset):
         img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
         h, w, _ = img_gt.shape
 
+        # get facial component coordinates
         if self.crop_components:
             locations = self.get_component_coordinates(index, status)
             loc_left_eye, loc_right_eye, loc_mouth = locations
@@ -173,9 +191,9 @@ class FFHQDegradationDataset(data.Dataset):
         if self.gray_prob and np.random.uniform() < self.gray_prob:
             img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
             img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
-            if self.opt.get('gt_gray'):
+            if self.opt.get('gt_gray'):  # whether convert GT to gray images
                 img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
-                img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])
+                img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])  # repeat the color channels
 
         # BGR to RGB, HWC to CHW, numpy to tensor
         img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
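
The options read in __init__ above parameterize the on-the-fly LQ synthesis: blur with a sampled kernel, downsample by a random factor, add Gaussian noise, JPEG-compress, then resize back to out_size. A condensed, illustrative version of that pipeline with fixed parameters; the `degrade` helper is hypothetical, since the real dataset samples kernels and ranges via basicsr's degradation utilities:

    import cv2
    import numpy as np

    def degrade(img_gt, sigma=5, scale=8, noise_sigma=10, jpeg_q=60, out_size=512):
        # img_gt: float32 BGR in [0, 1], shape (h, w, 3)
        h, w, _ = img_gt.shape
        img = cv2.GaussianBlur(img_gt, (41, 41), sigma)  # blur
        img = cv2.resize(img, (w // scale, h // scale), interpolation=cv2.INTER_LINEAR)  # downsample
        img = np.clip(img + np.random.randn(*img.shape).astype(np.float32) * noise_sigma / 255., 0, 1)  # noise
        _, buf = cv2.imencode('.jpg', (img * 255).round().astype(np.uint8),
                              [cv2.IMWRITE_JPEG_QUALITY, jpeg_q])  # jpeg artifacts
        img = cv2.imdecode(buf, cv2.IMREAD_COLOR).astype(np.float32) / 255.
        return cv2.resize(img, (out_size, out_size), interpolation=cv2.INTER_LINEAR)  # back to out_size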
gfpgan/models/gfpgan_model.py CHANGED
@@ -16,11 +16,11 @@ from tqdm import tqdm
 
 @MODEL_REGISTRY.register()
 class GFPGANModel(BaseModel):
-    """GFPGAN model for <Towards real-world blind face restoratin with generative facial prior>"""
+    """The GFPGAN model for Towards real-world blind face restoration with generative facial prior"""
 
     def __init__(self, opt):
         super(GFPGANModel, self).__init__(opt)
-        self.idx = 0
+        self.idx = 0  # it is used for saving data for check
 
         # define network
         self.net_g = build_network(opt['network_g'])
@@ -51,8 +51,7 @@ class GFPGANModel(BaseModel):
             self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
 
         # ----------- define net_g with Exponential Moving Average (EMA) ----------- #
-        # net_g_ema only used for testing on one GPU and saving
-        # There is no need to wrap with DistributedDataParallel
+        # net_g_ema only used for testing on one GPU and saving. There is no need to wrap with DistributedDataParallel
         self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
         # load pretrained model
         load_path = self.opt['path'].get('pretrain_network_g', None)
@@ -65,7 +64,7 @@ class GFPGANModel(BaseModel):
         self.net_d.train()
         self.net_g_ema.eval()
 
-        # ----------- facial components networks ----------- #
+        # ----------- facial component networks ----------- #
         if ('network_d_left_eye' in self.opt and 'network_d_right_eye' in self.opt and 'network_d_mouth' in self.opt):
             self.use_facial_disc = True
         else:
@@ -102,17 +101,19 @@ class GFPGANModel(BaseModel):
             self.cri_component = build_loss(train_opt['gan_component_opt']).to(self.device)
 
         # ----------- define losses ----------- #
+        # pixel loss
         if train_opt.get('pixel_opt'):
             self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
         else:
             self.cri_pix = None
 
+        # perceptual loss
         if train_opt.get('perceptual_opt'):
             self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
         else:
             self.cri_perceptual = None
 
-        # L1 loss used in pyramid loss, component style loss and identity loss
+        # L1 loss is used in pyramid loss, component style loss and identity loss
         self.cri_l1 = build_loss(train_opt['L1_opt']).to(self.device)
 
         # gan loss (wgan)
@@ -179,6 +180,7 @@ class GFPGANModel(BaseModel):
         self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas)
         self.optimizers.append(self.optimizer_d)
 
+        # ----------- optimizers for facial component networks ----------- #
         if self.use_facial_disc:
             # setup optimizers for facial component discriminators
             optim_type = train_opt['optim_component'].pop('type')
@@ -221,6 +223,7 @@ class GFPGANModel(BaseModel):
         # self.idx = self.idx + 1
 
     def construct_img_pyramid(self):
+        """Construct image pyramid for intermediate restoration loss"""
         pyramid_gt = [self.gt]
         down_img = self.gt
         for _ in range(0, self.log_size - 3):
@@ -229,7 +232,6 @@ class GFPGANModel(BaseModel):
         return pyramid_gt
 
     def get_roi_regions(self, eye_out_size=80, mouth_out_size=120):
-        # hard code
         face_ratio = int(self.opt['network_g']['out_size'] / 512)
         eye_out_size *= face_ratio
         mouth_out_size *= face_ratio
@@ -288,6 +290,7 @@ class GFPGANModel(BaseModel):
             p.requires_grad = False
         self.optimizer_g.zero_grad()
 
+        # do not update facial component net_d
         if self.use_facial_disc:
             for p in self.net_d_left_eye.parameters():
                 p.requires_grad = False
@@ -419,11 +422,12 @@ class GFPGANModel(BaseModel):
         real_d_pred = self.net_d(self.gt)
         l_d = self.cri_gan(real_d_pred, True, is_disc=True) + self.cri_gan(fake_d_pred, False, is_disc=True)
         loss_dict['l_d'] = l_d
-        # In wgan, real_score should be positive and fake_score should benegative
+        # In WGAN, real_score should be positive and fake_score should be negative
         loss_dict['real_score'] = real_d_pred.detach().mean()
         loss_dict['fake_score'] = fake_d_pred.detach().mean()
         l_d.backward()
 
+        # regularization loss
         if current_iter % self.net_d_reg_every == 0:
             self.gt.requires_grad = True
             real_pred = self.net_d(self.gt)
@@ -434,8 +438,9 @@ class GFPGANModel(BaseModel):
 
         self.optimizer_d.step()
 
+        # optimize facial component discriminators
         if self.use_facial_disc:
-            # lefe eye
+            # left eye
             fake_d_pred, _ = self.net_d_left_eye(self.left_eyes.detach())
             real_d_pred, _ = self.net_d_left_eye(self.left_eyes_gt)
             l_d_left_eye = self.cri_component(
@@ -485,22 +490,32 @@ class GFPGANModel(BaseModel):
485
  def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
486
  dataset_name = dataloader.dataset.opt['name']
487
  with_metrics = self.opt['val'].get('metrics') is not None
 
 
488
  if with_metrics:
489
- self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
490
- pbar = tqdm(total=len(dataloader), unit='image')
 
491
 
492
  for idx, val_data in enumerate(dataloader):
493
  img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
494
  self.feed_data(val_data)
495
  self.test()
496
 
497
- visuals = self.get_current_visuals()
498
- sr_img = tensor2img([visuals['sr']], min_max=(-1, 1))
499
- gt_img = tensor2img([visuals['gt']], min_max=(-1, 1))
500
-
501
- if 'gt' in visuals:
502
- gt_img = tensor2img([visuals['gt']], min_max=(-1, 1))
503
  del self.gt
 
504
  # tentative for out of GPU memory
505
  del self.lq
506
  del self.output
@@ -522,35 +537,38 @@ class GFPGANModel(BaseModel):
522
  if with_metrics:
523
  # calculate metrics
524
  for name, opt_ in self.opt['val']['metrics'].items():
525
- metric_data = dict(img1=sr_img, img2=gt_img)
526
  self.metric_results[name] += calculate_metric(metric_data, opt_)
527
- pbar.update(1)
528
- pbar.set_description(f'Test {img_name}')
529
- pbar.close()
 
 
530
 
531
  if with_metrics:
532
  for metric in self.metric_results.keys():
533
  self.metric_results[metric] /= (idx + 1)
 
 
534
 
535
  self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
536
 
537
  def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
538
  log_str = f'Validation {dataset_name}\n'
539
  for metric, value in self.metric_results.items():
540
- log_str += f'\t # {metric}: {value:.4f}\n'
 
541
  logger = get_root_logger()
542
  logger.info(log_str)
543
  if tb_logger:
544
  for metric, value in self.metric_results.items():
545
- tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
546
-
547
- def get_current_visuals(self):
548
- out_dict = OrderedDict()
549
- out_dict['gt'] = self.gt.detach().cpu()
550
- out_dict['sr'] = self.output.detach().cpu()
551
- return out_dict
552
 
553
  def save(self, epoch, current_iter):
 
554
  self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
555
  self.save_network(self.net_d, 'net_d', current_iter)
556
  # save component discriminators
@@ -558,4 +576,5 @@ class GFPGANModel(BaseModel):
558
  self.save_network(self.net_d_left_eye, 'net_d_left_eye', current_iter)
559
  self.save_network(self.net_d_right_eye, 'net_d_right_eye', current_iter)
560
  self.save_network(self.net_d_mouth, 'net_d_mouth', current_iter)
 
561
  self.save_training_state(epoch, current_iter)
 
16
 
17
  @MODEL_REGISTRY.register()
18
  class GFPGANModel(BaseModel):
19
+ """The GFPGAN model for Towards real-world blind face restoratin with generative facial prior"""
20
 
21
  def __init__(self, opt):
22
  super(GFPGANModel, self).__init__(opt)
23
+ self.idx = 0  # used for saving data samples for checking
24
 
25
  # define network
26
  self.net_g = build_network(opt['network_g'])
 
51
  self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
52
 
53
  # ----------- define net_g with Exponential Moving Average (EMA) ----------- #
54
+ # net_g_ema is only used for testing on one GPU and saving. There is no need to wrap it with DistributedDataParallel
 
55
  self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
56
  # load pretrained model
57
  load_path = self.opt['path'].get('pretrain_network_g', None)
 
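For context, net_g_ema is kept as a shadow copy of net_g. A minimal sketch of the exponential-moving-average update it receives (mirroring the spirit of BasicSR's BaseModel.model_ema; the decay value here is an assumption):

    import torch.nn as nn

    def model_ema_sketch(net_g: nn.Module, net_g_ema: nn.Module, decay: float = 0.999) -> None:
        # ema = decay * ema + (1 - decay) * current, applied parameter-wise
        ema_params = dict(net_g_ema.named_parameters())
        for name, param in net_g.named_parameters():
            ema_params[name].data.mul_(decay).add_(param.data, alpha=1 - decay)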
64
  self.net_d.train()
65
  self.net_g_ema.eval()
66
 
67
+ # ----------- facial component networks ----------- #
68
  if ('network_d_left_eye' in self.opt and 'network_d_right_eye' in self.opt and 'network_d_mouth' in self.opt):
69
  self.use_facial_disc = True
70
  else:
 
101
  self.cri_component = build_loss(train_opt['gan_component_opt']).to(self.device)
102
 
103
  # ----------- define losses ----------- #
104
+ # pixel loss
105
  if train_opt.get('pixel_opt'):
106
  self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
107
  else:
108
  self.cri_pix = None
109
 
110
+ # perceptual loss
111
  if train_opt.get('perceptual_opt'):
112
  self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
113
  else:
114
  self.cri_perceptual = None
115
 
116
+ # L1 loss is used in pyramid loss, component style loss and identity loss
117
  self.cri_l1 = build_loss(train_opt['L1_opt']).to(self.device)
118
 
119
  # gan loss (wgan)
 
180
  self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas)
181
  self.optimizers.append(self.optimizer_d)
182
 
183
+ # ----------- optimizers for facial component networks ----------- #
184
  if self.use_facial_disc:
185
  # setup optimizers for facial component discriminators
186
  optim_type = train_opt['optim_component'].pop('type')
 
223
  # self.idx = self.idx + 1
224
 
225
  def construct_img_pyramid(self):
226
+ """Construct image pyramid for intermediate restoration loss"""
227
  pyramid_gt = [self.gt]
228
  down_img = self.gt
229
  for _ in range(0, self.log_size - 3):
 
232
  return pyramid_gt
233
 
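The body of construct_img_pyramid is elided by the diff context; a hedged reconstruction of the downsampling loop (the bilinear mode and the coarse-to-fine insertion order are assumptions inferred from the surrounding lines):

    import torch.nn.functional as F

    def construct_img_pyramid_sketch(gt, log_size):
        # prepend each half-resolution copy so the list runs coarse -> fine, ending with the full-size gt
        pyramid_gt = [gt]
        down_img = gt
        for _ in range(0, log_size - 3):
            down_img = F.interpolate(down_img, scale_factor=0.5, mode='bilinear', align_corners=False)
            pyramid_gt.insert(0, down_img)
        return pyramid_gt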
234
  def get_roi_regions(self, eye_out_size=80, mouth_out_size=120):
 
235
  face_ratio = int(self.opt['network_g']['out_size'] / 512)
236
  eye_out_size *= face_ratio
237
  mouth_out_size *= face_ratio
 
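For example, with network_g.out_size = 1024 the face_ratio is 2, so the eye crops grow to 160x160 and the mouth crop to 240x240.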
290
  p.requires_grad = False
291
  self.optimizer_g.zero_grad()
292
 
293
+ # do not update facial component net_d
294
  if self.use_facial_disc:
295
  for p in self.net_d_left_eye.parameters():
296
  p.requires_grad = False
 
422
  real_d_pred = self.net_d(self.gt)
423
  l_d = self.cri_gan(real_d_pred, True, is_disc=True) + self.cri_gan(fake_d_pred, False, is_disc=True)
424
  loss_dict['l_d'] = l_d
425
+ # In WGAN, real_score should be positive and fake_score should be negative
426
  loss_dict['real_score'] = real_d_pred.detach().mean()
427
  loss_dict['fake_score'] = fake_d_pred.detach().mean()
428
  l_d.backward()
429
 
430
+ # regularization loss
431
  if current_iter % self.net_d_reg_every == 0:
432
  self.gt.requires_grad = True
433
  real_pred = self.net_d(self.gt)
 
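The wgan comment and the regularization branch above follow the StyleGAN2 recipe; a hedged sketch of both objectives (assuming BasicSR's gan_type='wgan_softplus' semantics):

    import torch
    import torch.nn.functional as F

    def wgan_softplus_d_loss(real_pred, fake_pred):
        # drives real scores positive and fake scores negative, matching the logged real_score/fake_score
        return F.softplus(-real_pred).mean() + F.softplus(fake_pred).mean()

    def r1_penalty(real_pred, real_img):
        # gradient penalty on real images, applied every net_d_reg_every iterations
        grad = torch.autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
        return grad.pow(2).reshape(grad.shape[0], -1).sum(1).mean()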
438
 
439
  self.optimizer_d.step()
440
 
441
+ # optimize facial component discriminators
442
  if self.use_facial_disc:
443
+ # left eye
444
  fake_d_pred, _ = self.net_d_left_eye(self.left_eyes.detach())
445
  real_d_pred, _ = self.net_d_left_eye(self.left_eyes_gt)
446
  l_d_left_eye = self.cri_component(
 
490
  def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
491
  dataset_name = dataloader.dataset.opt['name']
492
  with_metrics = self.opt['val'].get('metrics') is not None
493
+ use_pbar = self.opt['val'].get('pbar', False)
494
+
495
  if with_metrics:
496
+ if not hasattr(self, 'metric_results'): # only execute in the first run
497
+ self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
498
+ # initialize the best metric results for each dataset_name (supporting multiple validation datasets)
499
+ self._initialize_best_metric_results(dataset_name)
500
+ # zero self.metric_results
501
+ self.metric_results = {metric: 0 for metric in self.metric_results}
502
+
503
+ metric_data = dict()
504
+ if use_pbar:
505
+ pbar = tqdm(total=len(dataloader), unit='image')
506
 
507
  for idx, val_data in enumerate(dataloader):
508
  img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
509
  self.feed_data(val_data)
510
  self.test()
511
 
512
+ sr_img = tensor2img(self.output.detach().cpu(), min_max=(-1, 1))
513
+ metric_data['img'] = sr_img
514
+ if hasattr(self, 'gt'):
515
+ gt_img = tensor2img(self.gt.detach().cpu(), min_max=(-1, 1))
516
+ metric_data['img2'] = gt_img
 
517
  del self.gt
518
+
519
  # tentative for out of GPU memory
520
  del self.lq
521
  del self.output
 
537
  if with_metrics:
538
  # calculate metrics
539
  for name, opt_ in self.opt['val']['metrics'].items():
 
540
  self.metric_results[name] += calculate_metric(metric_data, opt_)
541
+ if use_pbar:
542
+ pbar.update(1)
543
+ pbar.set_description(f'Test {img_name}')
544
+ if use_pbar:
545
+ pbar.close()
546
 
547
  if with_metrics:
548
  for metric in self.metric_results.keys():
549
  self.metric_results[metric] /= (idx + 1)
550
+ # update the best metric result
551
+ self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)
552
 
553
  self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
554
 
555
  def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
556
  log_str = f'Validation {dataset_name}\n'
557
  for metric, value in self.metric_results.items():
558
+ log_str += f'\t # {metric}: {value:.4f}'
559
+ if hasattr(self, 'best_metric_results'):
560
+ log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
561
+ f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
562
+ log_str += '\n'
563
+
564
  logger = get_root_logger()
565
  logger.info(log_str)
566
  if tb_logger:
567
  for metric, value in self.metric_results.items():
568
+ tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter)
 
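_initialize_best_metric_results and _update_best_metric_result are inherited from BasicSR's BaseModel; a hedged sketch of the structure they maintain (the 'val'/'iter' keys are taken from the f-string above; the -inf initialization assumes 'higher is better' metrics such as PSNR):

    best_metric_results = {
        'validation_dataset': {  # keyed per dataset_name
            'psnr': {'val': float('-inf'), 'iter': -1},
        }
    }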
569
 
570
  def save(self, epoch, current_iter):
571
+ # save net_g and net_d
572
  self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
573
  self.save_network(self.net_d, 'net_d', current_iter)
574
  # save component discriminators
 
576
  self.save_network(self.net_d_left_eye, 'net_d_left_eye', current_iter)
577
  self.save_network(self.net_d_right_eye, 'net_d_right_eye', current_iter)
578
  self.save_network(self.net_d_mouth, 'net_d_mouth', current_iter)
579
+ # save training state
580
  self.save_training_state(epoch, current_iter)
gfpgan/utils.py CHANGED
@@ -2,10 +2,9 @@ import cv2
2
  import os
3
  import torch
4
  from basicsr.utils import img2tensor, tensor2img
 
5
  from facexlib.utils.face_restoration_helper import FaceRestoreHelper
6
- from torch.hub import download_url_to_file, get_dir
7
  from torchvision.transforms.functional import normalize
8
- from urllib.parse import urlparse
9
 
10
  from gfpgan.archs.gfpganv1_arch import GFPGANv1
11
  from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
@@ -14,6 +13,20 @@ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
14
 
15
 
16
  class GFPGANer():
 
17
 
18
  def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None):
19
  self.upscale = upscale
@@ -56,7 +69,8 @@ class GFPGANer():
56
  device=self.device)
57
 
58
  if model_path.startswith('https://'):
59
- model_path = load_file_from_url(url=model_path, model_dir='gfpgan/weights', progress=True, file_name=None)
 
60
  loadnet = torch.load(model_path)
61
  if 'params_ema' in loadnet:
62
  keyname = 'params_ema'
@@ -70,13 +84,15 @@ class GFPGANer():
70
  def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True):
71
  self.face_helper.clean_all()
72
 
73
- if has_aligned:
74
  img = cv2.resize(img, (512, 512))
75
  self.face_helper.cropped_faces = [img]
76
  else:
77
  self.face_helper.read_image(img)
78
  # get face landmarks for each face
79
- self.face_helper.get_face_landmarks_5(only_center_face=only_center_face)
 
 
80
  # align and warp each face
81
  self.face_helper.align_warp_face()
82
 
@@ -99,9 +115,9 @@ class GFPGANer():
99
  self.face_helper.add_restored_face(restored_face)
100
 
101
  if not has_aligned and paste_back:
102
-
103
  if self.bg_upsampler is not None:
104
- # Now only support RealESRGAN
105
  bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
106
  else:
107
  bg_img = None
@@ -112,23 +128,3 @@ class GFPGANer():
112
  return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
113
  else:
114
  return self.face_helper.cropped_faces, self.face_helper.restored_faces, None
115
-
116
-
117
- def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
118
- """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
119
- """
120
- if model_dir is None:
121
- hub_dir = get_dir()
122
- model_dir = os.path.join(hub_dir, 'checkpoints')
123
-
124
- os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
125
-
126
- parts = urlparse(url)
127
- filename = os.path.basename(parts.path)
128
- if file_name is not None:
129
- filename = file_name
130
- cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
131
- if not os.path.exists(cached_file):
132
- print(f'Downloading: "{url}" to {cached_file}\n')
133
- download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
134
- return cached_file
 
2
  import os
3
  import torch
4
  from basicsr.utils import img2tensor, tensor2img
5
+ from basicsr.utils.download_util import load_file_from_url
6
  from facexlib.utils.face_restoration_helper import FaceRestoreHelper
 
7
  from torchvision.transforms.functional import normalize
 
8
 
9
  from gfpgan.archs.gfpganv1_arch import GFPGANv1
10
  from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
 
13
 
14
 
15
  class GFPGANer():
16
+ """Helper for restoration with GFPGAN.
17
+
18
+ It will detect and crop faces, and then resize the faces to 512x512.
19
+ GFPGAN is used to restore the resized faces.
20
+ The background is upsampled with the bg_upsampler.
21
+ Finally, the faces will be pasted back to the upsampled background image.
22
+
23
+ Args:
24
+ model_path (str): The path to the GFPGAN model. It can also be a URL (the model will first be downloaded automatically).
25
+ upscale (float): The upscale of the final output. Default: 2.
26
+ arch (str): The GFPGAN architecture. Option: clean | original. Default: clean.
27
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
28
+ bg_upsampler (nn.Module): The upsampler for the background. Default: None.
29
+ """
30
 
31
  def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None):
32
  self.upscale = upscale
 
69
  device=self.device)
70
 
71
  if model_path.startswith('https://'):
72
+ model_path = load_file_from_url(
73
+ url=model_path, model_dir=os.path.join(ROOT_DIR, 'gfpgan/weights'), progress=True, file_name=None)
74
  loadnet = torch.load(model_path)
75
  if 'params_ema' in loadnet:
76
  keyname = 'params_ema'
 
84
  def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True):
85
  self.face_helper.clean_all()
86
 
87
+ if has_aligned: # the inputs are already aligned
88
  img = cv2.resize(img, (512, 512))
89
  self.face_helper.cropped_faces = [img]
90
  else:
91
  self.face_helper.read_image(img)
92
  # get face landmarks for each face
93
+ self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
94
+ # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
95
+ # TODO: even with eye_dist_threshold, it may still introduce wrong detections and restorations.
96
  # align and warp each face
97
  self.face_helper.align_warp_face()
98
 
 
115
  self.face_helper.add_restored_face(restored_face)
116
 
117
  if not has_aligned and paste_back:
118
+ # upsample the background
119
  if self.bg_upsampler is not None:
120
+ # Now only support RealESRGAN for upsampling background
121
  bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
122
  else:
123
  bg_img = None
 
128
  return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
129
  else:
130
  return self.face_helper.cropped_faces, self.face_helper.restored_faces, None
 
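Putting the new GFPGANer docstring into practice, a minimal usage sketch (the weight path and input image below are assumptions; as shown above, an https:// model_path is downloaded to gfpgan/weights automatically):

    import cv2
    from gfpgan import GFPGANer

    restorer = GFPGANer(
        model_path='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth',  # assumed local weights
        upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None)
    img = cv2.imread('inputs/whole_imgs/00.jpg', cv2.IMREAD_COLOR)  # assumed input image
    cropped_faces, restored_faces, restored_img = restorer.enhance(
        img, has_aligned=False, only_center_face=False, paste_back=True)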
inference_gfpgan.py CHANGED
@@ -10,20 +10,28 @@ from gfpgan import GFPGANer
10
 
11
 
12
  def main():
 
 
13
  parser = argparse.ArgumentParser()
14
-
15
- parser.add_argument('--upscale', type=int, default=2)
16
- parser.add_argument('--arch', type=str, default='clean')
17
- parser.add_argument('--channel', type=int, default=2)
18
- parser.add_argument('--model_path', type=str, default='GFPGANCleanv1-NoCE-C2.pth')
19
- parser.add_argument('--bg_upsampler', type=str, default='realesrgan')
20
- parser.add_argument('--bg_tile', type=int, default=400)
21
- parser.add_argument('--test_path', type=str, default='inputs/whole_imgs')
22
  parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
23
- parser.add_argument('--only_center_face', action='store_true')
24
- parser.add_argument('--aligned', action='store_true')
25
- parser.add_argument('--paste_back', action='store_false')
26
- parser.add_argument('--save_root', type=str, default='results')
 
27
 
28
  args = parser.parse_args()
29
  if args.test_path.endswith('/'):
@@ -38,10 +46,13 @@ def main():
38
  'If you really want to use it, please modify the corresponding codes.')
39
  bg_upsampler = None
40
  else:
 
41
  from realesrgan import RealESRGANer
 
42
  bg_upsampler = RealESRGANer(
43
  scale=2,
44
  model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
 
45
  tile=args.bg_tile,
46
  tile_pad=10,
47
  pre_pad=0,
@@ -64,31 +75,38 @@ def main():
64
  basename, ext = os.path.splitext(img_name)
65
  input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
66
 
 
67
  cropped_faces, restored_faces, restored_img = restorer.enhance(
68
  input_img, has_aligned=args.aligned, only_center_face=args.only_center_face, paste_back=args.paste_back)
69
 
70
  # save faces
71
  for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_faces)):
72
  # save cropped face
73
- save_crop_path = os.path.join(args.save_root, f'{basename}_{idx:02d}.png')
74
  imwrite(cropped_face, save_crop_path)
75
  # save restored face
76
  if args.suffix is not None:
77
  save_face_name = f'{basename}_{idx:02d}_{args.suffix}.png'
78
  else:
79
  save_face_name = f'{basename}_{idx:02d}.png'
80
- save_restore_path = os.path.join(args.save_root, save_face_name)
81
  imwrite(restored_face, save_restore_path)
82
- # save cmp image
83
- #cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
84
- #imwrite(cmp_img, os.path.join(args.save_root, f'{basename}_{idx:02d}.png'))
85
 
86
  # save restored img
87
  if restored_img is not None:
 
88
  if args.suffix is not None:
89
- save_restore_path = os.path.join(args.save_root, f'{basename}_{args.suffix}{ext}')
 
90
  else:
91
- save_restore_path = os.path.join(args.save_root, img_name)
92
  imwrite(restored_img, save_restore_path)
93
 
94
  print(f'Results are in the [{args.save_root}] folder.')
 
10
 
11
 
12
  def main():
13
+ """Inference demo for GFPGAN.
14
+ """
15
  parser = argparse.ArgumentParser()
16
+ parser.add_argument('--upscale', type=int, default=2, help='The final upsampling scale of the image')
17
+ parser.add_argument('--arch', type=str, default='clean', help='The GFPGAN architecture. Option: clean | original')
18
+ parser.add_argument('--channel', type=int, default=2, help='Channel multiplier for large networks of StyleGAN2')
19
+ parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth')
20
+ parser.add_argument('--bg_upsampler', type=str, default='realesrgan', help='background upsampler')
21
+ parser.add_argument(
22
+ '--bg_tile', type=int, default=400, help='Tile size for background upsampler, 0 for no tile during testing')
23
+ parser.add_argument('--test_path', type=str, default='inputs/whole_imgs', help='Input folder')
24
  parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
25
+ parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face')
26
+ parser.add_argument('--aligned', action='store_true', help='Inputs are aligned faces')
27
+ parser.add_argument('--paste_back', action='store_false', help='Paste the restored faces back to images')
28
+ parser.add_argument('--save_root', type=str, default='results', help='Path to save root')
29
+ parser.add_argument(
30
+ '--ext',
31
+ type=str,
32
+ default='auto',
33
+ help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
35
 
36
  args = parser.parse_args()
37
  if args.test_path.endswith('/'):
 
46
  'If you really want to use it, please modify the corresponding codes.')
47
  bg_upsampler = None
48
  else:
49
+ from basicsr.archs.rrdbnet_arch import RRDBNet
50
  from realesrgan import RealESRGANer
51
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
52
  bg_upsampler = RealESRGANer(
53
  scale=2,
54
  model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
55
+ model=model,
56
  tile=args.bg_tile,
57
  tile_pad=10,
58
  pre_pad=0,
 
75
  basename, ext = os.path.splitext(img_name)
76
  input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
77
 
78
+ # restore faces and background if necessary
79
  cropped_faces, restored_faces, restored_img = restorer.enhance(
80
  input_img, has_aligned=args.aligned, only_center_face=args.only_center_face, paste_back=args.paste_back)
81
 
82
  # save faces
83
  for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_faces)):
84
  # save cropped face
85
+ save_crop_path = os.path.join(args.save_root, 'cropped_faces', f'{basename}_{idx:02d}.png')
86
  imwrite(cropped_face, save_crop_path)
87
  # save restored face
88
  if args.suffix is not None:
89
  save_face_name = f'{basename}_{idx:02d}_{args.suffix}.png'
90
  else:
91
  save_face_name = f'{basename}_{idx:02d}.png'
92
+ save_restore_path = os.path.join(args.save_root, 'restored_faces', save_face_name)
93
  imwrite(restored_face, save_restore_path)
94
+ # save comparison image
95
+ cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
96
+ imwrite(cmp_img, os.path.join(args.save_root, 'cmp', f'{basename}_{idx:02d}.png'))
97
 
98
  # save restored img
99
  if restored_img is not None:
100
+ if args.ext == 'auto':
101
+ extension = ext[1:]
102
+ else:
103
+ extension = args.ext
104
+
105
  if args.suffix is not None:
106
+ save_restore_path = os.path.join(args.save_root, 'restored_imgs',
107
+ f'{basename}_{args.suffix}.{extension}')
108
  else:
109
+ save_restore_path = os.path.join(args.save_root, 'restored_imgs', f'{basename}.{extension}')
110
  imwrite(restored_img, save_restore_path)
111
 
112
  print(f'Results are in the [{args.save_root}] folder.')
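With the documented arguments in place, a typical run uses the script defaults, e.g. python inference_gfpgan.py --test_path inputs/whole_imgs --save_root results --upscale 2; outputs are grouped under results/cropped_faces, results/restored_faces, results/cmp and results/restored_imgs.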
options/train_gfpgan_v1.yml CHANGED
@@ -1,7 +1,7 @@
1
  # general settings
2
  name: train_GFPGANv1_512
3
  model_type: GFPGANModel
4
- num_gpu: 4
5
  manual_seed: 0
6
 
7
  # dataset and data loader settings
@@ -194,7 +194,7 @@ val:
194
  save_img: true
195
 
196
  metrics:
197
- psnr: # metric name, can be arbitrary
198
  type: calculate_psnr
199
  crop_border: 0
200
  test_y_channel: false
 
1
  # general settings
2
  name: train_GFPGANv1_512
3
  model_type: GFPGANModel
4
+ num_gpu: auto # officially, we use 4 GPUs
5
  manual_seed: 0
6
 
7
  # dataset and data loader settings
 
194
  save_img: true
195
 
196
  metrics:
197
+ psnr: # metric name
198
  type: calculate_psnr
199
  crop_border: 0
200
  test_y_channel: false
options/train_gfpgan_v1_simple.yml CHANGED
@@ -1,7 +1,7 @@
1
  # general settings
2
  name: train_GFPGANv1_512_simple
3
  model_type: GFPGANModel
4
- num_gpu: 4
5
  manual_seed: 0
6
 
7
  # dataset and data loader settings
@@ -40,10 +40,6 @@ datasets:
40
  # gray_prob: 0.01
41
  # gt_gray: True
42
 
43
- # crop_components: false
44
- # component_path: experiments/pretrained_models/FFHQ_eye_mouth_landmarks_512.pth
45
- # eye_enlarge_ratio: 1.4
46
-
47
  # data loader
48
  use_shuffle: true
49
  num_worker_per_gpu: 6
@@ -86,20 +82,6 @@ network_d:
86
  channel_multiplier: 1
87
  resample_kernel: [1, 3, 3, 1]
88
 
89
- # network_d_left_eye:
90
- # type: FacialComponentDiscriminator
91
-
92
- # network_d_right_eye:
93
- # type: FacialComponentDiscriminator
94
-
95
- # network_d_mouth:
96
- # type: FacialComponentDiscriminator
97
-
98
- network_identity:
99
- type: ResNetArcFace
100
- block: IRBlock
101
- layers: [2, 2, 2, 2]
102
- use_se: False
103
 
104
  # path
105
  path:
@@ -107,13 +89,7 @@ path:
107
  param_key_g: params_ema
108
  strict_load_g: ~
109
  pretrain_network_d: ~
110
- # pretrain_network_d_left_eye: ~
111
- # pretrain_network_d_right_eye: ~
112
- # pretrain_network_d_mouth: ~
113
- pretrain_network_identity: experiments/pretrained_models/arcface_resnet18.pth
114
- # resume
115
  resume_state: ~
116
- ignore_resume_networks: ['network_identity']
117
 
118
  # training settings
119
  train:
@@ -173,16 +149,6 @@ train:
173
  loss_weight: !!float 1e-1
174
  # r1 regularization for discriminator
175
  r1_reg_weight: 10
176
- # facial component loss
177
- # gan_component_opt:
178
- # type: GANLoss
179
- # gan_type: vanilla
180
- # real_label_val: 1.0
181
- # fake_label_val: 0.0
182
- # loss_weight: !!float 1
183
- # comp_style_weight: 200
184
- # identity loss
185
- identity_weight: 10
186
 
187
  net_d_iters: 1
188
  net_d_init_iters: 0
@@ -194,7 +160,7 @@ val:
194
  save_img: true
195
 
196
  metrics:
197
- psnr: # metric name, can be arbitrary
198
  type: calculate_psnr
199
  crop_border: 0
200
  test_y_channel: false
 
1
  # general settings
2
  name: train_GFPGANv1_512_simple
3
  model_type: GFPGANModel
4
+ num_gpu: auto # officially, we use 4 GPUs
5
  manual_seed: 0
6
 
7
  # dataset and data loader settings
 
40
  # gray_prob: 0.01
41
  # gt_gray: True
42
 
 
43
  # data loader
44
  use_shuffle: true
45
  num_worker_per_gpu: 6
 
82
  channel_multiplier: 1
83
  resample_kernel: [1, 3, 3, 1]
84
 
85
 
86
  # path
87
  path:
 
89
  param_key_g: params_ema
90
  strict_load_g: ~
91
  pretrain_network_d: ~
 
92
  resume_state: ~
 
93
 
94
  # training settings
95
  train:
 
149
  loss_weight: !!float 1e-1
150
  # r1 regularization for discriminator
151
  r1_reg_weight: 10
 
152
 
153
  net_d_iters: 1
154
  net_d_init_iters: 0
 
160
  save_img: true
161
 
162
  metrics:
163
+ psnr: # metric name
164
  type: calculate_psnr
165
  crop_border: 0
166
  test_y_channel: false
requirements.txt CHANGED
@@ -1,18 +1,12 @@
1
- torch
2
- numpy
3
- opencv-python-headless
4
- setuptools
5
- Pillow
6
- gradio
7
  torchvision
8
- addict
9
- future
 
 
10
  lmdb
11
  pyyaml
12
- requests
13
- scikit-image
14
- scipy
15
  tb-nightly
16
- tqdm
17
  yapf
18
- psutil
 
1
+ torch>=1.7
2
+ numpy<1.21 # numba requires numpy<1.21,>=1.17
3
+ opencv-python
 
 
 
4
  torchvision
5
+ scipy
6
+ tqdm
7
+ basicsr>=1.3.4.0
8
+ facexlib>=0.2.0.3
9
  lmdb
10
  pyyaml
 
 
 
11
  tb-nightly
 
12
  yapf
 
scripts/parse_landmark.py CHANGED
@@ -1,24 +1,31 @@
1
  import cv2
2
  import json
3
  import numpy as np
 
4
  import torch
5
  from basicsr.utils import FileClient, imfrombytes
6
  from collections import OrderedDict
7
 
 
 
  print('Load JSON metadata...')
9
- # use the json file in FFHQ dataset
10
- with open('ffhq-dataset-v2.json', 'rb') as f:
11
  json_data = json.load(f, object_pairs_hook=OrderedDict)
12
 
13
  print('Open LMDB file...')
14
  # read ffhq images
15
- file_client = FileClient('lmdb', db_paths='datasets/ffhq/ffhq_512.lmdb')
16
- with open('datasets/ffhq/ffhq_512.lmdb/meta_info.txt') as fin:
17
  paths = [line.split('.')[0] for line in fin]
18
 
19
- save_img = False
20
- scale = 0.5 # 0.5 for official FFHQ (512x512), 1 for others
21
- enlarge_ratio = 1.4 # only for eyes
22
  save_dict = {}
23
 
24
  for item_idx, item in enumerate(json_data.values()):
@@ -34,6 +41,7 @@ for item_idx, item in enumerate(json_data.values()):
34
  img_bytes = file_client.get(paths[item_idx])
35
  img = imfrombytes(img_bytes, float32=True)
36
 
 
37
  map_left_eye = list(range(36, 42))
38
  map_right_eye = list(range(42, 48))
39
  map_mouth = list(range(48, 68))
@@ -74,4 +82,4 @@ for item_idx, item in enumerate(json_data.values()):
74
  save_dict[f'{item_idx:08d}'] = item_dict
75
 
76
  print('Save...')
77
- torch.save(save_dict, './FFHQ_eye_mouth_landmarks_512.pth')
 
1
  import cv2
2
  import json
3
  import numpy as np
4
+ import os
5
  import torch
6
  from basicsr.utils import FileClient, imfrombytes
7
  from collections import OrderedDict
8
 
9
+ # ---------------------------- This script is used to parse facial landmarks ------------------------------------- #
10
+ # Configurations
11
+ save_img = False
12
+ scale = 0.5 # 0.5 for official FFHQ (512x512), 1 for others
13
+ enlarge_ratio = 1.4 # only for eyes
14
+ json_path = 'ffhq-dataset-v2.json'
15
+ face_path = 'datasets/ffhq/ffhq_512.lmdb'
16
+ save_path = './FFHQ_eye_mouth_landmarks_512.pth'
17
+
18
  print('Load JSON metadata...')
19
+ # use the official json file in FFHQ dataset
20
+ with open(json_path, 'rb') as f:
21
  json_data = json.load(f, object_pairs_hook=OrderedDict)
22
 
23
  print('Open LMDB file...')
24
  # read ffhq images
25
+ file_client = FileClient('lmdb', db_paths=face_path)
26
+ with open(os.path.join(face_path, 'meta_info.txt')) as fin:
27
  paths = [line.split('.')[0] for line in fin]
28
 
 
  save_dict = {}
30
 
31
  for item_idx, item in enumerate(json_data.values()):
 
41
  img_bytes = file_client.get(paths[item_idx])
42
  img = imfrombytes(img_bytes, float32=True)
43
 
44
+ # get landmarks for each component
45
  map_left_eye = list(range(36, 42))
46
  map_right_eye = list(range(42, 48))
47
  map_mouth = list(range(48, 68))
 
82
  save_dict[f'{item_idx:08d}'] = item_dict
83
 
84
  print('Save...')
85
+ torch.save(save_dict, save_path)
setup.cfg CHANGED
@@ -17,6 +17,17 @@ line_length = 120
17
  multi_line_output = 0
18
  known_standard_library = pkg_resources,setuptools
19
  known_first_party = gfpgan
20
- known_third_party = basicsr,cv2,facexlib,numpy,torch,torchvision,tqdm
21
  no_lines_before = STDLIB,LOCALFOLDER
22
  default_section = THIRDPARTY
 
 
17
  multi_line_output = 0
18
  known_standard_library = pkg_resources,setuptools
19
  known_first_party = gfpgan
20
+ known_third_party = basicsr,cv2,facexlib,numpy,pytest,torch,torchvision,tqdm,yaml
21
  no_lines_before = STDLIB,LOCALFOLDER
22
  default_section = THIRDPARTY
23
+
24
+ [codespell]
25
+ skip = .git,./docs/build
26
+ count =
27
+ quiet-level = 3
28
+
29
+ [aliases]
30
+ test=pytest
31
+
32
+ [tool:pytest]
33
+ addopts=tests/
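With [aliases] test=pytest and addopts=tests/, both a bare pytest invocation and python setup.py test now pick up the new tests/ suite by default.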
setup.py CHANGED
@@ -43,12 +43,6 @@ def get_git_hash():
43
  def get_hash():
44
  if os.path.exists('.git'):
45
  sha = get_git_hash()[:7]
46
- elif os.path.exists(version_file):
47
- try:
48
- from facexlib.version import __version__
49
- sha = __version__.split('+')[-1]
50
- except ImportError:
51
- raise ImportError('Unable to get git version')
52
  else:
53
  sha = 'unknown'
54
 
 
43
  def get_hash():
44
  if os.path.exists('.git'):
45
  sha = get_git_hash()[:7]
 
46
  else:
47
  sha = 'unknown'
48
 
tests/data/ffhq_gt.lmdb/data.mdb ADDED
Binary file (455 kB).
 
experiments/.DS_Store → tests/data/ffhq_gt.lmdb/lock.mdb RENAMED
Binary files a/experiments/.DS_Store and b/tests/data/ffhq_gt.lmdb/lock.mdb differ
 
tests/data/ffhq_gt.lmdb/meta_info.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 00000000.png (512,512,3) 1
tests/data/gt/00000000.png ADDED
tests/data/test_eye_mouth_landmarks.pth ADDED
@@ -0,0 +1,3 @@
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:131583fca2cc346652f8754eb3c5a0bdeda808686039ff10ead7a26254b72358
3
+ size 943
tests/data/test_ffhq_degradation_dataset.yml ADDED
@@ -0,0 +1,24 @@
 
1
+ name: UnitTest
2
+ type: FFHQDegradationDataset
3
+ dataroot_gt: tests/data/gt
4
+ io_backend:
5
+ type: disk
6
+
7
+ use_hflip: true
8
+ mean: [0.5, 0.5, 0.5]
9
+ std: [0.5, 0.5, 0.5]
10
+ out_size: 512
11
+
12
+ blur_kernel_size: 41
13
+ kernel_list: ['iso', 'aniso']
14
+ kernel_prob: [0.5, 0.5]
15
+ blur_sigma: [0.1, 10]
16
+ downsample_range: [0.8, 8]
17
+ noise_range: [0, 20]
18
+ jpeg_range: [60, 100]
19
+
20
+ # color jitter and gray
21
+ color_jitter_prob: 1
22
+ color_jitter_shift: 20
23
+ color_jitter_pt_prob: 1
24
+ gray_prob: 1
tests/data/test_gfpgan_model.yml ADDED
@@ -0,0 +1,140 @@
 
1
+ num_gpu: 1
2
+ manual_seed: 0
3
+ is_train: True
4
+ dist: False
5
+
6
+ # network structures
7
+ network_g:
8
+ type: GFPGANv1
9
+ out_size: 512
10
+ num_style_feat: 512
11
+ channel_multiplier: 1
12
+ resample_kernel: [1, 3, 3, 1]
13
+ decoder_load_path: ~
14
+ fix_decoder: true
15
+ num_mlp: 8
16
+ lr_mlp: 0.01
17
+ input_is_latent: true
18
+ different_w: true
19
+ narrow: 0.5
20
+ sft_half: true
21
+
22
+ network_d:
23
+ type: StyleGAN2Discriminator
24
+ out_size: 512
25
+ channel_multiplier: 1
26
+ resample_kernel: [1, 3, 3, 1]
27
+
28
+ network_d_left_eye:
29
+ type: FacialComponentDiscriminator
30
+
31
+ network_d_right_eye:
32
+ type: FacialComponentDiscriminator
33
+
34
+ network_d_mouth:
35
+ type: FacialComponentDiscriminator
36
+
37
+ network_identity:
38
+ type: ResNetArcFace
39
+ block: IRBlock
40
+ layers: [2, 2, 2, 2]
41
+ use_se: False
42
+
43
+ # path
44
+ path:
45
+ pretrain_network_g: ~
46
+ param_key_g: params_ema
47
+ strict_load_g: ~
48
+ pretrain_network_d: ~
49
+ pretrain_network_d_left_eye: ~
50
+ pretrain_network_d_right_eye: ~
51
+ pretrain_network_d_mouth: ~
52
+ pretrain_network_identity: ~
53
+ # resume
54
+ resume_state: ~
55
+ ignore_resume_networks: ['network_identity']
56
+
57
+ # training settings
58
+ train:
59
+ optim_g:
60
+ type: Adam
61
+ lr: !!float 2e-3
62
+ optim_d:
63
+ type: Adam
64
+ lr: !!float 2e-3
65
+ optim_component:
66
+ type: Adam
67
+ lr: !!float 2e-3
68
+
69
+ scheduler:
70
+ type: MultiStepLR
71
+ milestones: [600000, 700000]
72
+ gamma: 0.5
73
+
74
+ total_iter: 800000
75
+ warmup_iter: -1 # no warm up
76
+
77
+ # losses
78
+ # pixel loss
79
+ pixel_opt:
80
+ type: L1Loss
81
+ loss_weight: !!float 1e-1
82
+ reduction: mean
83
+ # L1 loss used in pyramid loss, component style loss and identity loss
84
+ L1_opt:
85
+ type: L1Loss
86
+ loss_weight: 1
87
+ reduction: mean
88
+
89
+ # image pyramid loss
90
+ pyramid_loss_weight: 1
91
+ remove_pyramid_loss: 50000
92
+ # perceptual loss (content and style losses)
93
+ perceptual_opt:
94
+ type: PerceptualLoss
95
+ layer_weights:
96
+ # before relu
97
+ 'conv1_2': 0.1
98
+ 'conv2_2': 0.1
99
+ 'conv3_4': 1
100
+ 'conv4_4': 1
101
+ 'conv5_4': 1
102
+ vgg_type: vgg19
103
+ use_input_norm: true
104
+ perceptual_weight: !!float 1
105
+ style_weight: 50
106
+ range_norm: true
107
+ criterion: l1
108
+ # gan loss
109
+ gan_opt:
110
+ type: GANLoss
111
+ gan_type: wgan_softplus
112
+ loss_weight: !!float 1e-1
113
+ # r1 regularization for discriminator
114
+ r1_reg_weight: 10
115
+ # facial component loss
116
+ gan_component_opt:
117
+ type: GANLoss
118
+ gan_type: vanilla
119
+ real_label_val: 1.0
120
+ fake_label_val: 0.0
121
+ loss_weight: !!float 1
122
+ comp_style_weight: 200
123
+ # identity loss
124
+ identity_weight: 10
125
+
126
+ net_d_iters: 1
127
+ net_d_init_iters: 0
128
+ net_d_reg_every: 1
129
+
130
+ # validation settings
131
+ val:
132
+ val_freq: !!float 5e3
133
+ save_img: True
134
+ use_pbar: True
135
+
136
+ metrics:
137
+ psnr: # metric name
138
+ type: calculate_psnr
139
+ crop_border: 0
140
+ test_y_channel: false
tests/test_arcface_arch.py ADDED
@@ -0,0 +1,49 @@
 
1
+ import torch
2
+
3
+ from gfpgan.archs.arcface_arch import BasicBlock, Bottleneck, ResNetArcFace
4
+
5
+
6
+ def test_resnetarcface():
7
+ """Test arch: ResNetArcFace."""
8
+
9
+ # model init and forward (gpu)
10
+ if torch.cuda.is_available():
11
+ net = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=True).cuda().eval()
12
+ img = torch.rand((1, 1, 128, 128), dtype=torch.float32).cuda()
13
+ output = net(img)
14
+ assert output.shape == (1, 512)
15
+
16
+ # -------------------- without SE block ----------------------- #
17
+ net = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=False).cuda().eval()
18
+ output = net(img)
19
+ assert output.shape == (1, 512)
20
+
21
+
22
+ def test_basicblock():
23
+ """Test the BasicBlock in arcface_arch"""
24
+ block = BasicBlock(1, 3, stride=1, downsample=None).cuda()
25
+ img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
26
+ output = block(img)
27
+ assert output.shape == (1, 3, 12, 12)
28
+
29
+ # ----------------- use the downsample module --------------- #
30
+ downsample = torch.nn.UpsamplingNearest2d(scale_factor=0.5).cuda()
31
+ block = BasicBlock(1, 3, stride=2, downsample=downsample).cuda()
32
+ img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
33
+ output = block(img)
34
+ assert output.shape == (1, 3, 6, 6)
35
+
36
+
37
+ def test_bottleneck():
38
+ """Test the Bottleneck in arcface_arch"""
39
+ block = Bottleneck(1, 1, stride=1, downsample=None).cuda()
40
+ img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
41
+ output = block(img)
42
+ assert output.shape == (1, 4, 12, 12)
43
+
44
+ # ----------------- use the downsample module --------------- #
45
+ downsample = torch.nn.UpsamplingNearest2d(scale_factor=0.5).cuda()
46
+ block = Bottleneck(1, 1, stride=2, downsample=downsample).cuda()
47
+ img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
48
+ output = block(img)
49
+ assert output.shape == (1, 4, 6, 6)
tests/test_ffhq_degradation_dataset.py ADDED
@@ -0,0 +1,96 @@
 
1
+ import pytest
2
+ import yaml
3
+
4
+ from gfpgan.data.ffhq_degradation_dataset import FFHQDegradationDataset
5
+
6
+
7
+ def test_ffhq_degradation_dataset():
8
+
9
+ with open('tests/data/test_ffhq_degradation_dataset.yml', mode='r') as f:
10
+ opt = yaml.load(f, Loader=yaml.FullLoader)
11
+
12
+ dataset = FFHQDegradationDataset(opt)
13
+ assert dataset.io_backend_opt['type'] == 'disk' # io backend
14
+ assert len(dataset) == 1 # whether to read correct meta info
15
+ assert dataset.kernel_list == ['iso', 'aniso'] # correct initialization of the degradation configurations
16
+ assert dataset.color_jitter_prob == 1
17
+
18
+ # test __getitem__
19
+ result = dataset.__getitem__(0)
20
+ # check returned keys
21
+ expected_keys = ['gt', 'lq', 'gt_path']
22
+ assert set(expected_keys).issubset(set(result.keys()))
23
+ # check shape and contents
24
+ assert result['gt'].shape == (3, 512, 512)
25
+ assert result['lq'].shape == (3, 512, 512)
26
+ assert result['gt_path'] == 'tests/data/gt/00000000.png'
27
+
28
+ # ------------------ test with probability = 0 -------------------- #
29
+ opt['color_jitter_prob'] = 0
30
+ opt['color_jitter_pt_prob'] = 0
31
+ opt['gray_prob'] = 0
32
+ opt['io_backend'] = dict(type='disk')
33
+ dataset = FFHQDegradationDataset(opt)
34
+ assert dataset.io_backend_opt['type'] == 'disk' # io backend
35
+ assert len(dataset) == 1 # whether to read correct meta info
36
+ assert dataset.kernel_list == ['iso', 'aniso'] # correct initialization of the degradation configurations
37
+ assert dataset.color_jitter_prob == 0
38
+
39
+ # test __getitem__
40
+ result = dataset.__getitem__(0)
41
+ # check returned keys
42
+ expected_keys = ['gt', 'lq', 'gt_path']
43
+ assert set(expected_keys).issubset(set(result.keys()))
44
+ # check shape and contents
45
+ assert result['gt'].shape == (3, 512, 512)
46
+ assert result['lq'].shape == (3, 512, 512)
47
+ assert result['gt_path'] == 'tests/data/gt/00000000.png'
48
+
49
+ # ------------------ test lmdb backend -------------------- #
50
+ opt['dataroot_gt'] = 'tests/data/ffhq_gt.lmdb'
51
+ opt['io_backend'] = dict(type='lmdb')
52
+
53
+ dataset = FFHQDegradationDataset(opt)
54
+ assert dataset.io_backend_opt['type'] == 'lmdb' # io backend
55
+ assert len(dataset) == 1 # whether to read correct meta info
56
+ assert dataset.kernel_list == ['iso', 'aniso'] # correct initialization of the degradation configurations
57
+ assert dataset.color_jitter_prob == 0
58
+
59
+ # test __getitem__
60
+ result = dataset.__getitem__(0)
61
+ # check returned keys
62
+ expected_keys = ['gt', 'lq', 'gt_path']
63
+ assert set(expected_keys).issubset(set(result.keys()))
64
+ # check shape and contents
65
+ assert result['gt'].shape == (3, 512, 512)
66
+ assert result['lq'].shape == (3, 512, 512)
67
+ assert result['gt_path'] == '00000000'
68
+
69
+ # ------------------ test with crop_components -------------------- #
70
+ opt['crop_components'] = True
71
+ opt['component_path'] = 'tests/data/test_eye_mouth_landmarks.pth'
72
+ opt['eye_enlarge_ratio'] = 1.4
73
+ opt['gt_gray'] = True
74
+ opt['io_backend'] = dict(type='lmdb')
75
+
76
+ dataset = FFHQDegradationDataset(opt)
77
+ assert dataset.crop_components is True
78
+
79
+ # test __getitem__
80
+ result = dataset.__getitem__(0)
81
+ # check returned keys
82
+ expected_keys = ['gt', 'lq', 'gt_path', 'loc_left_eye', 'loc_right_eye', 'loc_mouth']
83
+ assert set(expected_keys).issubset(set(result.keys()))
84
+ # check shape and contents
85
+ assert result['gt'].shape == (3, 512, 512)
86
+ assert result['lq'].shape == (3, 512, 512)
87
+ assert result['gt_path'] == '00000000'
88
+ assert result['loc_left_eye'].shape == (4, )
89
+ assert result['loc_right_eye'].shape == (4, )
90
+ assert result['loc_mouth'].shape == (4, )
91
+
92
+ # ------------------ lmdb backend should have paths ends with lmdb -------------------- #
93
+ with pytest.raises(ValueError):
94
+ opt['dataroot_gt'] = 'tests/data/gt'
95
+ opt['io_backend'] = dict(type='lmdb')
96
+ dataset = FFHQDegradationDataset(opt)
tests/test_gfpgan_arch.py ADDED
@@ -0,0 +1,203 @@
 
1
+ import torch
2
+
3
+ from gfpgan.archs.gfpganv1_arch import FacialComponentDiscriminator, GFPGANv1, StyleGAN2GeneratorSFT
4
+ from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean, StyleGAN2GeneratorCSFT
5
+
6
+
7
+ def test_stylegan2generatorsft():
8
+ """Test arch: StyleGAN2GeneratorSFT."""
9
+
10
+ # model init and forward (gpu)
11
+ if torch.cuda.is_available():
12
+ net = StyleGAN2GeneratorSFT(
13
+ out_size=32,
14
+ num_style_feat=512,
15
+ num_mlp=8,
16
+ channel_multiplier=1,
17
+ resample_kernel=(1, 3, 3, 1),
18
+ lr_mlp=0.01,
19
+ narrow=1,
20
+ sft_half=False).cuda().eval()
21
+ style = torch.rand((1, 512), dtype=torch.float32).cuda()
22
+ condition1 = torch.rand((1, 512, 8, 8), dtype=torch.float32).cuda()
23
+ condition2 = torch.rand((1, 512, 16, 16), dtype=torch.float32).cuda()
24
+ condition3 = torch.rand((1, 512, 32, 32), dtype=torch.float32).cuda()
25
+ conditions = [condition1, condition1, condition2, condition2, condition3, condition3]
26
+ output = net([style], conditions)
27
+ assert output[0].shape == (1, 3, 32, 32)
28
+ assert output[1] is None
29
+
30
+ # -------------------- with return_latents ----------------------- #
31
+ output = net([style], conditions, return_latents=True)
32
+ assert output[0].shape == (1, 3, 32, 32)
33
+ assert len(output[1]) == 1
34
+ # check latent
35
+ assert output[1][0].shape == (8, 512)
36
+
37
+ # -------------------- with randomize_noise = False ----------------------- #
38
+ output = net([style], conditions, randomize_noise=False)
39
+ assert output[0].shape == (1, 3, 32, 32)
40
+ assert output[1] is None
41
+
42
+ # -------------------- with truncation = 0.5 and mixing----------------------- #
43
+ output = net([style, style], conditions, truncation=0.5, truncation_latent=style)
44
+ assert output[0].shape == (1, 3, 32, 32)
45
+ assert output[1] is None
46
+
47
+
48
+ def test_gfpganv1():
49
+ """Test arch: GFPGANv1."""
50
+
51
+ # model init and forward (gpu)
52
+ if torch.cuda.is_available():
53
+ net = GFPGANv1(
54
+ out_size=32,
55
+ num_style_feat=512,
56
+ channel_multiplier=1,
57
+ resample_kernel=(1, 3, 3, 1),
58
+ decoder_load_path=None,
59
+ fix_decoder=True,
60
+ # for stylegan decoder
61
+ num_mlp=8,
62
+ lr_mlp=0.01,
63
+ input_is_latent=False,
64
+ different_w=False,
65
+ narrow=1,
66
+ sft_half=True).cuda().eval()
67
+ img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
68
+ output = net(img)
69
+ assert output[0].shape == (1, 3, 32, 32)
70
+ assert len(output[1]) == 3
71
+ # check out_rgbs for intermediate loss
72
+ assert output[1][0].shape == (1, 3, 8, 8)
73
+ assert output[1][1].shape == (1, 3, 16, 16)
74
+ assert output[1][2].shape == (1, 3, 32, 32)
75
+
76
+ # -------------------- with different_w = True ----------------------- #
77
+ net = GFPGANv1(
78
+ out_size=32,
79
+ num_style_feat=512,
80
+ channel_multiplier=1,
81
+ resample_kernel=(1, 3, 3, 1),
82
+ decoder_load_path=None,
83
+ fix_decoder=True,
84
+ # for stylegan decoder
85
+ num_mlp=8,
86
+ lr_mlp=0.01,
87
+ input_is_latent=False,
88
+ different_w=True,
89
+ narrow=1,
90
+ sft_half=True).cuda().eval()
91
+ img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
92
+ output = net(img)
93
+ assert output[0].shape == (1, 3, 32, 32)
94
+ assert len(output[1]) == 3
95
+ # check out_rgbs for intermediate loss
96
+ assert output[1][0].shape == (1, 3, 8, 8)
97
+ assert output[1][1].shape == (1, 3, 16, 16)
98
+ assert output[1][2].shape == (1, 3, 32, 32)
99
+
100
+
101
+ def test_facialcomponentdiscriminator():
102
+ """Test arch: FacialComponentDiscriminator."""
103
+
104
+ # model init and forward (gpu)
105
+ if torch.cuda.is_available():
106
+ net = FacialComponentDiscriminator().cuda().eval()
107
+ img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
108
+ output = net(img)
109
+ assert len(output) == 2
110
+ assert output[0].shape == (1, 1, 8, 8)
111
+ assert output[1] is None
112
+
113
+ # -------------------- return intermediate features ----------------------- #
114
+ output = net(img, return_feats=True)
115
+ assert len(output) == 2
116
+ assert output[0].shape == (1, 1, 8, 8)
117
+ assert len(output[1]) == 2
118
+ assert output[1][0].shape == (1, 128, 16, 16)
119
+ assert output[1][1].shape == (1, 256, 8, 8)
120
+
121
+
122
+ def test_stylegan2generatorcsft():
123
+ """Test arch: StyleGAN2GeneratorCSFT."""
124
+
125
+ # model init and forward (gpu)
126
+ if torch.cuda.is_available():
127
+ net = StyleGAN2GeneratorCSFT(
128
+ out_size=32, num_style_feat=512, num_mlp=8, channel_multiplier=1, narrow=1, sft_half=False).cuda().eval()
129
+ style = torch.rand((1, 512), dtype=torch.float32).cuda()
130
+ condition1 = torch.rand((1, 512, 8, 8), dtype=torch.float32).cuda()
131
+ condition2 = torch.rand((1, 512, 16, 16), dtype=torch.float32).cuda()
132
+ condition3 = torch.rand((1, 512, 32, 32), dtype=torch.float32).cuda()
133
+ conditions = [condition1, condition1, condition2, condition2, condition3, condition3]
134
+ output = net([style], conditions)
135
+ assert output[0].shape == (1, 3, 32, 32)
136
+ assert output[1] is None
137
+
138
+ # -------------------- with return_latents ----------------------- #
139
+ output = net([style], conditions, return_latents=True)
140
+ assert output[0].shape == (1, 3, 32, 32)
141
+ assert len(output[1]) == 1
142
+ # check latent
143
+ assert output[1][0].shape == (8, 512)
144
+
145
+ # -------------------- with randomize_noise = False ----------------------- #
146
+ output = net([style], conditions, randomize_noise=False)
147
+ assert output[0].shape == (1, 3, 32, 32)
148
+ assert output[1] is None
149
+
150
+ # -------------------- with truncation = 0.5 and mixing----------------------- #
151
+ output = net([style, style], conditions, truncation=0.5, truncation_latent=style)
152
+ assert output[0].shape == (1, 3, 32, 32)
153
+ assert output[1] is None
154
+
155
+
156
+ def test_gfpganv1clean():
157
+ """Test arch: GFPGANv1Clean."""
158
+
159
+ # model init and forward (gpu)
160
+ if torch.cuda.is_available():
161
+ net = GFPGANv1Clean(
162
+ out_size=32,
163
+ num_style_feat=512,
164
+ channel_multiplier=1,
165
+ decoder_load_path=None,
166
+ fix_decoder=True,
167
+ # for stylegan decoder
168
+ num_mlp=8,
169
+ input_is_latent=False,
170
+ different_w=False,
171
+ narrow=1,
172
+ sft_half=True).cuda().eval()
173
+
174
+ img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
175
+ output = net(img)
176
+ assert output[0].shape == (1, 3, 32, 32)
177
+ assert len(output[1]) == 3
178
+ # check out_rgbs for intermediate loss
179
+ assert output[1][0].shape == (1, 3, 8, 8)
180
+ assert output[1][1].shape == (1, 3, 16, 16)
181
+ assert output[1][2].shape == (1, 3, 32, 32)
182
+
183
+ # -------------------- with different_w = True ----------------------- #
184
+ net = GFPGANv1Clean(
185
+ out_size=32,
186
+ num_style_feat=512,
187
+ channel_multiplier=1,
188
+ decoder_load_path=None,
189
+ fix_decoder=True,
190
+ # for stylegan decoder
191
+ num_mlp=8,
192
+ input_is_latent=False,
193
+ different_w=True,
194
+ narrow=1,
195
+ sft_half=True).cuda().eval()
196
+ img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
197
+ output = net(img)
198
+ assert output[0].shape == (1, 3, 32, 32)
199
+ assert len(output[1]) == 3
200
+ # check out_rgbs for intermediate loss
201
+ assert output[1][0].shape == (1, 3, 8, 8)
202
+ assert output[1][1].shape == (1, 3, 16, 16)
203
+ assert output[1][2].shape == (1, 3, 32, 32)
tests/test_gfpgan_model.py ADDED
@@ -0,0 +1,132 @@
 
1
+ import tempfile
2
+ import torch
3
+ import yaml
4
+ from basicsr.archs.stylegan2_arch import StyleGAN2Discriminator
5
+ from basicsr.data.paired_image_dataset import PairedImageDataset
6
+ from basicsr.losses.losses import GANLoss, L1Loss, PerceptualLoss
7
+
8
+ from gfpgan.archs.arcface_arch import ResNetArcFace
9
+ from gfpgan.archs.gfpganv1_arch import FacialComponentDiscriminator, GFPGANv1
10
+ from gfpgan.models.gfpgan_model import GFPGANModel
11
+
12
+
13
+ def test_gfpgan_model():
14
+ with open('tests/data/test_gfpgan_model.yml', mode='r') as f:
15
+ opt = yaml.load(f, Loader=yaml.FullLoader)
16
+
17
+ # build model
18
+ model = GFPGANModel(opt)
19
+ # test attributes
20
+ assert model.__class__.__name__ == 'GFPGANModel'
21
+ assert isinstance(model.net_g, GFPGANv1) # generator
22
+ assert isinstance(model.net_d, StyleGAN2Discriminator) # discriminator
23
+ # facial component discriminators
24
+ assert isinstance(model.net_d_left_eye, FacialComponentDiscriminator)
25
+ assert isinstance(model.net_d_right_eye, FacialComponentDiscriminator)
26
+ assert isinstance(model.net_d_mouth, FacialComponentDiscriminator)
27
+ # identity network
28
+ assert isinstance(model.network_identity, ResNetArcFace)
29
+ # losses
30
+ assert isinstance(model.cri_pix, L1Loss)
31
+ assert isinstance(model.cri_perceptual, PerceptualLoss)
32
+ assert isinstance(model.cri_gan, GANLoss)
33
+ assert isinstance(model.cri_l1, L1Loss)
34
+ # optimizer
35
+ assert isinstance(model.optimizers[0], torch.optim.Adam)
36
+ assert isinstance(model.optimizers[1], torch.optim.Adam)
37
+
38
+ # prepare data
39
+ gt = torch.rand((1, 3, 512, 512), dtype=torch.float32)
40
+ lq = torch.rand((1, 3, 512, 512), dtype=torch.float32)
41
+ loc_left_eye = torch.rand((1, 4), dtype=torch.float32)
42
+ loc_right_eye = torch.rand((1, 4), dtype=torch.float32)
43
+ loc_mouth = torch.rand((1, 4), dtype=torch.float32)
44
+ data = dict(gt=gt, lq=lq, loc_left_eye=loc_left_eye, loc_right_eye=loc_right_eye, loc_mouth=loc_mouth)
45
+ model.feed_data(data)
46
+ # check data shape
47
+ assert model.lq.shape == (1, 3, 512, 512)
48
+ assert model.gt.shape == (1, 3, 512, 512)
49
+ assert model.loc_left_eyes.shape == (1, 4)
50
+ assert model.loc_right_eyes.shape == (1, 4)
51
+ assert model.loc_mouths.shape == (1, 4)
52
+
53
+ # ----------------- test optimize_parameters -------------------- #
54
+ model.feed_data(data)
55
+ model.optimize_parameters(1)
56
+ assert model.output.shape == (1, 3, 512, 512)
57
+ assert isinstance(model.log_dict, dict)
58
+ # check returned keys
59
+ expected_keys = [
60
+ 'l_g_pix', 'l_g_percep', 'l_g_style', 'l_g_gan', 'l_g_gan_left_eye', 'l_g_gan_right_eye', 'l_g_gan_mouth',
61
+ 'l_g_comp_style_loss', 'l_identity', 'l_d', 'real_score', 'fake_score', 'l_d_r1', 'l_d_left_eye',
62
+ 'l_d_right_eye', 'l_d_mouth'
63
+ ]
64
+ assert set(expected_keys).issubset(set(model.log_dict.keys()))
65
+
66
+ # ----------------- remove pyramid_loss_weight-------------------- #
67
+ model.feed_data(data)
68
+ model.optimize_parameters(100000) # large than remove_pyramid_loss = 50000
69
+ assert model.output.shape == (1, 3, 512, 512)
70
+ assert isinstance(model.log_dict, dict)
71
+ # check returned keys
72
+ expected_keys = [
73
+ 'l_g_pix', 'l_g_percep', 'l_g_style', 'l_g_gan', 'l_g_gan_left_eye', 'l_g_gan_right_eye', 'l_g_gan_mouth',
74
+ 'l_g_comp_style_loss', 'l_identity', 'l_d', 'real_score', 'fake_score', 'l_d_r1', 'l_d_left_eye',
75
+ 'l_d_right_eye', 'l_d_mouth'
76
+ ]
77
+ assert set(expected_keys).issubset(set(model.log_dict.keys()))
78
+
79
+ # ----------------- test save -------------------- #
80
+ with tempfile.TemporaryDirectory() as tmpdir:
81
+ model.opt['path']['models'] = tmpdir
82
+ model.opt['path']['training_states'] = tmpdir
83
+ model.save(0, 1)
84
+
85
+ # ----------------- test the test function -------------------- #
86
+ model.test()
87
+ assert model.output.shape == (1, 3, 512, 512)
88
+ # delete net_g_ema
89
+ model.__delattr__('net_g_ema')
90
+ model.test()
91
+ assert model.output.shape == (1, 3, 512, 512)
92
+ assert model.net_g.training is True # should back to training mode after testing
93
+
94
+ # ----------------- test nondist_validation -------------------- #
95
+ # construct dataloader
96
+ dataset_opt = dict(
97
+ name='Demo',
98
+ dataroot_gt='tests/data/gt',
99
+ dataroot_lq='tests/data/gt',
100
+ io_backend=dict(type='disk'),
101
+ scale=4,
102
+ phase='val')
103
+ dataset = PairedImageDataset(dataset_opt)
104
+ dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
105
+ assert model.is_train is True
106
+ with tempfile.TemporaryDirectory() as tmpdir:
107
+ model.opt['path']['visualization'] = tmpdir
108
+ model.nondist_validation(dataloader, 1, None, save_img=True)
109
+ assert model.is_train is True
110
+ # check metric_results
111
+ assert 'psnr' in model.metric_results
112
+ assert isinstance(model.metric_results['psnr'], float)
113
+
114
+ # validation
115
+ with tempfile.TemporaryDirectory() as tmpdir:
116
+ model.opt['is_train'] = False
117
+ model.opt['val']['suffix'] = 'test'
118
+ model.opt['path']['visualization'] = tmpdir
119
+ model.opt['val']['pbar'] = True
120
+ model.nondist_validation(dataloader, 1, None, save_img=True)
121
+ # check metric_results
122
+ assert 'psnr' in model.metric_results
123
+ assert isinstance(model.metric_results['psnr'], float)
124
+
125
+ # if opt['val']['suffix'] is None
126
+ model.opt['val']['suffix'] = None
127
+ model.opt['name'] = 'demo'
128
+ model.opt['path']['visualization'] = tmpdir
129
+ model.nondist_validation(dataloader, 1, None, save_img=True)
130
+ # check metric_results
131
+ assert 'psnr' in model.metric_results
132
+ assert isinstance(model.metric_results['psnr'], float)
tests/test_stylegan2_clean_arch.py ADDED
@@ -0,0 +1,52 @@
 
1
+ import torch
2
+
3
+ from gfpgan.archs.stylegan2_clean_arch import StyleGAN2GeneratorClean
4
+
5
+
6
+ def test_stylegan2generatorclean():
7
+ """Test arch: StyleGAN2GeneratorClean."""
8
+
9
+ # model init and forward (gpu)
10
+ if torch.cuda.is_available():
11
+ net = StyleGAN2GeneratorClean(
12
+ out_size=32, num_style_feat=512, num_mlp=8, channel_multiplier=1, narrow=0.5).cuda().eval()
13
+ style = torch.rand((1, 512), dtype=torch.float32).cuda()
14
+ output = net([style], input_is_latent=False)
15
+ assert output[0].shape == (1, 3, 32, 32)
16
+ assert output[1] is None
17
+
18
+ # -------------------- with return_latents ----------------------- #
19
+ output = net([style], input_is_latent=True, return_latents=True)
20
+ assert output[0].shape == (1, 3, 32, 32)
21
+ assert len(output[1]) == 1
22
+ # check latent
23
+ assert output[1][0].shape == (8, 512)
24
+
25
+ # -------------------- with randomize_noise = False ----------------------- #
26
+ output = net([style], randomize_noise=False)
27
+ assert output[0].shape == (1, 3, 32, 32)
28
+ assert output[1] is None
29
+
30
+ # -------------------- with truncation = 0.5 and mixing----------------------- #
31
+ output = net([style, style], truncation=0.5, truncation_latent=style)
32
+ assert output[0].shape == (1, 3, 32, 32)
33
+ assert output[1] is None
34
+
35
+ # ------------------ test make_noise ----------------------- #
36
+ out = net.make_noise()
37
+ assert len(out) == 7
38
+ assert out[0].shape == (1, 1, 4, 4)
39
+ assert out[1].shape == (1, 1, 8, 8)
40
+ assert out[2].shape == (1, 1, 8, 8)
41
+ assert out[3].shape == (1, 1, 16, 16)
42
+ assert out[4].shape == (1, 1, 16, 16)
43
+ assert out[5].shape == (1, 1, 32, 32)
44
+ assert out[6].shape == (1, 1, 32, 32)
45
+
46
+ # ------------------ test get_latent ----------------------- #
47
+ out = net.get_latent(style)
48
+ assert out.shape == (1, 512)
49
+
50
+ # ------------------ test mean_latent ----------------------- #
51
+ out = net.mean_latent(2)
52
+ assert out.shape == (1, 512)
tests/test_utils.py ADDED
@@ -0,0 +1,43 @@
 
1
+ import cv2
2
+ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
3
+
4
+ from gfpgan.archs.gfpganv1_arch import GFPGANv1
5
+ from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
6
+ from gfpgan.utils import GFPGANer
7
+
8
+
9
+ def test_gfpganer():
10
+ # initialize with the clean model
11
+ restorer = GFPGANer(
12
+ model_path='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth',
13
+ upscale=2,
14
+ arch='clean',
15
+ channel_multiplier=2,
16
+ bg_upsampler=None)
17
+ # test attribute
18
+ assert isinstance(restorer.gfpgan, GFPGANv1Clean)
19
+ assert isinstance(restorer.face_helper, FaceRestoreHelper)
20
+
21
+ # initialize with the original model
22
+ restorer = GFPGANer(
23
+ model_path='experiments/pretrained_models/GFPGANv1.pth',
24
+ upscale=2,
25
+ arch='original',
26
+ channel_multiplier=1,
27
+ bg_upsampler=None)
28
+ # test attribute
29
+ assert isinstance(restorer.gfpgan, GFPGANv1)
30
+ assert isinstance(restorer.face_helper, FaceRestoreHelper)
31
+
32
+ # ------------------ test enhance ---------------- #
33
+ img = cv2.imread('tests/data/gt/00000000.png', cv2.IMREAD_COLOR)
34
+ result = restorer.enhance(img, has_aligned=False, paste_back=True)
35
+ assert result[0][0].shape == (512, 512, 3)
36
+ assert result[1][0].shape == (512, 512, 3)
37
+ assert result[2].shape == (1024, 1024, 3)
38
+
39
+ # with has_aligned=True
40
+ result = restorer.enhance(img, has_aligned=True, paste_back=False)
41
+ assert result[0][0].shape == (512, 512, 3)
42
+ assert result[1][0].shape == (512, 512, 3)
43
+ assert result[2] is None