Jacobmadwed committed on
Commit
2702c71
1 Parent(s): f24cb6a

Upload 34 files

Browse files
Files changed (34) hide show
  1. gfpgan/.DS_Store +0 -0
  2. gfpgan/__init__.py +7 -0
  3. gfpgan/__pycache__/__init__.cpython-312.pyc +0 -0
  4. gfpgan/__pycache__/utils.cpython-312.pyc +0 -0
  5. gfpgan/archs/__init__.py +10 -0
  6. gfpgan/archs/__pycache__/__init__.cpython-312.pyc +0 -0
  7. gfpgan/archs/__pycache__/arcface_arch.cpython-312.pyc +0 -0
  8. gfpgan/archs/__pycache__/gfpgan_bilinear_arch.cpython-312.pyc +0 -0
  9. gfpgan/archs/__pycache__/gfpganv1_arch.cpython-312.pyc +0 -0
  10. gfpgan/archs/__pycache__/gfpganv1_clean_arch.cpython-312.pyc +0 -0
  11. gfpgan/archs/__pycache__/restoreformer_arch.cpython-312.pyc +0 -0
  12. gfpgan/archs/__pycache__/stylegan2_bilinear_arch.cpython-312.pyc +0 -0
  13. gfpgan/archs/__pycache__/stylegan2_clean_arch.cpython-312.pyc +0 -0
  14. gfpgan/archs/arcface_arch.py +245 -0
  15. gfpgan/archs/gfpgan_bilinear_arch.py +312 -0
  16. gfpgan/archs/gfpganv1_arch.py +439 -0
  17. gfpgan/archs/gfpganv1_clean_arch.py +324 -0
  18. gfpgan/archs/restoreformer_arch.py +658 -0
  19. gfpgan/archs/stylegan2_bilinear_arch.py +613 -0
  20. gfpgan/archs/stylegan2_clean_arch.py +368 -0
  21. gfpgan/data/__init__.py +10 -0
  22. gfpgan/data/__pycache__/__init__.cpython-312.pyc +0 -0
  23. gfpgan/data/__pycache__/ffhq_degradation_dataset.cpython-312.pyc +0 -0
  24. gfpgan/data/ffhq_degradation_dataset.py +230 -0
  25. gfpgan/models/__init__.py +10 -0
  26. gfpgan/models/__pycache__/__init__.cpython-312.pyc +0 -0
  27. gfpgan/models/__pycache__/gfpgan_model.cpython-312.pyc +0 -0
  28. gfpgan/models/gfpgan_model.py +579 -0
  29. gfpgan/train.py +11 -0
  30. gfpgan/utils.py +148 -0
  31. gfpgan/version.py +5 -0
  32. gfpgan/weights/README.md +3 -0
  33. gfpgan/weights/detection_Resnet50_Final.pth +3 -0
  34. gfpgan/weights/parsing_parsenet.pth +3 -0
gfpgan/.DS_Store ADDED
Binary file (6.15 kB). View file
 
gfpgan/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+ from .archs import *
3
+ from .data import *
4
+ from .models import *
5
+ from .utils import *
6
+
7
+ # from .version import *
gfpgan/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (238 Bytes). View file
 
gfpgan/__pycache__/utils.cpython-312.pyc ADDED
Binary file (6.61 kB). View file
 
gfpgan/archs/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ from basicsr.utils import scandir
3
+ from os import path as osp
4
+
5
# Discover the arch modules that live next to this file and import each of them
# (the original intent: "scan and import arch modules for registry").
arch_folder = osp.dirname(osp.abspath(__file__))
# Module names (file extension stripped) of every entry ending in '_arch.py'.
arch_filenames = [
    osp.splitext(osp.basename(entry))[0]
    for entry in scandir(arch_folder)
    if entry.endswith('_arch.py')
]
# Import every discovered module under the 'gfpgan.archs' package.
_arch_modules = [importlib.import_module(f'gfpgan.archs.{module_name}') for module_name in arch_filenames]
gfpgan/archs/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (884 Bytes). View file
 
gfpgan/archs/__pycache__/arcface_arch.cpython-312.pyc ADDED
Binary file (13.2 kB). View file
 
gfpgan/archs/__pycache__/gfpgan_bilinear_arch.cpython-312.pyc ADDED
Binary file (14.2 kB). View file
 
gfpgan/archs/__pycache__/gfpganv1_arch.cpython-312.pyc ADDED
Binary file (20 kB). View file
 
gfpgan/archs/__pycache__/gfpganv1_clean_arch.cpython-312.pyc ADDED
Binary file (15.6 kB). View file
 
gfpgan/archs/__pycache__/restoreformer_arch.cpython-312.pyc ADDED
Binary file (28.5 kB). View file
 
gfpgan/archs/__pycache__/stylegan2_bilinear_arch.cpython-312.pyc ADDED
Binary file (27.9 kB). View file
 
gfpgan/archs/__pycache__/stylegan2_clean_arch.cpython-312.pyc ADDED
Binary file (18.5 kB). View file
 
gfpgan/archs/arcface_arch.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from basicsr.utils.registry import ARCH_REGISTRY
3
+
4
+
5
def conv3x3(inplanes, outplanes, stride=1):
    """Create a bias-free 3x3 convolution with one pixel of padding.

    Args:
        inplanes (int): Channel number of inputs.
        outplanes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.

    Returns:
        nn.Conv2d: The configured convolution layer.
    """
    conv_kwargs = {'kernel_size': 3, 'stride': stride, 'padding': 1, 'bias': False}
    return nn.Conv2d(inplanes, outplanes, **conv_kwargs)
14
+
15
+
16
class BasicBlock(nn.Module):
    """Two-convolution residual block used in the ResNetArcFace architecture.

    The main path is conv3x3 -> BN -> ReLU -> conv3x3 -> BN. The shortcut is the
    input itself, or ``downsample(x)`` when a downsample module is supplied.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.
        downsample (nn.Module): The downsample module. Default: None.
    """
    expansion = 1  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Only the first convolution is strided.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
54
+
55
+
56
class IRBlock(nn.Module):
    """Improved residual block (IR Block) used in the ResNetArcFace architecture.

    The main path is BN -> conv3x3 -> BN -> PReLU -> conv3x3 (strided) -> BN,
    optionally followed by an SEBlock. A single PReLU instance is used both inside
    the main path and after the residual addition.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.
        downsample (nn.Module): The downsample module. Default: None.
        use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
    """
    expansion = 1  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
        super().__init__()
        self.bn0 = nn.BatchNorm2d(inplanes)
        # The first convolution keeps the channel count; the second one changes it and carries the stride.
        self.conv1 = conv3x3(inplanes, inplanes)
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.prelu = nn.PReLU()
        self.conv2 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.use_se = use_se
        if self.use_se:
            self.se = SEBlock(planes)

    def forward(self, x):
        """Return ``prelu(main_path(x) + shortcut(x))``."""
        out = self.bn0(x)
        out = self.prelu(self.bn1(self.conv1(out)))

        out = self.bn2(self.conv2(out))
        if self.use_se:
            out = self.se(out)

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.prelu(out)
101
+
102
+
103
class Bottleneck(nn.Module):
    """Bottleneck block used in the ResNetArcFace architecture.

    The main path is 1x1 conv -> BN -> ReLU -> 3x3 conv (strided) -> BN -> ReLU ->
    1x1 conv -> BN, producing ``planes * expansion`` channels.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of the intermediate features; the block outputs
            ``planes * expansion`` channels.
        stride (int): Stride in convolution. Default: 1.
        downsample (nn.Module): The downsample module. Default: None.
    """
    expansion = 4  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        out_planes = planes * self.expansion
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
147
+
148
+
149
class SEBlock(nn.Module):
    """The squeeze-and-excitation block (SEBlock) used in the IRBlock.

    Each channel of the input is rescaled by a gate in (0, 1) that is computed from
    the globally average-pooled features.

    Args:
        channel (int): Channel number of inputs.
        reduction (int): Channel reduction ratio. Default: 16.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # pool to 1x1 without spatial information
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden),
            nn.PReLU(),
            nn.Linear(hidden, channel),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Scale ``x`` channel-wise; ``x`` must be 4-dimensional."""
        batch, channels, _, _ = x.size()
        pooled = self.avg_pool(x).view(batch, channels)
        # Reshape the per-channel gates so they broadcast over the two trailing dimensions.
        gates = self.fc(pooled).view(batch, channels, 1, 1)
        return x * gates
169
+
170
+
171
@ARCH_REGISTRY.register()
class ResNetArcFace(nn.Module):
    """ArcFace with ResNet architectures.

    Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.

    The stem takes a single-channel image. The final linear layer expects the
    flattened ``layer4`` output to hold ``512 * 8 * 8`` features, which is what a
    128x128 input yields after the stem max-pool and three stride-2 stages of IRBlocks.

    Args:
        block (str): Block used in the ArcFace architecture. Only the name 'IRBlock' is
            translated to a class; any other value is used as given.
        layers (tuple(int)): Block numbers in each layer.
        use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
    """

    def __init__(self, block, layers, use_se=True):
        super().__init__()
        if block == 'IRBlock':
            block = IRBlock
        self.inplanes = 64  # running input-channel count, updated by _make_layer
        self.use_se = use_se

        # stem
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.prelu = nn.PReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # residual stages; every stage after the first halves the spatial size
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # embedding head
        self.bn4 = nn.BatchNorm2d(512)
        self.dropout = nn.Dropout()
        self.fc5 = nn.Linear(512 * 8 * 8, 512)
        self.bn5 = nn.BatchNorm1d(512)

        # initialization
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.xavier_normal_(module.weight)
            elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.xavier_normal_(module.weight)
                nn.init.constant_(module.bias, 0)

    def _make_layer(self, block, planes, num_blocks, stride=1):
        """Stack ``num_blocks`` blocks into one stage.

        Only the first block receives ``stride`` and the downsample shortcut. ``use_se``
        is always passed by keyword, so ``block`` must accept that argument (of the
        blocks in this file, only IRBlock does).
        """
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            # 1x1 projection so the shortcut matches the main path in channels and resolution
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        stage = [block(self.inplanes, planes, stride, downsample, use_se=self.use_se)]
        self.inplanes = planes
        stage.extend(block(self.inplanes, planes, use_se=self.use_se) for _ in range(1, num_blocks))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Compute the 512-dimensional feature for a batch of images."""
        x = self.maxpool(self.prelu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        x = self.dropout(self.bn4(x))
        x = x.view(x.size(0), -1)  # flatten everything but the batch dimension
        return self.bn5(self.fc5(x))
gfpgan/archs/gfpgan_bilinear_arch.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+ import torch
4
+ from basicsr.utils.registry import ARCH_REGISTRY
5
+ from torch import nn
6
+
7
+ from .gfpganv1_arch import ResUpBlock
8
+ from .stylegan2_bilinear_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
9
+ StyleGAN2GeneratorBilinear)
10
+
11
+
12
class StyleGAN2GeneratorBilinearSFT(StyleGAN2GeneratorBilinear):
    """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).

    It is the bilinear version. It does not use the complicated UpFirDnSmooth function that is not friendly for
    deployment. It can be easily converted to the clean version: StyleGAN2GeneratorCSFT.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """

    def __init__(self,
                 out_size,
                 num_style_feat=512,
                 num_mlp=8,
                 channel_multiplier=2,
                 lr_mlp=0.01,
                 narrow=1,
                 sft_half=False):
        super(StyleGAN2GeneratorBilinearSFT, self).__init__(
            out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            lr_mlp=lr_mlp,
            narrow=narrow)
        self.sft_half = sft_half

    def forward(self,
                styles,
                conditions,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2GeneratorBilinearSFT.

        Args:
            styles (list[Tensor]): Sample codes of styles. Must contain one or two entries.
            conditions (list[Tensor]): SFT conditions to generators, laid out as
                [scale, shift, scale, shift, ...] from the first modulated level onwards.
            input_is_latent (bool): Whether input is latent style. Default: False.
            noise (list[Tensor | None] | None): Per-layer input noise, or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
            truncation (float): The truncation ratio. Default: 1.
            truncation_latent (Tensor | None): The truncation latent tensor. It is used in the
                truncation arithmetic, so it must be given when truncation < 1. Default: None.
            inject_index (int | None): The injection index for mixing noise. Default: None.
            return_latents (bool): Whether to return style latents. Default: False.

        Returns:
            tuple: ``(image, latent)`` when ``return_latents`` is True, otherwise ``(image, None)``.

        Raises:
            ValueError: If ``styles`` does not contain exactly one or two entries.
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
        # style truncation: pull every style towards the truncation latent
        if truncation < 1:
            styles = [truncation_latent + truncation * (style - truncation_latent) for style in styles]
        # get style latents with injection
        if len(styles) == 1:
            inject_index = self.num_latent

            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)
        else:
            # Previously this fell through and failed later with an unbound 'latent'.
            raise ValueError(f'Expected 1 or 2 style codes, but got {len(styles)}.')

        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        i = 1  # index of the latent consumed by conv1; also points at the shift entry in conditions
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)

            # the conditions may have fewer levels
            if i < len(conditions):
                # SFT part to combine the conditions: out * scale + shift
                if self.sft_half:  # only apply SFT to half of the channels
                    out_same, out_sft = torch.split(out, out.size(1) // 2, dim=1)
                    out_sft = out_sft * conditions[i - 1] + conditions[i]
                    out = torch.cat([out_same, out_sft], dim=1)
                else:  # apply SFT to all the channels
                    out = out * conditions[i - 1] + conditions[i]

            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
            i += 2

        image = skip

        if return_latents:
            return image, latent
        return image, None
129
+
130
+
131
@ARCH_REGISTRY.register()
class GFPGANBilinear(nn.Module):
    """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.

    It is the bilinear version and it does not use the complicated UpFirDnSmooth function that is not friendly for
    deployment. It can be easily converted to the clean version: GFPGANv1Clean.

    Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 1.
        decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
        fix_decoder (bool): Whether to fix the decoder. Default: True.

        num_mlp (int): Layer number of MLP style layers. Default: 8.
        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
        input_is_latent (bool): Whether input is latent style. Default: False.
        different_w (bool): Whether to use different latent w for different layers. Default: False.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """

    def __init__(
            self,
            out_size,
            num_style_feat=512,
            channel_multiplier=1,
            decoder_load_path=None,
            fix_decoder=True,
            # for stylegan decoder
            num_mlp=8,
            lr_mlp=0.01,
            input_is_latent=False,
            different_w=False,
            narrow=1,
            sft_half=False):

        super(GFPGANBilinear, self).__init__()
        self.input_is_latent = input_is_latent
        self.different_w = different_w
        self.num_style_feat = num_style_feat

        unet_narrow = narrow * 0.5  # by default, use a half of input channels
        # channel number of the U-Net features, keyed by spatial resolution
        channels = {
            '4': int(512 * unet_narrow),
            '8': int(512 * unet_narrow),
            '16': int(512 * unet_narrow),
            '32': int(512 * unet_narrow),
            '64': int(256 * channel_multiplier * unet_narrow),
            '128': int(128 * channel_multiplier * unet_narrow),
            '256': int(64 * channel_multiplier * unet_narrow),
            '512': int(32 * channel_multiplier * unet_narrow),
            '1024': int(16 * channel_multiplier * unet_narrow)
        }

        # math.log2 is exact for powers of two, unlike the two-argument math.log(x, 2)
        self.log_size = int(math.log2(out_size))
        first_out_size = 2**self.log_size

        self.conv_body_first = ConvLayer(3, channels[f'{first_out_size}'], 1, bias=True, activate=True)

        # downsample: one ResBlock per resolution step from first_out_size down to 4
        in_channels = channels[f'{first_out_size}']
        self.conv_body_down = nn.ModuleList()
        for i in range(self.log_size, 2, -1):
            out_channels = channels[f'{2**(i - 1)}']
            self.conv_body_down.append(ResBlock(in_channels, out_channels))
            in_channels = out_channels

        self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True)

        # upsample: one ResUpBlock per resolution step from 8 up to first_out_size
        in_channels = channels['4']
        self.conv_body_up = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
            in_channels = out_channels

        # to RGB
        self.toRGB = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            self.toRGB.append(EqualConv2d(channels[f'{2**i}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0))

        if different_w:
            # one style vector per latent slot: (2 * log_size - 2) slots in total
            linear_out_channel = (self.log_size * 2 - 2) * num_style_feat
        else:
            linear_out_channel = num_style_feat

        # maps the flattened 4x4 bottleneck feature to the style code(s)
        self.final_linear = EqualLinear(
            channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None)

        # the decoder: stylegan2 generator with SFT modulations
        self.stylegan_decoder = StyleGAN2GeneratorBilinearSFT(
            out_size=out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            lr_mlp=lr_mlp,
            narrow=narrow,
            sft_half=sft_half)

        # load pre-trained stylegan2 model if necessary
        if decoder_load_path:
            # NOTE(review): torch.load unpickles the checkpoint; only load files from trusted sources.
            self.stylegan_decoder.load_state_dict(
                torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
        # fix decoder without updating params
        if fix_decoder:
            for _, param in self.stylegan_decoder.named_parameters():
                param.requires_grad = False

        # for SFT modulations (scale and shift)
        self.condition_scale = nn.ModuleList()
        self.condition_shift = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            # NOTE(review): the doubling presumably matches the decoder's channel count, since the
            # U-Net is built with narrow * 0.5 -- confirm against the decoder definition.
            if sft_half:
                sft_out_channels = out_channels
            else:
                sft_out_channels = out_channels * 2
            # the scale branch's last conv has bias initialized to 1, the shift branch's to 0
            self.condition_scale.append(
                nn.Sequential(
                    EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
                    ScaledLeakyReLU(0.2),
                    EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1)))
            self.condition_shift.append(
                nn.Sequential(
                    EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
                    ScaledLeakyReLU(0.2),
                    EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0)))

    def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
        """Forward function for GFPGANBilinear.

        Args:
            x (Tensor): Input images.
            return_latents (bool): Whether to return style latents. It is forwarded to the decoder,
                but the latents the decoder returns are discarded here. Default: False.
            return_rgb (bool): Whether return intermediate rgb images. Default: True.
            randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.

        Returns:
            tuple: ``(image, out_rgbs)`` where ``out_rgbs`` is the list of intermediate rgb images
            (empty when ``return_rgb`` is False).
        """
        conditions = []
        unet_skips = []
        out_rgbs = []

        # encoder
        feat = self.conv_body_first(x)
        for i in range(self.log_size - 2):
            feat = self.conv_body_down[i](feat)
            unet_skips.insert(0, feat)  # prepend so the list ends up ordered from coarse to fine

        feat = self.final_conv(feat)

        # style code
        style_code = self.final_linear(feat.view(feat.size(0), -1))
        if self.different_w:
            style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)

        # decode
        for i in range(self.log_size - 2):
            # add unet skip
            feat = feat + unet_skips[i]
            # ResUpLayer
            feat = self.conv_body_up[i](feat)
            # generate scale and shift for SFT layers (appended as scale, then shift)
            scale = self.condition_scale[i](feat)
            conditions.append(scale.clone())
            shift = self.condition_shift[i](feat)
            conditions.append(shift.clone())
            # generate rgb images
            if return_rgb:
                out_rgbs.append(self.toRGB[i](feat))

        # decoder
        image, _ = self.stylegan_decoder([style_code],
                                         conditions,
                                         return_latents=return_latents,
                                         input_is_latent=self.input_is_latent,
                                         randomize_noise=randomize_noise)

        return image, out_rgbs
gfpgan/archs/gfpganv1_arch.py ADDED
@@ -0,0 +1,439 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+ import torch
4
+ from basicsr.archs.stylegan2_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
5
+ StyleGAN2Generator)
6
+ from basicsr.ops.fused_act import FusedLeakyReLU
7
+ from basicsr.utils.registry import ARCH_REGISTRY
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+
11
+
12
class StyleGAN2GeneratorSFT(StyleGAN2Generator):
    """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross product will be
            applied to extend the 1D resample kernel to a 2D resample kernel. Default: (1, 3, 3, 1).
        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """

    def __init__(self,
                 out_size,
                 num_style_feat=512,
                 num_mlp=8,
                 channel_multiplier=2,
                 resample_kernel=(1, 3, 3, 1),
                 lr_mlp=0.01,
                 narrow=1,
                 sft_half=False):
        super(StyleGAN2GeneratorSFT, self).__init__(
            out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            resample_kernel=resample_kernel,
            lr_mlp=lr_mlp,
            narrow=narrow)
        self.sft_half = sft_half

    def forward(self,
                styles,
                conditions,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2GeneratorSFT.

        Args:
            styles (list[Tensor]): Sample codes of styles. Must contain one or two entries.
            conditions (list[Tensor]): SFT conditions to generators, laid out as
                [scale, shift, scale, shift, ...] from the first modulated level onwards.
            input_is_latent (bool): Whether input is latent style. Default: False.
            noise (list[Tensor | None] | None): Per-layer input noise, or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
            truncation (float): The truncation ratio. Default: 1.
            truncation_latent (Tensor | None): The truncation latent tensor. It is used in the
                truncation arithmetic, so it must be given when truncation < 1. Default: None.
            inject_index (int | None): The injection index for mixing noise. Default: None.
            return_latents (bool): Whether to return style latents. Default: False.

        Returns:
            tuple: ``(image, latent)`` when ``return_latents`` is True, otherwise ``(image, None)``.

        Raises:
            ValueError: If ``styles`` does not contain exactly one or two entries.
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
        # style truncation: pull every style towards the truncation latent
        if truncation < 1:
            styles = [truncation_latent + truncation * (style - truncation_latent) for style in styles]
        # get style latents with injection
        if len(styles) == 1:
            inject_index = self.num_latent

            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)
        else:
            # Previously this fell through and failed later with an unbound 'latent'.
            raise ValueError(f'Expected 1 or 2 style codes, but got {len(styles)}.')

        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        i = 1  # index of the latent consumed by conv1; also points at the shift entry in conditions
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)

            # the conditions may have fewer levels
            if i < len(conditions):
                # SFT part to combine the conditions: out * scale + shift
                if self.sft_half:  # only apply SFT to half of the channels
                    out_same, out_sft = torch.split(out, out.size(1) // 2, dim=1)
                    out_sft = out_sft * conditions[i - 1] + conditions[i]
                    out = torch.cat([out_same, out_sft], dim=1)
                else:  # apply SFT to all the channels
                    out = out * conditions[i - 1] + conditions[i]

            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
            i += 2

        image = skip

        if return_latents:
            return image, latent
        return image, None
130
+
131
+
132
class ConvUpLayer(nn.Module):
    """Convolutional upsampling layer. It uses bilinear upsampler + Conv.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        stride (int): Stride of the convolution. Default: 1
        padding (int): Zero-padding added to both sides of the input. Default: 0.
        bias (bool): If ``True``, adds a learnable bias to the output. Default: ``True``.
        bias_init_val (float): Bias initialized value. Default: 0.
        activate (bool): Whether use activation. Default: True.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 bias=True,
                 bias_init_val=0,
                 activate=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        # The weights are stored unscaled and multiplied by self.scale at call time;
        # the factor is related to the common initializations.
        self.scale = 1 / math.sqrt(in_channels * kernel_size**2)

        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))

        # A bias parameter is created here only when no activation follows.
        # NOTE(review): with activate=True the bias presumably lives inside FusedLeakyReLU -- confirm.
        if bias and not activate:
            self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
        else:
            self.register_parameter('bias', None)

        # activation
        if not activate:
            self.activation = None
        elif bias:
            self.activation = FusedLeakyReLU(out_channels)
        else:
            self.activation = ScaledLeakyReLU(0.2)

    def forward(self, x):
        """Upsample ``x`` by a factor of two, convolve it and apply the optional activation."""
        # bilinear upsample
        upsampled = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        # conv with the scaled weights
        out = F.conv2d(
            upsampled,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )
        # activation
        return out if self.activation is None else self.activation(out)
