NimaBoscarino commited on
Commit
e048e19
1 Parent(s): 60dd3db

Remove model code, only keep model checkpoint

Browse files
README.md CHANGED
@@ -2,6 +2,8 @@
2
 
3
  [**Paper**](https://arxiv.org/abs/2101.04061) **|** [**Project Page**](https://xinntao.github.io/projects/gfpgan)    [English](README.md) **|** [简体中文](README_CN.md)
4
 
 
 
5
  GFPGAN is a blind face restoration algorithm towards real-world face images.
6
 
7
  <a href="https://colab.research.google.com/drive/1sVsoBd9AjckIXThgtZhGrHRfFI6UUYOo"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>
 
2
 
3
  [**Paper**](https://arxiv.org/abs/2101.04061) **|** [**Project Page**](https://xinntao.github.io/projects/gfpgan) &emsp;&emsp; [English](README.md) **|** [简体中文](README_CN.md)
4
 
5
+ GitHub: https://github.com/TencentARC/GFPGAN
6
+
7
  GFPGAN is a blind face restoration algorithm towards real-world face images.
8
 
9
  <a href="https://colab.research.google.com/drive/1sVsoBd9AjckIXThgtZhGrHRfFI6UUYOo"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>
archs/__init__.py DELETED
@@ -1,12 +0,0 @@
1
- import importlib
2
- from os import path as osp
3
-
4
- from basicsr.utils import scandir
5
-
6
- # automatically scan and import arch modules for registry
7
- # scan all the files under the 'archs' folder and collect files ending with
8
- # '_arch.py'
9
- arch_folder = osp.dirname(osp.abspath(__file__))
10
- arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
11
- # import all the arch modules
12
- _arch_modules = [importlib.import_module(f'archs.{file_name}') for file_name in arch_filenames]
 
 
 
 
 
 
 
 
 
 
 
 
 
archs/arcface_arch.py DELETED
@@ -1,198 +0,0 @@
1
- import torch.nn as nn
2
-
3
- from basicsr.utils.registry import ARCH_REGISTRY
4
-
5
-
6
- def conv3x3(in_planes, out_planes, stride=1):
7
- """3x3 convolution with padding"""
8
- return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
9
-
10
-
11
- class BasicBlock(nn.Module):
12
- expansion = 1
13
-
14
- def __init__(self, inplanes, planes, stride=1, downsample=None):
15
- super(BasicBlock, self).__init__()
16
- self.conv1 = conv3x3(inplanes, planes, stride)
17
- self.bn1 = nn.BatchNorm2d(planes)
18
- self.relu = nn.ReLU(inplace=True)
19
- self.conv2 = conv3x3(planes, planes)
20
- self.bn2 = nn.BatchNorm2d(planes)
21
- self.downsample = downsample
22
- self.stride = stride
23
-
24
- def forward(self, x):
25
- residual = x
26
-
27
- out = self.conv1(x)
28
- out = self.bn1(out)
29
- out = self.relu(out)
30
-
31
- out = self.conv2(out)
32
- out = self.bn2(out)
33
-
34
- if self.downsample is not None:
35
- residual = self.downsample(x)
36
-
37
- out += residual
38
- out = self.relu(out)
39
-
40
- return out
41
-
42
-
43
- class IRBlock(nn.Module):
44
- expansion = 1
45
-
46
- def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
47
- super(IRBlock, self).__init__()
48
- self.bn0 = nn.BatchNorm2d(inplanes)
49
- self.conv1 = conv3x3(inplanes, inplanes)
50
- self.bn1 = nn.BatchNorm2d(inplanes)
51
- self.prelu = nn.PReLU()
52
- self.conv2 = conv3x3(inplanes, planes, stride)
53
- self.bn2 = nn.BatchNorm2d(planes)
54
- self.downsample = downsample
55
- self.stride = stride
56
- self.use_se = use_se
57
- if self.use_se:
58
- self.se = SEBlock(planes)
59
-
60
- def forward(self, x):
61
- residual = x
62
- out = self.bn0(x)
63
- out = self.conv1(out)
64
- out = self.bn1(out)
65
- out = self.prelu(out)
66
-
67
- out = self.conv2(out)
68
- out = self.bn2(out)
69
- if self.use_se:
70
- out = self.se(out)
71
-
72
- if self.downsample is not None:
73
- residual = self.downsample(x)
74
-
75
- out += residual
76
- out = self.prelu(out)
77
-
78
- return out
79
-
80
-
81
- class Bottleneck(nn.Module):
82
- expansion = 4
83
-
84
- def __init__(self, inplanes, planes, stride=1, downsample=None):
85
- super(Bottleneck, self).__init__()
86
- self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
87
- self.bn1 = nn.BatchNorm2d(planes)
88
- self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
89
- self.bn2 = nn.BatchNorm2d(planes)
90
- self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
91
- self.bn3 = nn.BatchNorm2d(planes * self.expansion)
92
- self.relu = nn.ReLU(inplace=True)
93
- self.downsample = downsample
94
- self.stride = stride
95
-
96
- def forward(self, x):
97
- residual = x
98
-
99
- out = self.conv1(x)
100
- out = self.bn1(out)
101
- out = self.relu(out)
102
-
103
- out = self.conv2(out)
104
- out = self.bn2(out)
105
- out = self.relu(out)
106
-
107
- out = self.conv3(out)
108
- out = self.bn3(out)
109
-
110
- if self.downsample is not None:
111
- residual = self.downsample(x)
112
-
113
- out += residual
114
- out = self.relu(out)
115
-
116
- return out
117
-
118
-
119
- class SEBlock(nn.Module):
120
-
121
- def __init__(self, channel, reduction=16):
122
- super(SEBlock, self).__init__()
123
- self.avg_pool = nn.AdaptiveAvgPool2d(1)
124
- self.fc = nn.Sequential(
125
- nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),
126
- nn.Sigmoid())
127
-
128
- def forward(self, x):
129
- b, c, _, _ = x.size()
130
- y = self.avg_pool(x).view(b, c)
131
- y = self.fc(y).view(b, c, 1, 1)
132
- return x * y
133
-
134
-
135
- @ARCH_REGISTRY.register()
136
- class ResNetArcFace(nn.Module):
137
-
138
- def __init__(self, block, layers, use_se=True):
139
- if block == 'IRBlock':
140
- block = IRBlock
141
- self.inplanes = 64
142
- self.use_se = use_se
143
- super(ResNetArcFace, self).__init__()
144
- self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
145
- self.bn1 = nn.BatchNorm2d(64)
146
- self.prelu = nn.PReLU()
147
- self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
148
- self.layer1 = self._make_layer(block, 64, layers[0])
149
- self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
150
- self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
151
- self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
152
- self.bn4 = nn.BatchNorm2d(512)
153
- self.dropout = nn.Dropout()
154
- self.fc5 = nn.Linear(512 * 8 * 8, 512)
155
- self.bn5 = nn.BatchNorm1d(512)
156
-
157
- for m in self.modules():
158
- if isinstance(m, nn.Conv2d):
159
- nn.init.xavier_normal_(m.weight)
160
- elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
161
- nn.init.constant_(m.weight, 1)
162
- nn.init.constant_(m.bias, 0)
163
- elif isinstance(m, nn.Linear):
164
- nn.init.xavier_normal_(m.weight)
165
- nn.init.constant_(m.bias, 0)
166
-
167
- def _make_layer(self, block, planes, blocks, stride=1):
168
- downsample = None
169
- if stride != 1 or self.inplanes != planes * block.expansion:
170
- downsample = nn.Sequential(
171
- nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
172
- nn.BatchNorm2d(planes * block.expansion),
173
- )
174
- layers = []
175
- layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
176
- self.inplanes = planes
177
- for _ in range(1, blocks):
178
- layers.append(block(self.inplanes, planes, use_se=self.use_se))
179
-
180
- return nn.Sequential(*layers)
181
-
182
- def forward(self, x):
183
- x = self.conv1(x)
184
- x = self.bn1(x)
185
- x = self.prelu(x)
186
- x = self.maxpool(x)
187
-
188
- x = self.layer1(x)
189
- x = self.layer2(x)
190
- x = self.layer3(x)
191
- x = self.layer4(x)
192
- x = self.bn4(x)
193
- x = self.dropout(x)
194
- x = x.view(x.size(0), -1)
195
- x = self.fc5(x)
196
- x = self.bn5(x)
197
-
198
- return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
archs/gfpganv1_arch.py DELETED
@@ -1,418 +0,0 @@
1
- import math
2
- import random
3
- import torch
4
- from torch import nn
5
- from torch.nn import functional as F
6
-
7
- from basicsr.archs.stylegan2_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
8
- StyleGAN2Generator)
9
- from basicsr.ops.fused_act import FusedLeakyReLU
10
- from basicsr.utils.registry import ARCH_REGISTRY
11
-
12
-
13
- class StyleGAN2GeneratorSFT(StyleGAN2Generator):
14
- """StyleGAN2 Generator.
15
-
16
- Args:
17
- out_size (int): The spatial size of outputs.
18
- num_style_feat (int): Channel number of style features. Default: 512.
19
- num_mlp (int): Layer number of MLP style layers. Default: 8.
20
- channel_multiplier (int): Channel multiplier for large networks of
21
- StyleGAN2. Default: 2.
22
- resample_kernel (list[int]): A list indicating the 1D resample kernel
23
- magnitude. A cross production will be applied to extent 1D resample
24
- kenrel to 2D resample kernel. Default: [1, 3, 3, 1].
25
- lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
26
- """
27
-
28
- def __init__(self,
29
- out_size,
30
- num_style_feat=512,
31
- num_mlp=8,
32
- channel_multiplier=2,
33
- resample_kernel=(1, 3, 3, 1),
34
- lr_mlp=0.01,
35
- narrow=1,
36
- sft_half=False):
37
- super(StyleGAN2GeneratorSFT, self).__init__(
38
- out_size,
39
- num_style_feat=num_style_feat,
40
- num_mlp=num_mlp,
41
- channel_multiplier=channel_multiplier,
42
- resample_kernel=resample_kernel,
43
- lr_mlp=lr_mlp,
44
- narrow=narrow)
45
- self.sft_half = sft_half
46
-
47
- def forward(self,
48
- styles,
49
- conditions,
50
- input_is_latent=False,
51
- noise=None,
52
- randomize_noise=True,
53
- truncation=1,
54
- truncation_latent=None,
55
- inject_index=None,
56
- return_latents=False):
57
- """Forward function for StyleGAN2Generator.
58
-
59
- Args:
60
- styles (list[Tensor]): Sample codes of styles.
61
- input_is_latent (bool): Whether input is latent style.
62
- Default: False.
63
- noise (Tensor | None): Input noise or None. Default: None.
64
- randomize_noise (bool): Randomize noise, used when 'noise' is
65
- False. Default: True.
66
- truncation (float): TODO. Default: 1.
67
- truncation_latent (Tensor | None): TODO. Default: None.
68
- inject_index (int | None): The injection index for mixing noise.
69
- Default: None.
70
- return_latents (bool): Whether to return style latents.
71
- Default: False.
72
- """
73
- # style codes -> latents with Style MLP layer
74
- if not input_is_latent:
75
- styles = [self.style_mlp(s) for s in styles]
76
- # noises
77
- if noise is None:
78
- if randomize_noise:
79
- noise = [None] * self.num_layers # for each style conv layer
80
- else: # use the stored noise
81
- noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
82
- # style truncation
83
- if truncation < 1:
84
- style_truncation = []
85
- for style in styles:
86
- style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
87
- styles = style_truncation
88
- # get style latent with injection
89
- if len(styles) == 1:
90
- inject_index = self.num_latent
91
-
92
- if styles[0].ndim < 3:
93
- # repeat latent code for all the layers
94
- latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
95
- else: # used for encoder with different latent code for each layer
96
- latent = styles[0]
97
- elif len(styles) == 2: # mixing noises
98
- if inject_index is None:
99
- inject_index = random.randint(1, self.num_latent - 1)
100
- latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
101
- latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
102
- latent = torch.cat([latent1, latent2], 1)
103
-
104
- # main generation
105
- out = self.constant_input(latent.shape[0])
106
- out = self.style_conv1(out, latent[:, 0], noise=noise[0])
107
- skip = self.to_rgb1(out, latent[:, 1])
108
-
109
- i = 1
110
- for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
111
- noise[2::2], self.to_rgbs):
112
- out = conv1(out, latent[:, i], noise=noise1)
113
-
114
- # the conditions may have fewer levels
115
- if i < len(conditions):
116
- # SFT part to combine the conditions
117
- if self.sft_half:
118
- out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
119
- out_sft = out_sft * conditions[i - 1] + conditions[i]
120
- out = torch.cat([out_same, out_sft], dim=1)
121
- else:
122
- out = out * conditions[i - 1] + conditions[i]
123
-
124
- out = conv2(out, latent[:, i + 1], noise=noise2)
125
- skip = to_rgb(out, latent[:, i + 2], skip)
126
- i += 2
127
-
128
- image = skip
129
-
130
- if return_latents:
131
- return image, latent
132
- else:
133
- return image, None
134
-
135
-
136
- class ConvUpLayer(nn.Module):
137
- """Conv Up Layer. Bilinear upsample + Conv.
138
-
139
- Args:
140
- in_channels (int): Channel number of the input.
141
- out_channels (int): Channel number of the output.
142
- kernel_size (int): Size of the convolving kernel.
143
- stride (int): Stride of the convolution. Default: 1
144
- padding (int): Zero-padding added to both sides of the input.
145
- Default: 0.
146
- bias (bool): If ``True``, adds a learnable bias to the output.
147
- Default: ``True``.
148
- bias_init_val (float): Bias initialized value. Default: 0.
149
- activate (bool): Whether use activateion. Default: True.
150
- """
151
-
152
- def __init__(self,
153
- in_channels,
154
- out_channels,
155
- kernel_size,
156
- stride=1,
157
- padding=0,
158
- bias=True,
159
- bias_init_val=0,
160
- activate=True):
161
- super(ConvUpLayer, self).__init__()
162
- self.in_channels = in_channels
163
- self.out_channels = out_channels
164
- self.kernel_size = kernel_size
165
- self.stride = stride
166
- self.padding = padding
167
- self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
168
-
169
- self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
170
-
171
- if bias and not activate:
172
- self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
173
- else:
174
- self.register_parameter('bias', None)
175
-
176
- # activation
177
- if activate:
178
- if bias:
179
- self.activation = FusedLeakyReLU(out_channels)
180
- else:
181
- self.activation = ScaledLeakyReLU(0.2)
182
- else:
183
- self.activation = None
184
-
185
- def forward(self, x):
186
- # bilinear upsample
187
- out = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
188
- # conv
189
- out = F.conv2d(
190
- out,
191
- self.weight * self.scale,
192
- bias=self.bias,
193
- stride=self.stride,
194
- padding=self.padding,
195
- )
196
- # activation
197
- if self.activation is not None:
198
- out = self.activation(out)
199
- return out
200
-
201
-
202
- class ResUpBlock(nn.Module):
203
- """Residual block with upsampling.
204
-
205
- Args:
206
- in_channels (int): Channel number of the input.
207
- out_channels (int): Channel number of the output.
208
- """
209
-
210
- def __init__(self, in_channels, out_channels):
211
- super(ResUpBlock, self).__init__()
212
-
213
- self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
214
- self.conv2 = ConvUpLayer(in_channels, out_channels, 3, stride=1, padding=1, bias=True, activate=True)
215
- self.skip = ConvUpLayer(in_channels, out_channels, 1, bias=False, activate=False)
216
-
217
- def forward(self, x):
218
- out = self.conv1(x)
219
- out = self.conv2(out)
220
- skip = self.skip(x)
221
- out = (out + skip) / math.sqrt(2)
222
- return out
223
-
224
-
225
- @ARCH_REGISTRY.register()
226
- class GFPGANv1(nn.Module):
227
- """Unet + StyleGAN2 decoder with SFT."""
228
-
229
- def __init__(
230
- self,
231
- out_size,
232
- num_style_feat=512,
233
- channel_multiplier=1,
234
- resample_kernel=(1, 3, 3, 1),
235
- decoder_load_path=None,
236
- fix_decoder=True,
237
- # for stylegan decoder
238
- num_mlp=8,
239
- lr_mlp=0.01,
240
- input_is_latent=False,
241
- different_w=False,
242
- narrow=1,
243
- sft_half=False):
244
-
245
- super(GFPGANv1, self).__init__()
246
- self.input_is_latent = input_is_latent
247
- self.different_w = different_w
248
- self.num_style_feat = num_style_feat
249
-
250
- unet_narrow = narrow * 0.5
251
- channels = {
252
- '4': int(512 * unet_narrow),
253
- '8': int(512 * unet_narrow),
254
- '16': int(512 * unet_narrow),
255
- '32': int(512 * unet_narrow),
256
- '64': int(256 * channel_multiplier * unet_narrow),
257
- '128': int(128 * channel_multiplier * unet_narrow),
258
- '256': int(64 * channel_multiplier * unet_narrow),
259
- '512': int(32 * channel_multiplier * unet_narrow),
260
- '1024': int(16 * channel_multiplier * unet_narrow)
261
- }
262
-
263
- self.log_size = int(math.log(out_size, 2))
264
- first_out_size = 2**(int(math.log(out_size, 2)))
265
-
266
- self.conv_body_first = ConvLayer(3, channels[f'{first_out_size}'], 1, bias=True, activate=True)
267
-
268
- # downsample
269
- in_channels = channels[f'{first_out_size}']
270
- self.conv_body_down = nn.ModuleList()
271
- for i in range(self.log_size, 2, -1):
272
- out_channels = channels[f'{2**(i - 1)}']
273
- self.conv_body_down.append(ResBlock(in_channels, out_channels, resample_kernel))
274
- in_channels = out_channels
275
-
276
- self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True)
277
-
278
- # upsample
279
- in_channels = channels['4']
280
- self.conv_body_up = nn.ModuleList()
281
- for i in range(3, self.log_size + 1):
282
- out_channels = channels[f'{2**i}']
283
- self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
284
- in_channels = out_channels
285
-
286
- # to RGB
287
- self.toRGB = nn.ModuleList()
288
- for i in range(3, self.log_size + 1):
289
- self.toRGB.append(EqualConv2d(channels[f'{2**i}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0))
290
-
291
- if different_w:
292
- linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
293
- else:
294
- linear_out_channel = num_style_feat
295
-
296
- self.final_linear = EqualLinear(
297
- channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None)
298
-
299
- self.stylegan_decoder = StyleGAN2GeneratorSFT(
300
- out_size=out_size,
301
- num_style_feat=num_style_feat,
302
- num_mlp=num_mlp,
303
- channel_multiplier=channel_multiplier,
304
- resample_kernel=resample_kernel,
305
- lr_mlp=lr_mlp,
306
- narrow=narrow,
307
- sft_half=sft_half)
308
-
309
- if decoder_load_path:
310
- self.stylegan_decoder.load_state_dict(
311
- torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
312
- if fix_decoder:
313
- for _, param in self.stylegan_decoder.named_parameters():
314
- param.requires_grad = False
315
-
316
- # for SFT
317
- self.condition_scale = nn.ModuleList()
318
- self.condition_shift = nn.ModuleList()
319
- for i in range(3, self.log_size + 1):
320
- out_channels = channels[f'{2**i}']
321
- if sft_half:
322
- sft_out_channels = out_channels
323
- else:
324
- sft_out_channels = out_channels * 2
325
- self.condition_scale.append(
326
- nn.Sequential(
327
- EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
328
- ScaledLeakyReLU(0.2),
329
- EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1)))
330
- self.condition_shift.append(
331
- nn.Sequential(
332
- EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
333
- ScaledLeakyReLU(0.2),
334
- EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0)))
335
-
336
- def forward(self,
337
- x,
338
- return_latents=False,
339
- save_feat_path=None,
340
- load_feat_path=None,
341
- return_rgb=True,
342
- randomize_noise=True):
343
- conditions = []
344
- unet_skips = []
345
- out_rgbs = []
346
-
347
- # encoder
348
- feat = self.conv_body_first(x)
349
- for i in range(self.log_size - 2):
350
- feat = self.conv_body_down[i](feat)
351
- unet_skips.insert(0, feat)
352
-
353
- feat = self.final_conv(feat)
354
-
355
- # style code
356
- style_code = self.final_linear(feat.view(feat.size(0), -1))
357
- if self.different_w:
358
- style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
359
-
360
- # decode
361
- for i in range(self.log_size - 2):
362
- # add unet skip
363
- feat = feat + unet_skips[i]
364
- # ResUpLayer
365
- feat = self.conv_body_up[i](feat)
366
- # generate scale and shift for SFT layer
367
- scale = self.condition_scale[i](feat)
368
- conditions.append(scale.clone())
369
- shift = self.condition_shift[i](feat)
370
- conditions.append(shift.clone())
371
- # generate rgb images
372
- if return_rgb:
373
- out_rgbs.append(self.toRGB[i](feat))
374
-
375
- if save_feat_path is not None:
376
- torch.save(conditions, save_feat_path)
377
- if load_feat_path is not None:
378
- conditions = torch.load(load_feat_path)
379
- conditions = [v.cuda() for v in conditions]
380
-
381
- # decoder
382
- image, _ = self.stylegan_decoder([style_code],
383
- conditions,
384
- return_latents=return_latents,
385
- input_is_latent=self.input_is_latent,
386
- randomize_noise=randomize_noise)
387
-
388
- return image, out_rgbs
389
-
390
-
391
- @ARCH_REGISTRY.register()
392
- class FacialComponentDiscriminator(nn.Module):
393
-
394
- def __init__(self):
395
- super(FacialComponentDiscriminator, self).__init__()
396
-
397
- self.conv1 = ConvLayer(3, 64, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
398
- self.conv2 = ConvLayer(64, 128, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
399
- self.conv3 = ConvLayer(128, 128, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
400
- self.conv4 = ConvLayer(128, 256, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
401
- self.conv5 = ConvLayer(256, 256, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
402
- self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False)
403
-
404
- def forward(self, x, return_feats=False):
405
- feat = self.conv1(x)
406
- feat = self.conv3(self.conv2(feat))
407
- rlt_feats = []
408
- if return_feats:
409
- rlt_feats.append(feat.clone())
410
- feat = self.conv5(self.conv4(feat))
411
- if return_feats:
412
- rlt_feats.append(feat.clone())
413
- out = self.final_conv(feat)
414
-
415
- if return_feats:
416
- return out, rlt_feats
417
- else:
418
- return out, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
data/__init__.py DELETED
@@ -1,11 +0,0 @@
1
- import importlib
2
- from os import path as osp
3
-
4
- from basicsr.utils import scandir
5
-
6
- # automatically scan and import dataset modules for registry
7
- # scan all the files under the data folder with '_dataset' in file names
8
- data_folder = osp.dirname(osp.abspath(__file__))
9
- dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
10
- # import all the dataset modules
11
- _dataset_modules = [importlib.import_module(f'data.{file_name}') for file_name in dataset_filenames]
 
 
 
 
 
 
 
 
 
 
 
 
data/ffhq_degradation_dataset.py DELETED
@@ -1,213 +0,0 @@
1
- import cv2
2
- import math
3
- import numpy as np
4
- import os.path as osp
5
- import torch
6
- import torch.utils.data as data
7
- from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
8
- normalize)
9
-
10
- from basicsr.data import degradations as degradations
11
- from basicsr.data.data_util import paths_from_folder
12
- from basicsr.data.transforms import augment
13
- from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
14
- from basicsr.utils.registry import DATASET_REGISTRY
15
-
16
-
17
- @DATASET_REGISTRY.register()
18
- class FFHQDegradationDataset(data.Dataset):
19
-
20
- def __init__(self, opt):
21
- super(FFHQDegradationDataset, self).__init__()
22
- self.opt = opt
23
- # file client (io backend)
24
- self.file_client = None
25
- self.io_backend_opt = opt['io_backend']
26
-
27
- self.gt_folder = opt['dataroot_gt']
28
- self.mean = opt['mean']
29
- self.std = opt['std']
30
- self.out_size = opt['out_size']
31
-
32
- self.crop_components = opt.get('crop_components', False) # facial components
33
- self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1)
34
-
35
- if self.crop_components:
36
- self.components_list = torch.load(opt.get('component_path'))
37
-
38
- if self.io_backend_opt['type'] == 'lmdb':
39
- self.io_backend_opt['db_paths'] = self.gt_folder
40
- if not self.gt_folder.endswith('.lmdb'):
41
- raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
42
- with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
43
- self.paths = [line.split('.')[0] for line in fin]
44
- else:
45
- self.paths = paths_from_folder(self.gt_folder)
46
-
47
- # degradations
48
- self.blur_kernel_size = opt['blur_kernel_size']
49
- self.kernel_list = opt['kernel_list']
50
- self.kernel_prob = opt['kernel_prob']
51
- self.blur_sigma = opt['blur_sigma']
52
- self.downsample_range = opt['downsample_range']
53
- self.noise_range = opt['noise_range']
54
- self.jpeg_range = opt['jpeg_range']
55
-
56
- # color jitter
57
- self.color_jitter_prob = opt.get('color_jitter_prob')
58
- self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob')
59
- self.color_jitter_shift = opt.get('color_jitter_shift', 20)
60
- # to gray
61
- self.gray_prob = opt.get('gray_prob')
62
-
63
- logger = get_root_logger()
64
- logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, '
65
- f'sigma: [{", ".join(map(str, self.blur_sigma))}]')
66
- logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
67
- logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
68
- logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
69
-
70
- if self.color_jitter_prob is not None:
71
- logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, '
72
- f'shift: {self.color_jitter_shift}')
73
- if self.gray_prob is not None:
74
- logger.info(f'Use random gray. Prob: {self.gray_prob}')
75
-
76
- self.color_jitter_shift /= 255.
77
-
78
- @staticmethod
79
- def color_jitter(img, shift):
80
- jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
81
- img = img + jitter_val
82
- img = np.clip(img, 0, 1)
83
- return img
84
-
85
- @staticmethod
86
- def color_jitter_pt(img, brightness, contrast, saturation, hue):
87
- fn_idx = torch.randperm(4)
88
- for fn_id in fn_idx:
89
- if fn_id == 0 and brightness is not None:
90
- brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
91
- img = adjust_brightness(img, brightness_factor)
92
-
93
- if fn_id == 1 and contrast is not None:
94
- contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
95
- img = adjust_contrast(img, contrast_factor)
96
-
97
- if fn_id == 2 and saturation is not None:
98
- saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
99
- img = adjust_saturation(img, saturation_factor)
100
-
101
- if fn_id == 3 and hue is not None:
102
- hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
103
- img = adjust_hue(img, hue_factor)
104
- return img
105
-
106
- def get_component_coordinates(self, index, status):
107
- components_bbox = self.components_list[f'{index:08d}']
108
- if status[0]: # hflip
109
- # exchange right and left eye
110
- tmp = components_bbox['left_eye']
111
- components_bbox['left_eye'] = components_bbox['right_eye']
112
- components_bbox['right_eye'] = tmp
113
- # modify the width coordinate
114
- components_bbox['left_eye'][0] = self.out_size - components_bbox['left_eye'][0]
115
- components_bbox['right_eye'][0] = self.out_size - components_bbox['right_eye'][0]
116
- components_bbox['mouth'][0] = self.out_size - components_bbox['mouth'][0]
117
-
118
- # get coordinates
119
- locations = []
120
- for part in ['left_eye', 'right_eye', 'mouth']:
121
- mean = components_bbox[part][0:2]
122
- half_len = components_bbox[part][2]
123
- if 'eye' in part:
124
- half_len *= self.eye_enlarge_ratio
125
- loc = np.hstack((mean - half_len + 1, mean + half_len))
126
- loc = torch.from_numpy(loc).float()
127
- locations.append(loc)
128
- return locations
129
-
130
- def __getitem__(self, index):
131
- if self.file_client is None:
132
- self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
133
-
134
- # load gt image
135
- gt_path = self.paths[index]
136
- img_bytes = self.file_client.get(gt_path)
137
- img_gt = imfrombytes(img_bytes, float32=True)
138
-
139
- # random horizontal flip
140
- img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
141
- h, w, _ = img_gt.shape
142
-
143
- if self.crop_components:
144
- locations = self.get_component_coordinates(index, status)
145
- loc_left_eye, loc_right_eye, loc_mouth = locations
146
-
147
- # ------------------------ generate lq image ------------------------ #
148
- # blur
149
- kernel = degradations.random_mixed_kernels(
150
- self.kernel_list,
151
- self.kernel_prob,
152
- self.blur_kernel_size,
153
- self.blur_sigma,
154
- self.blur_sigma, [-math.pi, math.pi],
155
- noise_range=None)
156
- img_lq = cv2.filter2D(img_gt, -1, kernel)
157
- # downsample
158
- scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
159
- img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR)
160
- # noise
161
- if self.noise_range is not None:
162
- img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range)
163
- # jpeg compression
164
- if self.jpeg_range is not None:
165
- img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range)
166
-
167
- # resize to original size
168
- img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)
169
-
170
- # random color jitter (only for lq)
171
- if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
172
- img_lq = self.color_jitter(img_lq, self.color_jitter_shift)
173
- # random to gray (only for lq)
174
- if self.gray_prob and np.random.uniform() < self.gray_prob:
175
- img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
176
- img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
177
- if self.opt.get('gt_gray'):
178
- img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
179
- img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])
180
-
181
- # BGR to RGB, HWC to CHW, numpy to tensor
182
- img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
183
-
184
- # random color jitter (pytorch version) (only for lq)
185
- if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
186
- brightness = self.opt.get('brightness', (0.5, 1.5))
187
- contrast = self.opt.get('contrast', (0.5, 1.5))
188
- saturation = self.opt.get('saturation', (0, 1.5))
189
- hue = self.opt.get('hue', (-0.1, 0.1))
190
- img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue)
191
-
192
- # round and clip
193
- img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.
194
-
195
- # normalize
196
- normalize(img_gt, self.mean, self.std, inplace=True)
197
- normalize(img_lq, self.mean, self.std, inplace=True)
198
-
199
- if self.crop_components:
200
- return_dict = {
201
- 'lq': img_lq,
202
- 'gt': img_gt,
203
- 'gt_path': gt_path,
204
- 'loc_left_eye': loc_left_eye,
205
- 'loc_right_eye': loc_right_eye,
206
- 'loc_mouth': loc_mouth
207
- }
208
- return return_dict
209
- else:
210
- return {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path}
211
-
212
- def __len__(self):
213
- return len(self.paths)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
experiments/pretrained_models/GFPGANv1.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:6db3a33dd00dd427b8a70a7e3c6244a5bcccb818736f4861ce1d609024a991de
3
- size 615378983
 
 
 
 
experiments/pretrained_models/README.md DELETED
@@ -1,7 +0,0 @@
1
- # Pre-trained Models and Other Data
2
-
3
- Download pre-trained models and other data. Put them in this folder.
4
-
5
- 1. [Pretrained StyleGAN2 model: StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth)
6
- 1. [Component locations of FFHQ: FFHQ_eye_mouth_landmarks_512.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/FFHQ_eye_mouth_landmarks_512.pth)
7
- 1. [A simple ArcFace model: arcface_resnet18.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/arcface_resnet18.pth)
 
 
 
 
 
 
 
 
inference_gfpgan_full.py DELETED
@@ -1,130 +0,0 @@
1
- import argparse
2
- import cv2
3
- import glob
4
- import numpy as np
5
- import os
6
- import torch
7
- from facexlib.utils.face_restoration_helper import FaceRestoreHelper
8
- from torchvision.transforms.functional import normalize
9
-
10
- from archs.gfpganv1_arch import GFPGANv1
11
- from basicsr.utils import img2tensor, imwrite, tensor2img
12
-
13
-
14
- def restoration(gfpgan,
15
- face_helper,
16
- img_path,
17
- save_root,
18
- has_aligned=False,
19
- only_center_face=True,
20
- suffix=None,
21
- paste_back=False):
22
- # read image
23
- img_name = os.path.basename(img_path)
24
- print(f'Processing {img_name} ...')
25
- basename, _ = os.path.splitext(img_name)
26
- input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
27
- face_helper.clean_all()
28
-
29
- if has_aligned:
30
- input_img = cv2.resize(input_img, (512, 512))
31
- face_helper.cropped_faces = [input_img]
32
- else:
33
- face_helper.read_image(input_img)
34
- # get face landmarks for each face
35
- face_helper.get_face_landmarks_5(only_center_face=only_center_face, pad_blur=False)
36
- # align and warp each face
37
- save_crop_path = os.path.join(save_root, 'cropped_faces', img_name)
38
- face_helper.align_warp_face(save_crop_path)
39
-
40
- # face restoration
41
- for idx, cropped_face in enumerate(face_helper.cropped_faces):
42
- # prepare data
43
- cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
44
- normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
45
- cropped_face_t = cropped_face_t.unsqueeze(0).to('cuda')
46
-
47
- try:
48
- with torch.no_grad():
49
- output = gfpgan(cropped_face_t, return_rgb=False)[0]
50
- # convert to image
51
- restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
52
- except RuntimeError as error:
53
- print(f'\tFailed inference for GFPGAN: {error}.')
54
- restored_face = cropped_face
55
-
56
- restored_face = restored_face.astype('uint8')
57
- face_helper.add_restored_face(restored_face)
58
-
59
- if suffix is not None:
60
- save_face_name = f'{basename}_{idx:02d}_{suffix}.png'
61
- else:
62
- save_face_name = f'{basename}_{idx:02d}.png'
63
- save_restore_path = os.path.join(save_root, 'restored_faces', save_face_name)
64
- imwrite(restored_face, save_restore_path)
65
-
66
- # save cmp image
67
- cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
68
- imwrite(cmp_img, os.path.join(save_root, 'cmp', f'{basename}_{idx:02d}.png'))
69
-
70
- if not has_aligned and paste_back:
71
- face_helper.get_inverse_affine(None)
72
- save_restore_path = os.path.join(save_root, 'restored_imgs', img_name)
73
- # paste each restored face to the input image
74
- face_helper.paste_faces_to_input_image(save_restore_path)
75
-
76
-
77
- if __name__ == '__main__':
78
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
79
- parser = argparse.ArgumentParser()
80
-
81
- parser.add_argument('--upscale_factor', type=int, default=1)
82
- parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANv1.pth')
83
- parser.add_argument('--test_path', type=str, default='inputs/whole_imgs')
84
- parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
85
- parser.add_argument('--only_center_face', action='store_true')
86
- parser.add_argument('--aligned', action='store_true')
87
- parser.add_argument('--paste_back', action='store_true')
88
-
89
- args = parser.parse_args()
90
- if args.test_path.endswith('/'):
91
- args.test_path = args.test_path[:-1]
92
- save_root = 'results/'
93
- os.makedirs(save_root, exist_ok=True)
94
-
95
- # initialize the GFP-GAN
96
- gfpgan = GFPGANv1(
97
- out_size=512,
98
- num_style_feat=512,
99
- channel_multiplier=1,
100
- decoder_load_path=None,
101
- fix_decoder=True,
102
- # for stylegan decoder
103
- num_mlp=8,
104
- input_is_latent=True,
105
- different_w=True,
106
- narrow=1,
107
- sft_half=True)
108
-
109
- gfpgan.to(device)
110
- checkpoint = torch.load(args.model_path, map_location=lambda storage, loc: storage)
111
- gfpgan.load_state_dict(checkpoint['params_ema'])
112
- gfpgan.eval()
113
-
114
- # initialize face helper
115
- face_helper = FaceRestoreHelper(
116
- args.upscale_factor, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png')
117
-
118
- img_list = sorted(glob.glob(os.path.join(args.test_path, '*')))
119
- for img_path in img_list:
120
- restoration(
121
- gfpgan,
122
- face_helper,
123
- img_path,
124
- save_root,
125
- has_aligned=args.aligned,
126
- only_center_face=args.only_center_face,
127
- suffix=args.suffix,
128
- paste_back=args.paste_back)
129
-
130
- print('Results are in the <results> folder.')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/__init__.py DELETED
@@ -1,12 +0,0 @@
1
- import importlib
2
- from os import path as osp
3
-
4
- from basicsr.utils import scandir
5
-
6
- # automatically scan and import model modules for registry
7
- # scan all the files under the 'models' folder and collect files ending with
8
- # '_model.py'
9
- model_folder = osp.dirname(osp.abspath(__file__))
10
- model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
11
- # import all the model modules
12
- _model_modules = [importlib.import_module(f'models.{file_name}') for file_name in model_filenames]
 
 
 
 
 
 
 
 
 
 
 
 
 
models/gfpgan_model.py DELETED
@@ -1,562 +0,0 @@
1
- import math
2
- import os.path as osp
3
- import torch
4
- from collections import OrderedDict
5
- from torch.nn import functional as F
6
- from torchvision.ops import roi_align
7
- from tqdm import tqdm
8
-
9
- from basicsr.archs import build_network
10
- from basicsr.losses import build_loss
11
- from basicsr.losses.losses import r1_penalty
12
- from basicsr.metrics import calculate_metric
13
- from basicsr.models.base_model import BaseModel
14
- from basicsr.utils import get_root_logger, imwrite, tensor2img
15
- from basicsr.utils.registry import MODEL_REGISTRY
16
-
17
-
18
- @MODEL_REGISTRY.register()
19
- class GFPGANModel(BaseModel):
20
- """GFPGAN model for <Towards real-world blind face restoratin with generative facial prior>"""
21
-
22
- def __init__(self, opt):
23
- super(GFPGANModel, self).__init__(opt)
24
- self.idx = 0
25
-
26
- # define network
27
- self.net_g = build_network(opt['network_g'])
28
- self.net_g = self.model_to_device(self.net_g)
29
- self.print_network(self.net_g)
30
-
31
- # load pretrained model
32
- load_path = self.opt['path'].get('pretrain_network_g', None)
33
- if load_path is not None:
34
- param_key = self.opt['path'].get('param_key_g', 'params')
35
- self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
36
-
37
- self.log_size = int(math.log(self.opt['network_g']['out_size'], 2))
38
-
39
- if self.is_train:
40
- self.init_training_settings()
41
-
42
- def init_training_settings(self):
43
- train_opt = self.opt['train']
44
-
45
- # ----------- define net_d ----------- #
46
- self.net_d = build_network(self.opt['network_d'])
47
- self.net_d = self.model_to_device(self.net_d)
48
- self.print_network(self.net_d)
49
- # load pretrained model
50
- load_path = self.opt['path'].get('pretrain_network_d', None)
51
- if load_path is not None:
52
- self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
53
-
54
- # ----------- define net_g with Exponential Moving Average (EMA) ----------- #
55
- # net_g_ema only used for testing on one GPU and saving
56
- # There is no need to wrap with DistributedDataParallel
57
- self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
58
- # load pretrained model
59
- load_path = self.opt['path'].get('pretrain_network_g', None)
60
- if load_path is not None:
61
- self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
62
- else:
63
- self.model_ema(0) # copy net_g weight
64
-
65
- self.net_g.train()
66
- self.net_d.train()
67
- self.net_g_ema.eval()
68
-
69
- # ----------- facial components networks ----------- #
70
- if ('network_d_left_eye' in self.opt and 'network_d_right_eye' in self.opt and 'network_d_mouth' in self.opt):
71
- self.use_facial_disc = True
72
- else:
73
- self.use_facial_disc = False
74
-
75
- if self.use_facial_disc:
76
- # left eye
77
- self.net_d_left_eye = build_network(self.opt['network_d_left_eye'])
78
- self.net_d_left_eye = self.model_to_device(self.net_d_left_eye)
79
- self.print_network(self.net_d_left_eye)
80
- load_path = self.opt['path'].get('pretrain_network_d_left_eye')
81
- if load_path is not None:
82
- self.load_network(self.net_d_left_eye, load_path, True, 'params')
83
- # right eye
84
- self.net_d_right_eye = build_network(self.opt['network_d_right_eye'])
85
- self.net_d_right_eye = self.model_to_device(self.net_d_right_eye)
86
- self.print_network(self.net_d_right_eye)
87
- load_path = self.opt['path'].get('pretrain_network_d_right_eye')
88
- if load_path is not None:
89
- self.load_network(self.net_d_right_eye, load_path, True, 'params')
90
- # mouth
91
- self.net_d_mouth = build_network(self.opt['network_d_mouth'])
92
- self.net_d_mouth = self.model_to_device(self.net_d_mouth)
93
- self.print_network(self.net_d_mouth)
94
- load_path = self.opt['path'].get('pretrain_network_d_mouth')
95
- if load_path is not None:
96
- self.load_network(self.net_d_mouth, load_path, True, 'params')
97
-
98
- self.net_d_left_eye.train()
99
- self.net_d_right_eye.train()
100
- self.net_d_mouth.train()
101
-
102
- # ----------- define facial component gan loss ----------- #
103
- self.cri_component = build_loss(train_opt['gan_component_opt']).to(self.device)
104
-
105
- # ----------- define losses ----------- #
106
- if train_opt.get('pixel_opt'):
107
- self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
108
- else:
109
- self.cri_pix = None
110
-
111
- if train_opt.get('perceptual_opt'):
112
- self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
113
- else:
114
- self.cri_perceptual = None
115
-
116
- # L1 loss used in pyramid loss, component style loss and identity loss
117
- self.cri_l1 = build_loss(train_opt['L1_opt']).to(self.device)
118
-
119
- # gan loss (wgan)
120
- self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
121
-
122
- # ----------- define identity loss ----------- #
123
- if 'network_identity' in self.opt:
124
- self.use_identity = True
125
- else:
126
- self.use_identity = False
127
-
128
- if self.use_identity:
129
- # define identity network
130
- self.network_identity = build_network(self.opt['network_identity'])
131
- self.network_identity = self.model_to_device(self.network_identity)
132
- self.print_network(self.network_identity)
133
- load_path = self.opt['path'].get('pretrain_network_identity')
134
- if load_path is not None:
135
- self.load_network(self.network_identity, load_path, True, None)
136
- self.network_identity.eval()
137
- for param in self.network_identity.parameters():
138
- param.requires_grad = False
139
-
140
- # regularization weights
141
- self.r1_reg_weight = train_opt['r1_reg_weight'] # for discriminator
142
- self.net_d_iters = train_opt.get('net_d_iters', 1)
143
- self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
144
- self.net_d_reg_every = train_opt['net_d_reg_every']
145
-
146
- # set up optimizers and schedulers
147
- self.setup_optimizers()
148
- self.setup_schedulers()
149
-
150
- def setup_optimizers(self):
151
- train_opt = self.opt['train']
152
-
153
- # ----------- optimizer g ----------- #
154
- net_g_reg_ratio = 1
155
- normal_params = []
156
- for _, param in self.net_g.named_parameters():
157
- normal_params.append(param)
158
- optim_params_g = [{ # add normal params first
159
- 'params': normal_params,
160
- 'lr': train_opt['optim_g']['lr']
161
- }]
162
- optim_type = train_opt['optim_g'].pop('type')
163
- lr = train_opt['optim_g']['lr'] * net_g_reg_ratio
164
- betas = (0**net_g_reg_ratio, 0.99**net_g_reg_ratio)
165
- self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, lr, betas=betas)
166
- self.optimizers.append(self.optimizer_g)
167
-
168
- # ----------- optimizer d ----------- #
169
- net_d_reg_ratio = self.net_d_reg_every / (self.net_d_reg_every + 1)
170
- normal_params = []
171
- for _, param in self.net_d.named_parameters():
172
- normal_params.append(param)
173
- optim_params_d = [{ # add normal params first
174
- 'params': normal_params,
175
- 'lr': train_opt['optim_d']['lr']
176
- }]
177
- optim_type = train_opt['optim_d'].pop('type')
178
- lr = train_opt['optim_d']['lr'] * net_d_reg_ratio
179
- betas = (0**net_d_reg_ratio, 0.99**net_d_reg_ratio)
180
- self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas)
181
- self.optimizers.append(self.optimizer_d)
182
-
183
- if self.use_facial_disc:
184
- # setup optimizers for facial component discriminators
185
- optim_type = train_opt['optim_component'].pop('type')
186
- lr = train_opt['optim_component']['lr']
187
- # left eye
188
- self.optimizer_d_left_eye = self.get_optimizer(
189
- optim_type, self.net_d_left_eye.parameters(), lr, betas=(0.9, 0.99))
190
- self.optimizers.append(self.optimizer_d_left_eye)
191
- # right eye
192
- self.optimizer_d_right_eye = self.get_optimizer(
193
- optim_type, self.net_d_right_eye.parameters(), lr, betas=(0.9, 0.99))
194
- self.optimizers.append(self.optimizer_d_right_eye)
195
- # mouth
196
- self.optimizer_d_mouth = self.get_optimizer(
197
- optim_type, self.net_d_mouth.parameters(), lr, betas=(0.9, 0.99))
198
- self.optimizers.append(self.optimizer_d_mouth)
199
-
200
- def feed_data(self, data):
201
- self.lq = data['lq'].to(self.device)
202
- if 'gt' in data:
203
- self.gt = data['gt'].to(self.device)
204
-
205
- if 'loc_left_eye' in data:
206
- # get facial component locations, shape (batch, 4)
207
- self.loc_left_eyes = data['loc_left_eye']
208
- self.loc_right_eyes = data['loc_right_eye']
209
- self.loc_mouths = data['loc_mouth']
210
-
211
- # uncomment to check data
212
- # import torchvision
213
- # if self.opt['rank'] == 0:
214
- # import os
215
- # os.makedirs('tmp/gt', exist_ok=True)
216
- # os.makedirs('tmp/lq', exist_ok=True)
217
- # print(self.idx)
218
- # torchvision.utils.save_image(
219
- # self.gt, f'tmp/gt/gt_{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1))
220
- # torchvision.utils.save_image(
221
- # self.lq, f'tmp/lq/lq{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1))
222
- # self.idx = self.idx + 1
223
-
224
- def construct_img_pyramid(self):
225
- pyramid_gt = [self.gt]
226
- down_img = self.gt
227
- for _ in range(0, self.log_size - 3):
228
- down_img = F.interpolate(down_img, scale_factor=0.5, mode='bilinear', align_corners=False)
229
- pyramid_gt.insert(0, down_img)
230
- return pyramid_gt
231
-
232
- def get_roi_regions(self, eye_out_size=80, mouth_out_size=120):
233
- # hard code
234
- face_ratio = int(self.opt['network_g']['out_size'] / 512)
235
- eye_out_size *= face_ratio
236
- mouth_out_size *= face_ratio
237
-
238
- rois_eyes = []
239
- rois_mouths = []
240
- for b in range(self.loc_left_eyes.size(0)): # loop for batch size
241
- # left eye and right eye
242
- img_inds = self.loc_left_eyes.new_full((2, 1), b)
243
- bbox = torch.stack([self.loc_left_eyes[b, :], self.loc_right_eyes[b, :]], dim=0) # shape: (2, 4)
244
- rois = torch.cat([img_inds, bbox], dim=-1) # shape: (2, 5)
245
- rois_eyes.append(rois)
246
- # mouse
247
- img_inds = self.loc_left_eyes.new_full((1, 1), b)
248
- rois = torch.cat([img_inds, self.loc_mouths[b:b + 1, :]], dim=-1) # shape: (1, 5)
249
- rois_mouths.append(rois)
250
-
251
- rois_eyes = torch.cat(rois_eyes, 0).to(self.device)
252
- rois_mouths = torch.cat(rois_mouths, 0).to(self.device)
253
-
254
- # real images
255
- all_eyes = roi_align(self.gt, boxes=rois_eyes, output_size=eye_out_size) * face_ratio
256
- self.left_eyes_gt = all_eyes[0::2, :, :, :]
257
- self.right_eyes_gt = all_eyes[1::2, :, :, :]
258
- self.mouths_gt = roi_align(self.gt, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio
259
- # output
260
- all_eyes = roi_align(self.output, boxes=rois_eyes, output_size=eye_out_size) * face_ratio
261
- self.left_eyes = all_eyes[0::2, :, :, :]
262
- self.right_eyes = all_eyes[1::2, :, :, :]
263
- self.mouths = roi_align(self.output, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio
264
-
265
- def _gram_mat(self, x):
266
- """Calculate Gram matrix.
267
-
268
- Args:
269
- x (torch.Tensor): Tensor with shape of (n, c, h, w).
270
-
271
- Returns:
272
- torch.Tensor: Gram matrix.
273
- """
274
- n, c, h, w = x.size()
275
- features = x.view(n, c, w * h)
276
- features_t = features.transpose(1, 2)
277
- gram = features.bmm(features_t) / (c * h * w)
278
- return gram
279
-
280
- def gray_resize_for_identity(self, out, size=128):
281
- out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :])
282
- out_gray = out_gray.unsqueeze(1)
283
- out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False)
284
- return out_gray
285
-
286
- def optimize_parameters(self, current_iter):
287
- # optimize net_g
288
- for p in self.net_d.parameters():
289
- p.requires_grad = False
290
- self.optimizer_g.zero_grad()
291
-
292
- if self.use_facial_disc:
293
- for p in self.net_d_left_eye.parameters():
294
- p.requires_grad = False
295
- for p in self.net_d_right_eye.parameters():
296
- p.requires_grad = False
297
- for p in self.net_d_mouth.parameters():
298
- p.requires_grad = False
299
-
300
- # image pyramid loss weight
301
- if current_iter < self.opt['train'].get('remove_pyramid_loss', float('inf')):
302
- pyramid_loss_weight = self.opt['train'].get('pyramid_loss_weight', 1)
303
- else:
304
- pyramid_loss_weight = 1e-12 # very small loss
305
- if pyramid_loss_weight > 0:
306
- self.output, out_rgbs = self.net_g(self.lq, return_rgb=True)
307
- pyramid_gt = self.construct_img_pyramid()
308
- else:
309
- self.output, out_rgbs = self.net_g(self.lq, return_rgb=False)
310
-
311
- # get roi-align regions
312
- if self.use_facial_disc:
313
- self.get_roi_regions(eye_out_size=80, mouth_out_size=120)
314
-
315
- l_g_total = 0
316
- loss_dict = OrderedDict()
317
- if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
318
- # pixel loss
319
- if self.cri_pix:
320
- l_g_pix = self.cri_pix(self.output, self.gt)
321
- l_g_total += l_g_pix
322
- loss_dict['l_g_pix'] = l_g_pix
323
-
324
- # image pyramid loss
325
- if pyramid_loss_weight > 0:
326
- for i in range(0, self.log_size - 2):
327
- l_pyramid = self.cri_l1(out_rgbs[i], pyramid_gt[i]) * pyramid_loss_weight
328
- l_g_total += l_pyramid
329
- loss_dict[f'l_p_{2**(i+3)}'] = l_pyramid
330
-
331
- # perceptual loss
332
- if self.cri_perceptual:
333
- l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
334
- if l_g_percep is not None:
335
- l_g_total += l_g_percep
336
- loss_dict['l_g_percep'] = l_g_percep
337
- if l_g_style is not None:
338
- l_g_total += l_g_style
339
- loss_dict['l_g_style'] = l_g_style
340
-
341
- # gan loss
342
- fake_g_pred = self.net_d(self.output)
343
- l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
344
- l_g_total += l_g_gan
345
- loss_dict['l_g_gan'] = l_g_gan
346
-
347
- # facial component loss
348
- if self.use_facial_disc:
349
- # left eye
350
- fake_left_eye, fake_left_eye_feats = self.net_d_left_eye(self.left_eyes, return_feats=True)
351
- l_g_gan = self.cri_component(fake_left_eye, True, is_disc=False)
352
- l_g_total += l_g_gan
353
- loss_dict['l_g_gan_left_eye'] = l_g_gan
354
- # right eye
355
- fake_right_eye, fake_right_eye_feats = self.net_d_right_eye(self.right_eyes, return_feats=True)
356
- l_g_gan = self.cri_component(fake_right_eye, True, is_disc=False)
357
- l_g_total += l_g_gan
358
- loss_dict['l_g_gan_right_eye'] = l_g_gan
359
- # mouth
360
- fake_mouth, fake_mouth_feats = self.net_d_mouth(self.mouths, return_feats=True)
361
- l_g_gan = self.cri_component(fake_mouth, True, is_disc=False)
362
- l_g_total += l_g_gan
363
- loss_dict['l_g_gan_mouth'] = l_g_gan
364
-
365
- if self.opt['train'].get('comp_style_weight', 0) > 0:
366
- # get gt feat
367
- _, real_left_eye_feats = self.net_d_left_eye(self.left_eyes_gt, return_feats=True)
368
- _, real_right_eye_feats = self.net_d_right_eye(self.right_eyes_gt, return_feats=True)
369
- _, real_mouth_feats = self.net_d_mouth(self.mouths_gt, return_feats=True)
370
-
371
- def _comp_style(feat, feat_gt, criterion):
372
- return criterion(self._gram_mat(feat[0]), self._gram_mat(
373
- feat_gt[0].detach())) * 0.5 + criterion(
374
- self._gram_mat(feat[1]), self._gram_mat(feat_gt[1].detach()))
375
-
376
- # facial component style loss
377
- comp_style_loss = 0
378
- comp_style_loss += _comp_style(fake_left_eye_feats, real_left_eye_feats, self.cri_l1)
379
- comp_style_loss += _comp_style(fake_right_eye_feats, real_right_eye_feats, self.cri_l1)
380
- comp_style_loss += _comp_style(fake_mouth_feats, real_mouth_feats, self.cri_l1)
381
- comp_style_loss = comp_style_loss * self.opt['train']['comp_style_weight']
382
- l_g_total += comp_style_loss
383
- loss_dict['l_g_comp_style_loss'] = comp_style_loss
384
-
385
- # identity loss
386
- if self.use_identity:
387
- identity_weight = self.opt['train']['identity_weight']
388
- # get gray images and resize
389
- out_gray = self.gray_resize_for_identity(self.output)
390
- gt_gray = self.gray_resize_for_identity(self.gt)
391
-
392
- identity_gt = self.network_identity(gt_gray).detach()
393
- identity_out = self.network_identity(out_gray)
394
- l_identity = self.cri_l1(identity_out, identity_gt) * identity_weight
395
- l_g_total += l_identity
396
- loss_dict['l_identity'] = l_identity
397
-
398
- l_g_total.backward()
399
- self.optimizer_g.step()
400
-
401
- # EMA
402
- self.model_ema(decay=0.5**(32 / (10 * 1000)))
403
-
404
- # ----------- optimize net_d ----------- #
405
- for p in self.net_d.parameters():
406
- p.requires_grad = True
407
- self.optimizer_d.zero_grad()
408
- if self.use_facial_disc:
409
- for p in self.net_d_left_eye.parameters():
410
- p.requires_grad = True
411
- for p in self.net_d_right_eye.parameters():
412
- p.requires_grad = True
413
- for p in self.net_d_mouth.parameters():
414
- p.requires_grad = True
415
- self.optimizer_d_left_eye.zero_grad()
416
- self.optimizer_d_right_eye.zero_grad()
417
- self.optimizer_d_mouth.zero_grad()
418
-
419
- fake_d_pred = self.net_d(self.output.detach())
420
- real_d_pred = self.net_d(self.gt)
421
- l_d = self.cri_gan(real_d_pred, True, is_disc=True) + self.cri_gan(fake_d_pred, False, is_disc=True)
422
- loss_dict['l_d'] = l_d
423
- # In wgan, real_score should be positive and fake_score should benegative
424
- loss_dict['real_score'] = real_d_pred.detach().mean()
425
- loss_dict['fake_score'] = fake_d_pred.detach().mean()
426
- l_d.backward()
427
-
428
- if current_iter % self.net_d_reg_every == 0:
429
- self.gt.requires_grad = True
430
- real_pred = self.net_d(self.gt)
431
- l_d_r1 = r1_penalty(real_pred, self.gt)
432
- l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.net_d_reg_every + 0 * real_pred[0])
433
- loss_dict['l_d_r1'] = l_d_r1.detach().mean()
434
- l_d_r1.backward()
435
-
436
- self.optimizer_d.step()
437
-
438
- if self.use_facial_disc:
439
- # lefe eye
440
- fake_d_pred, _ = self.net_d_left_eye(self.left_eyes.detach())
441
- real_d_pred, _ = self.net_d_left_eye(self.left_eyes_gt)
442
- l_d_left_eye = self.cri_component(
443
- real_d_pred, True, is_disc=True) + self.cri_gan(
444
- fake_d_pred, False, is_disc=True)
445
- loss_dict['l_d_left_eye'] = l_d_left_eye
446
- l_d_left_eye.backward()
447
- # right eye
448
- fake_d_pred, _ = self.net_d_right_eye(self.right_eyes.detach())
449
- real_d_pred, _ = self.net_d_right_eye(self.right_eyes_gt)
450
- l_d_right_eye = self.cri_component(
451
- real_d_pred, True, is_disc=True) + self.cri_gan(
452
- fake_d_pred, False, is_disc=True)
453
- loss_dict['l_d_right_eye'] = l_d_right_eye
454
- l_d_right_eye.backward()
455
- # mouth
456
- fake_d_pred, _ = self.net_d_mouth(self.mouths.detach())
457
- real_d_pred, _ = self.net_d_mouth(self.mouths_gt)
458
- l_d_mouth = self.cri_component(
459
- real_d_pred, True, is_disc=True) + self.cri_gan(
460
- fake_d_pred, False, is_disc=True)
461
- loss_dict['l_d_mouth'] = l_d_mouth
462
- l_d_mouth.backward()
463
-
464
- self.optimizer_d_left_eye.step()
465
- self.optimizer_d_right_eye.step()
466
- self.optimizer_d_mouth.step()
467
-
468
- self.log_dict = self.reduce_loss_dict(loss_dict)
469
-
470
- def test(self):
471
- with torch.no_grad():
472
- if hasattr(self, 'net_g_ema'):
473
- self.net_g_ema.eval()
474
- self.output, _ = self.net_g_ema(self.lq)
475
- else:
476
- logger = get_root_logger()
477
- logger.warning('Do not have self.net_g_ema, use self.net_g.')
478
- self.net_g.eval()
479
- self.output, _ = self.net_g(self.lq)
480
- self.net_g.train()
481
-
482
- def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
483
- if self.opt['rank'] == 0:
484
- self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
485
-
486
- def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
487
- dataset_name = dataloader.dataset.opt['name']
488
- with_metrics = self.opt['val'].get('metrics') is not None
489
- if with_metrics:
490
- self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
491
- pbar = tqdm(total=len(dataloader), unit='image')
492
-
493
- for idx, val_data in enumerate(dataloader):
494
- img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
495
- self.feed_data(val_data)
496
- self.test()
497
-
498
- visuals = self.get_current_visuals()
499
- sr_img = tensor2img([visuals['sr']], min_max=(-1, 1))
500
- gt_img = tensor2img([visuals['gt']], min_max=(-1, 1))
501
-
502
- if 'gt' in visuals:
503
- gt_img = tensor2img([visuals['gt']], min_max=(-1, 1))
504
- del self.gt
505
- # tentative for out of GPU memory
506
- del self.lq
507
- del self.output
508
- torch.cuda.empty_cache()
509
-
510
- if save_img:
511
- if self.opt['is_train']:
512
- save_img_path = osp.join(self.opt['path']['visualization'], img_name,
513
- f'{img_name}_{current_iter}.png')
514
- else:
515
- if self.opt['val']['suffix']:
516
- save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
517
- f'{img_name}_{self.opt["val"]["suffix"]}.png')
518
- else:
519
- save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
520
- f'{img_name}_{self.opt["name"]}.png')
521
- imwrite(sr_img, save_img_path)
522
-
523
- if with_metrics:
524
- # calculate metrics
525
- for name, opt_ in self.opt['val']['metrics'].items():
526
- metric_data = dict(img1=sr_img, img2=gt_img)
527
- self.metric_results[name] += calculate_metric(metric_data, opt_)
528
- pbar.update(1)
529
- pbar.set_description(f'Test {img_name}')
530
- pbar.close()
531
-
532
- if with_metrics:
533
- for metric in self.metric_results.keys():
534
- self.metric_results[metric] /= (idx + 1)
535
-
536
- self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
537
-
538
- def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
539
- log_str = f'Validation {dataset_name}\n'
540
- for metric, value in self.metric_results.items():
541
- log_str += f'\t # {metric}: {value:.4f}\n'
542
- logger = get_root_logger()
543
- logger.info(log_str)
544
- if tb_logger:
545
- for metric, value in self.metric_results.items():
546
- tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
547
-
548
- def get_current_visuals(self):
549
- out_dict = OrderedDict()
550
- out_dict['gt'] = self.gt.detach().cpu()
551
- out_dict['sr'] = self.output.detach().cpu()
552
- return out_dict
553
-
554
- def save(self, epoch, current_iter):
555
- self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
556
- self.save_network(self.net_d, 'net_d', current_iter)
557
- # save component discriminators
558
- if self.use_facial_disc:
559
- self.save_network(self.net_d_left_eye, 'net_d_left_eye', current_iter)
560
- self.save_network(self.net_d_right_eye, 'net_d_right_eye', current_iter)
561
- self.save_network(self.net_d_mouth, 'net_d_mouth', current_iter)
562
- self.save_training_state(epoch, current_iter)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt DELETED
@@ -1,10 +0,0 @@
1
- facexlib
2
- lmdb
3
- numpy
4
- opencv-python
5
- pyyaml
6
- tb-nightly
7
- torch>=1.7
8
- torchvision
9
- tqdm
10
- yapf
 
 
 
 
 
 
 
 
 
 
 
