Jacobmadwed committed
Commit 2702c71
Parent(s): f24cb6a

Upload 34 files
- gfpgan/.DS_Store +0 -0
- gfpgan/__init__.py +7 -0
- gfpgan/__pycache__/__init__.cpython-312.pyc +0 -0
- gfpgan/__pycache__/utils.cpython-312.pyc +0 -0
- gfpgan/archs/__init__.py +10 -0
- gfpgan/archs/__pycache__/__init__.cpython-312.pyc +0 -0
- gfpgan/archs/__pycache__/arcface_arch.cpython-312.pyc +0 -0
- gfpgan/archs/__pycache__/gfpgan_bilinear_arch.cpython-312.pyc +0 -0
- gfpgan/archs/__pycache__/gfpganv1_arch.cpython-312.pyc +0 -0
- gfpgan/archs/__pycache__/gfpganv1_clean_arch.cpython-312.pyc +0 -0
- gfpgan/archs/__pycache__/restoreformer_arch.cpython-312.pyc +0 -0
- gfpgan/archs/__pycache__/stylegan2_bilinear_arch.cpython-312.pyc +0 -0
- gfpgan/archs/__pycache__/stylegan2_clean_arch.cpython-312.pyc +0 -0
- gfpgan/archs/arcface_arch.py +245 -0
- gfpgan/archs/gfpgan_bilinear_arch.py +312 -0
- gfpgan/archs/gfpganv1_arch.py +439 -0
- gfpgan/archs/gfpganv1_clean_arch.py +324 -0
- gfpgan/archs/restoreformer_arch.py +658 -0
- gfpgan/archs/stylegan2_bilinear_arch.py +613 -0
- gfpgan/archs/stylegan2_clean_arch.py +368 -0
- gfpgan/data/__init__.py +10 -0
- gfpgan/data/__pycache__/__init__.cpython-312.pyc +0 -0
- gfpgan/data/__pycache__/ffhq_degradation_dataset.cpython-312.pyc +0 -0
- gfpgan/data/ffhq_degradation_dataset.py +230 -0
- gfpgan/models/__init__.py +10 -0
- gfpgan/models/__pycache__/__init__.cpython-312.pyc +0 -0
- gfpgan/models/__pycache__/gfpgan_model.cpython-312.pyc +0 -0
- gfpgan/models/gfpgan_model.py +579 -0
- gfpgan/train.py +11 -0
- gfpgan/utils.py +148 -0
- gfpgan/version.py +5 -0
- gfpgan/weights/README.md +3 -0
- gfpgan/weights/detection_Resnet50_Final.pth +3 -0
- gfpgan/weights/parsing_parsenet.pth +3 -0
gfpgan/.DS_Store
ADDED
Binary file (6.15 kB)
gfpgan/__init__.py
ADDED
@@ -0,0 +1,7 @@
+# flake8: noqa
+from .archs import *
+from .data import *
+from .models import *
+from .utils import *
+
+# from .version import *
gfpgan/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (238 Bytes)

gfpgan/__pycache__/utils.cpython-312.pyc
ADDED
Binary file (6.61 kB)
gfpgan/archs/__init__.py
ADDED
@@ -0,0 +1,10 @@
+import importlib
+from basicsr.utils import scandir
+from os import path as osp
+
+# automatically scan and import arch modules for registry
+# scan all the files that end with '_arch.py' under the archs folder
+arch_folder = osp.dirname(osp.abspath(__file__))
+arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
+# import all the arch modules
+_arch_modules = [importlib.import_module(f'gfpgan.archs.{file_name}') for file_name in arch_filenames]
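Importing `gfpgan.archs` therefore registers every class in the `*_arch.py` files below with basicsr's `ARCH_REGISTRY`, so networks can be built from a string name. A minimal sketch of that lookup (the keyword value is illustrative, not taken from this commit):

    from basicsr.utils.registry import ARCH_REGISTRY

    import gfpgan.archs  # noqa: F401  # importing the package triggers the scan above

    # 'GFPGANv1Clean' is registered by gfpganv1_clean_arch.py in this upload;
    # out_size is its only required constructor argument.
    net = ARCH_REGISTRY.get('GFPGANv1Clean')(out_size=512)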
gfpgan/archs/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (884 Bytes)

gfpgan/archs/__pycache__/arcface_arch.cpython-312.pyc
ADDED
Binary file (13.2 kB)

gfpgan/archs/__pycache__/gfpgan_bilinear_arch.cpython-312.pyc
ADDED
Binary file (14.2 kB)

gfpgan/archs/__pycache__/gfpganv1_arch.cpython-312.pyc
ADDED
Binary file (20 kB)

gfpgan/archs/__pycache__/gfpganv1_clean_arch.cpython-312.pyc
ADDED
Binary file (15.6 kB)

gfpgan/archs/__pycache__/restoreformer_arch.cpython-312.pyc
ADDED
Binary file (28.5 kB)

gfpgan/archs/__pycache__/stylegan2_bilinear_arch.cpython-312.pyc
ADDED
Binary file (27.9 kB)

gfpgan/archs/__pycache__/stylegan2_clean_arch.cpython-312.pyc
ADDED
Binary file (18.5 kB)
gfpgan/archs/arcface_arch.py
ADDED
@@ -0,0 +1,245 @@
+import torch.nn as nn
+from basicsr.utils.registry import ARCH_REGISTRY
+
+
+def conv3x3(inplanes, outplanes, stride=1):
+    """A simple wrapper for 3x3 convolution with padding.
+
+    Args:
+        inplanes (int): Channel number of inputs.
+        outplanes (int): Channel number of outputs.
+        stride (int): Stride in convolution. Default: 1.
+    """
+    return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+    """Basic residual block used in the ResNetArcFace architecture.
+
+    Args:
+        inplanes (int): Channel number of inputs.
+        planes (int): Channel number of outputs.
+        stride (int): Stride in convolution. Default: 1.
+        downsample (nn.Module): The downsample module. Default: None.
+    """
+    expansion = 1  # output channel expansion ratio
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class IRBlock(nn.Module):
+    """Improved residual block (IR Block) used in the ResNetArcFace architecture.
+
+    Args:
+        inplanes (int): Channel number of inputs.
+        planes (int): Channel number of outputs.
+        stride (int): Stride in convolution. Default: 1.
+        downsample (nn.Module): The downsample module. Default: None.
+        use_se (bool): Whether to use the SEBlock (squeeze and excitation block). Default: True.
+    """
+    expansion = 1  # output channel expansion ratio
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
+        super(IRBlock, self).__init__()
+        self.bn0 = nn.BatchNorm2d(inplanes)
+        self.conv1 = conv3x3(inplanes, inplanes)
+        self.bn1 = nn.BatchNorm2d(inplanes)
+        self.prelu = nn.PReLU()
+        self.conv2 = conv3x3(inplanes, planes, stride)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.downsample = downsample
+        self.stride = stride
+        self.use_se = use_se
+        if self.use_se:
+            self.se = SEBlock(planes)
+
+    def forward(self, x):
+        residual = x
+        out = self.bn0(x)
+        out = self.conv1(out)
+        out = self.bn1(out)
+        out = self.prelu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        if self.use_se:
+            out = self.se(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.prelu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    """Bottleneck block used in the ResNetArcFace architecture.
+
+    Args:
+        inplanes (int): Channel number of inputs.
+        planes (int): Channel number of outputs.
+        stride (int): Stride in convolution. Default: 1.
+        downsample (nn.Module): The downsample module. Default: None.
+    """
+    expansion = 4  # output channel expansion ratio
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Bottleneck, self).__init__()
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class SEBlock(nn.Module):
+    """The squeeze-and-excitation block (SEBlock) used in the IRBlock.
+
+    Args:
+        channel (int): Channel number of inputs.
+        reduction (int): Channel reduction ratio. Default: 16.
+    """
+
+    def __init__(self, channel, reduction=16):
+        super(SEBlock, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # pool to 1x1 without spatial information
+        self.fc = nn.Sequential(
+            nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),
+            nn.Sigmoid())
+
+    def forward(self, x):
+        b, c, _, _ = x.size()
+        y = self.avg_pool(x).view(b, c)
+        y = self.fc(y).view(b, c, 1, 1)
+        return x * y
+
+
+@ARCH_REGISTRY.register()
+class ResNetArcFace(nn.Module):
+    """ArcFace with ResNet architectures.
+
+    Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
+
+    Args:
+        block (str): Block used in the ArcFace architecture.
+        layers (tuple(int)): Block numbers in each layer.
+        use_se (bool): Whether to use the SEBlock (squeeze and excitation block). Default: True.
+    """
+
+    def __init__(self, block, layers, use_se=True):
+        if block == 'IRBlock':
+            block = IRBlock
+        self.inplanes = 64
+        self.use_se = use_se
+        super(ResNetArcFace, self).__init__()
+
+        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.prelu = nn.PReLU()
+        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+        self.bn4 = nn.BatchNorm2d(512)
+        self.dropout = nn.Dropout()
+        self.fc5 = nn.Linear(512 * 8 * 8, 512)
+        self.bn5 = nn.BatchNorm1d(512)
+
+        # initialization
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.xavier_normal_(m.weight)
+            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.xavier_normal_(m.weight)
+                nn.init.constant_(m.bias, 0)
+
+    def _make_layer(self, block, planes, num_blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(planes * block.expansion),
+            )
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
+        self.inplanes = planes
+        for _ in range(1, num_blocks):
+            layers.append(block(self.inplanes, planes, use_se=self.use_se))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.prelu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        x = self.bn4(x)
+        x = self.dropout(x)
+        x = x.view(x.size(0), -1)
+        x = self.fc5(x)
+        x = self.bn5(x)
+
+        return x
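A quick shape check of `ResNetArcFace`: `maxpool` and the three strided layers each halve the resolution, so the `512 * 8 * 8` input to `fc5` implies 128x128 single-channel faces. A minimal sketch, assuming torch and basicsr are installed (the ResNet18-style depths are an illustrative choice):

    import torch

    from gfpgan.archs.arcface_arch import ResNetArcFace

    # block='IRBlock' is mapped to the IRBlock class in __init__ above
    net = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=False).eval()
    with torch.no_grad():
        embedding = net(torch.randn(2, 1, 128, 128))  # grayscale 128x128 crops
    print(embedding.shape)  # torch.Size([2, 512])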
gfpgan/archs/gfpgan_bilinear_arch.py
ADDED
@@ -0,0 +1,312 @@
+import math
+import random
+import torch
+from basicsr.utils.registry import ARCH_REGISTRY
+from torch import nn
+
+from .gfpganv1_arch import ResUpBlock
+from .stylegan2_bilinear_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
+                                      StyleGAN2GeneratorBilinear)
+
+
+class StyleGAN2GeneratorBilinearSFT(StyleGAN2GeneratorBilinear):
+    """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
+
+    It is the bilinear version. It does not use the complicated UpFirDnSmooth function that is not friendly for
+    deployment. It can be easily converted to the clean version: StyleGAN2GeneratorCSFT.
+
+    Args:
+        out_size (int): The spatial size of outputs.
+        num_style_feat (int): Channel number of style features. Default: 512.
+        num_mlp (int): Layer number of MLP style layers. Default: 8.
+        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
+        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
+        narrow (float): The narrow ratio for channels. Default: 1.
+        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
+    """
+
+    def __init__(self,
+                 out_size,
+                 num_style_feat=512,
+                 num_mlp=8,
+                 channel_multiplier=2,
+                 lr_mlp=0.01,
+                 narrow=1,
+                 sft_half=False):
+        super(StyleGAN2GeneratorBilinearSFT, self).__init__(
+            out_size,
+            num_style_feat=num_style_feat,
+            num_mlp=num_mlp,
+            channel_multiplier=channel_multiplier,
+            lr_mlp=lr_mlp,
+            narrow=narrow)
+        self.sft_half = sft_half
+
+    def forward(self,
+                styles,
+                conditions,
+                input_is_latent=False,
+                noise=None,
+                randomize_noise=True,
+                truncation=1,
+                truncation_latent=None,
+                inject_index=None,
+                return_latents=False):
+        """Forward function for StyleGAN2GeneratorBilinearSFT.
+
+        Args:
+            styles (list[Tensor]): Sample codes of styles.
+            conditions (list[Tensor]): SFT conditions to generators.
+            input_is_latent (bool): Whether input is latent style. Default: False.
+            noise (Tensor | None): Input noise or None. Default: None.
+            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
+            truncation (float): The truncation ratio. Default: 1.
+            truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
+            inject_index (int | None): The injection index for mixing noise. Default: None.
+            return_latents (bool): Whether to return style latents. Default: False.
+        """
+        # style codes -> latents with Style MLP layer
+        if not input_is_latent:
+            styles = [self.style_mlp(s) for s in styles]
+        # noises
+        if noise is None:
+            if randomize_noise:
+                noise = [None] * self.num_layers  # for each style conv layer
+            else:  # use the stored noise
+                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
+        # style truncation
+        if truncation < 1:
+            style_truncation = []
+            for style in styles:
+                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
+            styles = style_truncation
+        # get style latents with injection
+        if len(styles) == 1:
+            inject_index = self.num_latent
+
+            if styles[0].ndim < 3:
+                # repeat latent code for all the layers
+                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+            else:  # used for encoder with different latent code for each layer
+                latent = styles[0]
+        elif len(styles) == 2:  # mixing noises
+            if inject_index is None:
+                inject_index = random.randint(1, self.num_latent - 1)
+            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
+            latent = torch.cat([latent1, latent2], 1)
+
+        # main generation
+        out = self.constant_input(latent.shape[0])
+        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
+        skip = self.to_rgb1(out, latent[:, 1])
+
+        i = 1
+        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
+                                                        noise[2::2], self.to_rgbs):
+            out = conv1(out, latent[:, i], noise=noise1)
+
+            # the conditions may have fewer levels
+            if i < len(conditions):
+                # SFT part to combine the conditions
+                if self.sft_half:  # only apply SFT to half of the channels
+                    out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
+                    out_sft = out_sft * conditions[i - 1] + conditions[i]
+                    out = torch.cat([out_same, out_sft], dim=1)
+                else:  # apply SFT to all the channels
+                    out = out * conditions[i - 1] + conditions[i]
+
+            out = conv2(out, latent[:, i + 1], noise=noise2)
+            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
+            i += 2
+
+        image = skip
+
+        if return_latents:
+            return image, latent
+        else:
+            return image, None
+
+
+@ARCH_REGISTRY.register()
+class GFPGANBilinear(nn.Module):
+    """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
+
+    It is the bilinear version and it does not use the complicated UpFirDnSmooth function that is not friendly for
+    deployment. It can be easily converted to the clean version: GFPGANv1Clean.
+
+    Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
+
+    Args:
+        out_size (int): The spatial size of outputs.
+        num_style_feat (int): Channel number of style features. Default: 512.
+        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
+        decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
+        fix_decoder (bool): Whether to fix the decoder. Default: True.
+
+        num_mlp (int): Layer number of MLP style layers. Default: 8.
+        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
+        input_is_latent (bool): Whether input is latent style. Default: False.
+        different_w (bool): Whether to use different latent w for different layers. Default: False.
+        narrow (float): The narrow ratio for channels. Default: 1.
+        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
+    """
+
+    def __init__(
+            self,
+            out_size,
+            num_style_feat=512,
+            channel_multiplier=1,
+            decoder_load_path=None,
+            fix_decoder=True,
+            # for stylegan decoder
+            num_mlp=8,
+            lr_mlp=0.01,
+            input_is_latent=False,
+            different_w=False,
+            narrow=1,
+            sft_half=False):
+
+        super(GFPGANBilinear, self).__init__()
+        self.input_is_latent = input_is_latent
+        self.different_w = different_w
+        self.num_style_feat = num_style_feat
+
+        unet_narrow = narrow * 0.5  # by default, use a half of input channels
+        channels = {
+            '4': int(512 * unet_narrow),
+            '8': int(512 * unet_narrow),
+            '16': int(512 * unet_narrow),
+            '32': int(512 * unet_narrow),
+            '64': int(256 * channel_multiplier * unet_narrow),
+            '128': int(128 * channel_multiplier * unet_narrow),
+            '256': int(64 * channel_multiplier * unet_narrow),
+            '512': int(32 * channel_multiplier * unet_narrow),
+            '1024': int(16 * channel_multiplier * unet_narrow)
+        }
+
+        self.log_size = int(math.log(out_size, 2))
+        first_out_size = 2**(int(math.log(out_size, 2)))
+
+        self.conv_body_first = ConvLayer(3, channels[f'{first_out_size}'], 1, bias=True, activate=True)
+
+        # downsample
+        in_channels = channels[f'{first_out_size}']
+        self.conv_body_down = nn.ModuleList()
+        for i in range(self.log_size, 2, -1):
+            out_channels = channels[f'{2**(i - 1)}']
+            self.conv_body_down.append(ResBlock(in_channels, out_channels))
+            in_channels = out_channels
+
+        self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True)
+
+        # upsample
+        in_channels = channels['4']
+        self.conv_body_up = nn.ModuleList()
+        for i in range(3, self.log_size + 1):
+            out_channels = channels[f'{2**i}']
+            self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
+            in_channels = out_channels
+
+        # to RGB
+        self.toRGB = nn.ModuleList()
+        for i in range(3, self.log_size + 1):
+            self.toRGB.append(EqualConv2d(channels[f'{2**i}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0))
+
+        if different_w:
+            linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
+        else:
+            linear_out_channel = num_style_feat
+
+        self.final_linear = EqualLinear(
+            channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None)
+
+        # the decoder: stylegan2 generator with SFT modulations
+        self.stylegan_decoder = StyleGAN2GeneratorBilinearSFT(
+            out_size=out_size,
+            num_style_feat=num_style_feat,
+            num_mlp=num_mlp,
+            channel_multiplier=channel_multiplier,
+            lr_mlp=lr_mlp,
+            narrow=narrow,
+            sft_half=sft_half)
+
+        # load pre-trained stylegan2 model if necessary
+        if decoder_load_path:
+            self.stylegan_decoder.load_state_dict(
+                torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
+        # fix decoder without updating params
+        if fix_decoder:
+            for _, param in self.stylegan_decoder.named_parameters():
+                param.requires_grad = False
+
+        # for SFT modulations (scale and shift)
+        self.condition_scale = nn.ModuleList()
+        self.condition_shift = nn.ModuleList()
+        for i in range(3, self.log_size + 1):
+            out_channels = channels[f'{2**i}']
+            if sft_half:
+                sft_out_channels = out_channels
+            else:
+                sft_out_channels = out_channels * 2
+            self.condition_scale.append(
+                nn.Sequential(
+                    EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
+                    ScaledLeakyReLU(0.2),
+                    EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1)))
+            self.condition_shift.append(
+                nn.Sequential(
+                    EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
+                    ScaledLeakyReLU(0.2),
+                    EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0)))
+
+    def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
+        """Forward function for GFPGANBilinear.
+
+        Args:
+            x (Tensor): Input images.
+            return_latents (bool): Whether to return style latents. Default: False.
+            return_rgb (bool): Whether return intermediate rgb images. Default: True.
+            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
+        """
+        conditions = []
+        unet_skips = []
+        out_rgbs = []
+
+        # encoder
+        feat = self.conv_body_first(x)
+        for i in range(self.log_size - 2):
+            feat = self.conv_body_down[i](feat)
+            unet_skips.insert(0, feat)
+
+        feat = self.final_conv(feat)
+
+        # style code
+        style_code = self.final_linear(feat.view(feat.size(0), -1))
+        if self.different_w:
+            style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
+
+        # decode
+        for i in range(self.log_size - 2):
+            # add unet skip
+            feat = feat + unet_skips[i]
+            # ResUpLayer
+            feat = self.conv_body_up[i](feat)
+            # generate scale and shift for SFT layers
+            scale = self.condition_scale[i](feat)
+            conditions.append(scale.clone())
+            shift = self.condition_shift[i](feat)
+            conditions.append(shift.clone())
+            # generate rgb images
+            if return_rgb:
+                out_rgbs.append(self.toRGB[i](feat))
+
+        # decoder
+        image, _ = self.stylegan_decoder([style_code],
+                                         conditions,
+                                         return_latents=return_latents,
+                                         input_is_latent=self.input_is_latent,
+                                         randomize_noise=randomize_noise)
+
+        return image, out_rgbs
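The SFT step inside the generator's forward pass is a learned per-pixel affine transform: each condition pair provides a scale map and a shift map. A standalone illustration of the `sft_half` branch with arbitrary shapes, assuming torch:

    import torch

    out = torch.randn(1, 64, 16, 16)    # decoder feature
    scale = torch.randn(1, 32, 16, 16)  # plays the role of conditions[i - 1]
    shift = torch.randn(1, 32, 16, 16)  # plays the role of conditions[i]

    # modulate only the second half of the channels; the first half passes
    # through unchanged, mirroring the torch.split/torch.cat in forward()
    out_same, out_sft = torch.split(out, out.size(1) // 2, dim=1)
    out_sft = out_sft * scale + shift
    out = torch.cat([out_same, out_sft], dim=1)
    print(out.shape)  # torch.Size([1, 64, 16, 16])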
gfpgan/archs/gfpganv1_arch.py
ADDED
@@ -0,0 +1,439 @@
+import math
+import random
+import torch
+from basicsr.archs.stylegan2_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
+                                          StyleGAN2Generator)
+from basicsr.ops.fused_act import FusedLeakyReLU
+from basicsr.utils.registry import ARCH_REGISTRY
+from torch import nn
+from torch.nn import functional as F
+
+
+class StyleGAN2GeneratorSFT(StyleGAN2Generator):
+    """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
+
+    Args:
+        out_size (int): The spatial size of outputs.
+        num_style_feat (int): Channel number of style features. Default: 512.
+        num_mlp (int): Layer number of MLP style layers. Default: 8.
+        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
+        resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross product will be
+            applied to extend the 1D resample kernel to a 2D resample kernel. Default: (1, 3, 3, 1).
+        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
+        narrow (float): The narrow ratio for channels. Default: 1.
+        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
+    """
+
+    def __init__(self,
+                 out_size,
+                 num_style_feat=512,
+                 num_mlp=8,
+                 channel_multiplier=2,
+                 resample_kernel=(1, 3, 3, 1),
+                 lr_mlp=0.01,
+                 narrow=1,
+                 sft_half=False):
+        super(StyleGAN2GeneratorSFT, self).__init__(
+            out_size,
+            num_style_feat=num_style_feat,
+            num_mlp=num_mlp,
+            channel_multiplier=channel_multiplier,
+            resample_kernel=resample_kernel,
+            lr_mlp=lr_mlp,
+            narrow=narrow)
+        self.sft_half = sft_half
+
+    def forward(self,
+                styles,
+                conditions,
+                input_is_latent=False,
+                noise=None,
+                randomize_noise=True,
+                truncation=1,
+                truncation_latent=None,
+                inject_index=None,
+                return_latents=False):
+        """Forward function for StyleGAN2GeneratorSFT.
+
+        Args:
+            styles (list[Tensor]): Sample codes of styles.
+            conditions (list[Tensor]): SFT conditions to generators.
+            input_is_latent (bool): Whether input is latent style. Default: False.
+            noise (Tensor | None): Input noise or None. Default: None.
+            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
+            truncation (float): The truncation ratio. Default: 1.
+            truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
+            inject_index (int | None): The injection index for mixing noise. Default: None.
+            return_latents (bool): Whether to return style latents. Default: False.
+        """
+        # style codes -> latents with Style MLP layer
+        if not input_is_latent:
+            styles = [self.style_mlp(s) for s in styles]
+        # noises
+        if noise is None:
+            if randomize_noise:
+                noise = [None] * self.num_layers  # for each style conv layer
+            else:  # use the stored noise
+                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
+        # style truncation
+        if truncation < 1:
+            style_truncation = []
+            for style in styles:
+                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
+            styles = style_truncation
+        # get style latents with injection
+        if len(styles) == 1:
+            inject_index = self.num_latent
+
+            if styles[0].ndim < 3:
+                # repeat latent code for all the layers
+                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+            else:  # used for encoder with different latent code for each layer
+                latent = styles[0]
+        elif len(styles) == 2:  # mixing noises
+            if inject_index is None:
+                inject_index = random.randint(1, self.num_latent - 1)
+            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
+            latent = torch.cat([latent1, latent2], 1)
+
+        # main generation
+        out = self.constant_input(latent.shape[0])
+        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
+        skip = self.to_rgb1(out, latent[:, 1])
+
+        i = 1
+        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
+                                                        noise[2::2], self.to_rgbs):
+            out = conv1(out, latent[:, i], noise=noise1)
+
+            # the conditions may have fewer levels
+            if i < len(conditions):
+                # SFT part to combine the conditions
+                if self.sft_half:  # only apply SFT to half of the channels
+                    out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
+                    out_sft = out_sft * conditions[i - 1] + conditions[i]
+                    out = torch.cat([out_same, out_sft], dim=1)
+                else:  # apply SFT to all the channels
+                    out = out * conditions[i - 1] + conditions[i]
+
+            out = conv2(out, latent[:, i + 1], noise=noise2)
+            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
+            i += 2
+
+        image = skip
+
+        if return_latents:
+            return image, latent
+        else:
+            return image, None
+
+
+class ConvUpLayer(nn.Module):
+    """Convolutional upsampling layer. It uses bilinear upsampler + Conv.
+
+    Args:
+        in_channels (int): Channel number of the input.
+        out_channels (int): Channel number of the output.
+        kernel_size (int): Size of the convolving kernel.
+        stride (int): Stride of the convolution. Default: 1
+        padding (int): Zero-padding added to both sides of the input. Default: 0.
+        bias (bool): If ``True``, adds a learnable bias to the output. Default: ``True``.
+        bias_init_val (float): Bias initialized value. Default: 0.
+        activate (bool): Whether to use activation. Default: True.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 bias=True,
+                 bias_init_val=0,
+                 activate=True):
+        super(ConvUpLayer, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        # self.scale is used to scale the convolution weights, which is related to the common initializations.
+        self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
+
+        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
+
+        if bias and not activate:
+            self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
+        else:
+            self.register_parameter('bias', None)
+
+        # activation
+        if activate:
+            if bias:
+                self.activation = FusedLeakyReLU(out_channels)
+            else:
+                self.activation = ScaledLeakyReLU(0.2)
+        else:
+            self.activation = None
+
+    def forward(self, x):
+        # bilinear upsample
+        out = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
+        # conv
+        out = F.conv2d(
+            out,
+            self.weight * self.scale,
+            bias=self.bias,
+            stride=self.stride,
+            padding=self.padding,
+        )
+        # activation
+        if self.activation is not None:
+            out = self.activation(out)
+        return out
+
+
+class ResUpBlock(nn.Module):
+    """Residual block with upsampling.
+
+    Args:
+        in_channels (int): Channel number of the input.
+        out_channels (int): Channel number of the output.
+    """
+
+    def __init__(self, in_channels, out_channels):
+        super(ResUpBlock, self).__init__()
+
+        self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
+        self.conv2 = ConvUpLayer(in_channels, out_channels, 3, stride=1, padding=1, bias=True, activate=True)
+        self.skip = ConvUpLayer(in_channels, out_channels, 1, bias=False, activate=False)
+
+    def forward(self, x):
+        out = self.conv1(x)
+        out = self.conv2(out)
+        skip = self.skip(x)
+        out = (out + skip) / math.sqrt(2)
+        return out
+
+
+@ARCH_REGISTRY.register()
+class GFPGANv1(nn.Module):
+    """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
+
+    Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
+
+    Args:
+        out_size (int): The spatial size of outputs.
+        num_style_feat (int): Channel number of style features. Default: 512.
+        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
+        resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross product will be
+            applied to extend the 1D resample kernel to a 2D resample kernel. Default: (1, 3, 3, 1).
+        decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
+        fix_decoder (bool): Whether to fix the decoder. Default: True.
+
+        num_mlp (int): Layer number of MLP style layers. Default: 8.
+        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
+        input_is_latent (bool): Whether input is latent style. Default: False.
+        different_w (bool): Whether to use different latent w for different layers. Default: False.
+        narrow (float): The narrow ratio for channels. Default: 1.
+        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
+    """
+
+    def __init__(
+            self,
+            out_size,
+            num_style_feat=512,
+            channel_multiplier=1,
+            resample_kernel=(1, 3, 3, 1),
+            decoder_load_path=None,
+            fix_decoder=True,
+            # for stylegan decoder
+            num_mlp=8,
+            lr_mlp=0.01,
+            input_is_latent=False,
+            different_w=False,
+            narrow=1,
+            sft_half=False):
+
+        super(GFPGANv1, self).__init__()
+        self.input_is_latent = input_is_latent
+        self.different_w = different_w
+        self.num_style_feat = num_style_feat
+
+        unet_narrow = narrow * 0.5  # by default, use a half of input channels
+        channels = {
+            '4': int(512 * unet_narrow),
+            '8': int(512 * unet_narrow),
+            '16': int(512 * unet_narrow),
+            '32': int(512 * unet_narrow),
+            '64': int(256 * channel_multiplier * unet_narrow),
+            '128': int(128 * channel_multiplier * unet_narrow),
+            '256': int(64 * channel_multiplier * unet_narrow),
+            '512': int(32 * channel_multiplier * unet_narrow),
+            '1024': int(16 * channel_multiplier * unet_narrow)
+        }
+
+        self.log_size = int(math.log(out_size, 2))
+        first_out_size = 2**(int(math.log(out_size, 2)))
+
+        self.conv_body_first = ConvLayer(3, channels[f'{first_out_size}'], 1, bias=True, activate=True)
+
+        # downsample
+        in_channels = channels[f'{first_out_size}']
+        self.conv_body_down = nn.ModuleList()
+        for i in range(self.log_size, 2, -1):
+            out_channels = channels[f'{2**(i - 1)}']
+            self.conv_body_down.append(ResBlock(in_channels, out_channels, resample_kernel))
+            in_channels = out_channels
+
+        self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True)
+
+        # upsample
+        in_channels = channels['4']
+        self.conv_body_up = nn.ModuleList()
+        for i in range(3, self.log_size + 1):
+            out_channels = channels[f'{2**i}']
+            self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
+            in_channels = out_channels
+
+        # to RGB
+        self.toRGB = nn.ModuleList()
+        for i in range(3, self.log_size + 1):
+            self.toRGB.append(EqualConv2d(channels[f'{2**i}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0))
+
+        if different_w:
+            linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
+        else:
+            linear_out_channel = num_style_feat
+
+        self.final_linear = EqualLinear(
+            channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None)
+
+        # the decoder: stylegan2 generator with SFT modulations
+        self.stylegan_decoder = StyleGAN2GeneratorSFT(
+            out_size=out_size,
+            num_style_feat=num_style_feat,
+            num_mlp=num_mlp,
+            channel_multiplier=channel_multiplier,
+            resample_kernel=resample_kernel,
+            lr_mlp=lr_mlp,
+            narrow=narrow,
+            sft_half=sft_half)
+
+        # load pre-trained stylegan2 model if necessary
+        if decoder_load_path:
+            self.stylegan_decoder.load_state_dict(
+                torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
+        # fix decoder without updating params
+        if fix_decoder:
+            for _, param in self.stylegan_decoder.named_parameters():
+                param.requires_grad = False
+
+        # for SFT modulations (scale and shift)
+        self.condition_scale = nn.ModuleList()
+        self.condition_shift = nn.ModuleList()
+        for i in range(3, self.log_size + 1):
+            out_channels = channels[f'{2**i}']
+            if sft_half:
+                sft_out_channels = out_channels
+            else:
+                sft_out_channels = out_channels * 2
+            self.condition_scale.append(
+                nn.Sequential(
+                    EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
+                    ScaledLeakyReLU(0.2),
+                    EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1)))
+            self.condition_shift.append(
+                nn.Sequential(
+                    EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
+                    ScaledLeakyReLU(0.2),
+                    EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0)))
+
+    def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs):
+        """Forward function for GFPGANv1.
+
+        Args:
+            x (Tensor): Input images.
+            return_latents (bool): Whether to return style latents. Default: False.
+            return_rgb (bool): Whether return intermediate rgb images. Default: True.
+            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
+        """
+        conditions = []
+        unet_skips = []
+        out_rgbs = []
+
+        # encoder
+        feat = self.conv_body_first(x)
+        for i in range(self.log_size - 2):
+            feat = self.conv_body_down[i](feat)
+            unet_skips.insert(0, feat)
+
+        feat = self.final_conv(feat)
+
+        # style code
+        style_code = self.final_linear(feat.view(feat.size(0), -1))
+        if self.different_w:
+            style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
+
+        # decode
+        for i in range(self.log_size - 2):
+            # add unet skip
+            feat = feat + unet_skips[i]
+            # ResUpLayer
+            feat = self.conv_body_up[i](feat)
+            # generate scale and shift for SFT layers
+            scale = self.condition_scale[i](feat)
+            conditions.append(scale.clone())
+            shift = self.condition_shift[i](feat)
+            conditions.append(shift.clone())
+            # generate rgb images
+            if return_rgb:
+                out_rgbs.append(self.toRGB[i](feat))
+
+        # decoder
+        image, _ = self.stylegan_decoder([style_code],
+                                         conditions,
+                                         return_latents=return_latents,
+                                         input_is_latent=self.input_is_latent,
+                                         randomize_noise=randomize_noise)
+
+        return image, out_rgbs
+
+
+@ARCH_REGISTRY.register()
+class FacialComponentDiscriminator(nn.Module):
+    """Facial component (eyes, mouth, nose) discriminator used in GFPGAN."""
+
+    def __init__(self):
+        super(FacialComponentDiscriminator, self).__init__()
+        # It now uses a VGG-style architecture with fixed model size
+        self.conv1 = ConvLayer(3, 64, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
+        self.conv2 = ConvLayer(64, 128, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
+        self.conv3 = ConvLayer(128, 128, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
+        self.conv4 = ConvLayer(128, 256, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
+        self.conv5 = ConvLayer(256, 256, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
+        self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False)
+
+    def forward(self, x, return_feats=False, **kwargs):
+        """Forward function for FacialComponentDiscriminator.
+
+        Args:
+            x (Tensor): Input images.
+            return_feats (bool): Whether to return intermediate features. Default: False.
+        """
+        feat = self.conv1(x)
+        feat = self.conv3(self.conv2(feat))
+        rlt_feats = []
+        if return_feats:
+            rlt_feats.append(feat.clone())
+        feat = self.conv5(self.conv4(feat))
+        if return_feats:
+            rlt_feats.append(feat.clone())
+        out = self.final_conv(feat)
+
+        if return_feats:
+            return out, rlt_feats
+        else:
+            return out, None
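`FacialComponentDiscriminator` is a small VGG-style patch critic: two stride-2 stages shrink a face-part crop by 4x, and `final_conv` maps it to a 1-channel realness map. A minimal sketch of calling it, assuming torch and basicsr are installed (the 80x80 crop size is illustrative; it is a training choice, not an architectural constraint):

    import torch

    from gfpgan.archs.gfpganv1_arch import FacialComponentDiscriminator

    disc = FacialComponentDiscriminator().eval()
    crop = torch.randn(1, 3, 80, 80)  # e.g. an eye or mouth crop
    with torch.no_grad():
        score_map, feats = disc(crop, return_feats=True)
    print(score_map.shape, len(feats))  # 1-channel map at 1/4 resolution, 2 feature maps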
gfpgan/archs/gfpganv1_clean_arch.py
ADDED
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import random
|
3 |
+
import torch
|
4 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
5 |
+
from torch import nn
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
from .stylegan2_clean_arch import StyleGAN2GeneratorClean
|
9 |
+
|
10 |
+
|
11 |
+
class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
|
12 |
+
"""StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
|
13 |
+
|
14 |
+
It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
out_size (int): The spatial size of outputs.
|
18 |
+
num_style_feat (int): Channel number of style features. Default: 512.
|
19 |
+
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
20 |
+
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
21 |
+
narrow (float): The narrow ratio for channels. Default: 1.
|
22 |
+
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
|
23 |
+
"""
|
24 |
+
|
25 |
+
def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1, sft_half=False):
|
26 |
+
super(StyleGAN2GeneratorCSFT, self).__init__(
|
27 |
+
out_size,
|
28 |
+
num_style_feat=num_style_feat,
|
29 |
+
num_mlp=num_mlp,
|
30 |
+
channel_multiplier=channel_multiplier,
|
31 |
+
narrow=narrow)
|
32 |
+
self.sft_half = sft_half
|
33 |
+
|
34 |
+
def forward(self,
|
35 |
+
styles,
|
36 |
+
conditions,
|
37 |
+
input_is_latent=False,
|
38 |
+
noise=None,
|
39 |
+
randomize_noise=True,
|
40 |
+
truncation=1,
|
41 |
+
truncation_latent=None,
|
42 |
+
inject_index=None,
|
43 |
+
return_latents=False):
|
44 |
+
"""Forward function for StyleGAN2GeneratorCSFT.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
styles (list[Tensor]): Sample codes of styles.
|
48 |
+
conditions (list[Tensor]): SFT conditions to generators.
|
49 |
+
input_is_latent (bool): Whether input is latent style. Default: False.
|
50 |
+
noise (Tensor | None): Input noise or None. Default: None.
|
51 |
+
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
|
52 |
+
truncation (float): The truncation ratio. Default: 1.
|
53 |
+
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
|
54 |
+
inject_index (int | None): The injection index for mixing noise. Default: None.
|
55 |
+
return_latents (bool): Whether to return style latents. Default: False.
|
56 |
+
"""
|
57 |
+
# style codes -> latents with Style MLP layer
|
58 |
+
if not input_is_latent:
|
59 |
+
styles = [self.style_mlp(s) for s in styles]
|
60 |
+
# noises
|
61 |
+
if noise is None:
|
62 |
+
if randomize_noise:
|
63 |
+
noise = [None] * self.num_layers # for each style conv layer
|
64 |
+
else: # use the stored noise
|
65 |
+
noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
|
66 |
+
# style truncation
|
67 |
+
if truncation < 1:
|
68 |
+
style_truncation = []
|
69 |
+
for style in styles:
|
70 |
+
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
|
71 |
+
styles = style_truncation
|
72 |
+
# get style latents with injection
|
73 |
+
if len(styles) == 1:
|
74 |
+
inject_index = self.num_latent
|
75 |
+
|
76 |
+
if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)

        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)

            # the conditions may have fewer levels
            if i < len(conditions):
                # SFT part to combine the conditions
                if self.sft_half:  # only apply SFT to half of the channels
                    out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
                    out_sft = out_sft * conditions[i - 1] + conditions[i]
                    out = torch.cat([out_same, out_sft], dim=1)
                else:  # apply SFT to all the channels
                    out = out * conditions[i - 1] + conditions[i]

            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None


class ResBlock(nn.Module):
    """Residual block with bilinear upsampling/downsampling.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        mode (str): Upsampling/downsampling mode. Options: down | up. Default: down.
    """

    def __init__(self, in_channels, out_channels, mode='down'):
        super(ResBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
        self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
        self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        if mode == 'down':
            self.scale_factor = 0.5
        elif mode == 'up':
            self.scale_factor = 2

    def forward(self, x):
        out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
        # upsample/downsample
        out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
        out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
        # skip
        x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
        skip = self.skip(x)
        out = out + skip
        return out


@ARCH_REGISTRY.register()
class GFPGANv1Clean(nn.Module):
    """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.

    It is the clean version without custom compiled CUDA extensions used in StyleGAN2.

    Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 1.
        decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
        fix_decoder (bool): Whether to fix the decoder. Default: True.

        num_mlp (int): Layer number of MLP style layers. Default: 8.
        input_is_latent (bool): Whether the input is latent style. Default: False.
        different_w (bool): Whether to use different latent w for different layers. Default: False.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """

    def __init__(
            self,
            out_size,
            num_style_feat=512,
            channel_multiplier=1,
            decoder_load_path=None,
            fix_decoder=True,
            # for stylegan decoder
            num_mlp=8,
            input_is_latent=False,
            different_w=False,
            narrow=1,
            sft_half=False):

        super(GFPGANv1Clean, self).__init__()
        self.input_is_latent = input_is_latent
        self.different_w = different_w
        self.num_style_feat = num_style_feat

        unet_narrow = narrow * 0.5  # by default, use half of the input channels
        channels = {
            '4': int(512 * unet_narrow),
            '8': int(512 * unet_narrow),
            '16': int(512 * unet_narrow),
            '32': int(512 * unet_narrow),
            '64': int(256 * channel_multiplier * unet_narrow),
            '128': int(128 * channel_multiplier * unet_narrow),
            '256': int(64 * channel_multiplier * unet_narrow),
            '512': int(32 * channel_multiplier * unet_narrow),
            '1024': int(16 * channel_multiplier * unet_narrow)
        }

        self.log_size = int(math.log(out_size, 2))
        first_out_size = 2**(int(math.log(out_size, 2)))

        self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)

        # downsample
        in_channels = channels[f'{first_out_size}']
        self.conv_body_down = nn.ModuleList()
        for i in range(self.log_size, 2, -1):
            out_channels = channels[f'{2**(i - 1)}']
            self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))
            in_channels = out_channels

        self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1)

        # upsample
        in_channels = channels['4']
        self.conv_body_up = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.conv_body_up.append(ResBlock(in_channels, out_channels, mode='up'))
            in_channels = out_channels

        # to RGB
        self.toRGB = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            self.toRGB.append(nn.Conv2d(channels[f'{2**i}'], 3, 1))

        if different_w:
            linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
        else:
            linear_out_channel = num_style_feat

        self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)

        # the decoder: stylegan2 generator with SFT modulations
        self.stylegan_decoder = StyleGAN2GeneratorCSFT(
            out_size=out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            narrow=narrow,
            sft_half=sft_half)

        # load pre-trained stylegan2 model if necessary
        if decoder_load_path:
            self.stylegan_decoder.load_state_dict(
                torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
        # fix decoder without updating params
        if fix_decoder:
            for _, param in self.stylegan_decoder.named_parameters():
                param.requires_grad = False

        # for SFT modulations (scale and shift)
        self.condition_scale = nn.ModuleList()
        self.condition_shift = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            if sft_half:
                sft_out_channels = out_channels
            else:
                sft_out_channels = out_channels * 2
            self.condition_scale.append(
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
                    nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
            self.condition_shift.append(
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
                    nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))

    def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs):
        """Forward function for GFPGANv1Clean.

        Args:
            x (Tensor): Input images.
            return_latents (bool): Whether to return style latents. Default: False.
            return_rgb (bool): Whether to return intermediate rgb images. Default: True.
            randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
        """
        conditions = []
        unet_skips = []
        out_rgbs = []

        # encoder
        feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2)
        for i in range(self.log_size - 2):
            feat = self.conv_body_down[i](feat)
            unet_skips.insert(0, feat)
        feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)

        # style code
        style_code = self.final_linear(feat.view(feat.size(0), -1))
        if self.different_w:
            style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)

        # decode
        for i in range(self.log_size - 2):
            # add unet skip
            feat = feat + unet_skips[i]
            # ResUpLayer
            feat = self.conv_body_up[i](feat)
            # generate scale and shift for SFT layers
            scale = self.condition_scale[i](feat)
            conditions.append(scale.clone())
            shift = self.condition_shift[i](feat)
            conditions.append(shift.clone())
            # generate rgb images
            if return_rgb:
                out_rgbs.append(self.toRGB[i](feat))

        # decoder
        image, _ = self.stylegan_decoder([style_code],
                                         conditions,
                                         return_latents=return_latents,
                                         input_is_latent=self.input_is_latent,
                                         randomize_noise=randomize_noise)

        return image, out_rgbs
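Below is a minimal usage sketch of GFPGANv1Clean at inference time; it is not part of the uploaded file, the weights here are randomly initialized, and the shape comments assume the default narrow=1.

# Illustrative sketch only; not part of gfpganv1_clean_arch.py.
import torch

restorer = GFPGANv1Clean(out_size=512, channel_multiplier=2, decoder_load_path=None, fix_decoder=False)
restorer.eval()
with torch.no_grad():
    lq_faces = torch.randn(2, 3, 512, 512)  # a batch of aligned low-quality faces
    restored, out_rgbs = restorer(lq_faces, return_rgb=True)
print(restored.shape)  # torch.Size([2, 3, 512, 512])
print(len(out_rgbs))   # 7 intermediate RGB outputs (log2(512) - 2)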
gfpgan/archs/restoreformer_arch.py
ADDED
@@ -0,0 +1,658 @@
"""Modified from https://github.com/wzhouxiff/RestoreFormer
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class VectorQuantizer(nn.Module):
    """
    see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
    ____________________________________________
    Discretization bottleneck part of the VQ-VAE.
    Inputs:
    - n_e : number of embeddings
    - e_dim : dimension of embedding
    - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
    _____________________________________________
    """

    def __init__(self, n_e, e_dim, beta):
        super(VectorQuantizer, self).__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

    def forward(self, z):
        """
        Inputs the output of the encoder network z and maps it to a discrete
        one-hot vector that is the index of the closest embedding vector e_j
        z (continuous) -> z_q (discrete)
        z.shape = (batch, channel, height, width)
        quantization pipeline:
            1. get encoder input (B,C,H,W)
            2. flatten input to (B*H*W,C)
        """
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2 e * z

        d = torch.sum(z_flattened**2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight**2, dim=1) - 2 * \
            torch.matmul(z_flattened, self.embedding.weight.t())

        # could possibly replace this here
        # #\start...
        # find closest encodings

        min_value, min_encoding_indices = torch.min(d, dim=1)

        min_encoding_indices = min_encoding_indices.unsqueeze(1)

        min_encodings = torch.zeros(min_encoding_indices.shape[0], self.n_e).to(z)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # dtype min encodings: torch.float32
        # min_encodings shape: torch.Size([2048, 512])
        # min_encoding_indices.shape: torch.Size([2048, 1])

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
        # .........\end

        # with:
        # .........\start
        # min_encoding_indices = torch.argmin(d, dim=1)
        # z_q = self.embedding(min_encoding_indices)
        # ......\end......... (TODO)

        # compute loss for embedding
        loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean((z_q - z.detach())**2)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # perplexity

        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices, d)

    def get_codebook_entry(self, indices, shape):
        # shape specifying (batch, height, width, channel)
        # TODO: check for more easy handling with nn.Embedding
        min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
        min_encodings.scatter_(1, indices[:, None], 1)

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings.float(), self.embedding.weight)

        if shape is not None:
            z_q = z_q.view(shape)

            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q


# pytorch_diffusion + derived encoder decoder
def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


class Upsample(nn.Module):

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode='nearest')
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode='constant', value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class ResnetBlock(nn.Module):

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


class MultiHeadAttnBlock(nn.Module):

    def __init__(self, in_channels, head_size=1):
        super().__init__()
        self.in_channels = in_channels
        self.head_size = head_size
        self.att_size = in_channels // head_size
        assert (in_channels % head_size == 0), 'The number of channels should be divisible by the head size.'

        self.norm1 = Normalize(in_channels)
        self.norm2 = Normalize(in_channels)

        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.num = 0

    def forward(self, x, y=None):
        h_ = x
        h_ = self.norm1(h_)
        if y is None:
            y = h_
        else:
            y = self.norm2(y)

        q = self.q(y)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, self.head_size, self.att_size, h * w)
        q = q.permute(0, 3, 1, 2)  # b, hw, head, att

        k = k.reshape(b, self.head_size, self.att_size, h * w)
        k = k.permute(0, 3, 1, 2)

        v = v.reshape(b, self.head_size, self.att_size, h * w)
        v = v.permute(0, 3, 1, 2)

        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        k = k.transpose(1, 2).transpose(2, 3)

        scale = int(self.att_size)**(-0.5)
        q.mul_(scale)
        w_ = torch.matmul(q, k)
        w_ = F.softmax(w_, dim=3)

        w_ = w_.matmul(v)

        w_ = w_.transpose(1, 2).contiguous()  # [b, h*w, head, att]
        w_ = w_.view(b, h, w, -1)
        w_ = w_.permute(0, 3, 1, 2)

        w_ = self.proj_out(w_)

        return x + w_


class MultiHeadEncoder(nn.Module):

    def __init__(self,
                 ch,
                 out_ch,
                 ch_mult=(1, 2, 4, 8),
                 num_res_blocks=2,
                 attn_resolutions=(16, ),
                 dropout=0.0,
                 resamp_with_conv=True,
                 in_channels=3,
                 resolution=512,
                 z_channels=256,
                 double_z=True,
                 enable_mid=True,
                 head_size=1,
                 **ignore_kwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.enable_mid = enable_mid

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1, ) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(MultiHeadAttnBlock(block_in, head_size))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        if self.enable_mid:
            self.mid = nn.Module()
            self.mid.block_1 = ResnetBlock(
                in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
            self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
            self.mid.block_2 = ResnetBlock(
                in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        hs = {}
        # timestep embedding
        temb = None

        # downsampling
        h = self.conv_in(x)
        hs['in'] = h
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](h, temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)

            if i_level != self.num_resolutions - 1:
                # hs.append(h)
                hs['block_' + str(i_level)] = h
                h = self.down[i_level].downsample(h)

        # middle
        # h = hs[-1]
        if self.enable_mid:
            h = self.mid.block_1(h, temb)
            hs['block_' + str(i_level) + '_atten'] = h
            h = self.mid.attn_1(h)
            h = self.mid.block_2(h, temb)
            hs['mid_atten'] = h

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        # hs.append(h)
        hs['out'] = h

        return hs


class MultiHeadDecoder(nn.Module):

    def __init__(self,
                 ch,
                 out_ch,
                 ch_mult=(1, 2, 4, 8),
                 num_res_blocks=2,
                 attn_resolutions=(16, ),
                 dropout=0.0,
                 resamp_with_conv=True,
                 in_channels=3,
                 resolution=512,
                 z_channels=256,
                 give_pre_end=False,
                 enable_mid=True,
                 head_size=1,
                 **ignorekwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.enable_mid = enable_mid

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2**(self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print('Working with z of shape {} = {} dimensions.'.format(self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        if self.enable_mid:
            self.mid = nn.Module()
            self.mid.block_1 = ResnetBlock(
                in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
            self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
            self.mid.block_2 = ResnetBlock(
                in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(MultiHeadAttnBlock(block_in, head_size))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z):
        # assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        if self.enable_mid:
            h = self.mid.block_1(h, temb)
            h = self.mid.attn_1(h)
            h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class MultiHeadDecoderTransformer(nn.Module):

    def __init__(self,
                 ch,
                 out_ch,
                 ch_mult=(1, 2, 4, 8),
                 num_res_blocks=2,
                 attn_resolutions=(16, ),
                 dropout=0.0,
                 resamp_with_conv=True,
                 in_channels=3,
                 resolution=512,
                 z_channels=256,
                 give_pre_end=False,
                 enable_mid=True,
                 head_size=1,
                 **ignorekwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.enable_mid = enable_mid

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2**(self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print('Working with z of shape {} = {} dimensions.'.format(self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        if self.enable_mid:
            self.mid = nn.Module()
            self.mid.block_1 = ResnetBlock(
                in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
            self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
            self.mid.block_2 = ResnetBlock(
                in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(MultiHeadAttnBlock(block_in, head_size))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z, hs):
        # assert z.shape[1:] == self.z_shape[1:]
        # self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        if self.enable_mid:
            h = self.mid.block_1(h, temb)
            h = self.mid.attn_1(h, hs['mid_atten'])
            h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h, hs['block_' + str(i_level) + '_atten'])
                    # hfeature = h.clone()
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class RestoreFormer(nn.Module):

    def __init__(self,
                 n_embed=1024,
                 embed_dim=256,
                 ch=64,
                 out_ch=3,
                 ch_mult=(1, 2, 2, 4, 4, 8),
                 num_res_blocks=2,
                 attn_resolutions=(16, ),
                 dropout=0.0,
                 in_channels=3,
                 resolution=512,
                 z_channels=256,
                 double_z=False,
                 enable_mid=True,
                 fix_decoder=False,
                 fix_codebook=True,
                 fix_encoder=False,
                 head_size=8):
        super(RestoreFormer, self).__init__()

        self.encoder = MultiHeadEncoder(
            ch=ch,
            out_ch=out_ch,
            ch_mult=ch_mult,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            in_channels=in_channels,
            resolution=resolution,
            z_channels=z_channels,
            double_z=double_z,
            enable_mid=enable_mid,
            head_size=head_size)
        self.decoder = MultiHeadDecoderTransformer(
            ch=ch,
            out_ch=out_ch,
            ch_mult=ch_mult,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            in_channels=in_channels,
            resolution=resolution,
            z_channels=z_channels,
            enable_mid=enable_mid,
            head_size=head_size)

        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25)

        self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)

        if fix_decoder:
            for _, param in self.decoder.named_parameters():
                param.requires_grad = False
            for _, param in self.post_quant_conv.named_parameters():
                param.requires_grad = False
            for _, param in self.quantize.named_parameters():
                param.requires_grad = False
        elif fix_codebook:
            for _, param in self.quantize.named_parameters():
                param.requires_grad = False

        if fix_encoder:
            for _, param in self.encoder.named_parameters():
                param.requires_grad = False

    def encode(self, x):

        hs = self.encoder(x)
        h = self.quant_conv(hs['out'])
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info, hs

    def decode(self, quant, hs):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant, hs)

        return dec

    def forward(self, input, **kwargs):
        quant, diff, info, hs = self.encode(input)
        dec = self.decode(quant, hs)

        return dec, None
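A minimal forward-pass sketch for RestoreFormer with its default 512x512 configuration follows; it is not part of the uploaded file, and with random weights the output is of course not a meaningful restoration.

# Illustrative sketch only; not part of restoreformer_arch.py.
import torch

net = RestoreFormer()  # defaults: 512x512 faces, 1024-entry codebook of dim 256
net.eval()
with torch.no_grad():
    degraded = torch.randn(1, 3, 512, 512)
    restored, _ = net(degraded)  # forward() returns (decoded image, None)
print(restored.shape)  # torch.Size([1, 3, 512, 512])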
gfpgan/archs/stylegan2_bilinear_arch.py
ADDED
@@ -0,0 +1,613 @@
1 |
+
import math
|
2 |
+
import random
|
3 |
+
import torch
|
4 |
+
from basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu
|
5 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
6 |
+
from torch import nn
|
7 |
+
from torch.nn import functional as F
|
8 |
+
|
9 |
+
|
10 |
+
class NormStyleCode(nn.Module):
|
11 |
+
|
12 |
+
def forward(self, x):
|
13 |
+
"""Normalize the style codes.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
x (Tensor): Style codes with shape (b, c).
|
17 |
+
|
18 |
+
Returns:
|
19 |
+
Tensor: Normalized tensor.
|
20 |
+
"""
|
21 |
+
return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
|
22 |
+
|
23 |
+
|
24 |
+
class EqualLinear(nn.Module):
|
25 |
+
"""Equalized Linear as StyleGAN2.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
in_channels (int): Size of each sample.
|
29 |
+
out_channels (int): Size of each output sample.
|
30 |
+
bias (bool): If set to ``False``, the layer will not learn an additive
|
31 |
+
bias. Default: ``True``.
|
32 |
+
bias_init_val (float): Bias initialized value. Default: 0.
|
33 |
+
lr_mul (float): Learning rate multiplier. Default: 1.
|
34 |
+
activation (None | str): The activation after ``linear`` operation.
|
35 |
+
Supported: 'fused_lrelu', None. Default: None.
|
36 |
+
"""
|
37 |
+
|
38 |
+
def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul=1, activation=None):
|
39 |
+
super(EqualLinear, self).__init__()
|
40 |
+
self.in_channels = in_channels
|
41 |
+
self.out_channels = out_channels
|
42 |
+
self.lr_mul = lr_mul
|
43 |
+
self.activation = activation
|
44 |
+
if self.activation not in ['fused_lrelu', None]:
|
45 |
+
raise ValueError(f'Wrong activation value in EqualLinear: {activation}'
|
46 |
+
"Supported ones are: ['fused_lrelu', None].")
|
47 |
+
self.scale = (1 / math.sqrt(in_channels)) * lr_mul
|
48 |
+
|
49 |
+
self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
|
50 |
+
if bias:
|
51 |
+
self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
|
52 |
+
else:
|
53 |
+
self.register_parameter('bias', None)
|
54 |
+
|
55 |
+
def forward(self, x):
|
56 |
+
if self.bias is None:
|
57 |
+
bias = None
|
58 |
+
else:
|
59 |
+
bias = self.bias * self.lr_mul
|
60 |
+
if self.activation == 'fused_lrelu':
|
61 |
+
out = F.linear(x, self.weight * self.scale)
|
62 |
+
out = fused_leaky_relu(out, bias)
|
63 |
+
else:
|
64 |
+
out = F.linear(x, self.weight * self.scale, bias=bias)
|
65 |
+
return out
|
66 |
+
|
67 |
+
def __repr__(self):
|
68 |
+
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
|
69 |
+
f'out_channels={self.out_channels}, bias={self.bias is not None})')
|
70 |
+
|
71 |
+
|
72 |
+
class ModulatedConv2d(nn.Module):
|
73 |
+
"""Modulated Conv2d used in StyleGAN2.
|
74 |
+
|
75 |
+
There is no bias in ModulatedConv2d.
|
76 |
+
|
77 |
+
Args:
|
78 |
+
in_channels (int): Channel number of the input.
|
79 |
+
out_channels (int): Channel number of the output.
|
80 |
+
kernel_size (int): Size of the convolving kernel.
|
81 |
+
num_style_feat (int): Channel number of style features.
|
82 |
+
demodulate (bool): Whether to demodulate in the conv layer.
|
83 |
+
Default: True.
|
84 |
+
sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
|
85 |
+
Default: None.
|
86 |
+
eps (float): A value added to the denominator for numerical stability.
|
87 |
+
Default: 1e-8.
|
88 |
+
"""
|
89 |
+
|
90 |
+
def __init__(self,
|
91 |
+
in_channels,
|
92 |
+
out_channels,
|
93 |
+
kernel_size,
|
94 |
+
num_style_feat,
|
95 |
+
demodulate=True,
|
96 |
+
sample_mode=None,
|
97 |
+
eps=1e-8,
|
98 |
+
interpolation_mode='bilinear'):
|
99 |
+
super(ModulatedConv2d, self).__init__()
|
100 |
+
self.in_channels = in_channels
|
101 |
+
self.out_channels = out_channels
|
102 |
+
self.kernel_size = kernel_size
|
103 |
+
self.demodulate = demodulate
|
104 |
+
self.sample_mode = sample_mode
|
105 |
+
self.eps = eps
|
106 |
+
self.interpolation_mode = interpolation_mode
|
107 |
+
if self.interpolation_mode == 'nearest':
|
108 |
+
self.align_corners = None
|
109 |
+
else:
|
110 |
+
self.align_corners = False
|
111 |
+
|
112 |
+
self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
|
113 |
+
# modulation inside each modulated conv
|
114 |
+
self.modulation = EqualLinear(
|
115 |
+
num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None)
|
116 |
+
|
117 |
+
self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size))
|
118 |
+
self.padding = kernel_size // 2
|
119 |
+
|
120 |
+
def forward(self, x, style):
|
121 |
+
"""Forward function.
|
122 |
+
|
123 |
+
Args:
|
124 |
+
x (Tensor): Tensor with shape (b, c, h, w).
|
125 |
+
style (Tensor): Tensor with shape (b, num_style_feat).
|
126 |
+
|
127 |
+
Returns:
|
128 |
+
Tensor: Modulated tensor after convolution.
|
129 |
+
"""
|
130 |
+
b, c, h, w = x.shape # c = c_in
|
131 |
+
# weight modulation
|
132 |
+
style = self.modulation(style).view(b, 1, c, 1, 1)
|
133 |
+
# self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
|
134 |
+
weight = self.scale * self.weight * style # (b, c_out, c_in, k, k)
|
135 |
+
|
136 |
+
if self.demodulate:
|
137 |
+
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
|
138 |
+
weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
|
139 |
+
|
140 |
+
weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
|
141 |
+
|
142 |
+
if self.sample_mode == 'upsample':
|
143 |
+
x = F.interpolate(x, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners)
|
144 |
+
elif self.sample_mode == 'downsample':
|
145 |
+
x = F.interpolate(x, scale_factor=0.5, mode=self.interpolation_mode, align_corners=self.align_corners)
|
146 |
+
|
147 |
+
b, c, h, w = x.shape
|
148 |
+
x = x.view(1, b * c, h, w)
|
149 |
+
# weight: (b*c_out, c_in, k, k), groups=b
|
150 |
+
out = F.conv2d(x, weight, padding=self.padding, groups=b)
|
151 |
+
out = out.view(b, self.out_channels, *out.shape[2:4])
|
152 |
+
|
153 |
+
return out
|
154 |
+
|
155 |
+
def __repr__(self):
|
156 |
+
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
|
157 |
+
f'out_channels={self.out_channels}, '
|
158 |
+
f'kernel_size={self.kernel_size}, '
|
159 |
+
f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')
|
160 |
+
|
161 |
+
|
162 |
+
class StyleConv(nn.Module):
|
163 |
+
"""Style conv.
|
164 |
+
|
165 |
+
Args:
|
166 |
+
in_channels (int): Channel number of the input.
|
167 |
+
out_channels (int): Channel number of the output.
|
168 |
+
kernel_size (int): Size of the convolving kernel.
|
169 |
+
num_style_feat (int): Channel number of style features.
|
170 |
+
demodulate (bool): Whether demodulate in the conv layer. Default: True.
|
171 |
+
sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
|
172 |
+
Default: None.
|
173 |
+
"""
|
174 |
+
|
175 |
+
def __init__(self,
|
176 |
+
in_channels,
|
177 |
+
out_channels,
|
178 |
+
kernel_size,
|
179 |
+
num_style_feat,
|
180 |
+
demodulate=True,
|
181 |
+
sample_mode=None,
|
182 |
+
interpolation_mode='bilinear'):
|
183 |
+
super(StyleConv, self).__init__()
|
184 |
+
self.modulated_conv = ModulatedConv2d(
|
185 |
+
in_channels,
|
186 |
+
out_channels,
|
187 |
+
kernel_size,
|
188 |
+
num_style_feat,
|
189 |
+
demodulate=demodulate,
|
190 |
+
sample_mode=sample_mode,
|
191 |
+
interpolation_mode=interpolation_mode)
|
192 |
+
self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
|
193 |
+
self.activate = FusedLeakyReLU(out_channels)
|
194 |
+
|
195 |
+
def forward(self, x, style, noise=None):
|
196 |
+
# modulate
|
197 |
+
out = self.modulated_conv(x, style)
|
198 |
+
# noise injection
|
199 |
+
if noise is None:
|
200 |
+
b, _, h, w = out.shape
|
201 |
+
noise = out.new_empty(b, 1, h, w).normal_()
|
202 |
+
out = out + self.weight * noise
|
203 |
+
# activation (with bias)
|
204 |
+
out = self.activate(out)
|
205 |
+
return out
|
206 |
+
|
207 |
+
|
208 |
+
class ToRGB(nn.Module):
|
209 |
+
"""To RGB from features.
|
210 |
+
|
211 |
+
Args:
|
212 |
+
in_channels (int): Channel number of input.
|
213 |
+
num_style_feat (int): Channel number of style features.
|
214 |
+
upsample (bool): Whether to upsample. Default: True.
|
215 |
+
"""
|
216 |
+
|
217 |
+
def __init__(self, in_channels, num_style_feat, upsample=True, interpolation_mode='bilinear'):
|
218 |
+
super(ToRGB, self).__init__()
|
219 |
+
self.upsample = upsample
|
220 |
+
self.interpolation_mode = interpolation_mode
|
221 |
+
if self.interpolation_mode == 'nearest':
|
222 |
+
self.align_corners = None
|
223 |
+
else:
|
224 |
+
self.align_corners = False
|
225 |
+
self.modulated_conv = ModulatedConv2d(
|
226 |
+
in_channels,
|
227 |
+
3,
|
228 |
+
kernel_size=1,
|
229 |
+
num_style_feat=num_style_feat,
|
230 |
+
demodulate=False,
|
231 |
+
sample_mode=None,
|
232 |
+
interpolation_mode=interpolation_mode)
|
233 |
+
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
|
234 |
+
|
235 |
+
def forward(self, x, style, skip=None):
|
236 |
+
"""Forward function.
|
237 |
+
|
238 |
+
Args:
|
239 |
+
x (Tensor): Feature tensor with shape (b, c, h, w).
|
240 |
+
style (Tensor): Tensor with shape (b, num_style_feat).
|
241 |
+
skip (Tensor): Base/skip tensor. Default: None.
|
242 |
+
|
243 |
+
Returns:
|
244 |
+
Tensor: RGB images.
|
245 |
+
"""
|
246 |
+
out = self.modulated_conv(x, style)
|
247 |
+
out = out + self.bias
|
248 |
+
if skip is not None:
|
249 |
+
if self.upsample:
|
250 |
+
skip = F.interpolate(
|
251 |
+
skip, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners)
|
252 |
+
out = out + skip
|
253 |
+
return out
|
254 |
+
|
255 |
+
|
256 |
+
class ConstantInput(nn.Module):
|
257 |
+
"""Constant input.
|
258 |
+
|
259 |
+
Args:
|
260 |
+
num_channel (int): Channel number of constant input.
|
261 |
+
size (int): Spatial size of constant input.
|
262 |
+
"""
|
263 |
+
|
264 |
+
def __init__(self, num_channel, size):
|
265 |
+
super(ConstantInput, self).__init__()
|
266 |
+
self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
|
267 |
+
|
268 |
+
def forward(self, batch):
|
269 |
+
out = self.weight.repeat(batch, 1, 1, 1)
|
270 |
+
return out
|
271 |
+
|
272 |
+
|
273 |
+
@ARCH_REGISTRY.register()
|
274 |
+
class StyleGAN2GeneratorBilinear(nn.Module):
|
275 |
+
"""StyleGAN2 Generator.
|
276 |
+
|
277 |
+
Args:
|
278 |
+
out_size (int): The spatial size of outputs.
|
279 |
+
num_style_feat (int): Channel number of style features. Default: 512.
|
280 |
+
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
281 |
+
channel_multiplier (int): Channel multiplier for large networks of
|
282 |
+
StyleGAN2. Default: 2.
|
283 |
+
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
|
284 |
+
narrow (float): Narrow ratio for channels. Default: 1.0.
|
285 |
+
"""
|
286 |
+
|
287 |
+
def __init__(self,
|
288 |
+
out_size,
|
289 |
+
num_style_feat=512,
|
290 |
+
num_mlp=8,
|
291 |
+
channel_multiplier=2,
|
292 |
+
lr_mlp=0.01,
|
293 |
+
narrow=1,
|
294 |
+
interpolation_mode='bilinear'):
|
295 |
+
super(StyleGAN2GeneratorBilinear, self).__init__()
|
296 |
+
# Style MLP layers
|
297 |
+
self.num_style_feat = num_style_feat
|
298 |
+
style_mlp_layers = [NormStyleCode()]
|
299 |
+
for i in range(num_mlp):
|
300 |
+
style_mlp_layers.append(
|
301 |
+
EqualLinear(
|
302 |
+
num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp,
|
303 |
+
activation='fused_lrelu'))
|
304 |
+
self.style_mlp = nn.Sequential(*style_mlp_layers)
|
305 |
+
|
306 |
+
channels = {
|
307 |
+
'4': int(512 * narrow),
|
308 |
+
'8': int(512 * narrow),
|
309 |
+
'16': int(512 * narrow),
|
310 |
+
'32': int(512 * narrow),
|
311 |
+
'64': int(256 * channel_multiplier * narrow),
|
312 |
+
'128': int(128 * channel_multiplier * narrow),
|
313 |
+
'256': int(64 * channel_multiplier * narrow),
|
314 |
+
'512': int(32 * channel_multiplier * narrow),
|
315 |
+
'1024': int(16 * channel_multiplier * narrow)
|
316 |
+
}
|
317 |
+
self.channels = channels
|
318 |
+
|
319 |
+
self.constant_input = ConstantInput(channels['4'], size=4)
|
320 |
+
self.style_conv1 = StyleConv(
|
321 |
+
channels['4'],
|
322 |
+
channels['4'],
|
323 |
+
kernel_size=3,
|
324 |
+
num_style_feat=num_style_feat,
|
325 |
+
demodulate=True,
|
326 |
+
sample_mode=None,
|
327 |
+
interpolation_mode=interpolation_mode)
|
328 |
+
self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False, interpolation_mode=interpolation_mode)
|
329 |
+
|
330 |
+
self.log_size = int(math.log(out_size, 2))
|
331 |
+
self.num_layers = (self.log_size - 2) * 2 + 1
|
332 |
+
self.num_latent = self.log_size * 2 - 2
|
333 |
+
|
334 |
+
self.style_convs = nn.ModuleList()
|
335 |
+
self.to_rgbs = nn.ModuleList()
|
336 |
+
self.noises = nn.Module()
|
337 |
+
|
338 |
+
in_channels = channels['4']
|
339 |
+
# noise
|
340 |
+
for layer_idx in range(self.num_layers):
|
341 |
+
resolution = 2**((layer_idx + 5) // 2)
|
342 |
+
shape = [1, 1, resolution, resolution]
|
343 |
+
self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
|
344 |
+
# style convs and to_rgbs
|
345 |
+
for i in range(3, self.log_size + 1):
|
346 |
+
out_channels = channels[f'{2**i}']
|
347 |
+
self.style_convs.append(
|
348 |
+
StyleConv(
|
349 |
+
in_channels,
|
350 |
+
out_channels,
|
351 |
+
kernel_size=3,
|
352 |
+
num_style_feat=num_style_feat,
|
353 |
+
demodulate=True,
|
354 |
+
sample_mode='upsample',
|
355 |
+
interpolation_mode=interpolation_mode))
|
356 |
+
self.style_convs.append(
|
357 |
+
StyleConv(
|
358 |
+
out_channels,
|
359 |
+
out_channels,
|
360 |
+
kernel_size=3,
|
361 |
+
num_style_feat=num_style_feat,
|
362 |
+
demodulate=True,
|
363 |
+
sample_mode=None,
|
364 |
+
interpolation_mode=interpolation_mode))
|
365 |
+
self.to_rgbs.append(
|
366 |
+
ToRGB(out_channels, num_style_feat, upsample=True, interpolation_mode=interpolation_mode))
|
367 |
+
in_channels = out_channels
|
368 |
+
|
369 |
+
def make_noise(self):
|
370 |
+
"""Make noise for noise injection."""
|
371 |
+
device = self.constant_input.weight.device
|
372 |
+
noises = [torch.randn(1, 1, 4, 4, device=device)]
|
373 |
+
|
374 |
+
for i in range(3, self.log_size + 1):
|
375 |
+
for _ in range(2):
|
376 |
+
noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
|
377 |
+
|
378 |
+
return noises
|
379 |
+
|
380 |
+
def get_latent(self, x):
|
381 |
+
return self.style_mlp(x)
|
382 |
+
|
383 |
+
def mean_latent(self, num_latent):
|
384 |
+
latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
|
385 |
+
latent = self.style_mlp(latent_in).mean(0, keepdim=True)
|
386 |
+
return latent
|
387 |
+
|
388 |
+
def forward(self,
|
389 |
+
styles,
|
390 |
+
input_is_latent=False,
|
391 |
+
noise=None,
|
392 |
+
randomize_noise=True,
|
393 |
+
truncation=1,
|
394 |
+
truncation_latent=None,
|
395 |
+
inject_index=None,
|
396 |
+
return_latents=False):
|
397 |
+
"""Forward function for StyleGAN2Generator.
|
398 |
+
|
399 |
+
Args:
|
400 |
+
styles (list[Tensor]): Sample codes of styles.
|
401 |
+
input_is_latent (bool): Whether input is latent style.
|
402 |
+
Default: False.
|
403 |
+
noise (Tensor | None): Input noise or None. Default: None.
|
404 |
+
randomize_noise (bool): Randomize noise, used when 'noise' is
|
405 |
+
False. Default: True.
|
406 |
+
truncation (float): TODO. Default: 1.
|
407 |
+
truncation_latent (Tensor | None): TODO. Default: None.
|
408 |
+
inject_index (int | None): The injection index for mixing noise.
|
409 |
+
Default: None.
|
410 |
+
return_latents (bool): Whether to return style latents.
|
411 |
+
Default: False.
|
412 |
+
"""
|
413 |
+
# style codes -> latents with Style MLP layer
|
414 |
+
if not input_is_latent:
|
415 |
+
styles = [self.style_mlp(s) for s in styles]
|
416 |
+
# noises
|
417 |
+
if noise is None:
|
418 |
+
if randomize_noise:
|
419 |
+
noise = [None] * self.num_layers # for each style conv layer
|
420 |
+
else: # use the stored noise
|
421 |
+
noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
|
422 |
+
# style truncation
|
423 |
+
if truncation < 1:
|
424 |
+
style_truncation = []
|
425 |
+
for style in styles:
|
426 |
+
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
|
427 |
+
styles = style_truncation
|
428 |
+
# get style latent with injection
|
429 |
+
if len(styles) == 1:
|
430 |
+
inject_index = self.num_latent
|
431 |
+
|
432 |
+
if styles[0].ndim < 3:
|
433 |
+
# repeat latent code for all the layers
|
434 |
+
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
435 |
+
else: # used for encoder with different latent code for each layer
|
436 |
+
latent = styles[0]
|
437 |
+
elif len(styles) == 2: # mixing noises
|
438 |
+
if inject_index is None:
|
439 |
+
inject_index = random.randint(1, self.num_latent - 1)
|
440 |
+
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
441 |
+
latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
|
442 |
+
latent = torch.cat([latent1, latent2], 1)
|
443 |
+
|
444 |
+
# main generation
|
445 |
+
out = self.constant_input(latent.shape[0])
|
446 |
+
out = self.style_conv1(out, latent[:, 0], noise=noise[0])
|
447 |
+
skip = self.to_rgb1(out, latent[:, 1])
|
448 |
+
|
449 |
+
i = 1
|
450 |
+
for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
|
451 |
+
noise[2::2], self.to_rgbs):
|
452 |
+
out = conv1(out, latent[:, i], noise=noise1)
|
453 |
+
out = conv2(out, latent[:, i + 1], noise=noise2)
|
454 |
+
skip = to_rgb(out, latent[:, i + 2], skip)
|
455 |
+
i += 2
|
456 |
+
|
457 |
+
image = skip
|
458 |
+
|
459 |
+
if return_latents:
|
460 |
+
return image, latent
|
461 |
+
else:
|
462 |
+
return image, None
|
463 |
+
|
464 |
+
|
465 |
+
class ScaledLeakyReLU(nn.Module):
|
466 |
+
"""Scaled LeakyReLU.
|
467 |
+
|
468 |
+
Args:
|
469 |
+
negative_slope (float): Negative slope. Default: 0.2.
|
470 |
+
"""
|
471 |
+
|
472 |
+
def __init__(self, negative_slope=0.2):
|
473 |
+
super(ScaledLeakyReLU, self).__init__()
|
474 |
+
self.negative_slope = negative_slope
|
475 |
+
|
476 |
+
def forward(self, x):
|
477 |
+
out = F.leaky_relu(x, negative_slope=self.negative_slope)
|
478 |
+
return out * math.sqrt(2)
|
479 |
+
|
480 |
+
|
481 |
+
class EqualConv2d(nn.Module):
|
482 |
+
"""Equalized Linear as StyleGAN2.
|
483 |
+
|
484 |
+
Args:
|
485 |
+
in_channels (int): Channel number of the input.
|
486 |
+
out_channels (int): Channel number of the output.
|
487 |
+
kernel_size (int): Size of the convolving kernel.
|
488 |
+
stride (int): Stride of the convolution. Default: 1
|
489 |
+
padding (int): Zero-padding added to both sides of the input.
|
490 |
+
Default: 0.
|
491 |
+
bias (bool): If ``True``, adds a learnable bias to the output.
|
492 |
+
Default: ``True``.
|
493 |
+
bias_init_val (float): Bias initialized value. Default: 0.
|
494 |
+
"""
|
495 |
+
|
496 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, bias_init_val=0):
|
497 |
+
super(EqualConv2d, self).__init__()
|
498 |
+
self.in_channels = in_channels
|
499 |
+
self.out_channels = out_channels
|
500 |
+
self.kernel_size = kernel_size
|
501 |
+
self.stride = stride
|
502 |
+
self.padding = padding
|
503 |
+
self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
|
504 |
+
|
505 |
+
self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
|
506 |
+
if bias:
|
507 |
+
self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
|
508 |
+
else:
|
509 |
+
self.register_parameter('bias', None)
|
510 |
+
|
511 |
+
def forward(self, x):
|
512 |
+
out = F.conv2d(
|
513 |
+
x,
|
514 |
+
self.weight * self.scale,
|
515 |
+
bias=self.bias,
|
516 |
+
stride=self.stride,
|
517 |
+
padding=self.padding,
|
518 |
+
)
|
519 |
+
|
520 |
+
return out
|
521 |
+
|
522 |
+
def __repr__(self):
|
523 |
+
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
|
524 |
+
f'out_channels={self.out_channels}, '
|
525 |
+
f'kernel_size={self.kernel_size},'
|
526 |
+
f' stride={self.stride}, padding={self.padding}, '
|
527 |
+
f'bias={self.bias is not None})')
|
+
+
+class ConvLayer(nn.Sequential):
+    """Conv Layer used in StyleGAN2 Discriminator.
+
+    Args:
+        in_channels (int): Channel number of the input.
+        out_channels (int): Channel number of the output.
+        kernel_size (int): Kernel size.
+        downsample (bool): Whether to downsample by a factor of 2.
+            Default: False.
+        bias (bool): Whether with bias. Default: True.
+        activate (bool): Whether to use activation. Default: True.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 downsample=False,
+                 bias=True,
+                 activate=True,
+                 interpolation_mode='bilinear'):
+        layers = []
+        self.interpolation_mode = interpolation_mode
+        # downsample
+        if downsample:
+            if self.interpolation_mode == 'nearest':
+                self.align_corners = None
+            else:
+                self.align_corners = False
+
+            layers.append(
+                torch.nn.Upsample(scale_factor=0.5, mode=interpolation_mode, align_corners=self.align_corners))
+        stride = 1
+        self.padding = kernel_size // 2
+        # conv
+        layers.append(
+            EqualConv2d(
+                in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias
+                and not activate))
+        # activation
+        if activate:
+            if bias:
+                layers.append(FusedLeakyReLU(out_channels))
+            else:
+                layers.append(ScaledLeakyReLU(0.2))
+
+        super(ConvLayer, self).__init__(*layers)
+
+
+class ResBlock(nn.Module):
+    """Residual block used in StyleGAN2 Discriminator.
+
+    Args:
+        in_channels (int): Channel number of the input.
+        out_channels (int): Channel number of the output.
+    """
+
+    def __init__(self, in_channels, out_channels, interpolation_mode='bilinear'):
+        super(ResBlock, self).__init__()
+
+        self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
+        self.conv2 = ConvLayer(
+            in_channels,
+            out_channels,
+            3,
+            downsample=True,
+            interpolation_mode=interpolation_mode,
+            bias=True,
+            activate=True)
+        self.skip = ConvLayer(
+            in_channels,
+            out_channels,
+            1,
+            downsample=True,
+            interpolation_mode=interpolation_mode,
+            bias=False,
+            activate=False)
+
+    def forward(self, x):
+        out = self.conv1(x)
+        out = self.conv2(out)
+        skip = self.skip(x)
+        out = (out + skip) / math.sqrt(2)
+        return out
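
Both the conv branch and the 1x1 skip halve the spatial size, and dividing the sum by sqrt(2) keeps the variance of the merged signal roughly constant. A quick shape check, assuming the classes above are importable and basicsr (for FusedLeakyReLU) is installed:

import torch
from gfpgan.archs.stylegan2_bilinear_arch import ResBlock

block = ResBlock(64, 128)
x = torch.randn(2, 64, 64, 64)
print(block(x).shape)  # torch.Size([2, 128, 32, 32]); downsampled by 2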
gfpgan/archs/stylegan2_clean_arch.py
ADDED
@@ -0,0 +1,368 @@
+import math
+import random
+import torch
+from basicsr.archs.arch_util import default_init_weights
+from basicsr.utils.registry import ARCH_REGISTRY
+from torch import nn
+from torch.nn import functional as F
+
+
+class NormStyleCode(nn.Module):
+
+    def forward(self, x):
+        """Normalize the style codes.
+
+        Args:
+            x (Tensor): Style codes with shape (b, c).
+
+        Returns:
+            Tensor: Normalized tensor.
+        """
+        return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
+
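
NormStyleCode is the PixelNorm applied to latent codes: each code is scaled to x / sqrt(mean(x^2) + eps) over the channel dimension. A one-off check (illustrative):

import torch

z = torch.randn(4, 512) * 3.0  # arbitrary input scale
z_norm = z * torch.rsqrt(torch.mean(z**2, dim=1, keepdim=True) + 1e-8)
print(z_norm.pow(2).mean(dim=1))  # ~1.0 for every sample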
+
+class ModulatedConv2d(nn.Module):
+    """Modulated Conv2d used in StyleGAN2.
+
+    There is no bias in ModulatedConv2d.
+
+    Args:
+        in_channels (int): Channel number of the input.
+        out_channels (int): Channel number of the output.
+        kernel_size (int): Size of the convolving kernel.
+        num_style_feat (int): Channel number of style features.
+        demodulate (bool): Whether to demodulate in the conv layer. Default: True.
+        sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
+        eps (float): A value added to the denominator for numerical stability. Default: 1e-8.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 num_style_feat,
+                 demodulate=True,
+                 sample_mode=None,
+                 eps=1e-8):
+        super(ModulatedConv2d, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.demodulate = demodulate
+        self.sample_mode = sample_mode
+        self.eps = eps
+
+        # modulation inside each modulated conv
+        self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
+        # initialization
+        default_init_weights(self.modulation, scale=1, bias_fill=1, a=0, mode='fan_in', nonlinearity='linear')
+
+        self.weight = nn.Parameter(
+            torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) /
+            math.sqrt(in_channels * kernel_size**2))
+        self.padding = kernel_size // 2
+
+    def forward(self, x, style):
+        """Forward function.
+
+        Args:
+            x (Tensor): Tensor with shape (b, c, h, w).
+            style (Tensor): Tensor with shape (b, num_style_feat).
+
+        Returns:
+            Tensor: Modulated tensor after convolution.
+        """
+        b, c, h, w = x.shape  # c = c_in
+        # weight modulation
+        style = self.modulation(style).view(b, 1, c, 1, 1)
+        # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
+        weight = self.weight * style  # (b, c_out, c_in, k, k)
+
+        if self.demodulate:
+            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
+            weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
+
+        weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
+
+        # upsample or downsample if necessary
+        if self.sample_mode == 'upsample':
+            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
+        elif self.sample_mode == 'downsample':
+            x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)
+
+        b, c, h, w = x.shape
+        x = x.view(1, b * c, h, w)
+        # weight: (b*c_out, c_in, k, k), groups=b
+        out = F.conv2d(x, weight, padding=self.padding, groups=b)
+        out = out.view(b, self.out_channels, *out.shape[2:4])
+
+        return out
+
+    def __repr__(self):
+        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, '
+                f'kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})')
+
+
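
The forward pass above folds the batch into the channel axis and runs one grouped convolution so that each sample is convolved with its own modulated kernel. A minimal sketch showing the equivalence to a per-sample loop (illustrative only):

import torch
import torch.nn.functional as F

b, c_in, c_out, k, hw = 2, 8, 16, 3, 32
x = torch.randn(b, c_in, hw, hw)
weight = torch.randn(b, c_out, c_in, k, k)  # one modulated kernel per sample

# grouped-conv trick, as in ModulatedConv2d.forward
out = F.conv2d(x.view(1, b * c_in, hw, hw),
               weight.view(b * c_out, c_in, k, k),
               padding=k // 2, groups=b).view(b, c_out, hw, hw)

# a naive per-sample loop gives the same result
ref = torch.stack([F.conv2d(x[i:i + 1], weight[i], padding=k // 2)[0] for i in range(b)])
print(torch.allclose(out, ref, atol=1e-5))  # True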
+class StyleConv(nn.Module):
+    """Style conv used in StyleGAN2.
+
+    Args:
+        in_channels (int): Channel number of the input.
+        out_channels (int): Channel number of the output.
+        kernel_size (int): Size of the convolving kernel.
+        num_style_feat (int): Channel number of style features.
+        demodulate (bool): Whether to demodulate in the conv layer. Default: True.
+        sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None):
+        super(StyleConv, self).__init__()
+        self.modulated_conv = ModulatedConv2d(
+            in_channels, out_channels, kernel_size, num_style_feat, demodulate=demodulate, sample_mode=sample_mode)
+        self.weight = nn.Parameter(torch.zeros(1))  # for noise injection
+        self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
+        self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+    def forward(self, x, style, noise=None):
+        # modulate
+        out = self.modulated_conv(x, style) * 2**0.5  # for conversion
+        # noise injection
+        if noise is None:
+            b, _, h, w = out.shape
+            noise = out.new_empty(b, 1, h, w).normal_()
+        out = out + self.weight * noise
+        # add bias
+        out = out + self.bias
+        # activation
+        out = self.activate(out)
+        return out
+
+
+class ToRGB(nn.Module):
+    """To RGB (image space) from features.
+
+    Args:
+        in_channels (int): Channel number of input.
+        num_style_feat (int): Channel number of style features.
+        upsample (bool): Whether to upsample. Default: True.
+    """
+
+    def __init__(self, in_channels, num_style_feat, upsample=True):
+        super(ToRGB, self).__init__()
+        self.upsample = upsample
+        self.modulated_conv = ModulatedConv2d(
+            in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None)
+        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
+
+    def forward(self, x, style, skip=None):
+        """Forward function.
+
+        Args:
+            x (Tensor): Feature tensor with shape (b, c, h, w).
+            style (Tensor): Tensor with shape (b, num_style_feat).
+            skip (Tensor): Base/skip tensor. Default: None.
+
+        Returns:
+            Tensor: RGB images.
+        """
+        out = self.modulated_conv(x, style)
+        out = out + self.bias
+        if skip is not None:
+            if self.upsample:
+                skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
+            out = out + skip
+        return out
+
+
+class ConstantInput(nn.Module):
+    """Constant input.
+
+    Args:
+        num_channel (int): Channel number of constant input.
+        size (int): Spatial size of constant input.
+    """
+
+    def __init__(self, num_channel, size):
+        super(ConstantInput, self).__init__()
+        self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
+
+    def forward(self, batch):
+        out = self.weight.repeat(batch, 1, 1, 1)
+        return out
+
+
+@ARCH_REGISTRY.register()
+class StyleGAN2GeneratorClean(nn.Module):
+    """Clean version of StyleGAN2 Generator.
+
+    Args:
+        out_size (int): The spatial size of outputs.
+        num_style_feat (int): Channel number of style features. Default: 512.
+        num_mlp (int): Layer number of MLP style layers. Default: 8.
+        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
+        narrow (float): Narrow ratio for channels. Default: 1.0.
+    """
+
+    def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1):
+        super(StyleGAN2GeneratorClean, self).__init__()
+        # Style MLP layers
+        self.num_style_feat = num_style_feat
+        style_mlp_layers = [NormStyleCode()]
+        for i in range(num_mlp):
+            style_mlp_layers.extend(
+                [nn.Linear(num_style_feat, num_style_feat, bias=True),
+                 nn.LeakyReLU(negative_slope=0.2, inplace=True)])
+        self.style_mlp = nn.Sequential(*style_mlp_layers)
+        # initialization
+        default_init_weights(self.style_mlp, scale=1, bias_fill=0, a=0.2, mode='fan_in', nonlinearity='leaky_relu')
+
+        # channel list
+        channels = {
+            '4': int(512 * narrow),
+            '8': int(512 * narrow),
+            '16': int(512 * narrow),
+            '32': int(512 * narrow),
+            '64': int(256 * channel_multiplier * narrow),
+            '128': int(128 * channel_multiplier * narrow),
+            '256': int(64 * channel_multiplier * narrow),
+            '512': int(32 * channel_multiplier * narrow),
+            '1024': int(16 * channel_multiplier * narrow)
+        }
+        self.channels = channels
+
+        self.constant_input = ConstantInput(channels['4'], size=4)
+        self.style_conv1 = StyleConv(
+            channels['4'],
+            channels['4'],
+            kernel_size=3,
+            num_style_feat=num_style_feat,
+            demodulate=True,
+            sample_mode=None)
+        self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False)
+
+        self.log_size = int(math.log(out_size, 2))
+        self.num_layers = (self.log_size - 2) * 2 + 1
+        self.num_latent = self.log_size * 2 - 2
+
+        self.style_convs = nn.ModuleList()
+        self.to_rgbs = nn.ModuleList()
+        self.noises = nn.Module()
+
+        in_channels = channels['4']
+        # noise
+        for layer_idx in range(self.num_layers):
+            resolution = 2**((layer_idx + 5) // 2)
+            shape = [1, 1, resolution, resolution]
+            self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
+        # style convs and to_rgbs
+        for i in range(3, self.log_size + 1):
+            out_channels = channels[f'{2**i}']
+            self.style_convs.append(
+                StyleConv(
+                    in_channels,
+                    out_channels,
+                    kernel_size=3,
+                    num_style_feat=num_style_feat,
+                    demodulate=True,
+                    sample_mode='upsample'))
+            self.style_convs.append(
+                StyleConv(
+                    out_channels,
+                    out_channels,
+                    kernel_size=3,
+                    num_style_feat=num_style_feat,
+                    demodulate=True,
+                    sample_mode=None))
+            self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
+            in_channels = out_channels
+
+    def make_noise(self):
+        """Make noise for noise injection."""
+        device = self.constant_input.weight.device
+        noises = [torch.randn(1, 1, 4, 4, device=device)]
+
+        for i in range(3, self.log_size + 1):
+            for _ in range(2):
+                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
+
+        return noises
+
+    def get_latent(self, x):
+        return self.style_mlp(x)
+
+    def mean_latent(self, num_latent):
+        latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
+        latent = self.style_mlp(latent_in).mean(0, keepdim=True)
+        return latent
+
+    def forward(self,
+                styles,
+                input_is_latent=False,
+                noise=None,
+                randomize_noise=True,
+                truncation=1,
+                truncation_latent=None,
+                inject_index=None,
+                return_latents=False):
+        """Forward function for StyleGAN2GeneratorClean.
+
+        Args:
+            styles (list[Tensor]): Sample codes of styles.
+            input_is_latent (bool): Whether the input is a latent style. Default: False.
+            noise (Tensor | None): Input noise or None. Default: None.
+            randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
+            truncation (float): The truncation ratio. Default: 1.
+            truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
+            inject_index (int | None): The injection index for mixing noise. Default: None.
+            return_latents (bool): Whether to return style latents. Default: False.
+        """
+        # style codes -> latents with Style MLP layer
+        if not input_is_latent:
+            styles = [self.style_mlp(s) for s in styles]
+        # noises
+        if noise is None:
+            if randomize_noise:
+                noise = [None] * self.num_layers  # for each style conv layer
+            else:  # use the stored noise
+                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
+        # style truncation
+        if truncation < 1:
+            style_truncation = []
+            for style in styles:
+                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
+            styles = style_truncation
+        # get style latents with injection
+        if len(styles) == 1:
+            inject_index = self.num_latent
+
+            if styles[0].ndim < 3:
+                # repeat latent code for all the layers
+                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+            else:  # used for encoder with different latent code for each layer
+                latent = styles[0]
+        elif len(styles) == 2:  # mixing noises
+            if inject_index is None:
+                inject_index = random.randint(1, self.num_latent - 1)
+            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
+            latent = torch.cat([latent1, latent2], 1)
+
+        # main generation
+        out = self.constant_input(latent.shape[0])
+        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
+        skip = self.to_rgb1(out, latent[:, 1])
+
+        i = 1
+        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
+                                                        noise[2::2], self.to_rgbs):
+            out = conv1(out, latent[:, i], noise=noise1)
+            out = conv2(out, latent[:, i + 1], noise=noise2)
+            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
+            i += 2
+
+        image = skip
+
+        if return_latents:
+            return image, latent
+        else:
+            return image, None
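
A minimal usage sketch for the generator above (random weights here; in GFPGAN the class is normally built from a config through ARCH_REGISTRY):

import torch
from gfpgan.archs.stylegan2_clean_arch import StyleGAN2GeneratorClean

gen = StyleGAN2GeneratorClean(out_size=256, num_style_feat=512, num_mlp=8)
z = torch.randn(1, 512)                  # one style code in Z space
img, _ = gen([z], randomize_noise=True)  # styles are passed as a list
print(img.shape)                         # torch.Size([1, 3, 256, 256])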
gfpgan/data/__init__.py
ADDED
@@ -0,0 +1,10 @@
+import importlib
+from basicsr.utils import scandir
+from os import path as osp
+
+# automatically scan and import dataset modules for registry
+# scan all the files that end with '_dataset.py' under the data folder
+data_folder = osp.dirname(osp.abspath(__file__))
+dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
+# import all the dataset modules
+_dataset_modules = [importlib.import_module(f'gfpgan.data.{file_name}') for file_name in dataset_filenames]
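
Importing this package runs every `@DATASET_REGISTRY.register()` decorator in the scanned modules, which is how basicsr later resolves a dataset class from the `type` field of a config. A small sketch of that lookup (illustrative):

from basicsr.utils.registry import DATASET_REGISTRY

import gfpgan.data  # executes the auto-imports above

dataset_cls = DATASET_REGISTRY.get('FFHQDegradationDataset')
print(dataset_cls)  # the class registered in ffhq_degradation_dataset.py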
gfpgan/data/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (893 Bytes). View file
gfpgan/data/__pycache__/ffhq_degradation_dataset.cpython-312.pyc
ADDED
Binary file (13.5 kB). View file
gfpgan/data/ffhq_degradation_dataset.py
ADDED
@@ -0,0 +1,230 @@
+import cv2
+import math
+import numpy as np
+import os.path as osp
+import torch
+import torch.utils.data as data
+from basicsr.data import degradations as degradations
+from basicsr.data.data_util import paths_from_folder
+from basicsr.data.transforms import augment
+from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
+from basicsr.utils.registry import DATASET_REGISTRY
+from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
+                                               normalize)
+
+
+@DATASET_REGISTRY.register()
+class FFHQDegradationDataset(data.Dataset):
+    """FFHQ dataset for GFPGAN.
+
+    It reads high resolution images, and then generates low-quality (LQ) images on-the-fly.
+
+    Args:
+        opt (dict): Config for train datasets. It contains the following keys:
+            dataroot_gt (str): Data root path for gt.
+            io_backend (dict): IO backend type and other kwargs.
+            mean (list | tuple): Image mean.
+            std (list | tuple): Image std.
+            use_hflip (bool): Whether to horizontally flip.
+        Please see more options in the codes.
+    """
+
+    def __init__(self, opt):
+        super(FFHQDegradationDataset, self).__init__()
+        self.opt = opt
+        # file client (io backend)
+        self.file_client = None
+        self.io_backend_opt = opt['io_backend']
+
+        self.gt_folder = opt['dataroot_gt']
+        self.mean = opt['mean']
+        self.std = opt['std']
+        self.out_size = opt['out_size']
+
+        self.crop_components = opt.get('crop_components', False)  # facial components
+        self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1)  # whether to enlarge eye regions
+
+        if self.crop_components:
+            # load component list from a pre-processed pth file
+            self.components_list = torch.load(opt.get('component_path'))
+
+        # file client (lmdb io backend)
+        if self.io_backend_opt['type'] == 'lmdb':
+            self.io_backend_opt['db_paths'] = self.gt_folder
+            if not self.gt_folder.endswith('.lmdb'):
+                raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
+            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
+                self.paths = [line.split('.')[0] for line in fin]
+        else:
+            # disk backend: scan file list from a folder
+            self.paths = paths_from_folder(self.gt_folder)
+
+        # degradation configurations
+        self.blur_kernel_size = opt['blur_kernel_size']
+        self.kernel_list = opt['kernel_list']
+        self.kernel_prob = opt['kernel_prob']
+        self.blur_sigma = opt['blur_sigma']
+        self.downsample_range = opt['downsample_range']
+        self.noise_range = opt['noise_range']
+        self.jpeg_range = opt['jpeg_range']
+
+        # color jitter
+        self.color_jitter_prob = opt.get('color_jitter_prob')
+        self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob')
+        self.color_jitter_shift = opt.get('color_jitter_shift', 20)
+        # to gray
+        self.gray_prob = opt.get('gray_prob')
+
+        logger = get_root_logger()
+        logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
+        logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
+        logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
+        logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
+
+        if self.color_jitter_prob is not None:
+            logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
+        if self.gray_prob is not None:
+            logger.info(f'Use random gray. Prob: {self.gray_prob}')
+        self.color_jitter_shift /= 255.
+
+    @staticmethod
+    def color_jitter(img, shift):
+        """jitter color: randomly jitter the RGB values, in numpy formats"""
+        jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
+        img = img + jitter_val
+        img = np.clip(img, 0, 1)
+        return img
+
+    @staticmethod
+    def color_jitter_pt(img, brightness, contrast, saturation, hue):
+        """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
+        fn_idx = torch.randperm(4)
+        for fn_id in fn_idx:
+            if fn_id == 0 and brightness is not None:
+                brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
+                img = adjust_brightness(img, brightness_factor)
+
+            if fn_id == 1 and contrast is not None:
+                contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
+                img = adjust_contrast(img, contrast_factor)
+
+            if fn_id == 2 and saturation is not None:
+                saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
+                img = adjust_saturation(img, saturation_factor)
+
+            if fn_id == 3 and hue is not None:
+                hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
+                img = adjust_hue(img, hue_factor)
+        return img
+
+    def get_component_coordinates(self, index, status):
+        """Get facial component (left_eye, right_eye, mouth) coordinates from a pre-loaded pth file"""
+        components_bbox = self.components_list[f'{index:08d}']
+        if status[0]:  # hflip
+            # exchange right and left eye
+            tmp = components_bbox['left_eye']
+            components_bbox['left_eye'] = components_bbox['right_eye']
+            components_bbox['right_eye'] = tmp
+            # modify the width coordinate
+            components_bbox['left_eye'][0] = self.out_size - components_bbox['left_eye'][0]
+            components_bbox['right_eye'][0] = self.out_size - components_bbox['right_eye'][0]
+            components_bbox['mouth'][0] = self.out_size - components_bbox['mouth'][0]
+
+        # get coordinates
+        locations = []
+        for part in ['left_eye', 'right_eye', 'mouth']:
+            mean = components_bbox[part][0:2]
+            half_len = components_bbox[part][2]
+            if 'eye' in part:
+                half_len *= self.eye_enlarge_ratio
+            loc = np.hstack((mean - half_len + 1, mean + half_len))
+            loc = torch.from_numpy(loc).float()
+            locations.append(loc)
+        return locations
+
+    def __getitem__(self, index):
+        if self.file_client is None:
+            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+        # load gt image
+        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
+        gt_path = self.paths[index]
+        img_bytes = self.file_client.get(gt_path)
+        img_gt = imfrombytes(img_bytes, float32=True)
+
+        # random horizontal flip
+        img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
+        h, w, _ = img_gt.shape
+
+        # get facial component coordinates
+        if self.crop_components:
+            locations = self.get_component_coordinates(index, status)
+            loc_left_eye, loc_right_eye, loc_mouth = locations
+
+        # ------------------------ generate lq image ------------------------ #
+        # blur
+        kernel = degradations.random_mixed_kernels(
+            self.kernel_list,
+            self.kernel_prob,
+            self.blur_kernel_size,
+            self.blur_sigma,
+            self.blur_sigma, [-math.pi, math.pi],
+            noise_range=None)
+        img_lq = cv2.filter2D(img_gt, -1, kernel)
+        # downsample
+        scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
+        img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR)
+        # noise
+        if self.noise_range is not None:
+            img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range)
+        # jpeg compression
+        if self.jpeg_range is not None:
+            img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range)
+
+        # resize to original size
+        img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)
+
+        # random color jitter (only for lq)
+        if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
+            img_lq = self.color_jitter(img_lq, self.color_jitter_shift)
+        # random to gray (only for lq)
+        if self.gray_prob and np.random.uniform() < self.gray_prob:
+            img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
+            img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
+            if self.opt.get('gt_gray'):  # whether to convert GT to gray images
+                img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
+                img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])  # repeat the color channels
+
+        # BGR to RGB, HWC to CHW, numpy to tensor
+        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
+
+        # random color jitter (pytorch version) (only for lq)
+        if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
+            brightness = self.opt.get('brightness', (0.5, 1.5))
+            contrast = self.opt.get('contrast', (0.5, 1.5))
+            saturation = self.opt.get('saturation', (0, 1.5))
+            hue = self.opt.get('hue', (-0.1, 0.1))
+            img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue)
+
+        # round and clip
+        img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.
+
+        # normalize
+        normalize(img_gt, self.mean, self.std, inplace=True)
+        normalize(img_lq, self.mean, self.std, inplace=True)
+
+        if self.crop_components:
+            return_dict = {
+                'lq': img_lq,
+                'gt': img_gt,
+                'gt_path': gt_path,
+                'loc_left_eye': loc_left_eye,
+                'loc_right_eye': loc_right_eye,
+                'loc_mouth': loc_mouth
+            }
+            return return_dict
+        else:
+            return {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path}
+
+    def __len__(self):
+        return len(self.paths)
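
The dataset can also be instantiated directly with an opt dict. The keys below mirror what `__init__` reads; the concrete values are illustrative, not the project's official training config (that lives in the yml options files):

from gfpgan.data.ffhq_degradation_dataset import FFHQDegradationDataset

opt = {
    'dataroot_gt': 'datasets/ffhq/ffhq_512',  # hypothetical path
    'io_backend': {'type': 'disk'},
    'mean': [0.5, 0.5, 0.5],
    'std': [0.5, 0.5, 0.5],
    'out_size': 512,
    'use_hflip': True,
    'blur_kernel_size': 41,
    'kernel_list': ['iso', 'aniso'],
    'kernel_prob': [0.5, 0.5],
    'blur_sigma': [0.1, 10],
    'downsample_range': [0.8, 8],
    'noise_range': [0, 20],
    'jpeg_range': [60, 100],
}
dataset = FFHQDegradationDataset(opt)
sample = dataset[0]            # dict with 'lq', 'gt' and 'gt_path'
print(sample['lq'].shape)      # torch.Size([3, 512, 512])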
gfpgan/models/__init__.py
ADDED
@@ -0,0 +1,10 @@
+import importlib
+from basicsr.utils import scandir
+from os import path as osp
+
+# automatically scan and import model modules for registry
+# scan all the files that end with '_model.py' under the model folder
+model_folder = osp.dirname(osp.abspath(__file__))
+model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
+# import all the model modules
+_model_modules = [importlib.import_module(f'gfpgan.models.{file_name}') for file_name in model_filenames]
gfpgan/models/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (890 Bytes). View file
gfpgan/models/__pycache__/gfpgan_model.cpython-312.pyc
ADDED
Binary file (31.5 kB). View file
gfpgan/models/gfpgan_model.py
ADDED
@@ -0,0 +1,579 @@
+import math
+import os.path as osp
+import torch
+from basicsr.archs import build_network
+from basicsr.losses import build_loss
+from basicsr.losses.gan_loss import r1_penalty
+from basicsr.metrics import calculate_metric
+from basicsr.models.base_model import BaseModel
+from basicsr.utils import get_root_logger, imwrite, tensor2img
+from basicsr.utils.registry import MODEL_REGISTRY
+from collections import OrderedDict
+from torch.nn import functional as F
+from torchvision.ops import roi_align
+from tqdm import tqdm
+
+
+@MODEL_REGISTRY.register()
+class GFPGANModel(BaseModel):
+    """The GFPGAN model for Towards Real-World Blind Face Restoration with Generative Facial Prior"""
+
+    def __init__(self, opt):
+        super(GFPGANModel, self).__init__(opt)
+        self.idx = 0  # it is used for saving data for checking
+
+        # define network
+        self.net_g = build_network(opt['network_g'])
+        self.net_g = self.model_to_device(self.net_g)
+        self.print_network(self.net_g)
+
+        # load pretrained model
+        load_path = self.opt['path'].get('pretrain_network_g', None)
+        if load_path is not None:
+            param_key = self.opt['path'].get('param_key_g', 'params')
+            self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
+
+        self.log_size = int(math.log(self.opt['network_g']['out_size'], 2))
+
+        if self.is_train:
+            self.init_training_settings()
+
+    def init_training_settings(self):
+        train_opt = self.opt['train']
+
+        # ----------- define net_d ----------- #
+        self.net_d = build_network(self.opt['network_d'])
+        self.net_d = self.model_to_device(self.net_d)
+        self.print_network(self.net_d)
+        # load pretrained model
+        load_path = self.opt['path'].get('pretrain_network_d', None)
+        if load_path is not None:
+            self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
+
+        # ----------- define net_g with Exponential Moving Average (EMA) ----------- #
+        # net_g_ema is only used for testing on one GPU and for saving. There is no need to wrap it with DistributedDataParallel
+        self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
+        # load pretrained model
+        load_path = self.opt['path'].get('pretrain_network_g', None)
+        if load_path is not None:
+            self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
+        else:
+            self.model_ema(0)  # copy net_g weight
+
+        self.net_g.train()
+        self.net_d.train()
+        self.net_g_ema.eval()
+
+        # ----------- facial component networks ----------- #
+        if ('network_d_left_eye' in self.opt and 'network_d_right_eye' in self.opt and 'network_d_mouth' in self.opt):
+            self.use_facial_disc = True
+        else:
+            self.use_facial_disc = False
+
+        if self.use_facial_disc:
+            # left eye
+            self.net_d_left_eye = build_network(self.opt['network_d_left_eye'])
+            self.net_d_left_eye = self.model_to_device(self.net_d_left_eye)
+            self.print_network(self.net_d_left_eye)
+            load_path = self.opt['path'].get('pretrain_network_d_left_eye')
+            if load_path is not None:
+                self.load_network(self.net_d_left_eye, load_path, True, 'params')
+            # right eye
+            self.net_d_right_eye = build_network(self.opt['network_d_right_eye'])
+            self.net_d_right_eye = self.model_to_device(self.net_d_right_eye)
+            self.print_network(self.net_d_right_eye)
+            load_path = self.opt['path'].get('pretrain_network_d_right_eye')
+            if load_path is not None:
+                self.load_network(self.net_d_right_eye, load_path, True, 'params')
+            # mouth
+            self.net_d_mouth = build_network(self.opt['network_d_mouth'])
+            self.net_d_mouth = self.model_to_device(self.net_d_mouth)
+            self.print_network(self.net_d_mouth)
+            load_path = self.opt['path'].get('pretrain_network_d_mouth')
+            if load_path is not None:
+                self.load_network(self.net_d_mouth, load_path, True, 'params')
+
+            self.net_d_left_eye.train()
+            self.net_d_right_eye.train()
+            self.net_d_mouth.train()
+
+            # ----------- define facial component gan loss ----------- #
+            self.cri_component = build_loss(train_opt['gan_component_opt']).to(self.device)
+
+        # ----------- define losses ----------- #
+        # pixel loss
+        if train_opt.get('pixel_opt'):
+            self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
+        else:
+            self.cri_pix = None
+
+        # perceptual loss
+        if train_opt.get('perceptual_opt'):
+            self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
+        else:
+            self.cri_perceptual = None
+
+        # L1 loss is used in pyramid loss, component style loss and identity loss
+        self.cri_l1 = build_loss(train_opt['L1_opt']).to(self.device)
+
+        # gan loss (wgan)
+        self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
+
+        # ----------- define identity loss ----------- #
+        if 'network_identity' in self.opt:
+            self.use_identity = True
+        else:
+            self.use_identity = False
+
+        if self.use_identity:
+            # define identity network
+            self.network_identity = build_network(self.opt['network_identity'])
+            self.network_identity = self.model_to_device(self.network_identity)
+            self.print_network(self.network_identity)
+            load_path = self.opt['path'].get('pretrain_network_identity')
+            if load_path is not None:
+                self.load_network(self.network_identity, load_path, True, None)
+            self.network_identity.eval()
+            for param in self.network_identity.parameters():
+                param.requires_grad = False
+
+        # regularization weights
+        self.r1_reg_weight = train_opt['r1_reg_weight']  # for discriminator
+        self.net_d_iters = train_opt.get('net_d_iters', 1)
+        self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
+        self.net_d_reg_every = train_opt['net_d_reg_every']
+
+        # set up optimizers and schedulers
+        self.setup_optimizers()
+        self.setup_schedulers()
+
+    def setup_optimizers(self):
+        train_opt = self.opt['train']
+
+        # ----------- optimizer g ----------- #
+        net_g_reg_ratio = 1
+        normal_params = []
+        for _, param in self.net_g.named_parameters():
+            normal_params.append(param)
+        optim_params_g = [{  # add normal params first
+            'params': normal_params,
+            'lr': train_opt['optim_g']['lr']
+        }]
+        optim_type = train_opt['optim_g'].pop('type')
+        lr = train_opt['optim_g']['lr'] * net_g_reg_ratio
+        betas = (0**net_g_reg_ratio, 0.99**net_g_reg_ratio)
+        self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, lr, betas=betas)
+        self.optimizers.append(self.optimizer_g)
+
+        # ----------- optimizer d ----------- #
+        net_d_reg_ratio = self.net_d_reg_every / (self.net_d_reg_every + 1)
+        normal_params = []
+        for _, param in self.net_d.named_parameters():
+            normal_params.append(param)
+        optim_params_d = [{  # add normal params first
+            'params': normal_params,
+            'lr': train_opt['optim_d']['lr']
+        }]
+        optim_type = train_opt['optim_d'].pop('type')
+        lr = train_opt['optim_d']['lr'] * net_d_reg_ratio
+        betas = (0**net_d_reg_ratio, 0.99**net_d_reg_ratio)
+        self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas)
+        self.optimizers.append(self.optimizer_d)
+
+        # ----------- optimizers for facial component networks ----------- #
+        if self.use_facial_disc:
+            # setup optimizers for facial component discriminators
+            optim_type = train_opt['optim_component'].pop('type')
+            lr = train_opt['optim_component']['lr']
+            # left eye
+            self.optimizer_d_left_eye = self.get_optimizer(
+                optim_type, self.net_d_left_eye.parameters(), lr, betas=(0.9, 0.99))
+            self.optimizers.append(self.optimizer_d_left_eye)
+            # right eye
+            self.optimizer_d_right_eye = self.get_optimizer(
+                optim_type, self.net_d_right_eye.parameters(), lr, betas=(0.9, 0.99))
+            self.optimizers.append(self.optimizer_d_right_eye)
+            # mouth
+            self.optimizer_d_mouth = self.get_optimizer(
+                optim_type, self.net_d_mouth.parameters(), lr, betas=(0.9, 0.99))
+            self.optimizers.append(self.optimizer_d_mouth)
+
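
The rescaled `lr` and `betas` implement StyleGAN2's lazy-regularization correction: since R1 is applied only every `net_d_reg_every` iterations, the discriminator's Adam hyperparameters are shrunk by the ratio so the effective dynamics match regularizing at every step. With a typical value (assumed here) of net_d_reg_every = 16:

net_d_reg_every = 16
ratio = net_d_reg_every / (net_d_reg_every + 1)  # 16/17 ~= 0.941

base_lr = 2e-3                                   # illustrative
lr = base_lr * ratio                             # ~1.88e-3
betas = (0**ratio, 0.99**ratio)                  # (0.0, ~0.9906)
print(lr, betas)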
+    def feed_data(self, data):
+        self.lq = data['lq'].to(self.device)
+        if 'gt' in data:
+            self.gt = data['gt'].to(self.device)
+
+        if 'loc_left_eye' in data:
+            # get facial component locations, shape (batch, 4)
+            self.loc_left_eyes = data['loc_left_eye']
+            self.loc_right_eyes = data['loc_right_eye']
+            self.loc_mouths = data['loc_mouth']
+
+        # uncomment to check data
+        # import torchvision
+        # if self.opt['rank'] == 0:
+        #     import os
+        #     os.makedirs('tmp/gt', exist_ok=True)
+        #     os.makedirs('tmp/lq', exist_ok=True)
+        #     print(self.idx)
+        #     torchvision.utils.save_image(
+        #         self.gt, f'tmp/gt/gt_{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1))
+        #     torchvision.utils.save_image(
+        #         self.lq, f'tmp/lq/lq{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1))
+        #     self.idx = self.idx + 1
+
+    def construct_img_pyramid(self):
+        """Construct image pyramid for intermediate restoration loss"""
+        pyramid_gt = [self.gt]
+        down_img = self.gt
+        for _ in range(0, self.log_size - 3):
+            down_img = F.interpolate(down_img, scale_factor=0.5, mode='bilinear', align_corners=False)
+            pyramid_gt.insert(0, down_img)
+        return pyramid_gt
+
+    def get_roi_regions(self, eye_out_size=80, mouth_out_size=120):
+        face_ratio = int(self.opt['network_g']['out_size'] / 512)
+        eye_out_size *= face_ratio
+        mouth_out_size *= face_ratio
+
+        rois_eyes = []
+        rois_mouths = []
+        for b in range(self.loc_left_eyes.size(0)):  # loop over the batch
+            # left eye and right eye
+            img_inds = self.loc_left_eyes.new_full((2, 1), b)
+            bbox = torch.stack([self.loc_left_eyes[b, :], self.loc_right_eyes[b, :]], dim=0)  # shape: (2, 4)
+            rois = torch.cat([img_inds, bbox], dim=-1)  # shape: (2, 5)
+            rois_eyes.append(rois)
+            # mouth
+            img_inds = self.loc_left_eyes.new_full((1, 1), b)
+            rois = torch.cat([img_inds, self.loc_mouths[b:b + 1, :]], dim=-1)  # shape: (1, 5)
+            rois_mouths.append(rois)
+
+        rois_eyes = torch.cat(rois_eyes, 0).to(self.device)
+        rois_mouths = torch.cat(rois_mouths, 0).to(self.device)
+
+        # real images
+        all_eyes = roi_align(self.gt, boxes=rois_eyes, output_size=eye_out_size) * face_ratio
+        self.left_eyes_gt = all_eyes[0::2, :, :, :]
+        self.right_eyes_gt = all_eyes[1::2, :, :, :]
+        self.mouths_gt = roi_align(self.gt, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio
+        # output
+        all_eyes = roi_align(self.output, boxes=rois_eyes, output_size=eye_out_size) * face_ratio
+        self.left_eyes = all_eyes[0::2, :, :, :]
+        self.right_eyes = all_eyes[1::2, :, :, :]
+        self.mouths = roi_align(self.output, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio
+
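
`torchvision.ops.roi_align` expects each ROI row as (batch_index, x1, y1, x2, y2), which is why a column of batch indices is prepended to every bounding box above. A standalone illustration:

import torch
from torchvision.ops import roi_align

feat = torch.arange(64, dtype=torch.float32).view(1, 1, 8, 8)
rois = torch.tensor([[0., 2., 2., 6., 6.]])  # (batch_index, x1, y1, x2, y2)
crop = roi_align(feat, boxes=rois, output_size=4)
print(crop.shape)  # torch.Size([1, 1, 4, 4])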
+    def _gram_mat(self, x):
+        """Calculate Gram matrix.
+
+        Args:
+            x (torch.Tensor): Tensor with shape of (n, c, h, w).
+
+        Returns:
+            torch.Tensor: Gram matrix.
+        """
+        n, c, h, w = x.size()
+        features = x.view(n, c, w * h)
+        features_t = features.transpose(1, 2)
+        gram = features.bmm(features_t) / (c * h * w)
+        return gram
+
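
With F the (c, h*w) feature matrix of one sample, this computes G = F F^T / (c * h * w): normalized channel-to-channel correlations, which the component style loss later compares between restored and ground-truth crops. A quick shape check (illustrative):

import torch

x = torch.randn(2, 64, 16, 16)  # (n, c, h, w)
n, c, h, w = x.shape
features = x.view(n, c, h * w)
gram = features.bmm(features.transpose(1, 2)) / (c * h * w)
print(gram.shape)  # torch.Size([2, 64, 64])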
+    def gray_resize_for_identity(self, out, size=128):
+        out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :])
+        out_gray = out_gray.unsqueeze(1)
+        out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False)
+        return out_gray
+
+    def optimize_parameters(self, current_iter):
+        # optimize net_g
+        for p in self.net_d.parameters():
+            p.requires_grad = False
+        self.optimizer_g.zero_grad()
+
+        # do not update facial component net_d
+        if self.use_facial_disc:
+            for p in self.net_d_left_eye.parameters():
+                p.requires_grad = False
+            for p in self.net_d_right_eye.parameters():
+                p.requires_grad = False
+            for p in self.net_d_mouth.parameters():
+                p.requires_grad = False
+
+        # image pyramid loss weight
+        pyramid_loss_weight = self.opt['train'].get('pyramid_loss_weight', 0)
+        if pyramid_loss_weight > 0 and current_iter > self.opt['train'].get('remove_pyramid_loss', float('inf')):
+            pyramid_loss_weight = 1e-12  # very small weight to avoid unused param error
+        if pyramid_loss_weight > 0:
+            self.output, out_rgbs = self.net_g(self.lq, return_rgb=True)
+            pyramid_gt = self.construct_img_pyramid()
+        else:
+            self.output, out_rgbs = self.net_g(self.lq, return_rgb=False)
+
+        # get roi-align regions
+        if self.use_facial_disc:
+            self.get_roi_regions(eye_out_size=80, mouth_out_size=120)
+
+        l_g_total = 0
+        loss_dict = OrderedDict()
+        if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
+            # pixel loss
+            if self.cri_pix:
+                l_g_pix = self.cri_pix(self.output, self.gt)
+                l_g_total += l_g_pix
+                loss_dict['l_g_pix'] = l_g_pix
+
+            # image pyramid loss
+            if pyramid_loss_weight > 0:
+                for i in range(0, self.log_size - 2):
+                    l_pyramid = self.cri_l1(out_rgbs[i], pyramid_gt[i]) * pyramid_loss_weight
+                    l_g_total += l_pyramid
+                    loss_dict[f'l_p_{2**(i+3)}'] = l_pyramid
+
+            # perceptual loss
+            if self.cri_perceptual:
+                l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
+                if l_g_percep is not None:
+                    l_g_total += l_g_percep
+                    loss_dict['l_g_percep'] = l_g_percep
+                if l_g_style is not None:
+                    l_g_total += l_g_style
+                    loss_dict['l_g_style'] = l_g_style
+
+            # gan loss
+            fake_g_pred = self.net_d(self.output)
+            l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
+            l_g_total += l_g_gan
+            loss_dict['l_g_gan'] = l_g_gan
+
+            # facial component loss
+            if self.use_facial_disc:
+                # left eye
+                fake_left_eye, fake_left_eye_feats = self.net_d_left_eye(self.left_eyes, return_feats=True)
+                l_g_gan = self.cri_component(fake_left_eye, True, is_disc=False)
+                l_g_total += l_g_gan
+                loss_dict['l_g_gan_left_eye'] = l_g_gan
+                # right eye
+                fake_right_eye, fake_right_eye_feats = self.net_d_right_eye(self.right_eyes, return_feats=True)
+                l_g_gan = self.cri_component(fake_right_eye, True, is_disc=False)
+                l_g_total += l_g_gan
+                loss_dict['l_g_gan_right_eye'] = l_g_gan
+                # mouth
+                fake_mouth, fake_mouth_feats = self.net_d_mouth(self.mouths, return_feats=True)
+                l_g_gan = self.cri_component(fake_mouth, True, is_disc=False)
+                l_g_total += l_g_gan
+                loss_dict['l_g_gan_mouth'] = l_g_gan
+
+                if self.opt['train'].get('comp_style_weight', 0) > 0:
+                    # get gt feat
+                    _, real_left_eye_feats = self.net_d_left_eye(self.left_eyes_gt, return_feats=True)
+                    _, real_right_eye_feats = self.net_d_right_eye(self.right_eyes_gt, return_feats=True)
+                    _, real_mouth_feats = self.net_d_mouth(self.mouths_gt, return_feats=True)
+
+                    def _comp_style(feat, feat_gt, criterion):
+                        return criterion(self._gram_mat(feat[0]), self._gram_mat(
+                            feat_gt[0].detach())) * 0.5 + criterion(
+                                self._gram_mat(feat[1]), self._gram_mat(feat_gt[1].detach()))
+
+                    # facial component style loss
+                    comp_style_loss = 0
+                    comp_style_loss += _comp_style(fake_left_eye_feats, real_left_eye_feats, self.cri_l1)
+                    comp_style_loss += _comp_style(fake_right_eye_feats, real_right_eye_feats, self.cri_l1)
+                    comp_style_loss += _comp_style(fake_mouth_feats, real_mouth_feats, self.cri_l1)
+                    comp_style_loss = comp_style_loss * self.opt['train']['comp_style_weight']
+                    l_g_total += comp_style_loss
+                    loss_dict['l_g_comp_style_loss'] = comp_style_loss
+
+            # identity loss
+            if self.use_identity:
+                identity_weight = self.opt['train']['identity_weight']
+                # get gray images and resize
+                out_gray = self.gray_resize_for_identity(self.output)
+                gt_gray = self.gray_resize_for_identity(self.gt)
+
+                identity_gt = self.network_identity(gt_gray).detach()
+                identity_out = self.network_identity(out_gray)
+                l_identity = self.cri_l1(identity_out, identity_gt) * identity_weight
+                l_g_total += l_identity
+                loss_dict['l_identity'] = l_identity
+
+            l_g_total.backward()
+            self.optimizer_g.step()
+
+        # EMA
+        self.model_ema(decay=0.5**(32 / (10 * 1000)))
+
+        # ----------- optimize net_d ----------- #
+        for p in self.net_d.parameters():
+            p.requires_grad = True
+        self.optimizer_d.zero_grad()
+        if self.use_facial_disc:
+            for p in self.net_d_left_eye.parameters():
+                p.requires_grad = True
+            for p in self.net_d_right_eye.parameters():
+                p.requires_grad = True
+            for p in self.net_d_mouth.parameters():
+                p.requires_grad = True
+            self.optimizer_d_left_eye.zero_grad()
+            self.optimizer_d_right_eye.zero_grad()
+            self.optimizer_d_mouth.zero_grad()
+
+        fake_d_pred = self.net_d(self.output.detach())
+        real_d_pred = self.net_d(self.gt)
+        l_d = self.cri_gan(real_d_pred, True, is_disc=True) + self.cri_gan(fake_d_pred, False, is_disc=True)
+        loss_dict['l_d'] = l_d
+        # In WGAN, real_score should be positive and fake_score should be negative
+        loss_dict['real_score'] = real_d_pred.detach().mean()
+        loss_dict['fake_score'] = fake_d_pred.detach().mean()
+        l_d.backward()
+
+        # regularization loss
+        if current_iter % self.net_d_reg_every == 0:
+            self.gt.requires_grad = True
+            real_pred = self.net_d(self.gt)
+            l_d_r1 = r1_penalty(real_pred, self.gt)
+            l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.net_d_reg_every + 0 * real_pred[0])
+            loss_dict['l_d_r1'] = l_d_r1.detach().mean()
+            l_d_r1.backward()
+
+        self.optimizer_d.step()
+
+        # optimize facial component discriminators
+        if self.use_facial_disc:
+            # left eye
+            fake_d_pred, _ = self.net_d_left_eye(self.left_eyes.detach())
+            real_d_pred, _ = self.net_d_left_eye(self.left_eyes_gt)
+            l_d_left_eye = self.cri_component(
+                real_d_pred, True, is_disc=True) + self.cri_gan(
+                    fake_d_pred, False, is_disc=True)
+            loss_dict['l_d_left_eye'] = l_d_left_eye
+            l_d_left_eye.backward()
+            # right eye
+            fake_d_pred, _ = self.net_d_right_eye(self.right_eyes.detach())
+            real_d_pred, _ = self.net_d_right_eye(self.right_eyes_gt)
+            l_d_right_eye = self.cri_component(
+                real_d_pred, True, is_disc=True) + self.cri_gan(
+                    fake_d_pred, False, is_disc=True)
+            loss_dict['l_d_right_eye'] = l_d_right_eye
+            l_d_right_eye.backward()
+            # mouth
+            fake_d_pred, _ = self.net_d_mouth(self.mouths.detach())
+            real_d_pred, _ = self.net_d_mouth(self.mouths_gt)
+            l_d_mouth = self.cri_component(
+                real_d_pred, True, is_disc=True) + self.cri_gan(
+                    fake_d_pred, False, is_disc=True)
+            loss_dict['l_d_mouth'] = l_d_mouth
+            l_d_mouth.backward()
+
+            self.optimizer_d_left_eye.step()
+            self.optimizer_d_right_eye.step()
+            self.optimizer_d_mouth.step()
+
+        self.log_dict = self.reduce_loss_dict(loss_dict)
+
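
`r1_penalty` is imported from basicsr; conceptually it is the R1 gradient penalty of Mescheder et al. (2018), the squared gradient norm of the discriminator output with respect to real images, scaled above by gamma/2 and by `net_d_reg_every` to offset the lazy application. A self-contained sketch of the computation (not basicsr's exact code):

import torch
from torch import autograd

def r1_penalty_sketch(real_pred, real_img):
    # mean squared gradient norm of D's output w.r.t. the real images
    grad, = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)
    return grad.pow(2).reshape(grad.shape[0], -1).sum(1).mean()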
+    def test(self):
+        with torch.no_grad():
+            if hasattr(self, 'net_g_ema'):
+                self.net_g_ema.eval()
+                self.output, _ = self.net_g_ema(self.lq)
+            else:
+                logger = get_root_logger()
+                logger.warning('Do not have self.net_g_ema, use self.net_g.')
+                self.net_g.eval()
+                self.output, _ = self.net_g(self.lq)
+                self.net_g.train()
+
+    def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
+        if self.opt['rank'] == 0:
+            self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
+
+    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
+        dataset_name = dataloader.dataset.opt['name']
+        with_metrics = self.opt['val'].get('metrics') is not None
+        use_pbar = self.opt['val'].get('pbar', False)
+
+        if with_metrics:
+            if not hasattr(self, 'metric_results'):  # only execute in the first run
+                self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
+            # initialize the best metric results for each dataset_name (supporting multiple validation datasets)
+            self._initialize_best_metric_results(dataset_name)
+            # zero self.metric_results
+            self.metric_results = {metric: 0 for metric in self.metric_results}
+
+        metric_data = dict()
+        if use_pbar:
+            pbar = tqdm(total=len(dataloader), unit='image')
+
+        for idx, val_data in enumerate(dataloader):
+            img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
+            self.feed_data(val_data)
+            self.test()
+
+            sr_img = tensor2img(self.output.detach().cpu(), min_max=(-1, 1))
+            metric_data['img'] = sr_img
+            if hasattr(self, 'gt'):
+                gt_img = tensor2img(self.gt.detach().cpu(), min_max=(-1, 1))
+                metric_data['img2'] = gt_img
+                del self.gt
+
+            # tentative for out of GPU memory
+            del self.lq
+            del self.output
+            torch.cuda.empty_cache()
+
+            if save_img:
+                if self.opt['is_train']:
+                    save_img_path = osp.join(self.opt['path']['visualization'], img_name,
+                                             f'{img_name}_{current_iter}.png')
+                else:
+                    if self.opt['val']['suffix']:
+                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
+                                                 f'{img_name}_{self.opt["val"]["suffix"]}.png')
+                    else:
+                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
+                                                 f'{img_name}_{self.opt["name"]}.png')
+                imwrite(sr_img, save_img_path)
+
+            if with_metrics:
+                # calculate metrics
+                for name, opt_ in self.opt['val']['metrics'].items():
+                    self.metric_results[name] += calculate_metric(metric_data, opt_)
+            if use_pbar:
+                pbar.update(1)
+                pbar.set_description(f'Test {img_name}')
+        if use_pbar:
+            pbar.close()
+
+        if with_metrics:
+            for metric in self.metric_results.keys():
+                self.metric_results[metric] /= (idx + 1)
+                # update the best metric result
+                self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)
+
+            self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
+
+    def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
+        log_str = f'Validation {dataset_name}\n'
+        for metric, value in self.metric_results.items():
+            log_str += f'\t # {metric}: {value:.4f}'
+            if hasattr(self, 'best_metric_results'):
+                log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
+                            f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
+            log_str += '\n'
+
+        logger = get_root_logger()
+        logger.info(log_str)
+        if tb_logger:
+            for metric, value in self.metric_results.items():
+                tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter)
+
+    def save(self, epoch, current_iter):
+        # save net_g and net_d
+        self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
+        self.save_network(self.net_d, 'net_d', current_iter)
+        # save component discriminators
+        if self.use_facial_disc:
+            self.save_network(self.net_d_left_eye, 'net_d_left_eye', current_iter)
+            self.save_network(self.net_d_right_eye, 'net_d_right_eye', current_iter)
+            self.save_network(self.net_d_mouth, 'net_d_mouth', current_iter)
+        # save training state
+        self.save_training_state(epoch, current_iter)
+
self.save_training_state(epoch, current_iter)
|
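For reference, r1_penalty above is imported from basicsr. A minimal sketch of the R1 gradient penalty it is understood to compute (the function name here is illustrative, not the library's code): the batch mean of the squared gradient norm of the discriminator's real-image scores.

import torch
from torch import autograd

def r1_penalty_sketch(real_pred: torch.Tensor, real_img: torch.Tensor) -> torch.Tensor:
    # gradient of the summed real scores w.r.t. the real images;
    # create_graph=True so the penalty itself can be backpropagated
    grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
    # batch mean of the squared L2 norm of each per-image gradient
    return grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()

Because the penalty is applied only every net_d_reg_every iterations, it is rescaled by net_d_reg_every (StyleGAN2's lazy-regularization trick); the `0 * real_pred[0]` term is commonly used to keep the discriminator output in the autograd graph so that no parameters appear unused during the backward pass.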
gfpgan/train.py
ADDED
@@ -0,0 +1,11 @@
# flake8: noqa
import os.path as osp
from basicsr.train import train_pipeline

import gfpgan.archs
import gfpgan.data
import gfpgan.models

if __name__ == '__main__':
    root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
    train_pipeline(root_path)
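Since basicsr's train_pipeline reads the experiment configuration from the -opt command-line flag, this entry point is typically launched from the repository root along the lines of the following (the option file name is an assumption; use whichever YAML defines the experiment):

python gfpgan/train.py -opt options/train_gfpgan_v1.yml

The bare import gfpgan.archs / gfpgan.data / gfpgan.models lines exist purely for their side effects: importing them registers the GFPGAN architectures, datasets, and models with basicsr's registries before the pipeline instantiates them from the YAML.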
gfpgan/utils.py
ADDED
@@ -0,0 +1,148 @@
import cv2
import os
import torch
from basicsr.utils import img2tensor, tensor2img
from basicsr.utils.download_util import load_file_from_url
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from torchvision.transforms.functional import normalize

from gfpgan.archs.gfpgan_bilinear_arch import GFPGANBilinear
from gfpgan.archs.gfpganv1_arch import GFPGANv1
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean

ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


class GFPGANer():
    """Helper for restoration with GFPGAN.

    It detects and crops faces, and then resizes the faces to 512x512.
    GFPGAN is used to restore the resized faces.
    The background is upsampled with the bg_upsampler.
    Finally, the faces are pasted back onto the upsampled background image.

    Args:
        model_path (str): The path to the GFPGAN model. It can also be a URL (the weights are then downloaded automatically).
        upscale (float): The upscale factor of the final output. Default: 2.
        arch (str): The GFPGAN architecture. Options: clean | bilinear | original | RestoreFormer. Default: clean.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        bg_upsampler (nn.Module): The upsampler for the background. Default: None.
    """

    def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None, device=None):
        self.upscale = upscale
        self.bg_upsampler = bg_upsampler

        # initialize model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
        # initialize the GFP-GAN
        if arch == 'clean':
            self.gfpgan = GFPGANv1Clean(
                out_size=512,
                num_style_feat=512,
                channel_multiplier=channel_multiplier,
                decoder_load_path=None,
                fix_decoder=False,
                num_mlp=8,
                input_is_latent=True,
                different_w=True,
                narrow=1,
                sft_half=True)
        elif arch == 'bilinear':
            self.gfpgan = GFPGANBilinear(
                out_size=512,
                num_style_feat=512,
                channel_multiplier=channel_multiplier,
                decoder_load_path=None,
                fix_decoder=False,
                num_mlp=8,
                input_is_latent=True,
                different_w=True,
                narrow=1,
                sft_half=True)
        elif arch == 'original':
            self.gfpgan = GFPGANv1(
                out_size=512,
                num_style_feat=512,
                channel_multiplier=channel_multiplier,
                decoder_load_path=None,
                fix_decoder=True,
                num_mlp=8,
                input_is_latent=True,
                different_w=True,
                narrow=1,
                sft_half=True)
        elif arch == 'RestoreFormer':
            from gfpgan.archs.restoreformer_arch import RestoreFormer
            self.gfpgan = RestoreFormer()
        # initialize face helper
        self.face_helper = FaceRestoreHelper(
            upscale,
            face_size=512,
            crop_ratio=(1, 1),
            det_model='retinaface_resnet50',
            save_ext='png',
            use_parse=True,
            device=self.device,
            model_rootpath='gfpgan/weights')

        if model_path.startswith('https://'):
            model_path = load_file_from_url(
                url=model_path, model_dir=os.path.join(ROOT_DIR, 'gfpgan/weights'), progress=True, file_name=None)
        loadnet = torch.load(model_path)
        if 'params_ema' in loadnet:
            keyname = 'params_ema'
        else:
            keyname = 'params'
        self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
        self.gfpgan.eval()
        self.gfpgan = self.gfpgan.to(self.device)

    @torch.no_grad()
    def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True, weight=0.5):
        self.face_helper.clean_all()

        if has_aligned:  # the inputs are already aligned
            img = cv2.resize(img, (512, 512))
            self.face_helper.cropped_faces = [img]
        else:
            self.face_helper.read_image(img)
            # get face landmarks for each face
            self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
            # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
            # TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
            # align and warp each face
            self.face_helper.align_warp_face()

        # face restoration
        for cropped_face in self.face_helper.cropped_faces:
            # prepare data
            cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)

            try:
                output = self.gfpgan(cropped_face_t, return_rgb=False, weight=weight)[0]
                # convert to image
                restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
            except RuntimeError as error:
                print(f'\tFailed inference for GFPGAN: {error}.')
                restored_face = cropped_face

            restored_face = restored_face.astype('uint8')
            self.face_helper.add_restored_face(restored_face)

        if not has_aligned and paste_back:
            # upsample the background
            if self.bg_upsampler is not None:
                # currently only RealESRGAN is supported for background upsampling
                bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
            else:
                bg_img = None

            self.face_helper.get_inverse_affine(None)
            # paste each restored face to the input image
            restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img)
            return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
        else:
            return self.face_helper.cropped_faces, self.face_helper.restored_faces, None
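A minimal usage sketch for GFPGANer (the model URL and the input/output file paths are assumptions for illustration):

import cv2

from gfpgan import GFPGANer

restorer = GFPGANer(
    model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
    upscale=2,
    arch='clean',
    channel_multiplier=2,
    bg_upsampler=None)  # no background upsampling in this sketch

img = cv2.imread('inputs/my_photo.jpg', cv2.IMREAD_COLOR)  # BGR, as enhance() expects
cropped_faces, restored_faces, restored_img = restorer.enhance(
    img, has_aligned=False, only_center_face=False, paste_back=True, weight=0.5)
if restored_img is not None:
    cv2.imwrite('results/my_photo_restored.png', restored_img)

Passing a URL as model_path triggers the load_file_from_url branch above, which caches the weights under gfpgan/weights before loading them.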
gfpgan/version.py
ADDED
@@ -0,0 +1,5 @@
# GENERATED VERSION FILE
# TIME: Sat May 25 17:25:12 2024
__version__ = '1.3.8'
__gitsha__ = '7552a77'
version_info = (1, 3, 8)
gfpgan/weights/README.md
ADDED
@@ -0,0 +1,3 @@
# Weights

Put the downloaded weights in this folder.
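Note that the two detection/parsing weights below are the ones FaceRestoreHelper looks for under model_rootpath='gfpgan/weights' (see utils.py above); facexlib should download them into this folder automatically on first use if they are missing, so placing them manually mainly matters for offline environments.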
gfpgan/weights/detection_Resnet50_Final.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6d1de9c2944f2ccddca5f5e010ea5ae64a39845a86311af6fdf30841b0a5a16d
size 109497761
gfpgan/weights/parsing_parsenet.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3d558d8d0e42c20224f13cf5a29c79eba2d59913419f945545d8cf7b72920de2
size 85331193
|