195
+
196
+
197
class ResUpBlock(nn.Module):
    """Residual block whose main and skip branches both upsample by 2x.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # main branch: same-resolution conv, then upsampling conv
        self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
        self.conv2 = ConvUpLayer(in_channels, out_channels, 3, stride=1, padding=1, bias=True, activate=True)
        # skip branch: linear 1x1 upsampling conv (no bias, no activation)
        self.skip = ConvUpLayer(in_channels, out_channels, 1, bias=False, activate=False)

    def forward(self, x):
        """Return ``(main(x) + skip(x)) / sqrt(2)``."""
        main = self.conv2(self.conv1(x))
        shortcut = self.skip(x)
        # dividing by sqrt(2) rescales the sum of the two branches
        return (main + shortcut) / math.sqrt(2)
218
+
219
+
220
@ARCH_REGISTRY.register()
class GFPGANv1(nn.Module):
    """GFPGAN generator: a U-Net whose decoder features modulate a StyleGAN2 decoder via SFT.

    Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 1.
        resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross product is
            applied to extend the 1D resample kernel to a 2D resample kernel. Default: (1, 3, 3, 1).
        decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
        fix_decoder (bool): Whether to fix the decoder. Default: True.

        num_mlp (int): Layer number of MLP style layers. Default: 8.
        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
        input_is_latent (bool): Whether input is latent style. Default: False.
        different_w (bool): Whether to use different latent w for different layers. Default: False.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """

    def __init__(
            self,
            out_size,
            num_style_feat=512,
            channel_multiplier=1,
            resample_kernel=(1, 3, 3, 1),
            decoder_load_path=None,
            fix_decoder=True,
            # for stylegan decoder
            num_mlp=8,
            lr_mlp=0.01,
            input_is_latent=False,
            different_w=False,
            narrow=1,
            sft_half=False):

        super().__init__()
        self.input_is_latent = input_is_latent
        self.different_w = different_w
        self.num_style_feat = num_style_feat

        # The U-Net runs at half of the (narrowed) channel widths.
        unet_narrow = narrow * 0.5
        # Channel width per spatial resolution; only resolutions >= 64 are
        # affected by channel_multiplier.
        channels = {str(res): int(512 * unet_narrow) for res in (4, 8, 16, 32)}
        for res, width in ((64, 256), (128, 128), (256, 64), (512, 32), (1024, 16)):
            channels[str(res)] = int(width * channel_multiplier * unet_narrow)

        self.log_size = int(math.log(out_size, 2))
        first_out_size = 2**self.log_size
        first_channels = channels[f'{first_out_size}']

        self.conv_body_first = ConvLayer(3, first_channels, 1, bias=True, activate=True)

        # encoder: one downsampling ResBlock per level, from out_size down to 4x4
        self.conv_body_down = nn.ModuleList()
        in_channels = first_channels
        for level in range(self.log_size, 2, -1):
            out_channels = channels[f'{2**(level - 1)}']
            self.conv_body_down.append(ResBlock(in_channels, out_channels, resample_kernel))
            in_channels = out_channels

        self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True)

        # decoder: one upsampling block per level, from 8x8 up to out_size
        self.conv_body_up = nn.ModuleList()
        in_channels = channels['4']
        for level in range(3, self.log_size + 1):
            out_channels = channels[f'{2**level}']
            self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
            in_channels = out_channels

        # per-level projection of decoder features to RGB
        self.toRGB = nn.ModuleList()
        for level in range(3, self.log_size + 1):
            self.toRGB.append(
                EqualConv2d(channels[f'{2**level}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0))

        # With different_w the linear layer emits (2 * log_size - 2) style
        # vectors; otherwise it emits a single one.
        if different_w:
            linear_out_channel = (self.log_size * 2 - 2) * num_style_feat
        else:
            linear_out_channel = num_style_feat

        self.final_linear = EqualLinear(
            channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None)

        # the decoder: stylegan2 generator with SFT modulations
        self.stylegan_decoder = StyleGAN2GeneratorSFT(
            out_size=out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            resample_kernel=resample_kernel,
            lr_mlp=lr_mlp,
            narrow=narrow,
            sft_half=sft_half)

        # load pre-trained stylegan2 model if necessary
        if decoder_load_path:
            # NOTE(review): torch.load unpickles the file; only point this at
            # trusted checkpoints.
            checkpoint = torch.load(decoder_load_path, map_location=lambda storage, loc: storage)
            self.stylegan_decoder.load_state_dict(checkpoint['params_ema'])
        # fix decoder without updating params
        if fix_decoder:
            for param in self.stylegan_decoder.parameters():
                param.requires_grad = False

        # SFT heads (scale and shift), built pairwise per level
        self.condition_scale = nn.ModuleList()
        self.condition_shift = nn.ModuleList()
        for level in range(3, self.log_size + 1):
            out_channels = channels[f'{2**level}']
            # sft_half modulates half of the decoder channels, so the heads
            # stay at the U-Net width; otherwise they emit twice that width.
            sft_out_channels = out_channels if sft_half else out_channels * 2
            self.condition_scale.append(
                nn.Sequential(
                    EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
                    ScaledLeakyReLU(0.2),
                    EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1)))
            self.condition_shift.append(
                nn.Sequential(
                    EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
                    ScaledLeakyReLU(0.2),
                    EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0)))

    def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs):
        """Forward function for GFPGANv1.

        Args:
            x (Tensor): Input images.
            return_latents (bool): Forwarded to the StyleGAN2 decoder; the latents it returns are
                discarded here. Default: False.
            return_rgb (bool): Whether return intermediate rgb images. Default: True.
            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.

        Returns:
            tuple: ``(image, out_rgbs)`` where ``out_rgbs`` is the list of intermediate RGB
            predictions (empty when ``return_rgb`` is False).
        """
        # encoder: keep every level's feature for the U-Net skips
        feat = self.conv_body_first(x)
        unet_skips = []
        for down in self.conv_body_down:
            feat = down(feat)
            unet_skips.append(feat)
        unet_skips.reverse()  # deepest (smallest) feature first

        feat = self.final_conv(feat)

        # style code
        style_code = self.final_linear(feat.view(feat.size(0), -1))
        if self.different_w:
            style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)

        # decoder: conditions is filled as [scale_0, shift_0, scale_1, shift_1, ...]
        conditions = []
        out_rgbs = []
        stages = zip(unet_skips, self.conv_body_up, self.condition_scale, self.condition_shift, self.toRGB)
        for skip, up_block, scale_head, shift_head, to_rgb in stages:
            # add the U-Net skip, then upsample
            feat = up_block(feat + skip)
            conditions.append(scale_head(feat).clone())
            conditions.append(shift_head(feat).clone())
            if return_rgb:
                out_rgbs.append(to_rgb(feat))

        # StyleGAN2 decoder modulated by the SFT conditions
        image, _ = self.stylegan_decoder([style_code],
                                         conditions,
                                         return_latents=return_latents,
                                         input_is_latent=self.input_is_latent,
                                         randomize_noise=randomize_noise)

        return image, out_rgbs
402
+
403
+
404
@ARCH_REGISTRY.register()
class FacialComponentDiscriminator(nn.Module):
    """Facial component (eyes, mouth, nose) discriminator used in GFPGAN.
    """

    def __init__(self):
        super().__init__()
        # VGG-style architecture with a fixed model size
        kernel = (1, 3, 3, 1)
        self.conv1 = ConvLayer(3, 64, 3, downsample=False, resample_kernel=kernel, bias=True, activate=True)
        self.conv2 = ConvLayer(64, 128, 3, downsample=True, resample_kernel=kernel, bias=True, activate=True)
        self.conv3 = ConvLayer(128, 128, 3, downsample=False, resample_kernel=kernel, bias=True, activate=True)
        self.conv4 = ConvLayer(128, 256, 3, downsample=True, resample_kernel=kernel, bias=True, activate=True)
        self.conv5 = ConvLayer(256, 256, 3, downsample=False, resample_kernel=kernel, bias=True, activate=True)
        self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False)

    def forward(self, x, return_feats=False, **kwargs):
        """Forward function for FacialComponentDiscriminator.

        Args:
            x (Tensor): Input images.
            return_feats (bool): Whether to return intermediate features. Default: False.

        Returns:
            tuple: ``(out, feats)``; ``feats`` holds clones of the conv3 and conv5 outputs when
            ``return_feats`` is True, otherwise it is ``None``.
        """
        feat = self.conv3(self.conv2(self.conv1(x)))
        rlt_feats = [feat.clone()] if return_feats else []

        feat = self.conv5(self.conv4(feat))
        if return_feats:
            rlt_feats.append(feat.clone())

        out = self.final_conv(feat)
        return (out, rlt_feats) if return_feats else (out, None)
gfpgan/archs/gfpganv1_clean_arch.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+ import torch
4
+ from basicsr.utils.registry import ARCH_REGISTRY
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from .stylegan2_clean_arch import StyleGAN2GeneratorClean
9
+
10
+
11
class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
    """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).

    It is the clean version without custom compiled CUDA extensions used in StyleGAN2.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """

    def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1, sft_half=False):
        super().__init__(
            out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            narrow=narrow)
        self.sft_half = sft_half

    def forward(self,
                styles,
                conditions,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2GeneratorCSFT.

        Args:
            styles (list[Tensor]): Sample codes of styles.
            conditions (list[Tensor]): SFT conditions to generators, laid out as
                ``[scale_0, shift_0, scale_1, shift_1, ...]``.
            input_is_latent (bool): Whether input is latent style. Default: False.
            noise (Tensor | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
            truncation (float): The truncation ratio. Default: 1.
            truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
            inject_index (int | None): The injection index for mixing noise. Default: None.
            return_latents (bool): Whether to return style latents. Default: False.

        Returns:
            tuple: ``(image, latent)`` when ``return_latents`` is True, else ``(image, None)``.
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(code) for code in styles]

        # noises: one entry per style conv layer
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{idx}') for idx in range(self.num_layers)]

        # style truncation
        if truncation < 1:
            styles = [truncation_latent + truncation * (code - truncation_latent) for code in styles]

        # get style latents with injection
        num_styles = len(styles)
        if num_styles == 1:
            inject_index = self.num_latent
            single = styles[0]
            if single.ndim < 3:
                # repeat latent code for all the layers
                latent = single.unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = single
        elif num_styles == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            head = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            tail = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([head, tail], 1)

        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        levels = zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2], noise[2::2], self.to_rgbs)
        for step, (conv1, conv2, noise1, noise2, to_rgb) in enumerate(levels):
            i = 2 * step + 1  # latent / condition index of this level
            out = conv1(out, latent[:, i], noise=noise1)

            # the conditions may have fewer levels
            if i < len(conditions):
                scale, shift = conditions[i - 1], conditions[i]
                if self.sft_half:  # only apply SFT to half of the channels
                    out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
                    out = torch.cat([out_same, out_sft * scale + shift], dim=1)
                else:  # apply SFT to all the channels
                    out = out * scale + shift

            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space

        image = skip
        return image, (latent if return_latents else None)
118
+
119
+
120
class ResBlock(nn.Module):
    """Residual block with bilinear upsampling/downsampling.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        mode (str): Upsampling/downsampling mode. Options: down | up. Default: down.

    Raises:
        ValueError: If ``mode`` is neither ``'down'`` nor ``'up'``.
    """

    def __init__(self, in_channels, out_channels, mode='down'):
        super(ResBlock, self).__init__()

        # Reject unknown modes up front. Otherwise scale_factor would never be
        # set and the failure would only surface later, in forward().
        if mode == 'down':
            scale_factor = 0.5
        elif mode == 'up':
            scale_factor = 2
        else:
            raise ValueError(f"mode must be 'down' or 'up', but got {mode!r}")

        self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
        self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
        self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        self.scale_factor = scale_factor

    def forward(self, x):
        """Return the resampled residual output: ``conv path + 1x1 skip path``."""
        out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
        # upsample/downsample
        out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
        out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
        # skip: resample the input the same way, then match channels with a 1x1 conv
        x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
        skip = self.skip(x)
        out = out + skip
        return out
150
+
151
+
152
@ARCH_REGISTRY.register()
class GFPGANv1Clean(nn.Module):
    """GFPGAN generator: a U-Net whose decoder features modulate a StyleGAN2 decoder via SFT.

    It is the clean version without custom compiled CUDA extensions used in StyleGAN2.

    Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 1.
        decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
        fix_decoder (bool): Whether to fix the decoder. Default: True.

        num_mlp (int): Layer number of MLP style layers. Default: 8.
        input_is_latent (bool): Whether input is latent style. Default: False.
        different_w (bool): Whether to use different latent w for different layers. Default: False.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """

    def __init__(
            self,
            out_size,
            num_style_feat=512,
            channel_multiplier=1,
            decoder_load_path=None,
            fix_decoder=True,
            # for stylegan decoder
            num_mlp=8,
            input_is_latent=False,
            different_w=False,
            narrow=1,
            sft_half=False):

        super().__init__()
        self.input_is_latent = input_is_latent
        self.different_w = different_w
        self.num_style_feat = num_style_feat

        # The U-Net runs at half of the (narrowed) channel widths.
        unet_narrow = narrow * 0.5
        # Channel width per spatial resolution; only resolutions >= 64 are
        # affected by channel_multiplier.
        channels = {str(res): int(512 * unet_narrow) for res in (4, 8, 16, 32)}
        for res, width in ((64, 256), (128, 128), (256, 64), (512, 32), (1024, 16)):
            channels[str(res)] = int(width * channel_multiplier * unet_narrow)

        self.log_size = int(math.log(out_size, 2))
        first_out_size = 2**self.log_size
        first_channels = channels[f'{first_out_size}']

        self.conv_body_first = nn.Conv2d(3, first_channels, 1)

        # encoder: one downsampling ResBlock per level, from out_size down to 4x4
        self.conv_body_down = nn.ModuleList()
        in_channels = first_channels
        for level in range(self.log_size, 2, -1):
            out_channels = channels[f'{2**(level - 1)}']
            self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))
            in_channels = out_channels

        self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1)

        # decoder: one upsampling ResBlock per level, from 8x8 up to out_size
        self.conv_body_up = nn.ModuleList()
        in_channels = channels['4']
        for level in range(3, self.log_size + 1):
            out_channels = channels[f'{2**level}']
            self.conv_body_up.append(ResBlock(in_channels, out_channels, mode='up'))
            in_channels = out_channels

        # per-level projection of decoder features to RGB
        self.toRGB = nn.ModuleList()
        for level in range(3, self.log_size + 1):
            self.toRGB.append(nn.Conv2d(channels[f'{2**level}'], 3, 1))

        # With different_w the linear layer emits (2 * log_size - 2) style
        # vectors; otherwise it emits a single one.
        if different_w:
            linear_out_channel = (self.log_size * 2 - 2) * num_style_feat
        else:
            linear_out_channel = num_style_feat

        self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)

        # the decoder: stylegan2 generator with SFT modulations
        self.stylegan_decoder = StyleGAN2GeneratorCSFT(
            out_size=out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            narrow=narrow,
            sft_half=sft_half)

        # load pre-trained stylegan2 model if necessary
        if decoder_load_path:
            # NOTE(review): torch.load unpickles the file; only point this at
            # trusted checkpoints.
            checkpoint = torch.load(decoder_load_path, map_location=lambda storage, loc: storage)
            self.stylegan_decoder.load_state_dict(checkpoint['params_ema'])
        # fix decoder without updating params
        if fix_decoder:
            for param in self.stylegan_decoder.parameters():
                param.requires_grad = False

        # SFT heads (scale and shift), built pairwise per level
        self.condition_scale = nn.ModuleList()
        self.condition_shift = nn.ModuleList()
        for level in range(3, self.log_size + 1):
            out_channels = channels[f'{2**level}']
            # sft_half modulates half of the decoder channels, so the heads
            # stay at the U-Net width; otherwise they emit twice that width.
            sft_out_channels = out_channels if sft_half else out_channels * 2
            self.condition_scale.append(
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
                    nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
            self.condition_shift.append(
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
                    nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))

    def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs):
        """Forward function for GFPGANv1Clean.

        Args:
            x (Tensor): Input images.
            return_latents (bool): Forwarded to the StyleGAN2 decoder; the latents it returns are
                discarded here. Default: False.
            return_rgb (bool): Whether return intermediate rgb images. Default: True.
            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.

        Returns:
            tuple: ``(image, out_rgbs)`` where ``out_rgbs`` is the list of intermediate RGB
            predictions (empty when ``return_rgb`` is False).
        """
        # encoder: keep every level's feature for the U-Net skips
        feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2)
        unet_skips = []
        for down in self.conv_body_down:
            feat = down(feat)
            unet_skips.append(feat)
        unet_skips.reverse()  # deepest (smallest) feature first
        feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)

        # style code
        style_code = self.final_linear(feat.view(feat.size(0), -1))
        if self.different_w:
            style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)

        # decoder: conditions is filled as [scale_0, shift_0, scale_1, shift_1, ...]
        conditions = []
        out_rgbs = []
        stages = zip(unet_skips, self.conv_body_up, self.condition_scale, self.condition_shift, self.toRGB)
        for skip, up_block, scale_head, shift_head, to_rgb in stages:
            # add the U-Net skip, then upsample
            feat = up_block(feat + skip)
            conditions.append(scale_head(feat).clone())
            conditions.append(shift_head(feat).clone())
            if return_rgb:
                out_rgbs.append(to_rgb(feat))

        # StyleGAN2 decoder modulated by the SFT conditions
        image, _ = self.stylegan_decoder([style_code],
                                         conditions,
                                         return_latents=return_latents,
                                         input_is_latent=self.input_is_latent,
                                         randomize_noise=randomize_noise)

        return image, out_rgbs
gfpgan/archs/restoreformer_arch.py ADDED
@@ -0,0 +1,658 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Modified from https://github.com/wzhouxiff/RestoreFormer
2
+ """
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+
9
class VectorQuantizer(nn.Module):
    """
    see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
    ____________________________________________
    Discretization bottleneck part of the VQ-VAE.
    Inputs:
    - n_e : number of embeddings
    - e_dim : dimension of embedding
    - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
    _____________________________________________
    """

    def __init__(self, n_e, e_dim, beta):
        super(VectorQuantizer, self).__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta

        # codebook, initialised uniformly in [-1/n_e, 1/n_e]
        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

    def forward(self, z):
        """Quantize ``z`` to its nearest codebook entries.

        Args:
            z (Tensor): Encoder output of shape (batch, channel, height, width); ``channel``
                must equal ``e_dim``.

        Returns:
            tuple: ``(z_q, loss, (perplexity, min_encodings, min_encoding_indices, d))`` where
            ``z_q`` has the shape of ``z`` (with straight-through gradients), ``min_encodings``
            is the (B*H*W, n_e) one-hot assignment, ``min_encoding_indices`` is (B*H*W, 1) and
            ``d`` holds the squared distances to every codebook entry.
        """
        # reshape z -> (batch, height, width, channel) and flatten to (B*H*W, C)
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.e_dim)

        # distances from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2 e * z
        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight**2, dim=1) - 2 * \
            torch.matmul(z_flattened, self.embedding.weight.t())

        # find closest encodings
        _, min_encoding_indices = torch.min(d, dim=1)

        # Quantized latents by direct codebook lookup. This selects the same
        # rows as multiplying a one-hot matrix with the codebook, without the
        # (B*H*W, n_e) x (n_e, e_dim) matmul.
        z_q = self.embedding(min_encoding_indices).view(z.shape)

        # The one-hot assignment is still part of the return value and is
        # needed for the perplexity below.
        min_encoding_indices = min_encoding_indices.unsqueeze(1)
        min_encodings = torch.zeros(min_encoding_indices.shape[0], self.n_e).to(z)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # compute loss for embedding
        loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean((z_q - z.detach())**2)

        # preserve gradients (straight-through estimator)
        z_q = z + (z_q - z).detach()

        # perplexity of the codebook usage
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices, d)

    def get_codebook_entry(self, indices, shape):
        """Look up codebook vectors for ``indices``.

        Args:
            indices (Tensor): Integer codebook indices.
            shape (tuple | None): Target shape as (batch, height, width, channel). When given,
                the result is reshaped to it and permuted to (batch, channel, height, width);
                when ``None`` the looked-up vectors are returned as they are.

        Returns:
            Tensor: The selected codebook vectors.
        """
        # direct lookup instead of building a one-hot matrix and multiplying
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q
106
+
107
+
108
+ # pytorch_diffusion + derived encoder decoder
109
def nonlinearity(x):
    """Swish activation: ``x * sigmoid(x)``, applied element-wise."""
    gate = torch.sigmoid(x)
    return x * gate