setup.cfg DELETED
@@ -1,22 +0,0 @@
1
- [flake8]
2
- ignore =
3
- # line break before binary operator (W503)
4
- W503,
5
- # line break after binary operator (W504)
6
- W504,
7
- max-line-length=120
8
-
9
- [yapf]
10
- based_on_style = pep8
11
- column_limit = 120
12
- blank_line_before_nested_class_or_def = true
13
- split_before_expression_after_opening_paren = true
14
-
15
- [isort]
16
- line_length = 120
17
- multi_line_output = 0
18
- known_standard_library = pkg_resources,setuptools
19
- known_first_party = basicsr
20
- known_third_party = cv2,facexlib,numpy,torch,torchvision,tqdm
21
- no_lines_before = STDLIB,LOCALFOLDER
22
- default_section = THIRDPARTY
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
train.py DELETED
@@ -1,10 +0,0 @@
1
- import os.path as osp
2
-
3
- import archs # noqa: F401
4
- import data # noqa: F401
5
- import models # noqa: F401
6
- from basicsr.train import train_pipeline
7
-
8
- if __name__ == '__main__':
9
- root_path = osp.abspath(osp.join(__file__, osp.pardir))
10
- train_pipeline(root_path)
 
 
 
 
 
 
 
 
 
 
 
