AK391 committed
Commit 13fd34d
Parent(s): 5e8cde2
update files
- VERSION +1 -1
- gfpgan/__init__.py +2 -1
- gfpgan/archs/arcface_arch.py +57 -9
- gfpgan/archs/gfpganv1_arch.py +66 -44
- gfpgan/archs/gfpganv1_clean_arch.py +58 -38
- gfpgan/archs/stylegan2_clean_arch.py +20 -29
- gfpgan/data/ffhq_degradation_dataset.py +27 -9
- gfpgan/models/gfpgan_model.py +48 -29
- gfpgan/utils.py +23 -27
- inference_gfpgan.py +37 -19
- options/train_gfpgan_v1.yml +2 -2
- options/train_gfpgan_v1_simple.yml +2 -36
- requirements.txt +7 -13
- scripts/parse_landmark.py +16 -8
- setup.cfg +12 -1
- setup.py +0 -6
- tests/data/ffhq_gt.lmdb/data.mdb +0 -0
- experiments/.DS_Store → tests/data/ffhq_gt.lmdb/lock.mdb +0 -0
- tests/data/ffhq_gt.lmdb/meta_info.txt +1 -0
- tests/data/gt/00000000.png +0 -0
- tests/data/test_eye_mouth_landmarks.pth +3 -0
- tests/data/test_ffhq_degradation_dataset.yml +24 -0
- tests/data/test_gfpgan_model.yml +140 -0
- tests/test_arcface_arch.py +49 -0
- tests/test_ffhq_degradation_dataset.py +96 -0
- tests/test_gfpgan_arch.py +203 -0
- tests/test_gfpgan_model.py +132 -0
- tests/test_stylegan2_clean_arch.py +52 -0
- tests/test_utils.py +43 -0
VERSION
CHANGED
@@ -1 +1 @@
-0.2.
+0.2.4
gfpgan/__init__.py
CHANGED
@@ -3,4 +3,5 @@ from .archs import *
 from .data import *
 from .models import *
 from .utils import *
-
+
+# from .version import *
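For reference, a minimal usage sketch of the package after this change; the helper name comes from gfpgan/utils.py (also touched by this commit), and the model path and argument values below are illustrative assumptions, not values fixed by the diff:

# A hedged usage sketch (assumption: the wildcard imports above re-export
# the GFPGANer helper defined in gfpgan/utils.py).
from gfpgan import GFPGANer

restorer = GFPGANer(
    model_path='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth',  # hypothetical local path
    upscale=2,
    arch='clean',           # assumed: selects the clean (no compiled CUDA ops) architecture
    channel_multiplier=2)   # assumed value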
gfpgan/archs/arcface_arch.py
CHANGED
@@ -2,13 +2,27 @@ import torch.nn as nn
 from basicsr.utils.registry import ARCH_REGISTRY
 
 
-def conv3x3(
-    """3x3 convolution with padding
+def conv3x3(inplanes, outplanes, stride=1):
+    """A simple wrapper for 3x3 convolution with padding.
+
+    Args:
+        inplanes (int): Channel number of inputs.
+        outplanes (int): Channel number of outputs.
+        stride (int): Stride in convolution. Default: 1.
+    """
+    return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False)
 
 
 class BasicBlock(nn.Module):
+    """Basic residual block used in the ResNetArcFace architecture.
+
+    Args:
+        inplanes (int): Channel number of inputs.
+        planes (int): Channel number of outputs.
+        stride (int): Stride in convolution. Default: 1.
+        downsample (nn.Module): The downsample module. Default: None.
+    """
+    expansion = 1  # output channel expansion ratio
 
     def __init__(self, inplanes, planes, stride=1, downsample=None):
         super(BasicBlock, self).__init__()
@@ -40,7 +54,16 @@ class BasicBlock(nn.Module):
 
 
 class IRBlock(nn.Module):
+    """Improved residual block (IR Block) used in the ResNetArcFace architecture.
+
+    Args:
+        inplanes (int): Channel number of inputs.
+        planes (int): Channel number of outputs.
+        stride (int): Stride in convolution. Default: 1.
+        downsample (nn.Module): The downsample module. Default: None.
+        use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
+    """
+    expansion = 1  # output channel expansion ratio
 
     def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
         super(IRBlock, self).__init__()
@@ -78,7 +101,15 @@ class IRBlock(nn.Module):
 
 
 class Bottleneck(nn.Module):
+    """Bottleneck block used in the ResNetArcFace architecture.
+
+    Args:
+        inplanes (int): Channel number of inputs.
+        planes (int): Channel number of outputs.
+        stride (int): Stride in convolution. Default: 1.
+        downsample (nn.Module): The downsample module. Default: None.
+    """
+    expansion = 4  # output channel expansion ratio
 
     def __init__(self, inplanes, planes, stride=1, downsample=None):
         super(Bottleneck, self).__init__()
@@ -116,10 +147,16 @@ class Bottleneck(nn.Module):
 
 
 class SEBlock(nn.Module):
+    """The squeeze-and-excitation block (SEBlock) used in the IRBlock.
+
+    Args:
+        channel (int): Channel number of inputs.
+        reduction (int): Channel reduction ration. Default: 16.
+    """
 
     def __init__(self, channel, reduction=16):
         super(SEBlock, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # pool to 1x1 without spatial information
         self.fc = nn.Sequential(
             nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),
             nn.Sigmoid())
@@ -133,6 +170,15 @@ class SEBlock(nn.Module):
 
 @ARCH_REGISTRY.register()
 class ResNetArcFace(nn.Module):
+    """ArcFace with ResNet architectures.
+
+    Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
+
+    Args:
+        block (str): Block used in the ArcFace architecture.
+        layers (tuple(int)): Block numbers in each layer.
+        use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
+    """
 
     def __init__(self, block, layers, use_se=True):
         if block == 'IRBlock':
@@ -140,6 +186,7 @@ class ResNetArcFace(nn.Module):
         self.inplanes = 64
         self.use_se = use_se
         super(ResNetArcFace, self).__init__()
+
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
         self.bn1 = nn.BatchNorm2d(64)
         self.prelu = nn.PReLU()
@@ -153,6 +200,7 @@ class ResNetArcFace(nn.Module):
         self.fc5 = nn.Linear(512 * 8 * 8, 512)
         self.bn5 = nn.BatchNorm1d(512)
 
+        # initialization
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
                 nn.init.xavier_normal_(m.weight)
@@ -163,7 +211,7 @@ class ResNetArcFace(nn.Module):
             nn.init.xavier_normal_(m.weight)
             nn.init.constant_(m.bias, 0)
 
-    def _make_layer(self, block, planes,
+    def _make_layer(self, block, planes, num_blocks, stride=1):
         downsample = None
         if stride != 1 or self.inplanes != planes * block.expansion:
             downsample = nn.Sequential(
@@ -173,7 +221,7 @@ class ResNetArcFace(nn.Module):
         layers = []
         layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
         self.inplanes = planes
-        for _ in range(1,
+        for _ in range(1, num_blocks):
             layers.append(block(self.inplanes, planes, use_se=self.use_se))
 
         return nn.Sequential(*layers)
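To see the repaired _make_layer signature in action, here is a hedged smoke-test sketch; the (2, 2, 2, 2) layer counts and the 1-channel 128x128 input are illustrative assumptions, not values fixed by this diff:

# Minimal sketch, assuming ResNet-18-style layer counts and a gray 128x128 input.
import torch

from gfpgan.archs.arcface_arch import ResNetArcFace

# 'IRBlock' is resolved to the IRBlock class inside __init__ (see the diff above);
# each entry of `layers` is the num_blocks argument consumed by _make_layer.
net = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=False)
net.eval()  # bn5 (BatchNorm1d) needs eval mode for a batch of one

with torch.no_grad():
    emb = net(torch.randn(1, 1, 128, 128))  # conv1 expects a single channel
print(emb.shape)  # expected: torch.Size([1, 512])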
gfpgan/archs/gfpganv1_arch.py
CHANGED
@@ -10,18 +10,18 @@ from torch.nn import functional as F
 
 
 class StyleGAN2GeneratorSFT(StyleGAN2Generator):
-    """StyleGAN2 Generator.
+    """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
 
     Args:
         out_size (int): The spatial size of outputs.
         num_style_feat (int): Channel number of style features. Default: 512.
         num_mlp (int): Layer number of MLP style layers. Default: 8.
-        channel_multiplier (int): Channel multiplier for large networks of
-            StyleGAN2. Default: 2.
-        resample_kernel (list[int]): A list indicating the 1D resample kernel
-            magnitude. A cross production will be applied to extent 1D resample
-            kenrel to 2D resample kernel. Default: [1, 3, 3, 1].
+        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
+        resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be
+            applied to extent 1D resample kernel to 2D resample kernel. Default: (1, 3, 3, 1).
         lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
+        narrow (float): The narrow ratio for channels. Default: 1.
+        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
     """
 
     def __init__(self,
@@ -53,21 +53,18 @@ class StyleGAN2GeneratorSFT(StyleGAN2Generator):
                 truncation_latent=None,
                 inject_index=None,
                 return_latents=False):
-        """Forward function for
+        """Forward function for StyleGAN2GeneratorSFT.
 
         Args:
             styles (list[Tensor]): Sample codes of styles.
+            conditions (list[Tensor]): SFT conditions to generators.
+            input_is_latent (bool): Whether input is latent style. Default: False.
             noise (Tensor | None): Input noise or None. Default: None.
-            randomize_noise (bool): Randomize noise, used when 'noise' is
-                Default: None.
-            return_latents (bool): Whether to return style latents.
-                Default: False.
+            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
+            truncation (float): The truncation ratio. Default: 1.
+            truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
+            inject_index (int | None): The injection index for mixing noise. Default: None.
+            return_latents (bool): Whether to return style latents. Default: False.
         """
         # style codes -> latents with Style MLP layer
         if not input_is_latent:
@@ -84,7 +81,7 @@ class StyleGAN2GeneratorSFT(StyleGAN2Generator):
         for style in styles:
             style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
         styles = style_truncation
-        # get style
+        # get style latents with injection
         if len(styles) == 1:
             inject_index = self.num_latent
@@ -113,15 +110,15 @@ class StyleGAN2GeneratorSFT(StyleGAN2Generator):
             # the conditions may have fewer levels
             if i < len(conditions):
                 # SFT part to combine the conditions
-                if self.sft_half:
+                if self.sft_half:  # only apply SFT to half of the channels
                     out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
                     out_sft = out_sft * conditions[i - 1] + conditions[i]
                     out = torch.cat([out_same, out_sft], dim=1)
-                else:
+                else:  # apply SFT to all the channels
                     out = out * conditions[i - 1] + conditions[i]
 
             out = conv2(out, latent[:, i + 1], noise=noise2)
-            skip = to_rgb(out, latent[:, i + 2], skip)
+            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
             i += 2
 
         image = skip
@@ -133,17 +130,15 @@ class StyleGAN2GeneratorSFT(StyleGAN2Generator):
 
 
 class ConvUpLayer(nn.Module):
-    """
+    """Convolutional upsampling layer. It uses bilinear upsampler + Conv.
 
     Args:
         in_channels (int): Channel number of the input.
         out_channels (int): Channel number of the output.
         kernel_size (int): Size of the convolving kernel.
         stride (int): Stride of the convolution. Default: 1
-        padding (int): Zero-padding added to both sides of the input.
-            Default: 0.
-        bias (bool): If ``True``, adds a learnable bias to the output.
-            Default: ``True``.
+        padding (int): Zero-padding added to both sides of the input. Default: 0.
+        bias (bool): If ``True``, adds a learnable bias to the output. Default: ``True``.
         bias_init_val (float): Bias initialized value. Default: 0.
         activate (bool): Whether use activateion. Default: True.
     """
@@ -163,6 +158,7 @@ class ConvUpLayer(nn.Module):
         self.kernel_size = kernel_size
         self.stride = stride
         self.padding = padding
+        # self.scale is used to scale the convolution weights, which is related to the common initializations.
         self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
 
         self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
@@ -223,7 +219,26 @@ class ResUpBlock(nn.Module):
 
 @ARCH_REGISTRY.register()
 class GFPGANv1(nn.Module):
-    """Unet + StyleGAN2 decoder with SFT.
+    """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
+
+    Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
+
+    Args:
+        out_size (int): The spatial size of outputs.
+        num_style_feat (int): Channel number of style features. Default: 512.
+        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
+        resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be
+            applied to extent 1D resample kernel to 2D resample kernel. Default: (1, 3, 3, 1).
+        decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
+        fix_decoder (bool): Whether to fix the decoder. Default: True.
+
+        num_mlp (int): Layer number of MLP style layers. Default: 8.
+        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
+        input_is_latent (bool): Whether input is latent style. Default: False.
+        different_w (bool): Whether to use different latent w for different layers. Default: False.
+        narrow (float): The narrow ratio for channels. Default: 1.
+        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
+    """
 
     def __init__(
             self,
@@ -246,7 +261,7 @@ class GFPGANv1(nn.Module):
         self.different_w = different_w
         self.num_style_feat = num_style_feat
 
-        unet_narrow = narrow * 0.5
+        unet_narrow = narrow * 0.5  # by default, use a half of input channels
         channels = {
             '4': int(512 * unet_narrow),
             '8': int(512 * unet_narrow),
@@ -295,6 +310,7 @@ class GFPGANv1(nn.Module):
         self.final_linear = EqualLinear(
             channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None)
 
+        # the decoder: stylegan2 generator with SFT modulations
         self.stylegan_decoder = StyleGAN2GeneratorSFT(
             out_size=out_size,
             num_style_feat=num_style_feat,
@@ -305,14 +321,16 @@ class GFPGANv1(nn.Module):
             narrow=narrow,
             sft_half=sft_half)
 
+        # load pre-trained stylegan2 model if necessary
         if decoder_load_path:
             self.stylegan_decoder.load_state_dict(
                 torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
+        # fix decoder without updating params
         if fix_decoder:
             for _, param in self.stylegan_decoder.named_parameters():
                 param.requires_grad = False
 
-        # for SFT
+        # for SFT modulations (scale and shift)
         self.condition_scale = nn.ModuleList()
         self.condition_shift = nn.ModuleList()
         for i in range(3, self.log_size + 1):
@@ -332,13 +350,15 @@ class GFPGANv1(nn.Module):
                     ScaledLeakyReLU(0.2),
                     EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0)))
 
-    def forward(self,
+    def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
+        """Forward function for GFPGANv1.
+
+        Args:
+            x (Tensor): Input images.
+            return_latents (bool): Whether to return style latents. Default: False.
+            return_rgb (bool): Whether return intermediate rgb images. Default: True.
+            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
+        """
         conditions = []
         unet_skips = []
         out_rgbs = []
@@ -362,7 +382,7 @@ class GFPGANv1(nn.Module):
             feat = feat + unet_skips[i]
             # ResUpLayer
             feat = self.conv_body_up[i](feat)
-            # generate scale and shift for SFT
+            # generate scale and shift for SFT layers
             scale = self.condition_scale[i](feat)
             conditions.append(scale.clone())
             shift = self.condition_shift[i](feat)
@@ -371,12 +391,6 @@ class GFPGANv1(nn.Module):
             if return_rgb:
                 out_rgbs.append(self.toRGB[i](feat))
 
-        if save_feat_path is not None:
-            torch.save(conditions, save_feat_path)
-        if load_feat_path is not None:
-            conditions = torch.load(load_feat_path)
-            conditions = [v.cuda() for v in conditions]
-
         # decoder
         image, _ = self.stylegan_decoder([style_code],
                                          conditions,
@@ -389,10 +403,12 @@ class GFPGANv1(nn.Module):
 
 @ARCH_REGISTRY.register()
 class FacialComponentDiscriminator(nn.Module):
+    """Facial component (eyes, mouth, noise) discriminator used in GFPGAN.
+    """
 
     def __init__(self):
         super(FacialComponentDiscriminator, self).__init__()
+        # It now uses a VGG-style architectrue with fixed model size
         self.conv1 = ConvLayer(3, 64, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
         self.conv2 = ConvLayer(64, 128, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
         self.conv3 = ConvLayer(128, 128, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
@@ -401,6 +417,12 @@ class FacialComponentDiscriminator(nn.Module):
         self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False)
 
     def forward(self, x, return_feats=False):
+        """Forward function for FacialComponentDiscriminator.
+
+        Args:
+            x (Tensor): Input images.
+            return_feats (bool): Whether to return intermediate features. Default: False.
+        """
         feat = self.conv1(x)
         feat = self.conv3(self.conv2(feat))
         rlt_feats = []
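A hedged sketch of the repaired forward signature above; every argument value is an illustrative assumption, and this non-clean variant needs basicsr's compiled StyleGAN2 CUDA ops (and a GPU) to actually run:

# Minimal sketch, assuming a 512x512 model and a CUDA device.
import torch

from gfpgan.archs.gfpganv1_arch import GFPGANv1

net = GFPGANv1(
    out_size=512,
    num_style_feat=512,
    channel_multiplier=1,    # assumed value
    decoder_load_path=None,  # the released models load a pre-trained StyleGAN2 here
    fix_decoder=True,
    different_w=True,        # assumed value
    narrow=1,
    sft_half=True).cuda()    # assumed value; compiled ops require CUDA

x = torch.randn(1, 3, 512, 512).cuda()
image, out_rgbs = net(x, return_rgb=True, randomize_noise=True)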
gfpgan/archs/gfpganv1_clean_arch.py
CHANGED
@@ -1,6 +1,7 @@
 import math
 import random
 import torch
+from basicsr.utils.registry import ARCH_REGISTRY
 from torch import nn
 from torch.nn import functional as F
@@ -8,14 +9,17 @@ from .stylegan2_clean_arch import StyleGAN2GeneratorClean
 
 
 class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
-    """StyleGAN2 Generator.
+    """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
+
+    It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
 
     Args:
         out_size (int): The spatial size of outputs.
         num_style_feat (int): Channel number of style features. Default: 512.
         num_mlp (int): Layer number of MLP style layers. Default: 8.
-        channel_multiplier (int): Channel multiplier for large networks of
-            StyleGAN2. Default: 2.
+        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
+        narrow (float): The narrow ratio for channels. Default: 1.
+        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
     """
 
     def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1, sft_half=False):
@@ -25,7 +29,6 @@ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
             num_mlp=num_mlp,
             channel_multiplier=channel_multiplier,
             narrow=narrow)
-
         self.sft_half = sft_half
 
     def forward(self,
@@ -38,21 +41,18 @@ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
                 truncation_latent=None,
                 inject_index=None,
                 return_latents=False):
-        """Forward function for
+        """Forward function for StyleGAN2GeneratorCSFT.
 
         Args:
             styles (list[Tensor]): Sample codes of styles.
+            conditions (list[Tensor]): SFT conditions to generators.
+            input_is_latent (bool): Whether input is latent style. Default: False.
             noise (Tensor | None): Input noise or None. Default: None.
-            randomize_noise (bool): Randomize noise, used when 'noise' is
-                Default: None.
-            return_latents (bool): Whether to return style latents.
-                Default: False.
+            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
+            truncation (float): The truncation ratio. Default: 1.
+            truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
+            inject_index (int | None): The injection index for mixing noise. Default: None.
+            return_latents (bool): Whether to return style latents. Default: False.
         """
         # style codes -> latents with Style MLP layer
         if not input_is_latent:
@@ -69,7 +69,7 @@ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
         for style in styles:
             style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
         styles = style_truncation
-        # get style
+        # get style latents with injection
         if len(styles) == 1:
             inject_index = self.num_latent
@@ -98,15 +98,15 @@ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
             # the conditions may have fewer levels
             if i < len(conditions):
                 # SFT part to combine the conditions
-                if self.sft_half:
+                if self.sft_half:  # only apply SFT to half of the channels
                     out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
                     out_sft = out_sft * conditions[i - 1] + conditions[i]
                     out = torch.cat([out_same, out_sft], dim=1)
-                else:
+                else:  # apply SFT to all the channels
                     out = out * conditions[i - 1] + conditions[i]
 
             out = conv2(out, latent[:, i + 1], noise=noise2)
-            skip = to_rgb(out, latent[:, i + 2], skip)
+            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
             i += 2
 
         image = skip
@@ -118,11 +118,12 @@ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
 
 
 class ResBlock(nn.Module):
-    """Residual block with upsampling/downsampling.
+    """Residual block with bilinear upsampling/downsampling.
 
     Args:
         in_channels (int): Channel number of the input.
         out_channels (int): Channel number of the output.
+        mode (str): Upsampling/downsampling mode. Options: down | up. Default: down.
     """
 
     def __init__(self, in_channels, out_channels, mode='down'):
@@ -148,8 +149,27 @@ class ResBlock(nn.Module):
         return out
 
 
+@ARCH_REGISTRY.register()
 class GFPGANv1Clean(nn.Module):
-    """
+    """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
+
+    It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
+
+    Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
+
+    Args:
+        out_size (int): The spatial size of outputs.
+        num_style_feat (int): Channel number of style features. Default: 512.
+        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
+        decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
+        fix_decoder (bool): Whether to fix the decoder. Default: True.
+
+        num_mlp (int): Layer number of MLP style layers. Default: 8.
+        input_is_latent (bool): Whether input is latent style. Default: False.
+        different_w (bool): Whether to use different latent w for different layers. Default: False.
+        narrow (float): The narrow ratio for channels. Default: 1.
+        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
+    """
 
     def __init__(
             self,
@@ -170,7 +190,7 @@ class GFPGANv1Clean(nn.Module):
         self.different_w = different_w
         self.num_style_feat = num_style_feat
 
-        unet_narrow = narrow * 0.5
+        unet_narrow = narrow * 0.5  # by default, use a half of input channels
         channels = {
             '4': int(512 * unet_narrow),
             '8': int(512 * unet_narrow),
@@ -218,6 +238,7 @@ class GFPGANv1Clean(nn.Module):
 
         self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)
 
+        # the decoder: stylegan2 generator with SFT modulations
         self.stylegan_decoder = StyleGAN2GeneratorCSFT(
             out_size=out_size,
             num_style_feat=num_style_feat,
@@ -226,14 +247,16 @@ class GFPGANv1Clean(nn.Module):
             narrow=narrow,
             sft_half=sft_half)
 
+        # load pre-trained stylegan2 model if necessary
         if decoder_load_path:
             self.stylegan_decoder.load_state_dict(
                 torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
+        # fix decoder without updating params
         if fix_decoder:
-            for
+            for _, param in self.stylegan_decoder.named_parameters():
                 param.requires_grad = False
 
-        # for SFT
+        # for SFT modulations (scale and shift)
         self.condition_scale = nn.ModuleList()
         self.condition_shift = nn.ModuleList()
         for i in range(3, self.log_size + 1):
@@ -251,13 +274,15 @@ class GFPGANv1Clean(nn.Module):
                     nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
                     nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
 
-    def forward(self,
+    def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
+        """Forward function for GFPGANv1Clean.
+
+        Args:
+            x (Tensor): Input images.
+            return_latents (bool): Whether to return style latents. Default: False.
+            return_rgb (bool): Whether return intermediate rgb images. Default: True.
+            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
+        """
         conditions = []
         unet_skips = []
         out_rgbs = []
@@ -273,13 +298,14 @@ class GFPGANv1Clean(nn.Module):
         style_code = self.final_linear(feat.view(feat.size(0), -1))
         if self.different_w:
             style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
+
         # decode
         for i in range(self.log_size - 2):
             # add unet skip
             feat = feat + unet_skips[i]
             # ResUpLayer
             feat = self.conv_body_up[i](feat)
-            # generate scale and shift for SFT
+            # generate scale and shift for SFT layers
             scale = self.condition_scale[i](feat)
             conditions.append(scale.clone())
             shift = self.condition_shift[i](feat)
@@ -288,12 +314,6 @@ class GFPGANv1Clean(nn.Module):
             if return_rgb:
                 out_rgbs.append(self.toRGB[i](feat))
 
-        if save_feat_path is not None:
-            torch.save(conditions, save_feat_path)
-        if load_feat_path is not None:
-            conditions = torch.load(load_feat_path)
-            conditions = [v.cuda() for v in conditions]
-
         # decoder
         image, _ = self.stylegan_decoder([style_code],
                                          conditions,
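Because the clean variant avoids the compiled CUDA extensions, it can be smoke-tested on CPU. A minimal sketch; the argument values are illustrative assumptions, not values fixed by this diff:

# Minimal CPU sketch of the registered GFPGANv1Clean architecture.
import torch

from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean

net = GFPGANv1Clean(
    out_size=512,
    num_style_feat=512,
    channel_multiplier=2,
    decoder_load_path=None,
    fix_decoder=False,
    different_w=True,   # assumed value
    narrow=1,
    sft_half=True)      # assumed value

with torch.no_grad():
    image, out_rgbs = net(torch.randn(1, 3, 512, 512), return_rgb=True)
print(image.shape)  # expected: torch.Size([1, 3, 512, 512])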
gfpgan/archs/stylegan2_clean_arch.py
CHANGED
@@ -31,12 +31,9 @@ class ModulatedConv2d(nn.Module):
         out_channels (int): Channel number of the output.
         kernel_size (int): Size of the convolving kernel.
         num_style_feat (int): Channel number of style features.
-        demodulate (bool): Whether to demodulate in the conv layer.
-            Default: True.
-        sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
-            Default: None.
-        eps (float): A value added to the denominator for numerical stability.
-            Default: 1e-8.
+        demodulate (bool): Whether to demodulate in the conv layer. Default: True.
+        sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
+        eps (float): A value added to the denominator for numerical stability. Default: 1e-8.
     """
 
     def __init__(self,
@@ -87,6 +84,7 @@ class ModulatedConv2d(nn.Module):
 
         weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
 
+        # upsample or downsample if necessary
         if self.sample_mode == 'upsample':
             x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
         elif self.sample_mode == 'downsample':
@@ -101,14 +99,12 @@ class ModulatedConv2d(nn.Module):
         return out
 
     def __repr__(self):
-        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
-                f'out_channels={self.out_channels}, '
-                f'kernel_size={self.kernel_size}, '
-                f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')
+        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, '
+                f'kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})')
 
 
 class StyleConv(nn.Module):
-    """Style conv.
+    """Style conv used in StyleGAN2.
 
     Args:
         in_channels (int): Channel number of the input.
@@ -116,8 +112,7 @@ class StyleConv(nn.Module):
         kernel_size (int): Size of the convolving kernel.
         num_style_feat (int): Channel number of style features.
         demodulate (bool): Whether demodulate in the conv layer. Default: True.
-        sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
-            Default: None.
+        sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
     """
 
     def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None):
@@ -144,7 +139,7 @@ class StyleConv(nn.Module):
 
 
 class ToRGB(nn.Module):
-    """To RGB from features.
+    """To RGB (image space) from features.
 
     Args:
         in_channels (int): Channel number of input.
@@ -204,8 +199,7 @@ class StyleGAN2GeneratorClean(nn.Module):
         out_size (int): The spatial size of outputs.
         num_style_feat (int): Channel number of style features. Default: 512.
         num_mlp (int): Layer number of MLP style layers. Default: 8.
-        channel_multiplier (int): Channel multiplier for large networks of
-            StyleGAN2. Default: 2.
+        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
         narrow (float): Narrow ratio for channels. Default: 1.0.
     """
 
@@ -222,6 +216,7 @@ class StyleGAN2GeneratorClean(nn.Module):
         # initialization
         default_init_weights(self.style_mlp, scale=1, bias_fill=0, a=0.2, mode='fan_in', nonlinearity='leaky_relu')
 
+        # channel list
         channels = {
             '4': int(512 * narrow),
             '8': int(512 * narrow),
@@ -309,21 +304,17 @@ class StyleGAN2GeneratorClean(nn.Module):
                 truncation_latent=None,
                 inject_index=None,
                 return_latents=False):
-        """Forward function for
+        """Forward function for StyleGAN2GeneratorClean.
 
         Args:
             styles (list[Tensor]): Sample codes of styles.
-            input_is_latent (bool): Whether input is latent style.
-                Default: False.
+            input_is_latent (bool): Whether input is latent style. Default: False.
             noise (Tensor | None): Input noise or None. Default: None.
-            randomize_noise (bool): Randomize noise, used when 'noise' is
-                Default: None.
-            return_latents (bool): Whether to return style latents.
-                Default: False.
+            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
+            truncation (float): The truncation ratio. Default: 1.
+            truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
+            inject_index (int | None): The injection index for mixing noise. Default: None.
+            return_latents (bool): Whether to return style latents. Default: False.
         """
         # style codes -> latents with Style MLP layer
         if not input_is_latent:
@@ -340,7 +331,7 @@ class StyleGAN2GeneratorClean(nn.Module):
         for style in styles:
             style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
         styles = style_truncation
-        # get style
+        # get style latents with injection
         if len(styles) == 1:
             inject_index = self.num_latent
@@ -366,7 +357,7 @@ class StyleGAN2GeneratorClean(nn.Module):
             noise[2::2], self.to_rgbs):
             out = conv1(out, latent[:, i], noise=noise1)
             out = conv2(out, latent[:, i + 1], noise=noise2)
-            skip = to_rgb(out, latent[:, i + 2], skip)
+            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
             i += 2
 
         image = skip
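A minimal sampling sketch for the clean generator documented above; out_size=256 and the single random style code are arbitrary assumptions for illustration:

# Minimal CPU sampling sketch for StyleGAN2GeneratorClean.
import torch

from gfpgan.archs.stylegan2_clean_arch import StyleGAN2GeneratorClean

gen = StyleGAN2GeneratorClean(out_size=256, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1)

z = torch.randn(1, 512)  # a style code; mapped through the style MLP since input_is_latent=False
with torch.no_grad():
    image, _ = gen([z], input_is_latent=False, randomize_noise=True)
print(image.shape)  # expected: torch.Size([1, 3, 256, 256])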
gfpgan/data/ffhq_degradation_dataset.py
CHANGED
@@ -15,6 +15,19 @@ from torchvision.transforms.functional import (adjust_brightness, adjust_contras
 
 @DATASET_REGISTRY.register()
 class FFHQDegradationDataset(data.Dataset):
+    """FFHQ dataset for GFPGAN.
+
+    It reads high resolution images, and then generate low-quality (LQ) images on-the-fly.
+
+    Args:
+        opt (dict): Config for train datasets. It contains the following keys:
+            dataroot_gt (str): Data root path for gt.
+            io_backend (dict): IO backend type and other kwarg.
+            mean (list | tuple): Image mean.
+            std (list | tuple): Image std.
+            use_hflip (bool): Whether to horizontally flip.
+        Please see more options in the codes.
+    """
 
     def __init__(self, opt):
         super(FFHQDegradationDataset, self).__init__()
@@ -29,11 +42,13 @@ class FFHQDegradationDataset(data.Dataset):
         self.out_size = opt['out_size']
 
         self.crop_components = opt.get('crop_components', False)  # facial components
-        self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1)
+        self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1)  # whether enlarge eye regions
 
         if self.crop_components:
+            # load component list from a pre-process pth files
             self.components_list = torch.load(opt.get('component_path'))
 
+        # file client (lmdb io backend)
         if self.io_backend_opt['type'] == 'lmdb':
             self.io_backend_opt['db_paths'] = self.gt_folder
             if not self.gt_folder.endswith('.lmdb'):
@@ -41,9 +56,10 @@ class FFHQDegradationDataset(data.Dataset):
             with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                 self.paths = [line.split('.')[0] for line in fin]
         else:
+            # disk backend: scan file list from a folder
             self.paths = paths_from_folder(self.gt_folder)
 
-        #
+        # degradation configurations
         self.blur_kernel_size = opt['blur_kernel_size']
         self.kernel_list = opt['kernel_list']
         self.kernel_prob = opt['kernel_prob']
@@ -60,22 +76,20 @@ class FFHQDegradationDataset(data.Dataset):
         self.gray_prob = opt.get('gray_prob')
 
         logger = get_root_logger()
-        logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, '
-                    f'sigma: [{", ".join(map(str, self.blur_sigma))}]')
+        logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
         logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
         logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
         logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
 
         if self.color_jitter_prob is not None:
-            logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, '
-                        f'shift: {self.color_jitter_shift}')
+            logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
         if self.gray_prob is not None:
             logger.info(f'Use random gray. Prob: {self.gray_prob}')
-
         self.color_jitter_shift /= 255.
 
     @staticmethod
     def color_jitter(img, shift):
+        """jitter color: randomly jitter the RGB values, in numpy formats"""
         jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
         img = img + jitter_val
         img = np.clip(img, 0, 1)
@@ -83,6 +97,7 @@ class FFHQDegradationDataset(data.Dataset):
 
     @staticmethod
     def color_jitter_pt(img, brightness, contrast, saturation, hue):
+        """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
         fn_idx = torch.randperm(4)
         for fn_id in fn_idx:
             if fn_id == 0 and brightness is not None:
@@ -103,6 +118,7 @@ class FFHQDegradationDataset(data.Dataset):
         return img
 
     def get_component_coordinates(self, index, status):
+        """Get facial component (left_eye, right_eye, mouth) coordinates from a pre-loaded pth file"""
         components_bbox = self.components_list[f'{index:08d}']
         if status[0]:  # hflip
             # exchange right and left eye
@@ -131,6 +147,7 @@ class FFHQDegradationDataset(data.Dataset):
             self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
 
         # load gt image
+        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
         gt_path = self.paths[index]
         img_bytes = self.file_client.get(gt_path)
         img_gt = imfrombytes(img_bytes, float32=True)
@@ -139,6 +156,7 @@ class FFHQDegradationDataset(data.Dataset):
         img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
         h, w, _ = img_gt.shape
 
+        # get facial component coordinates
         if self.crop_components:
             locations = self.get_component_coordinates(index, status)
             loc_left_eye, loc_right_eye, loc_mouth = locations
@@ -173,9 +191,9 @@ class FFHQDegradationDataset(data.Dataset):
         if self.gray_prob and np.random.uniform() < self.gray_prob:
             img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
             img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
-            if self.opt.get('gt_gray'):
+            if self.opt.get('gt_gray'):  # whether convert GT to gray images
                 img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
-                img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])
+                img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])  # repeat the color channels
 
         # BGR to RGB, HWC to CHW, numpy to tensor
         img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
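The docstring added above names only a few keys; a fuller illustrative opt dict can be assembled from the options read in __init__. The concrete values below are assumptions in the spirit of options/train_gfpgan_v1.yml, and the lmdb path is hypothetical:

# An illustrative config sketch for FFHQDegradationDataset (values assumed).
from gfpgan.data.ffhq_degradation_dataset import FFHQDegradationDataset

opt = {
    'dataroot_gt': 'datasets/ffhq/ffhq_512.lmdb',  # hypothetical path
    'io_backend': {'type': 'lmdb'},
    'mean': [0.5, 0.5, 0.5],
    'std': [0.5, 0.5, 0.5],
    'out_size': 512,
    'use_hflip': True,
    # degradation configurations read in __init__
    'blur_kernel_size': 41,
    'kernel_list': ['iso', 'aniso'],
    'kernel_prob': [0.5, 0.5],
    'blur_sigma': [0.1, 10],
    'downsample_range': [0.8, 8],
    'noise_range': [0, 20],
    'jpeg_range': [60, 100],
    # optional color jitter / gray / facial-component options
    'color_jitter_prob': 0.3,
    'color_jitter_shift': 20,
    'color_jitter_pt_prob': 0.3,
    'gray_prob': 0.01,
    'crop_components': False,
}
dataset = FFHQDegradationDataset(opt)
sample = dataset[0]  # dict with 'lq', 'gt', 'gt_path' (plus eye/mouth crops when crop_components is on)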
gfpgan/models/gfpgan_model.py
CHANGED
@@ -16,11 +16,11 @@ from tqdm import tqdm
 
 @MODEL_REGISTRY.register()
 class GFPGANModel(BaseModel):
-    """GFPGAN model for
+    """The GFPGAN model for Towards real-world blind face restoration with generative facial prior"""
 
     def __init__(self, opt):
         super(GFPGANModel, self).__init__(opt)
-        self.idx = 0
+        self.idx = 0  # it is used for saving data for check
 
         # define network
         self.net_g = build_network(opt['network_g'])
@@ -51,8 +51,7 @@ class GFPGANModel(BaseModel):
             self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
 
         # ----------- define net_g with Exponential Moving Average (EMA) ----------- #
-        # net_g_ema only used for testing on one GPU and saving
-        # There is no need to wrap with DistributedDataParallel
+        # net_g_ema only used for testing on one GPU and saving. There is no need to wrap with DistributedDataParallel
         self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
         # load pretrained model
         load_path = self.opt['path'].get('pretrain_network_g', None)
@@ -65,7 +64,7 @@ class GFPGANModel(BaseModel):
         self.net_d.train()
         self.net_g_ema.eval()
 
-        # ----------- facial
+        # ----------- facial component networks ----------- #
         if ('network_d_left_eye' in self.opt and 'network_d_right_eye' in self.opt and 'network_d_mouth' in self.opt):
             self.use_facial_disc = True
         else:
@@ -102,17 +101,19 @@ class GFPGANModel(BaseModel):
             self.cri_component = build_loss(train_opt['gan_component_opt']).to(self.device)
 
         # ----------- define losses ----------- #
+        # pixel loss
         if train_opt.get('pixel_opt'):
             self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
         else:
             self.cri_pix = None
 
+        # perceptual loss
         if train_opt.get('perceptual_opt'):
             self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
         else:
             self.cri_perceptual = None
 
-        # L1 loss used in pyramid loss, component style loss and identity loss
+        # L1 loss is used in pyramid loss, component style loss and identity loss
         self.cri_l1 = build_loss(train_opt['L1_opt']).to(self.device)
 
         # gan loss (wgan)
@@ -179,6 +180,7 @@ class GFPGANModel(BaseModel):
         self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas)
         self.optimizers.append(self.optimizer_d)
 
+        # ----------- optimizers for facial component networks ----------- #
         if self.use_facial_disc:
             # setup optimizers for facial component discriminators
             optim_type = train_opt['optim_component'].pop('type')
@@ -221,6 +223,7 @@ class GFPGANModel(BaseModel):
         # self.idx = self.idx + 1
 
     def construct_img_pyramid(self):
+        """Construct image pyramid for intermediate restoration loss"""
         pyramid_gt = [self.gt]
         down_img = self.gt
         for _ in range(0, self.log_size - 3):
@@ -229,7 +232,6 @@ class GFPGANModel(BaseModel):
         return pyramid_gt
 
     def get_roi_regions(self, eye_out_size=80, mouth_out_size=120):
-        # hard code
         face_ratio = int(self.opt['network_g']['out_size'] / 512)
         eye_out_size *= face_ratio
         mouth_out_size *= face_ratio
@@ -288,6 +290,7 @@ class GFPGANModel(BaseModel):
             p.requires_grad = False
         self.optimizer_g.zero_grad()
 
+        # do not update facial component net_d
         if self.use_facial_disc:
             for p in self.net_d_left_eye.parameters():
                 p.requires_grad = False
@@ -419,11 +422,12 @@ class GFPGANModel(BaseModel):
         real_d_pred = self.net_d(self.gt)
         l_d = self.cri_gan(real_d_pred, True, is_disc=True) + self.cri_gan(fake_d_pred, False, is_disc=True)
         loss_dict['l_d'] = l_d
-        # In
+        # In WGAN, real_score should be positive and fake_score should be negative
         loss_dict['real_score'] = real_d_pred.detach().mean()
         loss_dict['fake_score'] = fake_d_pred.detach().mean()
         l_d.backward()
 
+        # regularization loss
        if current_iter % self.net_d_reg_every == 0:
            self.gt.requires_grad = True
            real_pred = self.net_d(self.gt)
@@ -434,8 +438,9 @@ class GFPGANModel(BaseModel):
 
         self.optimizer_d.step()
 
+        # optimize facial component discriminators
         if self.use_facial_disc:
-            #
+            # left eye
             fake_d_pred, _ = self.net_d_left_eye(self.left_eyes.detach())
             real_d_pred, _ = self.net_d_left_eye(self.left_eyes_gt)
             l_d_left_eye = self.cri_component(
@@ -485,22 +490,32 @@ class GFPGANModel(BaseModel):
     def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
         dataset_name = dataloader.dataset.opt['name']
         with_metrics = self.opt['val'].get('metrics') is not None
+        use_pbar = self.opt['val'].get('pbar', False)
+
         if with_metrics:
-            self
+            if not hasattr(self, 'metric_results'):  # only execute in the first run
+                self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
+            # initialize the best metric results for each dataset_name (supporting multiple validation datasets)
+            self._initialize_best_metric_results(dataset_name)
+            # zero self.metric_results
+            self.metric_results = {metric: 0 for metric in self.metric_results}
+
+        metric_data = dict()
+        if use_pbar:
+            pbar = tqdm(total=len(dataloader), unit='image')
 
         for idx, val_data in enumerate(dataloader):
             img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
             self.feed_data(val_data)
             self.test()
 
-            gt_img = tensor2img([visuals['gt']], min_max=(-1, 1))
+            sr_img = tensor2img(self.output.detach().cpu(), min_max=(-1, 1))
+            metric_data['img'] = sr_img
+            if hasattr(self, 'gt'):
+                gt_img = tensor2img(self.gt.detach().cpu(), min_max=(-1, 1))
+                metric_data['img2'] = gt_img
                 del self.gt
+
             # tentative for out of GPU memory
             del self.lq
             del self.output
@@ -522,35 +537,38 @@ class GFPGANModel(BaseModel):
             if with_metrics:
                 # calculate metrics
                 for name, opt_ in self.opt['val']['metrics'].items():
-                    metric_data = dict(img1=sr_img, img2=gt_img)
                     self.metric_results[name] += calculate_metric(metric_data, opt_)
+            if use_pbar:
+                pbar.update(1)
+                pbar.set_description(f'Test {img_name}')
+        if use_pbar:
+            pbar.close()
 
         if with_metrics:
             for metric in self.metric_results.keys():
                 self.metric_results[metric] /= (idx + 1)
+                # update the best metric result
+                self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)
 
             self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
 
     def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
         log_str = f'Validation {dataset_name}\n'
         for metric, value in self.metric_results.items():
-            log_str += f'\t # {metric}: {value:.4f}
+            log_str += f'\t # {metric}: {value:.4f}'
+            if hasattr(self, 'best_metric_results'):
+                log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
+                            f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
+            log_str += '\n'
+
         logger = get_root_logger()
         logger.info(log_str)
         if tb_logger:
             for metric, value in self.metric_results.items():
-                tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
+                tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter)
-
-    def get_current_visuals(self):
-        out_dict = OrderedDict()
-        out_dict['gt'] = self.gt.detach().cpu()
-        out_dict['sr'] = self.output.detach().cpu()
-        return out_dict
 
     def save(self, epoch, current_iter):
+        # save net_g and net_d
         self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
         self.save_network(self.net_d, 'net_d', current_iter)
         # save component discriminators
@@ -558,4 +576,5 @@ class GFPGANModel(BaseModel):
         self.save_network(self.net_d_left_eye, 'net_d_left_eye', current_iter)
         self.save_network(self.net_d_right_eye, 'net_d_right_eye', current_iter)
         self.save_network(self.net_d_mouth, 'net_d_mouth', current_iter)
+        # save training state
         self.save_training_state(epoch, current_iter)
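The net_g_ema bookkeeping above follows the usual exponential-moving-average pattern for generator weights; a minimal sketch of that update rule, assuming a typical decay value rather than the one used in this commit:

import copy
import torch

@torch.no_grad()
def update_ema(net_ema, net, decay=0.999):
    # ema_param <- decay * ema_param + (1 - decay) * param
    ema_params = dict(net_ema.named_parameters())
    for name, param in net.named_parameters():
        ema_params[name].mul_(decay).add_(param.detach(), alpha=1 - decay)

net = torch.nn.Linear(4, 4)
net_ema = copy.deepcopy(net).eval()  # kept in eval mode, used only for testing and saving
update_ema(net_ema, net)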
gfpgan/utils.py
CHANGED
@@ -2,10 +2,9 @@ import cv2
 import os
 import torch
 from basicsr.utils import img2tensor, tensor2img
+from basicsr.utils.download_util import load_file_from_url
 from facexlib.utils.face_restoration_helper import FaceRestoreHelper
-from torch.hub import download_url_to_file, get_dir
 from torchvision.transforms.functional import normalize
-from urllib.parse import urlparse
 
 from gfpgan.archs.gfpganv1_arch import GFPGANv1
 from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
@@ -14,6 +13,20 @@ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 
 
 class GFPGANer():
+    """Helper for restoration with GFPGAN.
+
+    It will detect and crop faces, and then resize the faces to 512x512.
+    GFPGAN is used to restore the resized faces.
+    The background is upsampled with the bg_upsampler.
+    Finally, the faces will be pasted back to the upsampled background image.
+
+    Args:
+        model_path (str): The path to the GFPGAN model. It can be urls (will first download it automatically).
+        upscale (float): The upscale of the final output. Default: 2.
+        arch (str): The GFPGAN architecture. Option: clean | original. Default: clean.
+        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
+        bg_upsampler (nn.Module): The upsampler for the background. Default: None.
+    """
 
     def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None):
         self.upscale = upscale
@@ -56,7 +69,8 @@ class GFPGANer():
             device=self.device)
 
         if model_path.startswith('https://'):
-            model_path = load_file_from_url(
+            model_path = load_file_from_url(
+                url=model_path, model_dir=os.path.join(ROOT_DIR, 'gfpgan/weights'), progress=True, file_name=None)
         loadnet = torch.load(model_path)
         if 'params_ema' in loadnet:
             keyname = 'params_ema'
@@ -70,13 +84,15 @@ class GFPGANer():
     def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True):
         self.face_helper.clean_all()
 
-        if has_aligned:
+        if has_aligned:  # the inputs are already aligned
             img = cv2.resize(img, (512, 512))
             self.face_helper.cropped_faces = [img]
         else:
             self.face_helper.read_image(img)
             # get face landmarks for each face
-            self.face_helper.get_face_landmarks_5(only_center_face=only_center_face)
+            self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
+            # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
+            # TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
             # align and warp each face
             self.face_helper.align_warp_face()
 
@@ -99,9 +115,9 @@ class GFPGANer():
             self.face_helper.add_restored_face(restored_face)
 
         if not has_aligned and paste_back:
+            # upsample the background
             if self.bg_upsampler is not None:
-                # Now only support RealESRGAN
+                # Now only support RealESRGAN for upsampling background
                 bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
             else:
                 bg_img = None
@@ -112,23 +128,3 @@ class GFPGANer():
             return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
         else:
             return self.face_helper.cropped_faces, self.face_helper.restored_faces, None
-
-
-def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
-    """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
-    """
-    if model_dir is None:
-        hub_dir = get_dir()
-        model_dir = os.path.join(hub_dir, 'checkpoints')
-
-    os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
-
-    parts = urlparse(url)
-    filename = os.path.basename(parts.path)
-    if file_name is not None:
-        filename = file_name
-    cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
-    if not os.path.exists(cached_file):
-        print(f'Downloading: "{url}" to {cached_file}\n')
-        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
-    return cached_file
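The new GFPGANer docstring translates directly into a short usage sketch; the input image path below is a placeholder and the weights path is the default used by the inference script:

import cv2
from gfpgan import GFPGANer

restorer = GFPGANer(
    model_path='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth',  # an https:// URL also works
    upscale=2,
    arch='clean',
    channel_multiplier=2,
    bg_upsampler=None)  # without an upsampler the background is left as-is

img = cv2.imread('inputs/whole_imgs/00.jpg', cv2.IMREAD_COLOR)  # placeholder input
cropped_faces, restored_faces, restored_img = restorer.enhance(
    img, has_aligned=False, only_center_face=False, paste_back=True)
if restored_img is not None:
    cv2.imwrite('results/00_restored.png', restored_img)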
inference_gfpgan.py
CHANGED
@@ -10,20 +10,28 @@ from gfpgan import GFPGANer
 
 
 def main():
+    """Inference demo for GFPGAN.
+    """
     parser = argparse.ArgumentParser()
-
-    parser.add_argument('--
-    parser.add_argument('--
-    parser.add_argument('--
-    parser.add_argument('--
-    parser.add_argument(
-
-    parser.add_argument('--test_path', type=str, default='inputs/whole_imgs')
+    parser.add_argument('--upscale', type=int, default=2, help='The final upsampling scale of the image')
+    parser.add_argument('--arch', type=str, default='clean', help='The GFPGAN architecture. Option: clean | original')
+    parser.add_argument('--channel', type=int, default=2, help='Channel multiplier for large networks of StyleGAN2')
+    parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth')
+    parser.add_argument('--bg_upsampler', type=str, default='realesrgan', help='background upsampler')
+    parser.add_argument(
+        '--bg_tile', type=int, default=400, help='Tile size for background sampler, 0 for no tile during testing')
+    parser.add_argument('--test_path', type=str, default='inputs/whole_imgs', help='Input folder')
     parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
-    parser.add_argument('--only_center_face', action='store_true')
-    parser.add_argument('--aligned', action='store_true')
-    parser.add_argument('--paste_back', action='store_false')
-    parser.add_argument('--save_root', type=str, default='results')
+    parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face')
+    parser.add_argument('--aligned', action='store_true', help='Input are aligned faces')
+    parser.add_argument('--paste_back', action='store_false', help='Paste the restored faces back to images')
+    parser.add_argument('--save_root', type=str, default='results', help='Path to save root')
+    parser.add_argument(
+        '--ext',
+        type=str,
+        default='auto',
+        help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
+    args = parser.parse_args()
 
     args = parser.parse_args()
     if args.test_path.endswith('/'):
@@ -38,10 +46,13 @@ def main():
             'If you really want to use it, please modify the corresponding codes.')
         bg_upsampler = None
     else:
+        from basicsr.archs.rrdbnet_arch import RRDBNet
         from realesrgan import RealESRGANer
+        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
         bg_upsampler = RealESRGANer(
             scale=2,
             model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
+            model=model,
             tile=args.bg_tile,
             tile_pad=10,
             pre_pad=0,
@@ -64,31 +75,38 @@ def main():
         basename, ext = os.path.splitext(img_name)
         input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
 
+        # restore faces and background if necessary
         cropped_faces, restored_faces, restored_img = restorer.enhance(
             input_img, has_aligned=args.aligned, only_center_face=args.only_center_face, paste_back=args.paste_back)
 
         # save faces
         for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_faces)):
             # save cropped face
-            save_crop_path = os.path.join(args.save_root, f'{basename}_{idx:02d}.png')
+            save_crop_path = os.path.join(args.save_root, 'cropped_faces', f'{basename}_{idx:02d}.png')
             imwrite(cropped_face, save_crop_path)
             # save restored face
             if args.suffix is not None:
                 save_face_name = f'{basename}_{idx:02d}_{args.suffix}.png'
             else:
                 save_face_name = f'{basename}_{idx:02d}.png'
-            save_restore_path = os.path.join(args.save_root, save_face_name)
+            save_restore_path = os.path.join(args.save_root, 'restored_faces', save_face_name)
             imwrite(restored_face, save_restore_path)
-            # save
+            # save comparison image
+            cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
+            imwrite(cmp_img, os.path.join(args.save_root, 'cmp', f'{basename}_{idx:02d}.png'))
 
         # save restored img
         if restored_img is not None:
+            if args.ext == 'auto':
+                extension = ext[1:]
+            else:
+                extension = args.ext
+
             if args.suffix is not None:
-                save_restore_path = os.path.join(args.save_root,
+                save_restore_path = os.path.join(args.save_root, 'restored_imgs',
+                                                 f'{basename}_{args.suffix}.{extension}')
             else:
-                save_restore_path = os.path.join(args.save_root,
+                save_restore_path = os.path.join(args.save_root, 'restored_imgs', f'{basename}.{extension}')
             imwrite(restored_img, save_restore_path)
 
     print(f'Results are in the [{args.save_root}] folder.')
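The new --ext handling reduces to a small path-building rule; a standalone sketch of that logic (the helper name and sample values are illustrative):

import os

def build_save_path(save_root, basename, src_ext, suffix=None, ext='auto'):
    # 'auto' keeps the extension of the input image; otherwise the given extension is used
    extension = src_ext[1:] if ext == 'auto' else ext
    name = f'{basename}_{suffix}.{extension}' if suffix else f'{basename}.{extension}'
    return os.path.join(save_root, 'restored_imgs', name)

print(build_save_path('results', '00', '.jpg'))                          # results/restored_imgs/00.jpg
print(build_save_path('results', '00', '.jpg', suffix='v1', ext='png'))  # results/restored_imgs/00_v1.png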
options/train_gfpgan_v1.yml
CHANGED
@@ -1,7 +1,7 @@
 # general settings
 name: train_GFPGANv1_512
 model_type: GFPGANModel
-num_gpu: 4
+num_gpu: auto  # officially, we use 4 GPUs
 manual_seed: 0
 
 # dataset and data loader settings
@@ -194,7 +194,7 @@ val:
   save_img: true
 
   metrics:
-    psnr: # metric name
+    psnr: # metric name
       type: calculate_psnr
       crop_border: 0
       test_y_channel: false
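num_gpu: auto defers the GPU count to runtime; BasicSR-style option parsing typically resolves it with torch.cuda.device_count(), roughly as in this sketch (not the exact parser code):

import torch

def resolve_num_gpu(opt):
    # replace the 'auto' placeholder with the number of visible CUDA devices
    if opt.get('num_gpu') == 'auto':
        opt['num_gpu'] = torch.cuda.device_count()
    return opt

print(resolve_num_gpu({'num_gpu': 'auto'}))  # e.g. {'num_gpu': 0} on a CPU-only machine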
options/train_gfpgan_v1_simple.yml
CHANGED
@@ -1,7 +1,7 @@
 # general settings
 name: train_GFPGANv1_512_simple
 model_type: GFPGANModel
-num_gpu: 4
+num_gpu: auto  # officially, we use 4 GPUs
 manual_seed: 0
 
 # dataset and data loader settings
@@ -40,10 +40,6 @@ datasets:
     # gray_prob: 0.01
     # gt_gray: True
 
-    # crop_components: false
-    # component_path: experiments/pretrained_models/FFHQ_eye_mouth_landmarks_512.pth
-    # eye_enlarge_ratio: 1.4
-
     # data loader
     use_shuffle: true
     num_worker_per_gpu: 6
@@ -86,20 +82,6 @@ network_d:
   channel_multiplier: 1
   resample_kernel: [1, 3, 3, 1]
 
-# network_d_left_eye:
-#   type: FacialComponentDiscriminator
-
-# network_d_right_eye:
-#   type: FacialComponentDiscriminator
-
-# network_d_mouth:
-#   type: FacialComponentDiscriminator
-
-network_identity:
-  type: ResNetArcFace
-  block: IRBlock
-  layers: [2, 2, 2, 2]
-  use_se: False
 
 # path
 path:
@@ -107,13 +89,7 @@ path:
   param_key_g: params_ema
   strict_load_g: ~
   pretrain_network_d: ~
-  # pretrain_network_d_left_eye: ~
-  # pretrain_network_d_right_eye: ~
-  # pretrain_network_d_mouth: ~
-  pretrain_network_identity: experiments/pretrained_models/arcface_resnet18.pth
-  # resume
   resume_state: ~
-  ignore_resume_networks: ['network_identity']
 
 # training settings
 train:
@@ -173,16 +149,6 @@ train:
     loss_weight: !!float 1e-1
   # r1 regularization for discriminator
   r1_reg_weight: 10
-  # facial component loss
-  # gan_component_opt:
-  #   type: GANLoss
-  #   gan_type: vanilla
-  #   real_label_val: 1.0
-  #   fake_label_val: 0.0
-  #   loss_weight: !!float 1
-  #   comp_style_weight: 200
-  # identity loss
-  identity_weight: 10
 
   net_d_iters: 1
   net_d_init_iters: 0
@@ -194,7 +160,7 @@ val:
   save_img: true
 
   metrics:
-    psnr: # metric name
+    psnr: # metric name
      type: calculate_psnr
      crop_border: 0
      test_y_channel: false
requirements.txt
CHANGED
@@ -1,18 +1,12 @@
-torch
-numpy
-opencv-python
-setuptools
-Pillow
-gradio
+torch>=1.7
+numpy<1.21  # numba requires numpy<1.21,>=1.17
+opencv-python
 torchvision
+scipy
+tqdm
+basicsr>=1.3.4.0
+facexlib>=0.2.0.3
 lmdb
 pyyaml
-requests
-scikit-image
-scipy
 tb-nightly
-tqdm
 yapf
-psutil
scripts/parse_landmark.py
CHANGED
@@ -1,24 +1,31 @@
 import cv2
 import json
 import numpy as np
+import os
 import torch
 from basicsr.utils import FileClient, imfrombytes
 from collections import OrderedDict
 
+# ---------------------------- This script is used to parse facial landmarks ------------------------------------- #
+# Configurations
+save_img = False
+scale = 0.5  # 0.5 for official FFHQ (512x512), 1 for others
+enlarge_ratio = 1.4  # only for eyes
+json_path = 'ffhq-dataset-v2.json'
+face_path = 'datasets/ffhq/ffhq_512.lmdb'
+save_path = './FFHQ_eye_mouth_landmarks_512.pth'
+
 print('Load JSON metadata...')
-# use the json file in FFHQ dataset
-with open(
+# use the official json file in FFHQ dataset
+with open(json_path, 'rb') as f:
     json_data = json.load(f, object_pairs_hook=OrderedDict)
 
 print('Open LMDB file...')
 # read ffhq images
-file_client = FileClient('lmdb', db_paths=
-with open('
+file_client = FileClient('lmdb', db_paths=face_path)
+with open(os.path.join(face_path, 'meta_info.txt')) as fin:
     paths = [line.split('.')[0] for line in fin]
 
-save_img = False
-scale = 0.5  # 0.5 for official FFHQ (512x512), 1 for others
-enlarge_ratio = 1.4  # only for eyes
 save_dict = {}
 
 for item_idx, item in enumerate(json_data.values()):
@@ -34,6 +41,7 @@ for item_idx, item in enumerate(json_data.values()):
     img_bytes = file_client.get(paths[item_idx])
     img = imfrombytes(img_bytes, float32=True)
 
+    # get landmarks for each component
     map_left_eye = list(range(36, 42))
     map_right_eye = list(range(42, 48))
     map_mouth = list(range(48, 68))
@@ -74,4 +82,4 @@ for item_idx, item in enumerate(json_data.values()):
     save_dict[f'{item_idx:08d}'] = item_dict
 
 print('Save...')
-torch.save(save_dict,
+torch.save(save_dict, save_path)
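The landmark-to-bbox computation elided from this diff reduces each component to a square box around its mean landmark; the sketch below states the usual heuristic (half the point spread, enlarged for eyes) as an assumption rather than as the script's exact code:

import numpy as np

def component_bbox(landmarks, enlarge_ratio=1.0):
    # landmarks: (n, 2) array of x, y points for one facial component
    mean = landmarks.mean(axis=0)
    half_len = np.max(landmarks.max(axis=0) - landmarks.min(axis=0)) / 2 * enlarge_ratio
    x0, y0 = mean - half_len
    x1, y1 = mean + half_len
    return [x0, y0, x1, y1]

eye = np.array([[100, 120], [110, 118], [120, 121], [130, 119], [121, 125], [109, 126]], dtype=np.float32)
print(component_bbox(eye, enlarge_ratio=1.4))  # eyes use enlarge_ratio = 1.4 in this script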
setup.cfg
CHANGED
@@ -17,6 +17,17 @@ line_length = 120
 multi_line_output = 0
 known_standard_library = pkg_resources,setuptools
 known_first_party = gfpgan
-known_third_party = basicsr,cv2,facexlib,numpy,torch,torchvision,tqdm
+known_third_party = basicsr,cv2,facexlib,numpy,pytest,torch,torchvision,tqdm,yaml
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY
+
+[codespell]
+skip = .git,./docs/build
+count =
+quiet-level = 3
+
+[aliases]
+test=pytest
+
+[tool:pytest]
+addopts=tests/
setup.py
CHANGED
@@ -43,12 +43,6 @@ def get_git_hash():
 def get_hash():
     if os.path.exists('.git'):
         sha = get_git_hash()[:7]
-    elif os.path.exists(version_file):
-        try:
-            from facexlib.version import __version__
-            sha = __version__.split('+')[-1]
-        except ImportError:
-            raise ImportError('Unable to get git version')
     else:
         sha = 'unknown'
 
tests/data/ffhq_gt.lmdb/data.mdb
ADDED
Binary file (455 kB).

experiments/.DS_Store → tests/data/ffhq_gt.lmdb/lock.mdb
RENAMED
Binary files a/experiments/.DS_Store and b/tests/data/ffhq_gt.lmdb/lock.mdb differ

tests/data/ffhq_gt.lmdb/meta_info.txt
ADDED
@@ -0,0 +1 @@
+00000000.png (512,512,3) 1
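The single meta_info.txt line follows the BasicSR LMDB convention: image key with extension, (h,w,c) shape, and png compression level. A small parser sketch (the helper is illustrative, not part of the repo):

def parse_meta_info(line):
    # '00000000.png (512,512,3) 1' -> key, shape, compression level
    name, shape, level = line.split()
    key = name.split('.')[0]
    h, w, c = (int(v) for v in shape.strip('()').split(','))
    return key, (h, w, c), int(level)

print(parse_meta_info('00000000.png (512,512,3) 1'))  # ('00000000', (512, 512, 3), 1)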
tests/data/gt/00000000.png
ADDED

tests/data/test_eye_mouth_landmarks.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:131583fca2cc346652f8754eb3c5a0bdeda808686039ff10ead7a26254b72358
+size 943
tests/data/test_ffhq_degradation_dataset.yml
ADDED
@@ -0,0 +1,24 @@
+name: UnitTest
+type: FFHQDegradationDataset
+dataroot_gt: tests/data/gt
+io_backend:
+  type: disk
+
+use_hflip: true
+mean: [0.5, 0.5, 0.5]
+std: [0.5, 0.5, 0.5]
+out_size: 512
+
+blur_kernel_size: 41
+kernel_list: ['iso', 'aniso']
+kernel_prob: [0.5, 0.5]
+blur_sigma: [0.1, 10]
+downsample_range: [0.8, 8]
+noise_range: [0, 20]
+jpeg_range: [60, 100]
+
+# color jitter and gray
+color_jitter_prob: 1
+color_jitter_shift: 20
+color_jitter_pt_prob: 1
+gray_prob: 1
tests/data/test_gfpgan_model.yml
ADDED
@@ -0,0 +1,140 @@
+num_gpu: 1
+manual_seed: 0
+is_train: True
+dist: False
+
+# network structures
+network_g:
+  type: GFPGANv1
+  out_size: 512
+  num_style_feat: 512
+  channel_multiplier: 1
+  resample_kernel: [1, 3, 3, 1]
+  decoder_load_path: ~
+  fix_decoder: true
+  num_mlp: 8
+  lr_mlp: 0.01
+  input_is_latent: true
+  different_w: true
+  narrow: 0.5
+  sft_half: true
+
+network_d:
+  type: StyleGAN2Discriminator
+  out_size: 512
+  channel_multiplier: 1
+  resample_kernel: [1, 3, 3, 1]
+
+network_d_left_eye:
+  type: FacialComponentDiscriminator
+
+network_d_right_eye:
+  type: FacialComponentDiscriminator
+
+network_d_mouth:
+  type: FacialComponentDiscriminator
+
+network_identity:
+  type: ResNetArcFace
+  block: IRBlock
+  layers: [2, 2, 2, 2]
+  use_se: False
+
+# path
+path:
+  pretrain_network_g: ~
+  param_key_g: params_ema
+  strict_load_g: ~
+  pretrain_network_d: ~
+  pretrain_network_d_left_eye: ~
+  pretrain_network_d_right_eye: ~
+  pretrain_network_d_mouth: ~
+  pretrain_network_identity: ~
+  # resume
+  resume_state: ~
+  ignore_resume_networks: ['network_identity']
+
+# training settings
+train:
+  optim_g:
+    type: Adam
+    lr: !!float 2e-3
+  optim_d:
+    type: Adam
+    lr: !!float 2e-3
+  optim_component:
+    type: Adam
+    lr: !!float 2e-3
+
+  scheduler:
+    type: MultiStepLR
+    milestones: [600000, 700000]
+    gamma: 0.5
+
+  total_iter: 800000
+  warmup_iter: -1  # no warm up
+
+  # losses
+  # pixel loss
+  pixel_opt:
+    type: L1Loss
+    loss_weight: !!float 1e-1
+    reduction: mean
+  # L1 loss used in pyramid loss, component style loss and identity loss
+  L1_opt:
+    type: L1Loss
+    loss_weight: 1
+    reduction: mean
+
+  # image pyramid loss
+  pyramid_loss_weight: 1
+  remove_pyramid_loss: 50000
+  # perceptual loss (content and style losses)
+  perceptual_opt:
+    type: PerceptualLoss
+    layer_weights:
+      # before relu
+      'conv1_2': 0.1
+      'conv2_2': 0.1
+      'conv3_4': 1
+      'conv4_4': 1
+      'conv5_4': 1
+    vgg_type: vgg19
+    use_input_norm: true
+    perceptual_weight: !!float 1
+    style_weight: 50
+    range_norm: true
+    criterion: l1
+  # gan loss
+  gan_opt:
+    type: GANLoss
+    gan_type: wgan_softplus
+    loss_weight: !!float 1e-1
+  # r1 regularization for discriminator
+  r1_reg_weight: 10
+  # facial component loss
+  gan_component_opt:
+    type: GANLoss
+    gan_type: vanilla
+    real_label_val: 1.0
+    fake_label_val: 0.0
+    loss_weight: !!float 1
+    comp_style_weight: 200
+  # identity loss
+  identity_weight: 10
+
+  net_d_iters: 1
+  net_d_init_iters: 0
+  net_d_reg_every: 1
+
+# validation settings
+val:
+  val_freq: !!float 5e3
+  save_img: True
+  use_pbar: True
+
+  metrics:
+    psnr: # metric name
+      type: calculate_psnr
+      crop_border: 0
+      test_y_channel: false
tests/test_arcface_arch.py
ADDED
@@ -0,0 +1,49 @@
+import torch
+
+from gfpgan.archs.arcface_arch import BasicBlock, Bottleneck, ResNetArcFace
+
+
+def test_resnetarcface():
+    """Test arch: ResNetArcFace."""
+
+    # model init and forward (gpu)
+    if torch.cuda.is_available():
+        net = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=True).cuda().eval()
+        img = torch.rand((1, 1, 128, 128), dtype=torch.float32).cuda()
+        output = net(img)
+        assert output.shape == (1, 512)
+
+        # -------------------- without SE block ----------------------- #
+        net = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=False).cuda().eval()
+        output = net(img)
+        assert output.shape == (1, 512)
+
+
+def test_basicblock():
+    """Test the BasicBlock in arcface_arch"""
+    block = BasicBlock(1, 3, stride=1, downsample=None).cuda()
+    img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
+    output = block(img)
+    assert output.shape == (1, 3, 12, 12)
+
+    # ----------------- use the downsample module --------------- #
+    downsample = torch.nn.UpsamplingNearest2d(scale_factor=0.5).cuda()
+    block = BasicBlock(1, 3, stride=2, downsample=downsample).cuda()
+    img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
+    output = block(img)
+    assert output.shape == (1, 3, 6, 6)
+
+
+def test_bottleneck():
+    """Test the Bottleneck in arcface_arch"""
+    block = Bottleneck(1, 1, stride=1, downsample=None).cuda()
+    img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
+    output = block(img)
+    assert output.shape == (1, 4, 12, 12)
+
+    # ----------------- use the downsample module --------------- #
+    downsample = torch.nn.UpsamplingNearest2d(scale_factor=0.5).cuda()
+    block = Bottleneck(1, 1, stride=2, downsample=downsample).cuda()
+    img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
+    output = block(img)
+    assert output.shape == (1, 4, 6, 6)
tests/test_ffhq_degradation_dataset.py
ADDED
@@ -0,0 +1,96 @@
+import pytest
+import yaml
+
+from gfpgan.data.ffhq_degradation_dataset import FFHQDegradationDataset
+
+
+def test_ffhq_degradation_dataset():
+
+    with open('tests/data/test_ffhq_degradation_dataset.yml', mode='r') as f:
+        opt = yaml.load(f, Loader=yaml.FullLoader)
+
+    dataset = FFHQDegradationDataset(opt)
+    assert dataset.io_backend_opt['type'] == 'disk'  # io backend
+    assert len(dataset) == 1  # whether to read correct meta info
+    assert dataset.kernel_list == ['iso', 'aniso']  # correct initialization of the degradation configurations
+    assert dataset.color_jitter_prob == 1
+
+    # test __getitem__
+    result = dataset.__getitem__(0)
+    # check returned keys
+    expected_keys = ['gt', 'lq', 'gt_path']
+    assert set(expected_keys).issubset(set(result.keys()))
+    # check shape and contents
+    assert result['gt'].shape == (3, 512, 512)
+    assert result['lq'].shape == (3, 512, 512)
+    assert result['gt_path'] == 'tests/data/gt/00000000.png'
+
+    # ------------------ test with probability = 0 -------------------- #
+    opt['color_jitter_prob'] = 0
+    opt['color_jitter_pt_prob'] = 0
+    opt['gray_prob'] = 0
+    opt['io_backend'] = dict(type='disk')
+    dataset = FFHQDegradationDataset(opt)
+    assert dataset.io_backend_opt['type'] == 'disk'  # io backend
+    assert len(dataset) == 1  # whether to read correct meta info
+    assert dataset.kernel_list == ['iso', 'aniso']  # correct initialization of the degradation configurations
+    assert dataset.color_jitter_prob == 0
+
+    # test __getitem__
+    result = dataset.__getitem__(0)
+    # check returned keys
+    expected_keys = ['gt', 'lq', 'gt_path']
+    assert set(expected_keys).issubset(set(result.keys()))
+    # check shape and contents
+    assert result['gt'].shape == (3, 512, 512)
+    assert result['lq'].shape == (3, 512, 512)
+    assert result['gt_path'] == 'tests/data/gt/00000000.png'
+
+    # ------------------ test lmdb backend -------------------- #
+    opt['dataroot_gt'] = 'tests/data/ffhq_gt.lmdb'
+    opt['io_backend'] = dict(type='lmdb')
+
+    dataset = FFHQDegradationDataset(opt)
+    assert dataset.io_backend_opt['type'] == 'lmdb'  # io backend
+    assert len(dataset) == 1  # whether to read correct meta info
+    assert dataset.kernel_list == ['iso', 'aniso']  # correct initialization of the degradation configurations
+    assert dataset.color_jitter_prob == 0
+
+    # test __getitem__
+    result = dataset.__getitem__(0)
+    # check returned keys
+    expected_keys = ['gt', 'lq', 'gt_path']
+    assert set(expected_keys).issubset(set(result.keys()))
+    # check shape and contents
+    assert result['gt'].shape == (3, 512, 512)
+    assert result['lq'].shape == (3, 512, 512)
+    assert result['gt_path'] == '00000000'
+
+    # ------------------ test with crop_components -------------------- #
+    opt['crop_components'] = True
+    opt['component_path'] = 'tests/data/test_eye_mouth_landmarks.pth'
+    opt['eye_enlarge_ratio'] = 1.4
+    opt['gt_gray'] = True
+    opt['io_backend'] = dict(type='lmdb')
+
+    dataset = FFHQDegradationDataset(opt)
+    assert dataset.crop_components is True
+
+    # test __getitem__
+    result = dataset.__getitem__(0)
+    # check returned keys
+    expected_keys = ['gt', 'lq', 'gt_path', 'loc_left_eye', 'loc_right_eye', 'loc_mouth']
+    assert set(expected_keys).issubset(set(result.keys()))
+    # check shape and contents
+    assert result['gt'].shape == (3, 512, 512)
+    assert result['lq'].shape == (3, 512, 512)
+    assert result['gt_path'] == '00000000'
+    assert result['loc_left_eye'].shape == (4, )
+    assert result['loc_right_eye'].shape == (4, )
+    assert result['loc_mouth'].shape == (4, )
+
+    # ------------------ lmdb backend should have paths ending with lmdb -------------------- #
+    with pytest.raises(ValueError):
+        opt['dataroot_gt'] = 'tests/data/gt'
+        opt['io_backend'] = dict(type='lmdb')
+        dataset = FFHQDegradationDataset(opt)
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from gfpgan.archs.gfpganv1_arch import FacialComponentDiscriminator, GFPGANv1, StyleGAN2GeneratorSFT
|
4 |
+
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean, StyleGAN2GeneratorCSFT
|
5 |
+
|
6 |
+
|
7 |
+
def test_stylegan2generatorsft():
|
8 |
+
"""Test arch: StyleGAN2GeneratorSFT."""
|
9 |
+
|
10 |
+
# model init and forward (gpu)
|
11 |
+
if torch.cuda.is_available():
|
12 |
+
net = StyleGAN2GeneratorSFT(
|
13 |
+
out_size=32,
|
14 |
+
num_style_feat=512,
|
15 |
+
num_mlp=8,
|
16 |
+
channel_multiplier=1,
|
17 |
+
resample_kernel=(1, 3, 3, 1),
|
18 |
+
lr_mlp=0.01,
|
19 |
+
narrow=1,
|
20 |
+
sft_half=False).cuda().eval()
|
21 |
+
style = torch.rand((1, 512), dtype=torch.float32).cuda()
|
22 |
+
condition1 = torch.rand((1, 512, 8, 8), dtype=torch.float32).cuda()
|
23 |
+
condition2 = torch.rand((1, 512, 16, 16), dtype=torch.float32).cuda()
|
24 |
+
condition3 = torch.rand((1, 512, 32, 32), dtype=torch.float32).cuda()
|
25 |
+
conditions = [condition1, condition1, condition2, condition2, condition3, condition3]
|
26 |
+
output = net([style], conditions)
|
27 |
+
assert output[0].shape == (1, 3, 32, 32)
|
28 |
+
assert output[1] is None
|
29 |
+
|
30 |
+
# -------------------- with return_latents ----------------------- #
|
31 |
+
output = net([style], conditions, return_latents=True)
|
32 |
+
assert output[0].shape == (1, 3, 32, 32)
|
33 |
+
assert len(output[1]) == 1
|
34 |
+
# check latent
|
35 |
+
assert output[1][0].shape == (8, 512)
|
36 |
+
|
37 |
+
# -------------------- with randomize_noise = False ----------------------- #
|
38 |
+
output = net([style], conditions, randomize_noise=False)
|
39 |
+
assert output[0].shape == (1, 3, 32, 32)
|
40 |
+
assert output[1] is None
|
41 |
+
|
42 |
+
# -------------------- with truncation = 0.5 and mixing----------------------- #
|
43 |
+
output = net([style, style], conditions, truncation=0.5, truncation_latent=style)
|
44 |
+
assert output[0].shape == (1, 3, 32, 32)
|
45 |
+
assert output[1] is None
|
46 |
+
|
47 |
+
|
48 |
+
def test_gfpganv1():
|
49 |
+
"""Test arch: GFPGANv1."""
|
50 |
+
|
51 |
+
# model init and forward (gpu)
|
52 |
+
if torch.cuda.is_available():
|
53 |
+
net = GFPGANv1(
|
54 |
+
out_size=32,
|
55 |
+
num_style_feat=512,
|
56 |
+
channel_multiplier=1,
|
57 |
+
resample_kernel=(1, 3, 3, 1),
|
58 |
+
decoder_load_path=None,
|
59 |
+
fix_decoder=True,
|
60 |
+
# for stylegan decoder
|
61 |
+
num_mlp=8,
|
62 |
+
lr_mlp=0.01,
|
63 |
+
input_is_latent=False,
|
64 |
+
different_w=False,
|
65 |
+
narrow=1,
|
66 |
+
sft_half=True).cuda().eval()
|
67 |
+
img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
|
68 |
+
output = net(img)
|
69 |
+
assert output[0].shape == (1, 3, 32, 32)
|
70 |
+
assert len(output[1]) == 3
|
71 |
+
# check out_rgbs for intermediate loss
|
72 |
+
assert output[1][0].shape == (1, 3, 8, 8)
|
73 |
+
assert output[1][1].shape == (1, 3, 16, 16)
|
74 |
+
assert output[1][2].shape == (1, 3, 32, 32)
|
75 |
+
|
76 |
+
# -------------------- with different_w = True ----------------------- #
|
77 |
+
net = GFPGANv1(
|
78 |
+
out_size=32,
|
79 |
+
num_style_feat=512,
|
80 |
+
channel_multiplier=1,
|
81 |
+
resample_kernel=(1, 3, 3, 1),
|
82 |
+
decoder_load_path=None,
|
83 |
+
fix_decoder=True,
|
84 |
+
# for stylegan decoder
|
85 |
+
num_mlp=8,
|
86 |
+
lr_mlp=0.01,
|
87 |
+
input_is_latent=False,
|
88 |
+
different_w=True,
|
89 |
+
narrow=1,
|
90 |
+
sft_half=True).cuda().eval()
|
91 |
+
img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
|
92 |
+
output = net(img)
|
93 |
+
assert output[0].shape == (1, 3, 32, 32)
|
94 |
+
assert len(output[1]) == 3
|
95 |
+
# check out_rgbs for intermediate loss
+        assert output[1][0].shape == (1, 3, 8, 8)
+        assert output[1][1].shape == (1, 3, 16, 16)
+        assert output[1][2].shape == (1, 3, 32, 32)
+
+
+def test_facialcomponentdiscriminator():
+    """Test arch: FacialComponentDiscriminator."""
+
+    # model init and forward (gpu)
+    if torch.cuda.is_available():
+        net = FacialComponentDiscriminator().cuda().eval()
+        img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
+        output = net(img)
+        assert len(output) == 2
+        assert output[0].shape == (1, 1, 8, 8)
+        assert output[1] is None
+
+        # -------------------- return intermediate features ----------------------- #
+        output = net(img, return_feats=True)
+        assert len(output) == 2
+        assert output[0].shape == (1, 1, 8, 8)
+        assert len(output[1]) == 2
+        assert output[1][0].shape == (1, 128, 16, 16)
+        assert output[1][1].shape == (1, 256, 8, 8)
+
+
+def test_stylegan2generatorcsft():
+    """Test arch: StyleGAN2GeneratorCSFT."""
+
+    # model init and forward (gpu)
+    if torch.cuda.is_available():
+        net = StyleGAN2GeneratorCSFT(
+            out_size=32, num_style_feat=512, num_mlp=8, channel_multiplier=1, narrow=1, sft_half=False).cuda().eval()
+        style = torch.rand((1, 512), dtype=torch.float32).cuda()
+        condition1 = torch.rand((1, 512, 8, 8), dtype=torch.float32).cuda()
+        condition2 = torch.rand((1, 512, 16, 16), dtype=torch.float32).cuda()
+        condition3 = torch.rand((1, 512, 32, 32), dtype=torch.float32).cuda()
+        conditions = [condition1, condition1, condition2, condition2, condition3, condition3]
+        output = net([style], conditions)
+        assert output[0].shape == (1, 3, 32, 32)
+        assert output[1] is None
+
+        # -------------------- with return_latents ----------------------- #
+        output = net([style], conditions, return_latents=True)
+        assert output[0].shape == (1, 3, 32, 32)
+        assert len(output[1]) == 1
+        # check latent
+        assert output[1][0].shape == (8, 512)
+
+        # -------------------- with randomize_noise = False ----------------------- #
+        output = net([style], conditions, randomize_noise=False)
+        assert output[0].shape == (1, 3, 32, 32)
+        assert output[1] is None
+
+        # -------------------- with truncation = 0.5 and mixing ----------------------- #
+        output = net([style, style], conditions, truncation=0.5, truncation_latent=style)
+        assert output[0].shape == (1, 3, 32, 32)
+        assert output[1] is None
+
+
+def test_gfpganv1clean():
+    """Test arch: GFPGANv1Clean."""
+
+    # model init and forward (gpu)
+    if torch.cuda.is_available():
+        net = GFPGANv1Clean(
+            out_size=32,
+            num_style_feat=512,
+            channel_multiplier=1,
+            decoder_load_path=None,
+            fix_decoder=True,
+            # for stylegan decoder
+            num_mlp=8,
+            input_is_latent=False,
+            different_w=False,
+            narrow=1,
+            sft_half=True).cuda().eval()
+
+        img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
+        output = net(img)
+        assert output[0].shape == (1, 3, 32, 32)
+        assert len(output[1]) == 3
+        # check out_rgbs for intermediate loss
+        assert output[1][0].shape == (1, 3, 8, 8)
+        assert output[1][1].shape == (1, 3, 16, 16)
+        assert output[1][2].shape == (1, 3, 32, 32)
+
+        # -------------------- with different_w = True ----------------------- #
+        net = GFPGANv1Clean(
+            out_size=32,
+            num_style_feat=512,
+            channel_multiplier=1,
+            decoder_load_path=None,
+            fix_decoder=True,
+            # for stylegan decoder
+            num_mlp=8,
+            input_is_latent=False,
+            different_w=True,
+            narrow=1,
+            sft_half=True).cuda().eval()
+        img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
+        output = net(img)
+        assert output[0].shape == (1, 3, 32, 32)
+        assert len(output[1]) == 3
+        # check out_rgbs for intermediate loss
+        assert output[1][0].shape == (1, 3, 8, 8)
+        assert output[1][1].shape == (1, 3, 16, 16)
+        assert output[1][2].shape == (1, 3, 32, 32)
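
Note: every test body in this file is guarded by if torch.cuda.is_available(), so on a CPU-only machine the assertions above pass vacuously rather than fail. In the CSFT test, the six-entry conditions list appears to supply one (scale, shift) pair of SFT maps per generated resolution (8, 16, 32), which is why each condition tensor is passed twice. A minimal sketch of exercising GFPGANv1Clean without a GPU, assuming the clean arch also runs on CPU (something these tests do not verify):

    import torch
    from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean

    net = GFPGANv1Clean(
        out_size=32, num_style_feat=512, channel_multiplier=1, decoder_load_path=None,
        fix_decoder=True, num_mlp=8, input_is_latent=False, different_w=False,
        narrow=1, sft_half=True).eval()
    with torch.no_grad():
        image, out_rgbs = net(torch.rand(1, 3, 32, 32))  # expect (1, 3, 32, 32) plus 3 pyramid RGBs
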
tests/test_gfpgan_model.py
ADDED
@@ -0,0 +1,132 @@
+import tempfile
+import torch
+import yaml
+from basicsr.archs.stylegan2_arch import StyleGAN2Discriminator
+from basicsr.data.paired_image_dataset import PairedImageDataset
+from basicsr.losses.losses import GANLoss, L1Loss, PerceptualLoss
+
+from gfpgan.archs.arcface_arch import ResNetArcFace
+from gfpgan.archs.gfpganv1_arch import FacialComponentDiscriminator, GFPGANv1
+from gfpgan.models.gfpgan_model import GFPGANModel
+
+
+def test_gfpgan_model():
+    with open('tests/data/test_gfpgan_model.yml', mode='r') as f:
+        opt = yaml.load(f, Loader=yaml.FullLoader)
+
+    # build model
+    model = GFPGANModel(opt)
+    # test attributes
+    assert model.__class__.__name__ == 'GFPGANModel'
+    assert isinstance(model.net_g, GFPGANv1)  # generator
+    assert isinstance(model.net_d, StyleGAN2Discriminator)  # discriminator
+    # facial component discriminators
+    assert isinstance(model.net_d_left_eye, FacialComponentDiscriminator)
+    assert isinstance(model.net_d_right_eye, FacialComponentDiscriminator)
+    assert isinstance(model.net_d_mouth, FacialComponentDiscriminator)
+    # identity network
+    assert isinstance(model.network_identity, ResNetArcFace)
+    # losses
+    assert isinstance(model.cri_pix, L1Loss)
+    assert isinstance(model.cri_perceptual, PerceptualLoss)
+    assert isinstance(model.cri_gan, GANLoss)
+    assert isinstance(model.cri_l1, L1Loss)
+    # optimizer
+    assert isinstance(model.optimizers[0], torch.optim.Adam)
+    assert isinstance(model.optimizers[1], torch.optim.Adam)
+
+    # prepare data
+    gt = torch.rand((1, 3, 512, 512), dtype=torch.float32)
+    lq = torch.rand((1, 3, 512, 512), dtype=torch.float32)
+    loc_left_eye = torch.rand((1, 4), dtype=torch.float32)
+    loc_right_eye = torch.rand((1, 4), dtype=torch.float32)
+    loc_mouth = torch.rand((1, 4), dtype=torch.float32)
+    data = dict(gt=gt, lq=lq, loc_left_eye=loc_left_eye, loc_right_eye=loc_right_eye, loc_mouth=loc_mouth)
+    model.feed_data(data)
+    # check data shape
+    assert model.lq.shape == (1, 3, 512, 512)
+    assert model.gt.shape == (1, 3, 512, 512)
+    assert model.loc_left_eyes.shape == (1, 4)
+    assert model.loc_right_eyes.shape == (1, 4)
+    assert model.loc_mouths.shape == (1, 4)
+
+    # ----------------- test optimize_parameters -------------------- #
+    model.feed_data(data)
+    model.optimize_parameters(1)
+    assert model.output.shape == (1, 3, 512, 512)
+    assert isinstance(model.log_dict, dict)
+    # check returned keys
+    expected_keys = [
+        'l_g_pix', 'l_g_percep', 'l_g_style', 'l_g_gan', 'l_g_gan_left_eye', 'l_g_gan_right_eye', 'l_g_gan_mouth',
+        'l_g_comp_style_loss', 'l_identity', 'l_d', 'real_score', 'fake_score', 'l_d_r1', 'l_d_left_eye',
+        'l_d_right_eye', 'l_d_mouth'
+    ]
+    assert set(expected_keys).issubset(set(model.log_dict.keys()))
+
+    # ----------------- remove pyramid_loss_weight -------------------- #
+    model.feed_data(data)
+    model.optimize_parameters(100000)  # larger than remove_pyramid_loss = 50000
+    assert model.output.shape == (1, 3, 512, 512)
+    assert isinstance(model.log_dict, dict)
+    # check returned keys
+    expected_keys = [
+        'l_g_pix', 'l_g_percep', 'l_g_style', 'l_g_gan', 'l_g_gan_left_eye', 'l_g_gan_right_eye', 'l_g_gan_mouth',
+        'l_g_comp_style_loss', 'l_identity', 'l_d', 'real_score', 'fake_score', 'l_d_r1', 'l_d_left_eye',
+        'l_d_right_eye', 'l_d_mouth'
+    ]
+    assert set(expected_keys).issubset(set(model.log_dict.keys()))
+
+    # ----------------- test save -------------------- #
+    with tempfile.TemporaryDirectory() as tmpdir:
+        model.opt['path']['models'] = tmpdir
+        model.opt['path']['training_states'] = tmpdir
+        model.save(0, 1)
+
+    # ----------------- test the test function -------------------- #
+    model.test()
+    assert model.output.shape == (1, 3, 512, 512)
+    # delete net_g_ema
+    model.__delattr__('net_g_ema')
+    model.test()
+    assert model.output.shape == (1, 3, 512, 512)
+    assert model.net_g.training is True  # should be back in training mode after testing
+
+    # ----------------- test nondist_validation -------------------- #
+    # construct dataloader
+    dataset_opt = dict(
+        name='Demo',
+        dataroot_gt='tests/data/gt',
+        dataroot_lq='tests/data/gt',
+        io_backend=dict(type='disk'),
+        scale=4,
+        phase='val')
+    dataset = PairedImageDataset(dataset_opt)
+    dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
+    assert model.is_train is True
+    with tempfile.TemporaryDirectory() as tmpdir:
+        model.opt['path']['visualization'] = tmpdir
+        model.nondist_validation(dataloader, 1, None, save_img=True)
+        assert model.is_train is True
+        # check metric_results
+        assert 'psnr' in model.metric_results
+        assert isinstance(model.metric_results['psnr'], float)
+
+    # validation
+    with tempfile.TemporaryDirectory() as tmpdir:
+        model.opt['is_train'] = False
+        model.opt['val']['suffix'] = 'test'
+        model.opt['path']['visualization'] = tmpdir
+        model.opt['val']['pbar'] = True
+        model.nondist_validation(dataloader, 1, None, save_img=True)
+        # check metric_results
+        assert 'psnr' in model.metric_results
+        assert isinstance(model.metric_results['psnr'], float)
+
+        # if opt['val']['suffix'] is None
+        model.opt['val']['suffix'] = None
+        model.opt['name'] = 'demo'
+        model.opt['path']['visualization'] = tmpdir
+        model.nondist_validation(dataloader, 1, None, save_img=True)
+        # check metric_results
+        assert 'psnr' in model.metric_results
+        assert isinstance(model.metric_results['psnr'], float)
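
For reference, a condensed sketch of driving GFPGANModel through a single training step outside pytest, using only the calls exercised in the test above (the YAML path and tensor shapes are taken from the test; anything further about the options file is assumed):

    import torch
    import yaml
    from gfpgan.models.gfpgan_model import GFPGANModel

    with open('tests/data/test_gfpgan_model.yml', mode='r') as f:
        opt = yaml.load(f, Loader=yaml.FullLoader)
    model = GFPGANModel(opt)
    data = dict(
        gt=torch.rand(1, 3, 512, 512),
        lq=torch.rand(1, 3, 512, 512),
        loc_left_eye=torch.rand(1, 4),
        loc_right_eye=torch.rand(1, 4),
        loc_mouth=torch.rand(1, 4))
    model.feed_data(data)
    model.optimize_parameters(1)  # current_iter=1; the losses land in model.log_dict
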
tests/test_stylegan2_clean_arch.py
ADDED
@@ -0,0 +1,52 @@
+import torch
+
+from gfpgan.archs.stylegan2_clean_arch import StyleGAN2GeneratorClean
+
+
+def test_stylegan2generatorclean():
+    """Test arch: StyleGAN2GeneratorClean."""
+
+    # model init and forward (gpu)
+    if torch.cuda.is_available():
+        net = StyleGAN2GeneratorClean(
+            out_size=32, num_style_feat=512, num_mlp=8, channel_multiplier=1, narrow=0.5).cuda().eval()
+        style = torch.rand((1, 512), dtype=torch.float32).cuda()
+        output = net([style], input_is_latent=False)
+        assert output[0].shape == (1, 3, 32, 32)
+        assert output[1] is None
+
+        # -------------------- with return_latents ----------------------- #
+        output = net([style], input_is_latent=True, return_latents=True)
+        assert output[0].shape == (1, 3, 32, 32)
+        assert len(output[1]) == 1
+        # check latent
+        assert output[1][0].shape == (8, 512)
+
+        # -------------------- with randomize_noise = False ----------------------- #
+        output = net([style], randomize_noise=False)
+        assert output[0].shape == (1, 3, 32, 32)
+        assert output[1] is None
+
+        # -------------------- with truncation = 0.5 and mixing ----------------------- #
+        output = net([style, style], truncation=0.5, truncation_latent=style)
+        assert output[0].shape == (1, 3, 32, 32)
+        assert output[1] is None
+
+        # ------------------ test make_noise ----------------------- #
+        out = net.make_noise()
+        assert len(out) == 7
+        assert out[0].shape == (1, 1, 4, 4)
+        assert out[1].shape == (1, 1, 8, 8)
+        assert out[2].shape == (1, 1, 8, 8)
+        assert out[3].shape == (1, 1, 16, 16)
+        assert out[4].shape == (1, 1, 16, 16)
+        assert out[5].shape == (1, 1, 32, 32)
+        assert out[6].shape == (1, 1, 32, 32)
+
+        # ------------------ test get_latent ----------------------- #
+        out = net.get_latent(style)
+        assert out.shape == (1, 512)
+
+        # ------------------ test mean_latent ----------------------- #
+        out = net.mean_latent(2)
+        assert out.shape == (1, 512)
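
The seven noise maps asserted for make_noise() follow the usual StyleGAN2 layout: one map at the base 4x4 resolution plus two per doubling up to out_size, i.e. 2 * (log2(out_size) - 2) + 1 in total. A quick check of that count; the closed form is inferred from the asserted shapes, not stated in the code:

    import math
    out_size = 32  # as used in the test above
    assert 2 * (int(math.log2(out_size)) - 2) + 1 == 7
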
tests/test_utils.py
ADDED
@@ -0,0 +1,43 @@
+import cv2
+from facexlib.utils.face_restoration_helper import FaceRestoreHelper
+
+from gfpgan.archs.gfpganv1_arch import GFPGANv1
+from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
+from gfpgan.utils import GFPGANer
+
+
+def test_gfpganer():
+    # initialize with the clean model
+    restorer = GFPGANer(
+        model_path='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth',
+        upscale=2,
+        arch='clean',
+        channel_multiplier=2,
+        bg_upsampler=None)
+    # test attribute
+    assert isinstance(restorer.gfpgan, GFPGANv1Clean)
+    assert isinstance(restorer.face_helper, FaceRestoreHelper)
+
+    # initialize with the original model
+    restorer = GFPGANer(
+        model_path='experiments/pretrained_models/GFPGANv1.pth',
+        upscale=2,
+        arch='original',
+        channel_multiplier=1,
+        bg_upsampler=None)
+    # test attribute
+    assert isinstance(restorer.gfpgan, GFPGANv1)
+    assert isinstance(restorer.face_helper, FaceRestoreHelper)
+
+    # ------------------ test enhance ---------------- #
+    img = cv2.imread('tests/data/gt/00000000.png', cv2.IMREAD_COLOR)
+    result = restorer.enhance(img, has_aligned=False, paste_back=True)
+    assert result[0][0].shape == (512, 512, 3)
+    assert result[1][0].shape == (512, 512, 3)
+    assert result[2].shape == (1024, 1024, 3)
+
+    # with has_aligned=True
+    result = restorer.enhance(img, has_aligned=True, paste_back=False)
+    assert result[0][0].shape == (512, 512, 3)
+    assert result[1][0].shape == (512, 512, 3)
+    assert result[2] is None
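
As exercised above, GFPGANer.enhance returns a 3-tuple: a list of cropped faces, a list of restored faces, and the pasted-back image (None when paste_back=False). A minimal sketch of saving the outputs; the variable names and the output path are illustrative, not from the test:

    cropped_faces, restored_faces, restored_img = restorer.enhance(
        img, has_aligned=False, paste_back=True)
    cv2.imwrite('restored.png', restored_img)  # 1024x1024 here because upscale=2
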