112
+
113
+
114
def Normalize(in_channels):
    """Build the affine GroupNorm (32 groups, eps=1e-6) used across this file."""
    return nn.GroupNorm(32, in_channels, eps=1e-6, affine=True)
116
+
117
+
118
class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling, optionally followed by a 3x3 convolution.

    Args:
        in_channels (int): Channel number of the input (and of the output).
        with_conv (bool): Whether to apply the 3x3 convolution after upsampling.
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        # the conv only exists when requested
        if with_conv:
            self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=2.0, mode='nearest')
        if not self.with_conv:
            return upsampled
        return self.conv(upsampled)
131
+
132
+
133
class Downsample(nn.Module):
    """2x downsampling, either by a stride-2 3x3 convolution or by average pooling.

    Args:
        in_channels (int): Channel number of the input (and of the output).
        with_conv (bool): Use the strided convolution when True, average pooling otherwise.
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            # padding=0 here: the asymmetric padding is applied by hand in forward()
            self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if not self.with_conv:
            return F.avg_pool2d(x, kernel_size=2, stride=2)
        # pad one zero column on the right and one zero row at the bottom
        padded = F.pad(x, (0, 1, 0, 1), mode='constant', value=0)
        return self.conv(padded)
150
+
151
+
152
class ResnetBlock(nn.Module):
    """Pre-activation residual block: (norm -> swish -> conv) twice, plus a shortcut.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int | None): Channel number of the output; ``None`` means ``in_channels``.
        conv_shortcut (bool): When the channel counts differ, use a 3x3 conv shortcut instead
            of a 1x1 one. Default: False.
        dropout (float): Dropout probability applied before the second convolution.
        temb_channels (int): Size of the optional embedding passed to ``forward``; the
            projection layer is only created when this is > 0. Default: 512.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # a shortcut projection is only needed when the channel count changes
        if in_channels != out_channels:
            if conv_shortcut:
                self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb):
        """Return ``shortcut(x) + residual(x)``; ``temb`` may be ``None``."""
        h = self.conv1(nonlinearity(self.norm1(x)))

        if temb is not None:
            # project the embedding and broadcast it over the spatial dims
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.conv2(self.dropout(nonlinearity(self.norm2(h))))

        if self.in_channels != self.out_channels:
            project = self.conv_shortcut if self.use_conv_shortcut else self.nin_shortcut
            x = project(x)

        return x + h
195
+
196
+
197
class MultiHeadAttnBlock(nn.Module):
    """Multi-head attention over the spatial positions of a feature map.

    Keys and values always come from ``x``; queries come from ``x`` as well,
    or from a second feature map ``y`` when one is supplied.

    Args:
        in_channels (int): Channel number of the input.
        head_size (int): Number of groups the channels are split into (used as
            the head dimension when reshaping). Must divide ``in_channels``.
    """

    def __init__(self, in_channels, head_size=1):
        super().__init__()
        self.in_channels = in_channels
        self.head_size = head_size
        self.att_size = in_channels // head_size
        assert (in_channels % head_size == 0), 'The size of head should be divided by the number of channels.'

        self.norm1 = Normalize(in_channels)
        self.norm2 = Normalize(in_channels)

        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.num = 0

    def forward(self, x, y=None):
        """Attend over spatial positions and add the result to ``x``.

        Args:
            x (Tensor): Feature map with shape (b, c, h, w); source of keys
                and values.
            y (Tensor | None): Optional feature map used as the query source
                (normalised with ``norm2``). Defaults to the normalised ``x``.

        Returns:
            Tensor: ``x`` plus the projected attention output.
        """
        normed_x = self.norm1(x)
        query_src = normed_x if y is None else self.norm2(y)

        q = self.q(query_src)
        k = self.k(normed_x)
        v = self.v(normed_x)

        # Split channels into (head, att) and flatten the spatial grid.
        b, _, h, w = q.shape
        heads, dim, tokens = self.head_size, self.att_size, h * w
        q = q.reshape(b, heads, dim, tokens).transpose(2, 3)  # b, head, hw, att
        k = k.reshape(b, heads, dim, tokens)  # b, head, att, hw
        v = v.reshape(b, heads, dim, tokens).transpose(2, 3)  # b, head, hw, att

        # Scaled dot-product attention; the query is scaled in place.
        q.mul_(int(dim)**(-0.5))
        attn = F.softmax(torch.matmul(q, k), dim=3)
        ctx = attn.matmul(v)

        # Back to an image layout: [b, hw, head, att] -> [b, h, w, c] -> [b, c, h, w].
        ctx = ctx.transpose(1, 2).contiguous()
        ctx = ctx.view(b, h, w, -1).permute(0, 3, 1, 2)

        return x + self.proj_out(ctx)