train_gfpgan_v1.yml DELETED
@@ -1,210 +0,0 @@
1
- # general settings
2
- name: train_GFPGANv1_512
3
- model_type: GFPGANModel
4
- num_gpu: 4
5
- manual_seed: 0
6
-
7
- # dataset and data loader settings
8
- datasets:
9
- train:
10
- name: FFHQ
11
- type: FFHQDegradationDataset
12
- # dataroot_gt: datasets/ffhq/ffhq_512.lmdb
13
- dataroot_gt: datasets/ffhq/ffhq_512
14
- io_backend:
15
- # type: lmdb
16
- type: disk
17
-
18
- use_hflip: true
19
- mean: [0.5, 0.5, 0.5]
20
- std: [0.5, 0.5, 0.5]
21
- out_size: 512
22
-
23
- blur_kernel_size: 41
24
- kernel_list: ['iso', 'aniso']
25
- kernel_prob: [0.5, 0.5]
26
- blur_sigma: [0.1, 10]
27
- downsample_range: [0.8, 8]
28
- noise_range: [0, 20]
29
- jpeg_range: [60, 100]
30
-
31
- # color jitter and gray
32
- color_jitter_prob: 0.3
33
- color_jitter_shift: 20
34
- color_jitter_pt_prob: 0.3
35
- gray_prob: 0.01
36
-
37
- crop_components: true
38
- component_path: experiments/pretrained_models/FFHQ_eye_mouth_landmarks_512.pth
39
- eye_enlarge_ratio: 1.4
40
-
41
- # data loader
42
- use_shuffle: true
43
- num_worker_per_gpu: 6
44
- batch_size_per_gpu: 3
45
- dataset_enlarge_ratio: 100
46
- prefetch_mode: ~
47
-
48
- val:
49
- # Please modify accordingly to use your own validation
50
- # Or comment the val block if do not need validation during training
51
- name: validation
52
- type: PairedImageDataset
53
- dataroot_lq: datasets/faces/validation/input
54
- dataroot_gt: datasets/faces/validation/reference
55
- io_backend:
56
- type: disk
57
- mean: [0.5, 0.5, 0.5]
58
- std: [0.5, 0.5, 0.5]
59
- scale: 1
60
-
61
- # network structures
62
- network_g:
63
- type: GFPGANv1
64
- out_size: 512
65
- num_style_feat: 512
66
- channel_multiplier: 1
67
- resample_kernel: [1, 3, 3, 1]
68
- decoder_load_path: experiments/pretrained_models/StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth
69
- fix_decoder: true
70
- num_mlp: 8
71
- lr_mlp: 0.01
72
- input_is_latent: true
73
- different_w: true
74
- narrow: 1
75
- sft_half: true
76
-
77
- network_d:
78
- type: StyleGAN2Discriminator
79
- out_size: 512
80
- channel_multiplier: 1
81
- resample_kernel: [1, 3, 3, 1]
82
-
83
- network_d_left_eye:
84
- type: FacialComponentDiscriminator
85
-
86
- network_d_right_eye:
87
- type: FacialComponentDiscriminator
88
-
89
- network_d_mouth:
90
- type: FacialComponentDiscriminator
91
-
92
- network_identity:
93
- type: ResNetArcFace
94
- block: IRBlock
95
- layers: [2, 2, 2, 2]
96
- use_se: False
97
-
98
- # path
99
- path:
100
- pretrain_network_g: ~
101
- param_key_g: params_ema
102
- strict_load_g: ~
103
- pretrain_network_d: ~
104
- pretrain_network_d_left_eye: ~
105
- pretrain_network_d_right_eye: ~
106
- pretrain_network_d_mouth: ~
107
- pretrain_network_identity: experiments/pretrained_models/arcface_resnet18.pth
108
- # resume
109
- resume_state: ~
110
- ignore_resume_networks: ['network_identity']
111
-
112
- # training settings
113
- train:
114
- optim_g:
115
- type: Adam
116
- lr: !!float 2e-3
117
- optim_d:
118
- type: Adam
119
- lr: !!float 2e-3
120
- optim_component:
121
- type: Adam
122
- lr: !!float 2e-3
123
-
124
- scheduler:
125
- type: MultiStepLR
126
- milestones: [600000, 700000]
127
- gamma: 0.5
128
-
129
- total_iter: 800000
130
- warmup_iter: -1 # no warm up
131
-
132
- # losses
133
- # pixel loss
134
- pixel_opt:
135
- type: L1Loss
136
- loss_weight: !!float 1e-1
137
- reduction: mean
138
- # L1 loss used in pyramid loss, component style loss and identity loss
139
- L1_opt:
140
- type: L1Loss
141
- loss_weight: 1
142
- reduction: mean
143
-
144
- # image pyramid loss
145
- pyramid_loss_weight: 1
146
- remove_pyramid_loss: 50000
147
- # perceptual loss (content and style losses)
148
- perceptual_opt:
149
- type: PerceptualLoss
150
- layer_weights:
151
- # before relu
152
- 'conv1_2': 0.1
153
- 'conv2_2': 0.1
154
- 'conv3_4': 1
155
- 'conv4_4': 1
156
- 'conv5_4': 1
157
- vgg_type: vgg19
158
- use_input_norm: true
159
- perceptual_weight: !!float 1
160
- style_weight: 50
161
- range_norm: true
162
- criterion: l1
163
- # gan loss
164
- gan_opt:
165
- type: GANLoss
166
- gan_type: wgan_softplus
167
- loss_weight: !!float 1e-1
168
- # r1 regularization for discriminator
169
- r1_reg_weight: 10
170
- # facial component loss
171
- gan_component_opt:
172
- type: GANLoss
173
- gan_type: vanilla
174
- real_label_val: 1.0
175
- fake_label_val: 0.0
176
- loss_weight: !!float 1
177
- comp_style_weight: 200
178
- # identity loss
179
- identity_weight: 10
180
-
181
- net_d_iters: 1
182
- net_d_init_iters: 0
183
- net_d_reg_every: 16
184
-
185
- # validation settings
186
- val:
187
- val_freq: !!float 5e3
188
- save_img: true
189
-
190
- metrics:
191
- psnr: # metric name, can be arbitrary
192
- type: calculate_psnr
193
- crop_border: 0
194
- test_y_channel: false
195
-
196
- # logging settings
197
- logger:
198
- print_freq: 100
199
- save_checkpoint_freq: !!float 5e3
200
- use_tb_logger: true
201
- wandb:
202
- project: ~
203
- resume_id: ~
204
-
205
- # dist training settings
206
- dist_params:
207
- backend: nccl
208
- port: 29500
209
-
210
- find_unused_parameters: true