256
+
257
+
258
class MultiHeadEncoder(nn.Module):
    """Convolutional encoder with optional attention at selected resolutions.

    Args:
        ch (int): Base channel number; each level uses ``ch * ch_mult[level]``.
        out_ch (int): Accepted for signature compatibility; not used here.
        ch_mult (tuple[int]): Channel multiplier per resolution level.
        num_res_blocks (int): Residual blocks per level.
        attn_resolutions (tuple[int]): Spatial resolutions at which an
            attention block follows every residual block.
        dropout (float): Dropout probability inside the residual blocks.
        resamp_with_conv (bool): Passed to ``Downsample``.
        in_channels (int): Channel number of the input image.
        resolution (int): Nominal input resolution, used to track which
            levels get attention.
        z_channels (int): Channel number of the output.
        double_z (bool): If True the output has ``2 * z_channels`` channels.
        enable_mid (bool): Whether to build and run the middle blocks.
        head_size (int): Passed to ``MultiHeadAttnBlock``.
    """

    def __init__(self,
                 ch,
                 out_ch,
                 ch_mult=(1, 2, 4, 8),
                 num_res_blocks=2,
                 attn_resolutions=(16, ),
                 dropout=0.0,
                 resamp_with_conv=True,
                 in_channels=3,
                 resolution=512,
                 z_channels=256,
                 double_z=True,
                 enable_mid=True,
                 head_size=1,
                 **ignore_kwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.enable_mid = enable_mid

        # stem
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        # downsampling stages
        curr_res = resolution
        in_ch_mult = (1, ) + tuple(ch_mult)
        last_level = self.num_resolutions - 1
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            res_blocks = nn.ModuleList()
            attn_blocks = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                res_blocks.append(
                    ResnetBlock(
                        in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn_blocks.append(MultiHeadAttnBlock(block_in, head_size))
            stage = nn.Module()
            stage.block = res_blocks
            stage.attn = attn_blocks
            if i_level != last_level:
                stage.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(stage)

        # middle
        if self.enable_mid:
            self.mid = nn.Module()
            self.mid.block_1 = ResnetBlock(
                in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
            self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
            self.mid.block_2 = ResnetBlock(
                in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)

        # output head
        out_channels = 2 * z_channels if double_z else z_channels
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Encode ``x`` and return the intermediate features by name.

        Returns:
            dict[str, Tensor]: ``'in'`` (after the stem), ``'block_<i>'`` for
            every level except the last (captured before downsampling),
            ``'block_<last>_atten'`` and ``'mid_atten'`` when the middle
            blocks are enabled, and ``'out'`` (final output).
        """
        temb = None  # no timestep embedding is used
        feats = {}

        h = self.conv_in(x)
        feats['in'] = h

        last_level = self.num_resolutions - 1
        for i_level in range(self.num_resolutions):
            stage = self.down[i_level]
            use_attn = len(stage.attn) > 0
            for i_block in range(self.num_res_blocks):
                h = stage.block[i_block](h, temb)
                if use_attn:
                    h = stage.attn[i_block](h)

            if i_level != last_level:
                feats['block_' + str(i_level)] = h
                h = stage.downsample(h)

        if self.enable_mid:
            h = self.mid.block_1(h, temb)
            # Keyed by the index of the last (lowest-resolution) level.
            feats['block_' + str(last_level) + '_atten'] = h
            h = self.mid.attn_1(h)
            h = self.mid.block_2(h, temb)
            feats['mid_atten'] = h

        h = self.conv_out(nonlinearity(self.norm_out(h)))
        feats['out'] = h

        return feats
360
+
361
+
362
class MultiHeadDecoder(nn.Module):
    """Convolutional decoder with optional self-attention at selected resolutions.

    Args:
        ch (int): Base channel number; each level uses ``ch * ch_mult[level]``.
        out_ch (int): Channel number of the decoded output.
        ch_mult (tuple[int]): Channel multiplier per resolution level.
        num_res_blocks (int): Each level holds ``num_res_blocks + 1`` residual
            blocks.
        attn_resolutions (tuple[int]): Spatial resolutions at which an
            attention block follows every residual block.
        dropout (float): Dropout probability inside the residual blocks.
        resamp_with_conv (bool): Passed to ``Upsample``.
        in_channels (int): Stored on the module; not otherwise used here.
        resolution (int): Nominal output resolution.
        z_channels (int): Channel number of the latent input.
        give_pre_end (bool): If True, ``forward`` returns the features before
            the final norm / nonlinearity / conv.
        enable_mid (bool): Whether to build and run the middle blocks.
        head_size (int): Passed to ``MultiHeadAttnBlock``.
    """

    def __init__(self,
                 ch,
                 out_ch,
                 ch_mult=(1, 2, 4, 8),
                 num_res_blocks=2,
                 attn_resolutions=(16, ),
                 dropout=0.0,
                 resamp_with_conv=True,
                 in_channels=3,
                 resolution=512,
                 z_channels=256,
                 give_pre_end=False,
                 enable_mid=True,
                 head_size=1,
                 **ignorekwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.enable_mid = enable_mid

        # channel width and spatial size at the lowest resolution
        top_level = self.num_resolutions - 1
        block_in = ch * ch_mult[top_level]
        curr_res = resolution // 2**top_level
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print('Working with z of shape {} = {} dimensions.'.format(self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        if self.enable_mid:
            self.mid = nn.Module()
            self.mid.block_1 = ResnetBlock(
                in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
            self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
            self.mid.block_2 = ResnetBlock(
                in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)

        # upsampling stages, built from the lowest resolution upwards
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            res_blocks = nn.ModuleList()
            attn_blocks = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                res_blocks.append(
                    ResnetBlock(
                        in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn_blocks.append(MultiHeadAttnBlock(block_in, head_size))
            stage = nn.Module()
            stage.block = res_blocks
            stage.attn = attn_blocks
            if i_level != 0:
                stage.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, stage)  # prepend so that self.up[i] is level i

        # output head
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z):
        """Decode the latent ``z``.

        Also records ``z.shape`` on ``self.last_z_shape``.

        Returns:
            Tensor: Decoded output, or the pre-head features when
            ``give_pre_end`` is set.
        """
        self.last_z_shape = z.shape

        temb = None  # no timestep embedding is used

        h = self.conv_in(z)

        if self.enable_mid:
            h = self.mid.block_1(h, temb)
            h = self.mid.attn_1(h)
            h = self.mid.block_2(h, temb)

        for i_level in reversed(range(self.num_resolutions)):
            stage = self.up[i_level]
            for i_block in range(self.num_res_blocks + 1):
                h = stage.block[i_block](h, temb)
                if len(stage.attn) > 0:
                    h = stage.attn[i_block](h)
            if i_level != 0:
                h = stage.upsample(h)

        if self.give_pre_end:
            return h

        return self.conv_out(nonlinearity(self.norm_out(h)))
465
+
466
+
467
class MultiHeadDecoderTransformer(nn.Module):
    """Decoder whose attention blocks take queries from encoder features.

    Built like ``MultiHeadDecoder``, but ``forward`` additionally receives the
    feature dict produced by the encoder and passes the matching entry as the
    second argument of every attention block.

    Args:
        ch (int): Base channel number; each level uses ``ch * ch_mult[level]``.
        out_ch (int): Channel number of the decoded output.
        ch_mult (tuple[int]): Channel multiplier per resolution level.
        num_res_blocks (int): Each level holds ``num_res_blocks + 1`` residual
            blocks.
        attn_resolutions (tuple[int]): Spatial resolutions at which an
            attention block follows every residual block.
        dropout (float): Dropout probability inside the residual blocks.
        resamp_with_conv (bool): Passed to ``Upsample``.
        in_channels (int): Stored on the module; not otherwise used here.
        resolution (int): Nominal output resolution.
        z_channels (int): Channel number of the latent input.
        give_pre_end (bool): If True, ``forward`` returns the features before
            the final norm / nonlinearity / conv.
        enable_mid (bool): Whether to build and run the middle blocks.
        head_size (int): Passed to ``MultiHeadAttnBlock``.
    """

    def __init__(self,
                 ch,
                 out_ch,
                 ch_mult=(1, 2, 4, 8),
                 num_res_blocks=2,
                 attn_resolutions=(16, ),
                 dropout=0.0,
                 resamp_with_conv=True,
                 in_channels=3,
                 resolution=512,
                 z_channels=256,
                 give_pre_end=False,
                 enable_mid=True,
                 head_size=1,
                 **ignorekwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.enable_mid = enable_mid

        # channel width and spatial size at the lowest resolution
        top_level = self.num_resolutions - 1
        block_in = ch * ch_mult[top_level]
        curr_res = resolution // 2**top_level
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print('Working with z of shape {} = {} dimensions.'.format(self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        if self.enable_mid:
            self.mid = nn.Module()
            self.mid.block_1 = ResnetBlock(
                in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
            self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
            self.mid.block_2 = ResnetBlock(
                in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)

        # upsampling stages, built from the lowest resolution upwards
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            res_blocks = nn.ModuleList()
            attn_blocks = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                res_blocks.append(
                    ResnetBlock(
                        in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn_blocks.append(MultiHeadAttnBlock(block_in, head_size))
            stage = nn.Module()
            stage.block = res_blocks
            stage.attn = attn_blocks
            if i_level != 0:
                stage.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, stage)  # prepend so that self.up[i] is level i

        # output head
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z, hs):
        """Decode ``z`` using encoder features as attention queries.

        Args:
            z (Tensor): Latent input.
            hs (dict[str, Tensor]): Encoder features. ``hs['mid_atten']`` is
                read when the middle blocks are enabled, and
                ``hs['block_<i>_atten']`` for every level ``i`` that has
                attention blocks.

        Returns:
            Tensor: Decoded output, or the pre-head features when
            ``give_pre_end`` is set.
        """
        temb = None  # no timestep embedding is used

        h = self.conv_in(z)

        if self.enable_mid:
            h = self.mid.block_1(h, temb)
            h = self.mid.attn_1(h, hs['mid_atten'])
            h = self.mid.block_2(h, temb)

        for i_level in reversed(range(self.num_resolutions)):
            stage = self.up[i_level]
            for i_block in range(self.num_res_blocks + 1):
                h = stage.block[i_block](h, temb)
                if len(stage.attn) > 0:
                    h = stage.attn[i_block](h, hs['block_' + str(i_level) + '_atten'])
            if i_level != 0:
                h = stage.upsample(h)

        if self.give_pre_end:
            return h

        return self.conv_out(nonlinearity(self.norm_out(h)))
571
+
572
+
573
class RestoreFormer(nn.Module):
    """Encoder -> vector quantizer -> attention decoder.

    Args:
        n_embed (int): Number of codebook entries passed to ``VectorQuantizer``.
        embed_dim (int): Dimension of each codebook entry.
        ch, out_ch, ch_mult, num_res_blocks, attn_resolutions, dropout,
        in_channels, resolution, z_channels, enable_mid, head_size:
            Forwarded unchanged to both the encoder and the decoder.
        double_z (bool): Forwarded to the encoder only.
        fix_decoder (bool): Freeze the decoder, ``post_quant_conv`` and the
            quantizer.
        fix_codebook (bool): Freeze the quantizer. Only consulted when
            ``fix_decoder`` is False.
        fix_encoder (bool): Freeze the encoder.
    """

    def __init__(self,
                 n_embed=1024,
                 embed_dim=256,
                 ch=64,
                 out_ch=3,
                 ch_mult=(1, 2, 2, 4, 4, 8),
                 num_res_blocks=2,
                 attn_resolutions=(16, ),
                 dropout=0.0,
                 in_channels=3,
                 resolution=512,
                 z_channels=256,
                 double_z=False,
                 enable_mid=True,
                 fix_decoder=False,
                 fix_codebook=True,
                 fix_encoder=False,
                 head_size=8):
        super().__init__()

        # Arguments shared verbatim by the encoder and the decoder.
        shared = dict(
            ch=ch,
            out_ch=out_ch,
            ch_mult=ch_mult,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            in_channels=in_channels,
            resolution=resolution,
            z_channels=z_channels,
            enable_mid=enable_mid,
            head_size=head_size)

        self.encoder = MultiHeadEncoder(double_z=double_z, **shared)
        self.decoder = MultiHeadDecoderTransformer(**shared)

        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25)

        self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)

        # Collect the sub-networks to freeze, then switch their gradients off.
        frozen = []
        if fix_decoder:
            frozen.extend([self.decoder, self.post_quant_conv, self.quantize])
        elif fix_codebook:
            frozen.append(self.quantize)
        if fix_encoder:
            frozen.append(self.encoder)

        for module in frozen:
            for param in module.parameters():
                param.requires_grad = False

    def encode(self, x):
        """Encode ``x`` and quantize the encoder output.

        Returns:
            tuple: ``(quant, emb_loss, info, hs)`` where the first three come
            from the quantizer and ``hs`` is the encoder feature dict.
        """
        hs = self.encoder(x)
        quant, emb_loss, info = self.quantize(self.quant_conv(hs['out']))
        return quant, emb_loss, info, hs

    def decode(self, quant, hs):
        """Decode the quantized latent with the encoder features ``hs``."""
        return self.decoder(self.post_quant_conv(quant), hs)

    def forward(self, input, **kwargs):
        """Run encode then decode.

        Returns:
            tuple: ``(dec, None)``; the second element is always ``None``.
        """
        quant, _, _, hs = self.encode(input)
        return self.decode(quant, hs), None
gfpgan/archs/stylegan2_bilinear_arch.py ADDED
@@ -0,0 +1,613 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+ import torch
4
+ from basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu
5
+ from basicsr.utils.registry import ARCH_REGISTRY
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+
10
class NormStyleCode(nn.Module):
    """Scale each style code to unit root-mean-square along dim 1."""

    def forward(self, x):
        """Normalize the style codes.

        Args:
            x (Tensor): Style codes with shape (b, c).

        Returns:
            Tensor: ``x`` multiplied by ``rsqrt(mean(x**2, dim=1) + 1e-8)``.
        """
        mean_square = torch.mean(x**2, dim=1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + 1e-8)
        return x * inv_rms
22
+
23
+
24
class EqualLinear(nn.Module):
    """Equalized Linear as StyleGAN2.

    The stored weight is divided by ``lr_mul`` at initialisation and
    multiplied by ``lr_mul / sqrt(in_channels)`` on every forward pass; the
    bias is multiplied by ``lr_mul`` on every forward pass.

    Args:
        in_channels (int): Size of each sample.
        out_channels (int): Size of each output sample.
        bias (bool): If set to ``False``, the layer will not learn an additive
            bias. Default: ``True``.
        bias_init_val (float): Bias initialized value. Default: 0.
        lr_mul (float): Learning rate multiplier. Default: 1.
        activation (None | str): The activation after ``linear`` operation.
            Supported: 'fused_lrelu', None. Default: None.

    Raises:
        ValueError: If ``activation`` is not one of the supported values.
    """

    def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul=1, activation=None):
        super(EqualLinear, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.lr_mul = lr_mul
        self.activation = activation
        if self.activation not in ['fused_lrelu', None]:
            # The two literals are concatenated, so the separator must be
            # written explicitly or the value runs into the next sentence.
            raise ValueError(f'Wrong activation value in EqualLinear: {activation}. '
                             "Supported ones are: ['fused_lrelu', None].")
        self.scale = (1 / math.sqrt(in_channels)) * lr_mul

        self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        """Apply the scaled linear map, then the optional fused activation."""
        bias = None if self.bias is None else self.bias * self.lr_mul
        if self.activation == 'fused_lrelu':
            # The fused op adds the bias itself, so the linear map runs without it.
            out = F.linear(x, self.weight * self.scale)
            out = fused_leaky_relu(out, bias)
        else:
            out = F.linear(x, self.weight * self.scale, bias=bias)
        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, bias={self.bias is not None})')
70
+
71
+
72
class ModulatedConv2d(nn.Module):
    """Modulated Conv2d used in StyleGAN2, resampling with ``F.interpolate``.

    There is no bias in ModulatedConv2d.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether to demodulate in the conv layer.
            Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
            Default: None.
        eps (float): A value added to the denominator for numerical stability.
            Default: 1e-8.
        interpolation_mode (str): ``mode`` passed to ``F.interpolate`` when
            resampling. Default: 'bilinear'.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 num_style_feat,
                 demodulate=True,
                 sample_mode=None,
                 eps=1e-8,
                 interpolation_mode='bilinear'):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.demodulate = demodulate
        self.sample_mode = sample_mode
        self.eps = eps
        self.interpolation_mode = interpolation_mode
        # 'nearest' is given align_corners=None; every other mode gets False.
        self.align_corners = None if interpolation_mode == 'nearest' else False

        self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
        # modulation inside each modulated conv
        self.modulation = EqualLinear(
            num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None)

        self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size))
        self.padding = kernel_size // 2

    def forward(self, x, style):
        """Forward function.

        Args:
            x (Tensor): Tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).

        Returns:
            Tensor: Modulated tensor after convolution.
        """
        batch, in_ch, _, _ = x.shape
        ksize = self.kernel_size

        # Per-sample weights: (1, c_out, c_in, k, k) * (b, 1, c_in, 1, 1).
        style = self.modulation(style).view(batch, 1, in_ch, 1, 1)
        weight = self.scale * self.weight * style  # (b, c_out, c_in, k, k)

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
            weight = weight * demod.view(batch, self.out_channels, 1, 1, 1)

        weight = weight.view(batch * self.out_channels, in_ch, ksize, ksize)

        # Optional resampling of the input before the convolution.
        factor = {'upsample': 2, 'downsample': 0.5}.get(self.sample_mode)
        if factor is not None:
            x = F.interpolate(x, scale_factor=factor, mode=self.interpolation_mode, align_corners=self.align_corners)

        # Fold the batch into the channel axis and run one grouped conv
        # (groups=b) so every sample uses its own weights.
        batch, in_ch, height, width = x.shape
        x = x.view(1, batch * in_ch, height, width)
        out = F.conv2d(x, weight, padding=self.padding, groups=batch)
        return out.view(batch, self.out_channels, *out.shape[2:4])

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size}, '
                f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')
160
+
161
+
162
class StyleConv(nn.Module):
    """Modulated conv followed by noise injection and a fused leaky ReLU.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether demodulate in the conv layer. Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
            Default: None.
        interpolation_mode (str): Forwarded to ``ModulatedConv2d``.
            Default: 'bilinear'.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 num_style_feat,
                 demodulate=True,
                 sample_mode=None,
                 interpolation_mode='bilinear'):
        super().__init__()
        self.modulated_conv = ModulatedConv2d(
            in_channels,
            out_channels,
            kernel_size,
            num_style_feat,
            demodulate=demodulate,
            sample_mode=sample_mode,
            interpolation_mode=interpolation_mode)
        # Scalar strength of the injected noise; starts at zero.
        self.weight = nn.Parameter(torch.zeros(1))
        self.activate = FusedLeakyReLU(out_channels)

    def forward(self, x, style, noise=None):
        """Apply the styled conv.

        Args:
            x (Tensor): Input features.
            style (Tensor): Style vector for the modulated conv.
            noise (Tensor | None): Noise to inject. When ``None``, standard
                normal noise of shape (b, 1, h, w) is drawn.

        Returns:
            Tensor: Activated features.
        """
        feat = self.modulated_conv(x, style)

        if noise is None:
            batch, _, height, width = feat.shape
            noise = feat.new_empty(batch, 1, height, width).normal_()

        # activation (with bias)
        return self.activate(feat + self.weight * noise)
206
+
207
+
208
class ToRGB(nn.Module):
    """Project features to 3 channels and accumulate onto an optional skip image.

    Args:
        in_channels (int): Channel number of input.
        num_style_feat (int): Channel number of style features.
        upsample (bool): Whether to upsample the skip image by 2 before
            adding. Default: True.
        interpolation_mode (str): ``mode`` passed to ``F.interpolate``.
            Default: 'bilinear'.
    """

    def __init__(self, in_channels, num_style_feat, upsample=True, interpolation_mode='bilinear'):
        super().__init__()
        self.upsample = upsample
        self.interpolation_mode = interpolation_mode
        # 'nearest' is given align_corners=None; every other mode gets False.
        self.align_corners = None if interpolation_mode == 'nearest' else False
        self.modulated_conv = ModulatedConv2d(
            in_channels,
            3,
            kernel_size=1,
            num_style_feat=num_style_feat,
            demodulate=False,
            sample_mode=None,
            interpolation_mode=interpolation_mode)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, x, style, skip=None):
        """Forward function.

        Args:
            x (Tensor): Feature tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).
            skip (Tensor): Base/skip tensor. Default: None.

        Returns:
            Tensor: RGB images.
        """
        rgb = self.modulated_conv(x, style) + self.bias
        if skip is None:
            return rgb

        if self.upsample:
            skip = F.interpolate(
                skip, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners)
        return rgb + skip
254
+
255
+
256
class ConstantInput(nn.Module):
    """Learned constant tensor, tiled along the batch dimension.

    Args:
        num_channel (int): Channel number of constant input.
        size (int): Spatial size of constant input.
    """

    def __init__(self, num_channel, size):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))

    def forward(self, batch):
        """Return the constant repeated ``batch`` times along dim 0."""
        return self.weight.repeat(batch, 1, 1, 1)
271
+
272
+
273
@ARCH_REGISTRY.register()
class StyleGAN2GeneratorBilinear(nn.Module):
    """StyleGAN2 Generator that resamples with ``F.interpolate``.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of
            StyleGAN2. Default: 2.
        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
        narrow (float): Narrow ratio for channels. Default: 1.0.
        interpolation_mode (str): Forwarded to every ``StyleConv`` / ``ToRGB``.
            Default: 'bilinear'.
    """

    def __init__(self,
                 out_size,
                 num_style_feat=512,
                 num_mlp=8,
                 channel_multiplier=2,
                 lr_mlp=0.01,
                 narrow=1,
                 interpolation_mode='bilinear'):
        super().__init__()
        # Style MLP: one normalisation layer followed by num_mlp equalized linears.
        self.num_style_feat = num_style_feat
        mlp_layers = [NormStyleCode()]
        mlp_layers.extend(
            EqualLinear(
                num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp, activation='fused_lrelu')
            for _ in range(num_mlp))
        self.style_mlp = nn.Sequential(*mlp_layers)

        # Channel width per output resolution (keys are the resolution as str).
        wide = int(512 * narrow)
        channels = {
            '4': wide,
            '8': wide,
            '16': wide,
            '32': wide,
            '64': int(256 * channel_multiplier * narrow),
            '128': int(128 * channel_multiplier * narrow),
            '256': int(64 * channel_multiplier * narrow),
            '512': int(32 * channel_multiplier * narrow),
            '1024': int(16 * channel_multiplier * narrow)
        }
        self.channels = channels

        # 4x4 stem
        self.constant_input = ConstantInput(channels['4'], size=4)
        self.style_conv1 = StyleConv(
            channels['4'],
            channels['4'],
            kernel_size=3,
            num_style_feat=num_style_feat,
            demodulate=True,
            sample_mode=None,
            interpolation_mode=interpolation_mode)
        self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False, interpolation_mode=interpolation_mode)

        self.log_size = int(math.log(out_size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1
        self.num_latent = self.log_size * 2 - 2

        self.style_convs = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        # Stored noise, one buffer per style conv layer.
        for layer_idx in range(self.num_layers):
            side = 2**((layer_idx + 5) // 2)
            self.noises.register_buffer(f'noise{layer_idx}', torch.randn(1, 1, side, side))

        # One (upsampling conv, plain conv, to_rgb) triple per resolution from 8 up.
        in_channels = channels['4']
        for log_res in range(3, self.log_size + 1):
            out_channels = channels[f'{2**log_res}']
            self.style_convs.append(
                StyleConv(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode='upsample',
                    interpolation_mode=interpolation_mode))
            self.style_convs.append(
                StyleConv(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode=None,
                    interpolation_mode=interpolation_mode))
            self.to_rgbs.append(
                ToRGB(out_channels, num_style_feat, upsample=True, interpolation_mode=interpolation_mode))
            in_channels = out_channels

    def make_noise(self):
        """Make noise for noise injection."""
        device = self.constant_input.weight.device
        noises = [torch.randn(1, 1, 4, 4, device=device)]
        # Two noise maps per resolution from 8x8 upwards.
        noises.extend(
            torch.randn(1, 1, 2**log_res, 2**log_res, device=device)
            for log_res in range(3, self.log_size + 1)
            for _ in range(2))
        return noises

    def get_latent(self, x):
        """Map style codes through the style MLP."""
        return self.style_mlp(x)

    def mean_latent(self, num_latent):
        """Average the style MLP output over ``num_latent`` random codes."""
        device = self.constant_input.weight.device
        codes = torch.randn(num_latent, self.num_style_feat, device=device)
        return self.style_mlp(codes).mean(0, keepdim=True)

    def forward(self,
                styles,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2Generator.

        Args:
            styles (list[Tensor]): Sample codes of styles (one or two entries).
            input_is_latent (bool): Whether input is latent style; if False
                the codes are first passed through the style MLP.
                Default: False.
            noise (list[Tensor | None] | None): Per-layer noise, indexed by
                style conv layer, or None. Default: None.
            randomize_noise (bool): When ``noise`` is None, draw fresh noise
                in every layer (True) or use the stored buffers (False).
                Default: True.
            truncation (float): When below 1, every style is replaced by
                ``truncation_latent + truncation * (style - truncation_latent)``.
                Default: 1.
            truncation_latent (Tensor | None): Reference latent used by the
                truncation formula. Default: None.
            inject_index (int | None): The injection index for mixing noise.
                Drawn at random when two styles are given and this is None.
                Default: None.
            return_latents (bool): Whether to return style latents.
                Default: False.

        Returns:
            tuple: ``(image, latent)`` if ``return_latents`` else
            ``(image, None)``.
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(code) for code in styles]

        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{idx}') for idx in range(self.num_layers)]

        # style truncation
        if truncation < 1:
            styles = [truncation_latent + truncation * (style - truncation_latent) for style in styles]

        # get style latent with injection
        if len(styles) == 1:
            inject_index = self.num_latent
            latent = styles[0]
            if latent.ndim < 3:
                # repeat latent code for all the layers
                latent = latent.unsqueeze(1).repeat(1, inject_index, 1)
            # otherwise: one latent code per layer was supplied (e.g. by an encoder)
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            first = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            second = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([first, second], 1)

        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        stages = zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2], noise[2::2], self.to_rgbs)
        for step, (conv_up, conv_same, noise_up, noise_same, to_rgb) in enumerate(stages):
            idx = 1 + 2 * step  # latent index of this stage's first conv
            out = conv_up(out, latent[:, idx], noise=noise_up)
            out = conv_same(out, latent[:, idx + 1], noise=noise_same)
            skip = to_rgb(out, latent[:, idx + 2], skip)

        image = skip

        return image, (latent if return_latents else None)
463
+
464
+
465
class ScaledLeakyReLU(nn.Module):
    """Leaky ReLU whose output is multiplied by sqrt(2).

    Args:
        negative_slope (float): Negative slope. Default: 0.2.
    """

    def __init__(self, negative_slope=0.2):
        super().__init__()
        self.negative_slope = negative_slope

    def forward(self, x):
        """Return ``leaky_relu(x) * sqrt(2)``."""
        activated = F.leaky_relu(x, negative_slope=self.negative_slope)
        return activated * math.sqrt(2)
479
+
480
+
481
class EqualConv2d(nn.Module):
    """Equalized Conv2d as used in StyleGAN2.

    The weight is initialised from a standard normal distribution and is
    multiplied by ``1 / sqrt(in_channels * kernel_size ** 2)`` on every
    forward pass instead of being scaled once at initialisation.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the (square) convolving kernel.
        stride (int): Stride of the convolution. Default: 1.
        padding (int): Zero-padding added to both sides of the input.
            Default: 0.
        bias (bool): If ``True``, adds a learnable bias to the output.
            Default: ``True``.
        bias_init_val (float): Value the bias is filled with. Default: 0.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, bias_init_val=0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        # runtime weight multiplier derived from the fan-in of the kernel
        fan_in = in_channels * kernel_size**2
        self.scale = 1 / math.sqrt(fan_in)

        weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
        self.weight = nn.Parameter(torch.randn(*weight_shape))
        if not bias:
            # keep a ``bias`` attribute that is None so forward/repr still work
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))

    def forward(self, x):
        """Convolve ``x`` with the rescaled weight."""
        scaled_weight = self.weight * self.scale
        return F.conv2d(x, scaled_weight, bias=self.bias, stride=self.stride, padding=self.padding)

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size},'
                f' stride={self.stride}, padding={self.padding}, '
                f'bias={self.bias is not None})')
528
+
529
+
530
class ConvLayer(nn.Sequential):
    """Conv Layer used in StyleGAN2 Discriminator.

    The layer is a sequence of: an optional x0.5 resize, an ``EqualConv2d``
    and an optional activation.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Kernel size.
        downsample (bool): Whether to downsample by a factor of 2.
            Default: False.
        bias (bool): Whether with bias. Default: True.
        activate (bool): Whether to use an activation. Default: True.
        interpolation_mode (str): Mode passed to ``torch.nn.Upsample`` when
            downsampling. Default: 'bilinear'.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 downsample=False,
                 bias=True,
                 activate=True,
                 interpolation_mode='bilinear'):
        # NOTE: plain attributes are assigned before nn.Sequential.__init__ runs,
        # because the sub-modules have to be collected first.
        self.interpolation_mode = interpolation_mode
        stages = []

        # optional downsampling stage
        if downsample:
            # align_corners is passed as None for 'nearest', False otherwise
            self.align_corners = None if self.interpolation_mode == 'nearest' else False
            resize = torch.nn.Upsample(
                scale_factor=0.5, mode=interpolation_mode, align_corners=self.align_corners)
            stages.append(resize)

        # convolution stage; the conv carries its own bias only when no activation follows
        stride = 1
        self.padding = kernel_size // 2
        conv_has_bias = bias and not activate
        stages.append(
            EqualConv2d(
                in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=conv_has_bias))

        # optional activation stage
        if activate:
            stages.append(FusedLeakyReLU(out_channels) if bias else ScaledLeakyReLU(0.2))

        super(ConvLayer, self).__init__(*stages)
577
+
578
+
579
class ResBlock(nn.Module):
    """Residual block used in StyleGAN2 Discriminator.

    The main path is two 3x3 ``ConvLayer`` modules (the second one
    downsamples); the shortcut is a downsampling 1x1 ``ConvLayer`` without
    bias or activation. Their sum is divided by sqrt(2).

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        interpolation_mode (str): Interpolation mode forwarded to the
            downsampling layers. Default: 'bilinear'.
    """

    def __init__(self, in_channels, out_channels, interpolation_mode='bilinear'):
        super().__init__()

        # options shared by the two downsampling layers
        down_opts = dict(downsample=True, interpolation_mode=interpolation_mode)

        self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
        self.conv2 = ConvLayer(in_channels, out_channels, 3, bias=True, activate=True, **down_opts)
        self.skip = ConvLayer(in_channels, out_channels, 1, bias=False, activate=False, **down_opts)

    def forward(self, x):
        """Return ``(conv2(conv1(x)) + skip(x)) / sqrt(2)``."""
        main = self.conv2(self.conv1(x))
        shortcut = self.skip(x)
        return (main + shortcut) / math.sqrt(2)
gfpgan/archs/stylegan2_clean_arch.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+ import torch
4
+ from basicsr.archs.arch_util import default_init_weights
5
+ from basicsr.utils.registry import ARCH_REGISTRY
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+
10
class NormStyleCode(nn.Module):
    """Scale each style code by the inverse root-mean-square over dim 1."""

    def forward(self, x):
        """Normalize the style codes.

        Args:
            x (Tensor): Style codes with shape (b, c).

        Returns:
            Tensor: Normalized tensor.
        """
        # 1e-8 keeps rsqrt finite for all-zero codes
        mean_square = torch.mean(x**2, dim=1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + 1e-8)
        return x * inv_rms
22
+
23
+
24
class ModulatedConv2d(nn.Module):
    """Modulated Conv2d used in StyleGAN2.

    The style vector is mapped by a linear layer to one scale per input
    channel, which multiplies the shared kernel per sample. The convolution
    for the whole batch is then run as a single grouped convolution.

    There is no bias in ModulatedConv2d.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether to demodulate in the conv layer. Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
        eps (float): A value added to the denominator for numerical stability. Default: 1e-8.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 num_style_feat,
                 demodulate=True,
                 sample_mode=None,
                 eps=1e-8):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.demodulate = demodulate
        self.sample_mode = sample_mode
        self.eps = eps

        # style -> per-input-channel scale, bias filled with 1 at init
        self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
        default_init_weights(self.modulation, scale=1, bias_fill=1, a=0, mode='fan_in', nonlinearity='linear')

        # shared kernel with a leading singleton dim that broadcasts over the batch
        fan_in = in_channels * kernel_size**2
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) / math.sqrt(fan_in))
        self.padding = kernel_size // 2

    def forward(self, x, style):
        """Forward function.

        Args:
            x (Tensor): Tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).

        Returns:
            Tensor: Modulated tensor after convolution.
        """
        batch, in_ch, _, _ = x.shape

        # modulation: (1, c_out, c_in, k, k) * (b, 1, c_in, 1, 1) -> (b, c_out, c_in, k, k)
        scales = self.modulation(style).view(batch, 1, in_ch, 1, 1)
        kernels = self.weight * scales

        if self.demodulate:
            inv_norm = torch.rsqrt(kernels.pow(2).sum([2, 3, 4]) + self.eps)
            kernels = kernels * inv_norm.view(batch, self.out_channels, 1, 1, 1)

        # fold the batch into the output-channel dim for the grouped conv
        kernels = kernels.view(batch * self.out_channels, in_ch, self.kernel_size, self.kernel_size)

        # resize the input first if a sample mode is configured
        factor = {'upsample': 2, 'downsample': 0.5}.get(self.sample_mode)
        if factor is not None:
            x = F.interpolate(x, scale_factor=factor, mode='bilinear', align_corners=False)

        batch, in_ch, height, width = x.shape
        x = x.view(1, batch * in_ch, height, width)
        # one group per sample, so each sample sees only its own kernels
        out = F.conv2d(x, kernels, padding=self.padding, groups=batch)
        return out.view(batch, self.out_channels, *out.shape[2:4])

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})')
104
+
105
+
106
class StyleConv(nn.Module):
    """Style conv used in StyleGAN2.

    Runs a ``ModulatedConv2d``, then adds weighted noise and a per-channel
    bias, and finally applies a LeakyReLU.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether demodulate in the conv layer. Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
    """

    def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None):
        super().__init__()
        self.modulated_conv = ModulatedConv2d(
            in_channels,
            out_channels,
            kernel_size,
            num_style_feat,
            demodulate=demodulate,
            sample_mode=sample_mode)
        # scalar strength of the injected noise, starts at zero
        self.weight = nn.Parameter(torch.zeros(1))
        self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
        self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x, style, noise=None):
        """Forward function.

        Args:
            x (Tensor): Input features.
            style (Tensor): Style code passed to the modulated conv.
            noise (Tensor | None): Noise to inject. When None, fresh normal
                noise with one channel and the output's spatial size is drawn.
                Default: None.

        Returns:
            Tensor: Activated features.
        """
        # modulated conv, scaled by sqrt(2) ("for conversion")
        feat = self.modulated_conv(x, style) * 2**0.5

        if noise is None:
            n, _, height, width = feat.shape
            noise = feat.new_empty(n, 1, height, width).normal_()

        # noise injection followed by the bias
        feat = feat + self.weight * noise + self.bias
        return self.activate(feat)
139
+
140
+
141
class ToRGB(nn.Module):
    """To RGB (image space) from features.

    Uses a 1x1 ``ModulatedConv2d`` without demodulation to produce three
    channels, then adds a learnable bias and an optional skip image.

    Args:
        in_channels (int): Channel number of input.
        num_style_feat (int): Channel number of style features.
        upsample (bool): Whether to upsample. Default: True.
    """

    def __init__(self, in_channels, num_style_feat, upsample=True):
        super().__init__()
        self.upsample = upsample
        self.modulated_conv = ModulatedConv2d(
            in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, x, style, skip=None):
        """Forward function.

        Args:
            x (Tensor): Feature tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).
            skip (Tensor): Base/skip tensor. Default: None.

        Returns:
            Tensor: RGB images.
        """
        rgb = self.modulated_conv(x, style) + self.bias
        if skip is None:
            return rgb

        # bring the skip image to twice its size before adding, if configured
        if self.upsample:
            skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
        return rgb + skip
175
+
176
+
177
class ConstantInput(nn.Module):
    """Learnable constant tensor used as the generator's starting input.

    Args:
        num_channel (int): Channel number of constant input.
        size (int): Spatial size of constant input.
    """

    def __init__(self, num_channel, size):
        super().__init__()
        initial = torch.randn(1, num_channel, size, size)
        self.weight = nn.Parameter(initial)

    def forward(self, batch):
        """Return the constant repeated ``batch`` times along dim 0."""
        repeats = (batch, 1, 1, 1)
        return self.weight.repeat(*repeats)
192
+
193
+
194
@ARCH_REGISTRY.register()
class StyleGAN2GeneratorClean(nn.Module):
    """Clean version of StyleGAN2 Generator.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        narrow (float): Narrow ratio for channels. Default: 1.0.
    """

    def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1):
        super(StyleGAN2GeneratorClean, self).__init__()
        # Style MLP layers: normalization followed by num_mlp (Linear, LeakyReLU) pairs
        self.num_style_feat = num_style_feat
        style_mlp_layers = [NormStyleCode()]
        for _ in range(num_mlp):
            style_mlp_layers.extend(
                [nn.Linear(num_style_feat, num_style_feat, bias=True),
                 nn.LeakyReLU(negative_slope=0.2, inplace=True)])
        self.style_mlp = nn.Sequential(*style_mlp_layers)
        # initialization
        default_init_weights(self.style_mlp, scale=1, bias_fill=0, a=0.2, mode='fan_in', nonlinearity='leaky_relu')

        # channel list, keyed by the spatial resolution as a string
        channels = {
            '4': int(512 * narrow),
            '8': int(512 * narrow),
            '16': int(512 * narrow),
            '32': int(512 * narrow),
            '64': int(256 * channel_multiplier * narrow),
            '128': int(128 * channel_multiplier * narrow),
            '256': int(64 * channel_multiplier * narrow),
            '512': int(32 * channel_multiplier * narrow),
            '1024': int(16 * channel_multiplier * narrow)
        }
        self.channels = channels

        # 4x4 stage
        self.constant_input = ConstantInput(channels['4'], size=4)
        self.style_conv1 = StyleConv(
            channels['4'],
            channels['4'],
            kernel_size=3,
            num_style_feat=num_style_feat,
            demodulate=True,
            sample_mode=None)
        self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False)

        self.log_size = int(math.log(out_size, 2))
        # one style conv at 4x4 plus two per higher resolution
        self.num_layers = (self.log_size - 2) * 2 + 1
        self.num_latent = self.log_size * 2 - 2

        self.style_convs = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channels = channels['4']
        # stored noise buffers, one per style conv layer (used when randomize_noise is False)
        for layer_idx in range(self.num_layers):
            resolution = 2**((layer_idx + 5) // 2)
            shape = [1, 1, resolution, resolution]
            self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
        # style convs and to_rgbs for resolutions 8 ... 2**log_size
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.style_convs.append(
                StyleConv(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode='upsample'))
            self.style_convs.append(
                StyleConv(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode=None))
            self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
            in_channels = out_channels

    def make_noise(self):
        """Make noise for noise injection.

        Returns:
            list[Tensor]: One (1, 1, 4, 4) tensor followed by two tensors per
                resolution from 8 up to ``2**log_size``, all on the device of
                the constant input.
        """
        device = self.constant_input.weight.device
        noises = [torch.randn(1, 1, 4, 4, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))

        return noises

    def get_latent(self, x):
        """Map style codes ``x`` to latents with the style MLP."""
        return self.style_mlp(x)

    def mean_latent(self, num_latent):
        """Return the mean latent of ``num_latent`` random style codes, shape (1, num_style_feat)."""
        latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
        latent = self.style_mlp(latent_in).mean(0, keepdim=True)
        return latent

    def forward(self,
                styles,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2GeneratorClean.

        Args:
            styles (list[Tensor]): Sample codes of styles. Must contain one
                entry, or two entries for style mixing.
            input_is_latent (bool): Whether input is latent style. Default: False.
            noise (list[Tensor | None] | None): Per-layer noise (indexed by
                style conv layer) or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
            truncation (float): The truncation ratio. Default: 1.
            truncation_latent (Tensor | None): The truncation latent tensor;
                used when ``truncation < 1``. Default: None.
            inject_index (int | None): The injection index for mixing noise. Default: None.
            return_latents (bool): Whether to return style latents. Default: False.

        Returns:
            tuple: ``(image, latent)`` when ``return_latents`` is True,
                otherwise ``(image, None)``.

        Raises:
            ValueError: If ``styles`` does not contain exactly 1 or 2 entries.
        """
        # Fail early with a clear message; otherwise `latent` would never be
        # assigned below and the failure would surface as an UnboundLocalError.
        if len(styles) not in (1, 2):
            raise ValueError(f'styles must contain 1 or 2 style codes, but received {len(styles)}')

        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
        # style truncation: interpolate each style towards truncation_latent
        if truncation < 1:
            style_truncation = []
            for style in styles:
                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
            styles = style_truncation
        # get style latents with injection
        if len(styles) == 1:
            inject_index = self.num_latent

            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        else:  # two styles: mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)

        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        # each resolution consumes two style convs and one to_rgb;
        # the to_rgb shares latent index i + 2 with the next resolution's first conv
        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None
gfpgan/data/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ from basicsr.utils import scandir
3
+ from os import path as osp
4
+
5
# Import every dataset module in this package so that importing the package
# is enough to register them: all files under the data folder whose name
# ends with '_dataset.py' are collected and imported.
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [
    osp.splitext(osp.basename(path))[0]
    for path in scandir(data_folder)
    if path.endswith('_dataset.py')
]
# import all the dataset modules
_dataset_modules = list(
    map(importlib.import_module, (f'gfpgan.data.{file_name}' for file_name in dataset_filenames)))
gfpgan/data/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (893 Bytes). View file
 
gfpgan/data/__pycache__/ffhq_degradation_dataset.cpython-312.pyc ADDED
Binary file (13.5 kB). View file
 
gfpgan/data/ffhq_degradation_dataset.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import math
3
+ import numpy as np
4
+ import os.path as osp
5
+ import torch
6
+ import torch.utils.data as data
7
+ from basicsr.data import degradations as degradations
8
+ from basicsr.data.data_util import paths_from_folder
9
+ from basicsr.data.transforms import augment
10
+ from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
11
+ from basicsr.utils.registry import DATASET_REGISTRY
12
+ from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
13
+ normalize)
14
+
15
+
16
@DATASET_REGISTRY.register()
class FFHQDegradationDataset(data.Dataset):
    """FFHQ dataset for GFPGAN.

    It reads high resolution images, and then generate low-quality (LQ) images on-the-fly.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            io_backend (dict): IO backend type and other kwarg.
            mean (list | tuple): Image mean.
            std (list | tuple): Image std.
            use_hflip (bool): Whether to horizontally flip.
            Please see more options in the codes.
    """

    def __init__(self, opt):
        super(FFHQDegradationDataset, self).__init__()
        self.opt = opt
        # file client (io backend); created lazily on the first __getitem__ call
        self.file_client = None
        self.io_backend_opt = opt['io_backend']

        self.gt_folder = opt['dataroot_gt']
        self.mean = opt['mean']
        self.std = opt['std']
        self.out_size = opt['out_size']

        self.crop_components = opt.get('crop_components', False)  # facial components
        self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1)  # whether enlarge eye regions

        if self.crop_components:
            # load component list from a pre-process pth files
            # NOTE(review): torch.load unpickles the file; only point 'component_path' at trusted files.
            self.components_list = torch.load(opt.get('component_path'))

        # file client (lmdb io backend)
        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = self.gt_folder
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                # keep the part of each line before the first '.'
                self.paths = [line.split('.')[0] for line in fin]
        else:
            # disk backend: scan file list from a folder
            self.paths = paths_from_folder(self.gt_folder)

        # degradation configurations
        self.blur_kernel_size = opt['blur_kernel_size']
        self.kernel_list = opt['kernel_list']
        self.kernel_prob = opt['kernel_prob']
        self.blur_sigma = opt['blur_sigma']
        self.downsample_range = opt['downsample_range']
        self.noise_range = opt['noise_range']
        self.jpeg_range = opt['jpeg_range']

        # color jitter
        self.color_jitter_prob = opt.get('color_jitter_prob')
        self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob')
        self.color_jitter_shift = opt.get('color_jitter_shift', 20)
        # to gray
        self.gray_prob = opt.get('gray_prob')

        logger = get_root_logger()
        logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
        logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
        logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
        logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')

        if self.color_jitter_prob is not None:
            logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
        if self.gray_prob is not None:
            logger.info(f'Use random gray. Prob: {self.gray_prob}')
        # the shift is configured on a 0-255 scale; images are handled in [0, 1]
        self.color_jitter_shift /= 255.

    @staticmethod
    def color_jitter(img, shift):
        """jitter color: randomly jitter the RGB values, in numpy formats

        Adds one uniform random offset in ``[-shift, shift]`` per channel and
        clips the result to ``[0, 1]``.
        """
        jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
        img = img + jitter_val
        img = np.clip(img, 0, 1)
        return img

    @staticmethod
    def color_jitter_pt(img, brightness, contrast, saturation, hue):
        """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats

        The four adjustments are applied in a random order. Each of
        ``brightness``, ``contrast``, ``saturation`` and ``hue`` is either
        None (adjustment skipped) or a ``(low, high)`` pair the factor is
        sampled from uniformly.
        """
        fn_idx = torch.randperm(4)
        for fn_id in fn_idx:
            if fn_id == 0 and brightness is not None:
                brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
                img = adjust_brightness(img, brightness_factor)

            if fn_id == 1 and contrast is not None:
                contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
                img = adjust_contrast(img, contrast_factor)

            if fn_id == 2 and saturation is not None:
                saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
                img = adjust_saturation(img, saturation_factor)

            if fn_id == 3 and hue is not None:
                hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
                img = adjust_hue(img, hue_factor)
        return img

    def get_component_coordinates(self, index, status):
        """Get facial component (left_eye, right_eye, mouth) coordinates from a pre-loaded pth file

        Args:
            index (int): Sample index; the entry is looked up under the
                zero-padded 8-digit key.
            status (tuple): Augmentation status; ``status[0]`` tells whether a
                horizontal flip was applied.

        Returns:
            list[Tensor]: Float boxes for left eye, right eye and mouth, each
                built as ``(center - half_len + 1, center + half_len)``.
        """
        # local import: keeps this fix independent of the file's import block
        from copy import deepcopy

        # Work on a private copy. The flip handling below swaps and rewrites
        # entries in place; doing that on the cached dict would corrupt the
        # stored boxes for every later access of the same index.
        components_bbox = deepcopy(self.components_list[f'{index:08d}'])
        if status[0]:  # hflip
            # exchange right and left eye
            components_bbox['left_eye'], components_bbox['right_eye'] = (components_bbox['right_eye'],
                                                                         components_bbox['left_eye'])
            # modify the width coordinate
            components_bbox['left_eye'][0] = self.out_size - components_bbox['left_eye'][0]
            components_bbox['right_eye'][0] = self.out_size - components_bbox['right_eye'][0]
            components_bbox['mouth'][0] = self.out_size - components_bbox['mouth'][0]

        # get coordinates
        locations = []
        for part in ['left_eye', 'right_eye', 'mouth']:
            mean = components_bbox[part][0:2]
            half_len = components_bbox[part][2]
            if 'eye' in part:
                half_len *= self.eye_enlarge_ratio
            loc = np.hstack((mean - half_len + 1, mean + half_len))
            loc = torch.from_numpy(loc).float()
            locations.append(loc)
        return locations

    def __getitem__(self, index):
        """Load one GT image and synthesize its degraded LQ counterpart.

        Returns:
            dict: Keys 'lq', 'gt' and 'gt_path'; plus 'loc_left_eye',
                'loc_right_eye' and 'loc_mouth' when ``crop_components`` is set.
        """
        if self.file_client is None:
            # note: pop() removes 'type' from the io_backend option dict
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load gt image
        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
        gt_path = self.paths[index]
        img_bytes = self.file_client.get(gt_path)
        img_gt = imfrombytes(img_bytes, float32=True)

        # random horizontal flip
        img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
        h, w, _ = img_gt.shape

        # get facial component coordinates
        if self.crop_components:
            locations = self.get_component_coordinates(index, status)
            loc_left_eye, loc_right_eye, loc_mouth = locations

        # ------------------------ generate lq image ------------------------ #
        # blur
        kernel = degradations.random_mixed_kernels(
            self.kernel_list,
            self.kernel_prob,
            self.blur_kernel_size,
            self.blur_sigma,
            self.blur_sigma, [-math.pi, math.pi],
            noise_range=None)
        img_lq = cv2.filter2D(img_gt, -1, kernel)
        # downsample
        scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
        img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR)
        # noise
        if self.noise_range is not None:
            img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range)
        # jpeg compression
        if self.jpeg_range is not None:
            img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range)

        # resize to original size
        img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)

        # random color jitter (only for lq)
        if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
            img_lq = self.color_jitter(img_lq, self.color_jitter_shift)
        # random to gray (only for lq)
        if self.gray_prob and np.random.uniform() < self.gray_prob:
            img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
            img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
            if self.opt.get('gt_gray'):  # whether convert GT to gray images
                img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
                img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])  # repeat the color channels

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)

        # random color jitter (pytorch version) (only for lq)
        if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
            brightness = self.opt.get('brightness', (0.5, 1.5))
            contrast = self.opt.get('contrast', (0.5, 1.5))
            saturation = self.opt.get('saturation', (0, 1.5))
            hue = self.opt.get('hue', (-0.1, 0.1))
            img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue)

        # round and clip
        img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.

        # normalize
        normalize(img_gt, self.mean, self.std, inplace=True)
        normalize(img_lq, self.mean, self.std, inplace=True)

        if self.crop_components:
            return_dict = {
                'lq': img_lq,
                'gt': img_gt,
                'gt_path': gt_path,
                'loc_left_eye': loc_left_eye,
                'loc_right_eye': loc_right_eye,
                'loc_mouth': loc_mouth
            }
            return return_dict
        else:
            return {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path}

    def __len__(self):
        """Return the number of GT image paths."""
        return len(self.paths)
gfpgan/models/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ from basicsr.utils import scandir
3
+ from os import path as osp
4
+
5
# Import every model module in this package so that importing the package
# is enough to register them: all files under the model folder whose name
# ends with '_model.py' are collected and imported.
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [
    osp.splitext(osp.basename(path))[0]
    for path in scandir(model_folder)
    if path.endswith('_model.py')
]
# import all the model modules
_model_modules = list(
    map(importlib.import_module, (f'gfpgan.models.{file_name}' for file_name in model_filenames)))
gfpgan/models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (890 Bytes). View file
 
gfpgan/models/__pycache__/gfpgan_model.cpython-312.pyc ADDED
Binary file (31.5 kB). View file
 
gfpgan/models/gfpgan_model.py ADDED
@@ -0,0 +1,579 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os.path as osp
3
+ import torch
4
+ from basicsr.archs import build_network
5
+ from basicsr.losses import build_loss
6
+ from basicsr.losses.gan_loss import r1_penalty
7
+ from basicsr.metrics import calculate_metric
8
+ from basicsr.models.base_model import BaseModel
9
+ from basicsr.utils import get_root_logger, imwrite, tensor2img
10
+ from basicsr.utils.registry import MODEL_REGISTRY
11
+ from collections import OrderedDict
12
+ from torch.nn import functional as F
13
+ from torchvision.ops import roi_align
14
+ from tqdm import tqdm
15
+
16
+
17
@MODEL_REGISTRY.register()
class GFPGANModel(BaseModel):
    """The GFPGAN model for "Towards Real-World Blind Face Restoration with Generative Facial Prior".

    The model owns a generator (``net_g``) and, in training mode, an EMA copy of
    the generator, a global discriminator (``net_d``), optional left-eye /
    right-eye / mouth discriminators and an optional frozen identity network.
    All networks and losses are built from the option dict ``opt``.
    """

    # Suffixes of the option keys / attribute names of the facial component discriminators.
    _FACIAL_COMPONENTS = ('left_eye', 'right_eye', 'mouth')

    def __init__(self, opt):
        """Build the generator, optionally load its weights and set up training.

        Args:
            opt (dict): Full option dict. Reads ``opt['network_g']`` and
                ``opt['path']`` here; training options are read in
                :meth:`init_training_settings`.
        """
        super(GFPGANModel, self).__init__(opt)
        self.idx = 0  # counter kept for ad-hoc debugging dumps of the fed batches

        # define network
        self.net_g = build_network(opt['network_g'])
        self.net_g = self.model_to_device(self.net_g)
        self.print_network(self.net_g)

        # load pretrained model
        load_path = self.opt['path'].get('pretrain_network_g', None)
        if load_path is not None:
            param_key = self.opt['path'].get('param_key_g', 'params')
            self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)

        # log2 is exact for power-of-two sizes, whereas log(x, 2) may round just below the integer
        self.log_size = int(math.log2(self.opt['network_g']['out_size']))

        if self.is_train:
            self.init_training_settings()

    def _build_component_discriminator(self, component):
        """Build one facial component discriminator and optionally load its weights.

        Args:
            component (str): Component suffix used in the option keys, e.g. ``'left_eye'``.

        Returns:
            The network returned by ``model_to_device``.
        """
        net = build_network(self.opt[f'network_d_{component}'])
        net = self.model_to_device(net)
        self.print_network(net)
        load_path = self.opt['path'].get(f'pretrain_network_d_{component}')
        if load_path is not None:
            self.load_network(net, load_path, True, 'params')
        return net

    def init_training_settings(self):
        """Create discriminators, the EMA generator, losses, optimizers and schedulers."""
        train_opt = self.opt['train']

        # ----------- define net_d ----------- #
        self.net_d = build_network(self.opt['network_d'])
        self.net_d = self.model_to_device(self.net_d)
        self.print_network(self.net_d)
        # load pretrained model
        load_path = self.opt['path'].get('pretrain_network_d', None)
        if load_path is not None:
            self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))

        # ----------- define net_g with Exponential Moving Average (EMA) ----------- #
        # net_g_ema only used for testing on one GPU and saving. There is no need to wrap with DistributedDataParallel
        self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
        # load pretrained model
        load_path = self.opt['path'].get('pretrain_network_g', None)
        if load_path is not None:
            self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
        else:
            self.model_ema(0)  # copy net_g weight

        self.net_g.train()
        self.net_d.train()
        self.net_g_ema.eval()

        # ----------- facial component networks ----------- #
        # component discriminators are used only when all three are configured
        self.use_facial_disc = all(f'network_d_{name}' in self.opt for name in self._FACIAL_COMPONENTS)

        if self.use_facial_disc:
            self.net_d_left_eye = self._build_component_discriminator('left_eye')
            self.net_d_right_eye = self._build_component_discriminator('right_eye')
            self.net_d_mouth = self._build_component_discriminator('mouth')

            self.net_d_left_eye.train()
            self.net_d_right_eye.train()
            self.net_d_mouth.train()

            # ----------- define facial component gan loss ----------- #
            self.cri_component = build_loss(train_opt['gan_component_opt']).to(self.device)

        # ----------- define losses ----------- #
        # pixel loss
        if train_opt.get('pixel_opt'):
            self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
        else:
            self.cri_pix = None

        # perceptual loss
        if train_opt.get('perceptual_opt'):
            self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
        else:
            self.cri_perceptual = None

        # L1 loss is used in pyramid loss, component style loss and identity loss
        self.cri_l1 = build_loss(train_opt['L1_opt']).to(self.device)

        # gan loss (wgan)
        self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)

        # ----------- define identity loss ----------- #
        self.use_identity = 'network_identity' in self.opt

        if self.use_identity:
            # define identity network
            self.network_identity = build_network(self.opt['network_identity'])
            self.network_identity = self.model_to_device(self.network_identity)
            self.print_network(self.network_identity)
            load_path = self.opt['path'].get('pretrain_network_identity')
            if load_path is not None:
                self.load_network(self.network_identity, load_path, True, None)
            # the identity network is a fixed feature extractor: eval mode, no gradients
            self.network_identity.eval()
            for param in self.network_identity.parameters():
                param.requires_grad = False

        # regularization weights
        self.r1_reg_weight = train_opt['r1_reg_weight']  # for discriminator
        self.net_d_iters = train_opt.get('net_d_iters', 1)
        self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
        self.net_d_reg_every = train_opt['net_d_reg_every']

        # set up optimizers and schedulers
        self.setup_optimizers()
        self.setup_schedulers()

    def setup_optimizers(self):
        """Create the optimizers for the generator and all discriminators.

        Note:
            The ``type`` entry is popped from each ``optim_*`` option dict, so
            the option dict is mutated by this call.
        """
        train_opt = self.opt['train']

        # ----------- optimizer g ----------- #
        net_g_reg_ratio = 1
        optim_params_g = [{'params': list(self.net_g.parameters()), 'lr': train_opt['optim_g']['lr']}]
        optim_type = train_opt['optim_g'].pop('type')
        lr = train_opt['optim_g']['lr'] * net_g_reg_ratio
        betas = (0**net_g_reg_ratio, 0.99**net_g_reg_ratio)
        self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, lr, betas=betas)
        self.optimizers.append(self.optimizer_g)

        # ----------- optimizer d ----------- #
        # lr and betas are rescaled by this ratio, which is derived from net_d_reg_every
        net_d_reg_ratio = self.net_d_reg_every / (self.net_d_reg_every + 1)
        optim_params_d = [{'params': list(self.net_d.parameters()), 'lr': train_opt['optim_d']['lr']}]
        optim_type = train_opt['optim_d'].pop('type')
        lr = train_opt['optim_d']['lr'] * net_d_reg_ratio
        betas = (0**net_d_reg_ratio, 0.99**net_d_reg_ratio)
        self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas)
        self.optimizers.append(self.optimizer_d)

        # ----------- optimizers for facial component networks ----------- #
        if self.use_facial_disc:
            # setup optimizers for facial component discriminators
            optim_type = train_opt['optim_component'].pop('type')
            lr = train_opt['optim_component']['lr']
            # left eye
            self.optimizer_d_left_eye = self.get_optimizer(
                optim_type, self.net_d_left_eye.parameters(), lr, betas=(0.9, 0.99))
            self.optimizers.append(self.optimizer_d_left_eye)
            # right eye
            self.optimizer_d_right_eye = self.get_optimizer(
                optim_type, self.net_d_right_eye.parameters(), lr, betas=(0.9, 0.99))
            self.optimizers.append(self.optimizer_d_right_eye)
            # mouth
            self.optimizer_d_mouth = self.get_optimizer(
                optim_type, self.net_d_mouth.parameters(), lr, betas=(0.9, 0.99))
            self.optimizers.append(self.optimizer_d_mouth)

    def feed_data(self, data):
        """Store one batch on the model.

        Args:
            data (dict): Must contain ``'lq'``. ``'gt'`` and the component
                locations (``'loc_left_eye'``, ``'loc_right_eye'``,
                ``'loc_mouth'``) are optional.
        """
        self.lq = data['lq'].to(self.device)
        if 'gt' in data:
            self.gt = data['gt'].to(self.device)

        if 'loc_left_eye' in data:
            # get facial component locations, shape (batch, 4)
            # they are moved to the device later, in get_roi_regions
            self.loc_left_eyes = data['loc_left_eye']
            self.loc_right_eyes = data['loc_right_eye']
            self.loc_mouths = data['loc_mouth']

    def construct_img_pyramid(self):
        """Construct image pyramid for intermediate restoration loss.

        Returns:
            list[torch.Tensor]: ``self.gt`` repeatedly downsampled by 2 with
            bilinear interpolation, ordered from smallest to full resolution
            (``log_size - 2`` entries).
        """
        pyramid_gt = [self.gt]
        down_img = self.gt
        for _ in range(0, self.log_size - 3):
            down_img = F.interpolate(down_img, scale_factor=0.5, mode='bilinear', align_corners=False)
            pyramid_gt.insert(0, down_img)
        return pyramid_gt

    def get_roi_regions(self, eye_out_size=80, mouth_out_size=120):
        """Crop eye and mouth regions from ``self.gt`` and ``self.output`` with RoI align.

        Results are stored on ``self`` as ``left_eyes_gt`` / ``right_eyes_gt`` /
        ``mouths_gt`` (from the ground truth) and ``left_eyes`` / ``right_eyes`` /
        ``mouths`` (from the generator output).

        Args:
            eye_out_size (int): Output size of the eye crops for a 512 model. Default: 80.
            mouth_out_size (int): Output size of the mouth crops for a 512 model. Default: 120.
        """
        # NOTE(review): face_ratio is 0 for out_size < 512, which yields empty crops — confirm
        # that smaller generators are never trained with component discriminators.
        face_ratio = int(self.opt['network_g']['out_size'] / 512)
        eye_out_size *= face_ratio
        mouth_out_size *= face_ratio

        rois_eyes = []
        rois_mouths = []
        for b in range(self.loc_left_eyes.size(0)):  # loop for batch size
            # left eye and right eye
            img_inds = self.loc_left_eyes.new_full((2, 1), b)
            bbox = torch.stack([self.loc_left_eyes[b, :], self.loc_right_eyes[b, :]], dim=0)  # shape: (2, 4)
            rois = torch.cat([img_inds, bbox], dim=-1)  # shape: (2, 5)
            rois_eyes.append(rois)
            # mouth
            img_inds = self.loc_left_eyes.new_full((1, 1), b)
            rois = torch.cat([img_inds, self.loc_mouths[b:b + 1, :]], dim=-1)  # shape: (1, 5)
            rois_mouths.append(rois)

        rois_eyes = torch.cat(rois_eyes, 0).to(self.device)
        rois_mouths = torch.cat(rois_mouths, 0).to(self.device)

        # real images; eye rois alternate left/right, hence the strided slicing
        all_eyes = roi_align(self.gt, boxes=rois_eyes, output_size=eye_out_size) * face_ratio
        self.left_eyes_gt = all_eyes[0::2, :, :, :]
        self.right_eyes_gt = all_eyes[1::2, :, :, :]
        self.mouths_gt = roi_align(self.gt, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio
        # output
        all_eyes = roi_align(self.output, boxes=rois_eyes, output_size=eye_out_size) * face_ratio
        self.left_eyes = all_eyes[0::2, :, :, :]
        self.right_eyes = all_eyes[1::2, :, :, :]
        self.mouths = roi_align(self.output, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio

    def _gram_mat(self, x):
        """Calculate Gram matrix.

        Args:
            x (torch.Tensor): Tensor with shape of (n, c, h, w).

        Returns:
            torch.Tensor: Gram matrix with shape (n, c, c), normalized by ``c * h * w``.
        """
        n, c, h, w = x.size()
        features = x.view(n, c, w * h)
        features_t = features.transpose(1, 2)
        gram = features.bmm(features_t) / (c * h * w)
        return gram

    def _comp_style(self, feat, feat_gt):
        """L1 distance between Gram matrices of the first two feature maps.

        The ground-truth features are detached; the first level is weighted by 0.5.
        """
        return self.cri_l1(self._gram_mat(feat[0]), self._gram_mat(feat_gt[0].detach())) * 0.5 + self.cri_l1(
            self._gram_mat(feat[1]), self._gram_mat(feat_gt[1].detach()))

    def gray_resize_for_identity(self, out, size=128):
        """Convert an (n, 3, h, w) batch to one gray channel and resize it to ``size`` x ``size``.

        Args:
            out (torch.Tensor): Batch with at least three channels; channels 0, 1, 2 are
                weighted with 0.2989, 0.5870 and 0.1140.
            size (int): Output height and width. Default: 128.

        Returns:
            torch.Tensor: Tensor with shape (n, 1, size, size).
        """
        out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :])
        out_gray = out_gray.unsqueeze(1)
        out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False)
        return out_gray

    @staticmethod
    def _toggle_grad(nets, requires_grad):
        """Set ``requires_grad`` on every parameter of every network in ``nets``."""
        for net in nets:
            for p in net.parameters():
                p.requires_grad = requires_grad

    def optimize_parameters(self, current_iter):
        """Run one training iteration: generator step, EMA update, discriminator steps.

        Args:
            current_iter (int): Current iteration; controls when the generator is
                updated and when the R1 regularization is applied.
        """
        train_opt = self.opt['train']
        component_discs = []
        if self.use_facial_disc:
            component_discs = [self.net_d_left_eye, self.net_d_right_eye, self.net_d_mouth]

        # ----------- optimize net_g ----------- #
        # discriminators are frozen while the generator is updated
        self._toggle_grad([self.net_d], False)
        self.optimizer_g.zero_grad()
        self._toggle_grad(component_discs, False)

        # image pyramid loss weight
        pyramid_loss_weight = train_opt.get('pyramid_loss_weight', 0)
        if pyramid_loss_weight > 0 and current_iter > train_opt.get('remove_pyramid_loss', float('inf')):
            pyramid_loss_weight = 1e-12  # very small weight to avoid unused param error
        if pyramid_loss_weight > 0:
            self.output, out_rgbs = self.net_g(self.lq, return_rgb=True)
            pyramid_gt = self.construct_img_pyramid()
        else:
            self.output, out_rgbs = self.net_g(self.lq, return_rgb=False)

        # get roi-align regions
        components = ()
        if self.use_facial_disc:
            self.get_roi_regions(eye_out_size=80, mouth_out_size=120)
            # (name, discriminator, optimizer, generated crop, ground-truth crop)
            components = (
                ('left_eye', self.net_d_left_eye, self.optimizer_d_left_eye, self.left_eyes, self.left_eyes_gt),
                ('right_eye', self.net_d_right_eye, self.optimizer_d_right_eye, self.right_eyes,
                 self.right_eyes_gt),
                ('mouth', self.net_d_mouth, self.optimizer_d_mouth, self.mouths, self.mouths_gt),
            )

        l_g_total = 0
        loss_dict = OrderedDict()
        if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
            # pixel loss
            if self.cri_pix:
                l_g_pix = self.cri_pix(self.output, self.gt)
                l_g_total += l_g_pix
                loss_dict['l_g_pix'] = l_g_pix

            # image pyramid loss
            if pyramid_loss_weight > 0:
                for i in range(0, self.log_size - 2):
                    l_pyramid = self.cri_l1(out_rgbs[i], pyramid_gt[i]) * pyramid_loss_weight
                    l_g_total += l_pyramid
                    loss_dict[f'l_p_{2**(i+3)}'] = l_pyramid

            # perceptual loss
            if self.cri_perceptual:
                l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
                if l_g_percep is not None:
                    l_g_total += l_g_percep
                    loss_dict['l_g_percep'] = l_g_percep
                if l_g_style is not None:
                    l_g_total += l_g_style
                    loss_dict['l_g_style'] = l_g_style

            # gan loss
            fake_g_pred = self.net_d(self.output)
            l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
            l_g_total += l_g_gan
            loss_dict['l_g_gan'] = l_g_gan

            # facial component loss
            if self.use_facial_disc:
                fake_feats = {}
                for name, net, _, fake_crop, _ in components:
                    fake_pred, fake_feats[name] = net(fake_crop, return_feats=True)
                    l_g_gan = self.cri_component(fake_pred, True, is_disc=False)
                    l_g_total += l_g_gan
                    loss_dict[f'l_g_gan_{name}'] = l_g_gan

                comp_style_weight = train_opt.get('comp_style_weight', 0)
                if comp_style_weight > 0:
                    # facial component style loss
                    comp_style_loss = 0
                    for name, net, _, _, real_crop in components:
                        # get gt feat
                        _, real_feats = net(real_crop, return_feats=True)
                        comp_style_loss += self._comp_style(fake_feats[name], real_feats)
                    comp_style_loss = comp_style_loss * comp_style_weight
                    l_g_total += comp_style_loss
                    loss_dict['l_g_comp_style_loss'] = comp_style_loss

            # identity loss
            if self.use_identity:
                identity_weight = train_opt['identity_weight']
                # get gray images and resize
                out_gray = self.gray_resize_for_identity(self.output)
                gt_gray = self.gray_resize_for_identity(self.gt)

                identity_gt = self.network_identity(gt_gray).detach()
                identity_out = self.network_identity(out_gray)
                l_identity = self.cri_l1(identity_out, identity_gt) * identity_weight
                l_g_total += l_identity
                loss_dict['l_identity'] = l_identity

            l_g_total.backward()
            self.optimizer_g.step()

        # EMA
        self.model_ema(decay=0.5**(32 / (10 * 1000)))

        # ----------- optimize net_d ----------- #
        self._toggle_grad([self.net_d], True)
        self.optimizer_d.zero_grad()
        self._toggle_grad(component_discs, True)
        for _, _, optimizer, _, _ in components:
            optimizer.zero_grad()

        fake_d_pred = self.net_d(self.output.detach())
        real_d_pred = self.net_d(self.gt)
        l_d = self.cri_gan(real_d_pred, True, is_disc=True) + self.cri_gan(fake_d_pred, False, is_disc=True)
        loss_dict['l_d'] = l_d
        # In WGAN, real_score should be positive and fake_score should be negative
        loss_dict['real_score'] = real_d_pred.detach().mean()
        loss_dict['fake_score'] = fake_d_pred.detach().mean()
        l_d.backward()

        # regularization loss
        if current_iter % self.net_d_reg_every == 0:
            self.gt.requires_grad = True
            real_pred = self.net_d(self.gt)
            l_d_r1 = r1_penalty(real_pred, self.gt)
            l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.net_d_reg_every + 0 * real_pred[0])
            loss_dict['l_d_r1'] = l_d_r1.detach().mean()
            l_d_r1.backward()

        self.optimizer_d.step()

        # optimize facial component discriminators
        if self.use_facial_disc:
            for name, net, _, fake_crop, real_crop in components:
                fake_d_pred, _ = net(fake_crop.detach())
                real_d_pred, _ = net(real_crop)
                # both terms use the component criterion, consistent with the generator-side
                # component loss (the fake term previously went through self.cri_gan)
                l_d_component = self.cri_component(
                    real_d_pred, True, is_disc=True) + self.cri_component(
                        fake_d_pred, False, is_disc=True)
                loss_dict[f'l_d_{name}'] = l_d_component
                l_d_component.backward()

            for _, _, optimizer, _, _ in components:
                optimizer.step()

        self.log_dict = self.reduce_loss_dict(loss_dict)

    def test(self):
        """Run inference on ``self.lq`` and store the result in ``self.output``.

        Uses ``net_g_ema`` when it exists; otherwise falls back to ``net_g``
        (temporarily switched to eval mode) and logs a warning.
        """
        with torch.no_grad():
            if hasattr(self, 'net_g_ema'):
                self.net_g_ema.eval()
                self.output, _ = self.net_g_ema(self.lq)
            else:
                logger = get_root_logger()
                logger.warning('Do not have self.net_g_ema, use self.net_g.')
                self.net_g.eval()
                self.output, _ = self.net_g(self.lq)
                self.net_g.train()

    def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
        """Run validation on rank 0 only."""
        if self.opt['rank'] == 0:
            self.nondist_validation(dataloader, current_iter, tb_logger, save_img)

    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        """Restore every image of ``dataloader``, optionally save it and accumulate metrics.

        Args:
            dataloader: Validation dataloader; ``dataloader.dataset.opt['name']`` names the dataset.
            current_iter (int): Current iteration, used in file names and logs.
            tb_logger: Tensorboard logger or a falsy value.
            save_img (bool): Whether to write the restored images to disk.
        """
        dataset_name = dataloader.dataset.opt['name']
        val_opt = self.opt['val']
        with_metrics = val_opt.get('metrics') is not None
        use_pbar = val_opt.get('pbar', False)

        if with_metrics:
            if not hasattr(self, 'metric_results'):  # only execute in the first run
                self.metric_results = {metric: 0 for metric in val_opt['metrics'].keys()}
            # initialize the best metric results for each dataset_name (supporting multiple validation datasets)
            self._initialize_best_metric_results(dataset_name)
            # zero self.metric_results
            self.metric_results = {metric: 0 for metric in self.metric_results}

        metric_data = dict()
        if use_pbar:
            pbar = tqdm(total=len(dataloader), unit='image')

        num_samples = 0  # counted explicitly so an empty dataloader cannot leave it undefined
        for val_data in dataloader:
            num_samples += 1
            img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
            self.feed_data(val_data)
            self.test()

            sr_img = tensor2img(self.output.detach().cpu(), min_max=(-1, 1))
            metric_data['img'] = sr_img
            if hasattr(self, 'gt'):
                gt_img = tensor2img(self.gt.detach().cpu(), min_max=(-1, 1))
                metric_data['img2'] = gt_img
                del self.gt

            # tentative for out of GPU memory
            del self.lq
            del self.output
            torch.cuda.empty_cache()

            if save_img:
                if self.opt['is_train']:
                    save_img_path = osp.join(self.opt['path']['visualization'], img_name,
                                             f'{img_name}_{current_iter}.png')
                else:
                    # a missing 'suffix' entry is treated like an empty one
                    suffix = val_opt.get('suffix')
                    if suffix:
                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                                 f'{img_name}_{suffix}.png')
                    else:
                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                                 f'{img_name}_{self.opt["name"]}.png')
                imwrite(sr_img, save_img_path)

            if with_metrics:
                # calculate metrics
                for name, opt_ in val_opt['metrics'].items():
                    self.metric_results[name] += calculate_metric(metric_data, opt_)
            if use_pbar:
                pbar.update(1)
                pbar.set_description(f'Test {img_name}')
        if use_pbar:
            pbar.close()

        if with_metrics:
            if num_samples == 0:
                # nothing was evaluated: averaging would be meaningless, so only report it
                get_root_logger().warning(f'Validation {dataset_name}: dataloader is empty, metrics are skipped.')
                return
            for metric in self.metric_results.keys():
                self.metric_results[metric] /= num_samples
                # update the best metric result
                self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)

            self._log_validation_metric_values(current_iter, dataset_name, tb_logger)

    def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
        """Log the averaged metrics (and the best ones, if tracked) to the logger and tensorboard."""
        log_str = f'Validation {dataset_name}\n'
        for metric, value in self.metric_results.items():
            log_str += f'\t # {metric}: {value:.4f}'
            if hasattr(self, 'best_metric_results'):
                log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
                            f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
            log_str += '\n'

        logger = get_root_logger()
        logger.info(log_str)
        if tb_logger:
            for metric, value in self.metric_results.items():
                tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter)

    def save(self, epoch, current_iter):
        """Save all networks and the training state.

        Args:
            epoch (int): Current epoch, stored in the training state.
            current_iter (int): Current iteration, used to name the saved files.
        """
        # save net_g and net_d
        self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
        self.save_network(self.net_d, 'net_d', current_iter)
        # save component discriminators
        if self.use_facial_disc:
            self.save_network(self.net_d_left_eye, 'net_d_left_eye', current_iter)
            self.save_network(self.net_d_right_eye, 'net_d_right_eye', current_iter)
            self.save_network(self.net_d_mouth, 'net_d_mouth', current_iter)
        # save training state
        self.save_training_state(epoch, current_iter)
gfpgan/train.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
# flake8: noqa
"""Command-line entry point that launches the BasicSR training pipeline for GFPGAN."""
import os.path as osp
from basicsr.train import train_pipeline

# These packages are imported only for their import-time effects
# (presumably registering archs, datasets and models with BasicSR — verify in each __init__).
import gfpgan.archs
import gfpgan.data
import gfpgan.models

if __name__ == '__main__':
    # the root path is two levels above this file
    repo_root = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
    train_pipeline(repo_root)
gfpgan/utils.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import os
3
+ import torch
4
+ from basicsr.utils import img2tensor, tensor2img
5
+ from basicsr.utils.download_util import load_file_from_url
6
+ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
7
+ from torchvision.transforms.functional import normalize
8
+
9
+ from gfpgan.archs.gfpgan_bilinear_arch import GFPGANBilinear
10
+ from gfpgan.archs.gfpganv1_arch import GFPGANv1
11
+ from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
12
+
13
+ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
14
+
15
+
16
class GFPGANer():
    """Helper for restoration with GFPGAN.

    It will detect and crop faces, and then resize the faces to 512x512.
    GFPGAN is used to restore the resized faces.
    The background is upsampled with the bg_upsampler.
    Finally, the faces will be pasted back to the upsampled background image.

    Args:
        model_path (str): The path to the GFPGAN model. It can be a URL starting with ``https://``
            (the file is downloaded automatically first).
        upscale (float): The upscale of the final output. Default: 2.
        arch (str): The GFPGAN architecture. Option: clean | bilinear | original | RestoreFormer.
            Default: clean.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        bg_upsampler (nn.Module): The upsampler for the background. Default: None.
        device (torch.device | str | None): Device to run on. Default: None, which selects CUDA
            when it is available and CPU otherwise.

    Raises:
        ValueError: If ``arch`` is not one of the supported options.
    """

    _SUPPORTED_ARCHS = ('clean', 'bilinear', 'original', 'RestoreFormer')

    def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None, device=None):
        # fail early with a clear message instead of an AttributeError on self.gfpgan later
        if arch not in self._SUPPORTED_ARCHS:
            raise ValueError(f'Unsupported arch {arch!r}; expected one of {self._SUPPORTED_ARCHS}.')

        self.upscale = upscale
        self.bg_upsampler = bg_upsampler

        # initialize model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device

        # initialize the GFP-GAN; the three StyleGAN2-based variants share these settings
        # and differ only in the class and in fix_decoder
        shared_kwargs = dict(
            out_size=512,
            num_style_feat=512,
            channel_multiplier=channel_multiplier,
            decoder_load_path=None,
            num_mlp=8,
            input_is_latent=True,
            different_w=True,
            narrow=1,
            sft_half=True)
        if arch == 'clean':
            self.gfpgan = GFPGANv1Clean(fix_decoder=False, **shared_kwargs)
        elif arch == 'bilinear':
            self.gfpgan = GFPGANBilinear(fix_decoder=False, **shared_kwargs)
        elif arch == 'original':
            self.gfpgan = GFPGANv1(fix_decoder=True, **shared_kwargs)
        else:  # 'RestoreFormer'
            # imported lazily so the other architectures do not need this module
            from gfpgan.archs.restoreformer_arch import RestoreFormer
            self.gfpgan = RestoreFormer()

        # initialize face helper
        self.face_helper = FaceRestoreHelper(
            upscale,
            face_size=512,
            crop_ratio=(1, 1),
            det_model='retinaface_resnet50',
            save_ext='png',
            use_parse=True,
            device=self.device,
            model_rootpath='gfpgan/weights')

        if model_path.startswith('https://'):
            model_path = load_file_from_url(
                url=model_path, model_dir=os.path.join(ROOT_DIR, 'gfpgan/weights'), progress=True, file_name=None)
        # load onto CPU first so checkpoints saved from a GPU also work on CPU-only hosts;
        # the network is moved to self.device below
        loadnet = torch.load(model_path, map_location=torch.device('cpu'))
        keyname = 'params_ema' if 'params_ema' in loadnet else 'params'
        self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
        self.gfpgan.eval()
        self.gfpgan = self.gfpgan.to(self.device)

    @torch.no_grad()
    def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True, weight=0.5):
        """Restore the faces found in ``img``.

        Args:
            img (numpy.ndarray): Input image; it is passed to ``cv2.resize`` or to the face helper,
                and converted with ``bgr2rgb=True``.
            has_aligned (bool): If True, ``img`` is treated as one already aligned face and is only
                resized to 512x512. Default: False.
            only_center_face (bool): Forwarded to the landmark detection. Default: False.
            paste_back (bool): Whether to paste the restored faces back into the (optionally
                upsampled) input image. Ignored when ``has_aligned`` is True. Default: True.
            weight (float): Forwarded to the restoration network. Default: 0.5.

        Returns:
            tuple: ``(cropped_faces, restored_faces, restored_img)``; ``restored_img`` is None
            when ``has_aligned`` is True or ``paste_back`` is False.
        """
        self.face_helper.clean_all()

        if has_aligned:  # the inputs are already aligned
            img = cv2.resize(img, (512, 512))
            self.face_helper.cropped_faces = [img]
        else:
            self.face_helper.read_image(img)
            # get face landmarks for each face
            self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
            # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
            # TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
            # align and warp each face
            self.face_helper.align_warp_face()

        # face restoration
        for cropped_face in self.face_helper.cropped_faces:
            # prepare data: scale to [0, 1], then normalize to [-1, 1]
            cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)

            try:
                output = self.gfpgan(cropped_face_t, return_rgb=False, weight=weight)[0]
                # convert to image
                restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
            except RuntimeError as error:
                # best effort: keep the unrestored crop so the remaining faces are still processed
                print(f'\tFailed inference for GFPGAN: {error}.')
                restored_face = cropped_face

            restored_face = restored_face.astype('uint8')
            self.face_helper.add_restored_face(restored_face)

        if has_aligned or not paste_back:
            return self.face_helper.cropped_faces, self.face_helper.restored_faces, None

        # upsample the background
        if self.bg_upsampler is not None:
            # Now only support RealESRGAN for upsampling background
            bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
        else:
            bg_img = None

        self.face_helper.get_inverse_affine(None)
        # paste each restored face to the input image
        restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img)
        return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
gfpgan/version.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # GENERATED VERSION FILE
2
+ # TIME: Sat May 25 17:25:12 2024
3
+ __version__ = '1.3.8'
4
+ __gitsha__ = '7552a77'
5
+ version_info = (1, 3, 8)
gfpgan/weights/README.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Weights
2
+
3
+ Put the downloaded weights into this folder.
gfpgan/weights/detection_Resnet50_Final.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d1de9c2944f2ccddca5f5e010ea5ae64a39845a86311af6fdf30841b0a5a16d
3
+ size 109497761
gfpgan/weights/parsing_parsenet.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d558d8d0e42c20224f13cf5a29c79eba2d59913419f945545d8cf7b72920de2
3
+ size 85331193