Xintao committed on
Commit d0fef57
1 Parent(s): d561a66

add GFPGAN

app.py CHANGED
@@ -1,8 +1,82 @@
- def greet(name):
-     return "Hello " + name + "!!"

- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch()
+ import os
+
+ import cv2
  import gradio as gr
+ import torch
+ from PIL import Image
+
+ from realesrgan_utils import RealESRGANer
+ from srvgg_arch import SRVGGNetCompact
+
+ os.system("pip freeze")
+ os.system(
+     "wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P ./weights")
+ os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth -P ./weights")
+ os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -P ./weights")
+
+ torch.hub.download_url_to_file(
+     'https://upload.wikimedia.org/wikipedia/commons/thumb/a/ab/Abraham_Lincoln_O-77_matte_collodion_print.jpg/1024px-Abraham_Lincoln_O-77_matte_collodion_print.jpg',
+     'lincoln.jpg')
+ torch.hub.download_url_to_file('https://upload.wikimedia.org/wikipedia/commons/5/50/Albert_Einstein_%28Nobel%29.png',
+                                'einstein.png')
+ torch.hub.download_url_to_file(
+     'https://upload.wikimedia.org/wikipedia/commons/thumb/9/9d/Thomas_Edison2.jpg/1024px-Thomas_Edison2.jpg',
+     'edison.jpg')
+ torch.hub.download_url_to_file(
+     'https://upload.wikimedia.org/wikipedia/commons/thumb/a/a9/Henry_Ford_1888.jpg/1024px-Henry_Ford_1888.jpg',
+     'Henry.jpg')
+ torch.hub.download_url_to_file(
+     'https://upload.wikimedia.org/wikipedia/commons/thumb/0/06/Frida_Kahlo%2C_by_Guillermo_Kahlo.jpg/800px-Frida_Kahlo%2C_by_Guillermo_Kahlo.jpg',
+     'Frida.jpg')
+
+ # determine models according to model names
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
+ netscale = 4
+ model_path = os.path.join('weights', 'realesr-general-x4v3.pth')
+
+ # restorer
+ upsampler = RealESRGANer(scale=netscale, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=True)
+
+ # Use GFPGAN for face enhancement
+ from gfpgan_utils import GFPGANer
+
+ face_enhancer = GFPGANer(
+     model_path='weights/GFPGANv1.3.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
+ os.makedirs('output', exist_ok=True)
+
+
+ def inference(img):
+     img = cv2.imread(img, cv2.IMREAD_UNCHANGED)
+
+     h, w = img.shape[0:2]
+     if h < 400:
+         img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
+
+     if len(img.shape) == 3 and img.shape[2] == 4:
+         img_mode = 'RGBA'  # RGBA inputs would need to be saved as png; kept for reference
+     else:
+         img_mode = None
+
+     try:
+         _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
+     except RuntimeError as error:
+         print('Error', error)
+         print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
+         return None
+
+     # the restorer returns a BGR array; convert it so the PIL output shows correct colors
+     output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
+     return Image.fromarray(output)
+
+
+ title = "GFP-GAN"
+ description = "Gradio demo for GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below. Please click Submit only once."
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2101.04061' target='_blank'>Towards Real-World Blind Face Restoration with Generative Facial Prior</a> | <a href='https://github.com/TencentARC/GFPGAN' target='_blank'>Github Repo</a></p><center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_GFPGAN' alt='visitor badge'></center>"
+ gr.Interface(
+     inference, [gr.inputs.Image(type="filepath", label="Input")],
+     gr.outputs.Image(type="pil", label="Output"),
+     title=title,
+     description=description,
+     article=article,
+     examples=[['lincoln.jpg'], ['einstein.png'], ['edison.jpg'], ['Henry.jpg'], ['Frida.jpg']]).launch()
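
The snippet below is a hedged, self-contained sketch (not part of this commit) of the pipeline that app.py wires together: Real-ESRGAN upsamples the background while GFPGAN detects, restores, and pastes back the faces. It assumes the weight files downloaded above are already in ./weights and that a CUDA device is available for half=True.

    import os
    import cv2
    from gfpgan_utils import GFPGANer
    from realesrgan_utils import RealESRGANer
    from srvgg_arch import SRVGGNetCompact

    # background upsampler: the same compact SRVGG model used in app.py
    bg_model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
    bg_upsampler = RealESRGANer(scale=4, model_path='weights/realesr-general-x4v3.pth', model=bg_model,
                                tile=0, tile_pad=10, pre_pad=0, half=True)  # half=True assumes a GPU

    # face restorer that pastes restored faces onto the upsampled background
    restorer = GFPGANer(model_path='weights/GFPGANv1.3.pth', upscale=2, arch='clean',
                        channel_multiplier=2, bg_upsampler=bg_upsampler)

    os.makedirs('output', exist_ok=True)
    img = cv2.imread('lincoln.jpg', cv2.IMREAD_UNCHANGED)  # one of the sample images downloaded above
    _, _, restored = restorer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
    cv2.imwrite('output/lincoln_restored.png', restored)
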
gfpgan_utils.py ADDED
@@ -0,0 +1,119 @@
1
+ import os
2
+
3
+ import cv2
4
+ import torch
5
+ from basicsr.utils import img2tensor, tensor2img
6
+ from basicsr.utils.download_util import load_file_from_url
7
+ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
8
+ from torchvision.transforms.functional import normalize
9
+
10
+ from gfpganv1_clean_arch import GFPGANv1Clean
11
+
12
+ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
13
+
14
+
15
+ class GFPGANer():
16
+ """Helper for restoration with GFPGAN.
17
+
18
+ It will detect and crop faces, and then resize the faces to 512x512.
19
+ GFPGAN is used to restore the resized faces.
20
+ The background is upsampled with the bg_upsampler.
21
+ Finally, the faces will be pasted back onto the upsampled background image.
22
+
23
+ Args:
24
+ model_path (str): The path to the GFPGAN model. It can also be a URL (the weights will be downloaded automatically first).
25
+ upscale (float): The upscale of the final output. Default: 2.
26
+ arch (str): The GFPGAN architecture. Option: clean | original. Default: clean.
27
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
28
+ bg_upsampler (nn.Module): The upsampler for the background. Default: None.
29
+ """
30
+
31
+ def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None, device=None):
32
+ self.upscale = upscale
33
+ self.bg_upsampler = bg_upsampler
34
+
35
+ # initialize model
36
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
37
+ # initialize the GFP-GAN
38
+ self.gfpgan = GFPGANv1Clean(
39
+ out_size=512,
40
+ num_style_feat=512,
41
+ channel_multiplier=channel_multiplier,
42
+ decoder_load_path=None,
43
+ fix_decoder=False,
44
+ num_mlp=8,
45
+ input_is_latent=True,
46
+ different_w=True,
47
+ narrow=1,
48
+ sft_half=True)
49
+
50
+ # initialize face helper
51
+ self.face_helper = FaceRestoreHelper(
52
+ upscale,
53
+ face_size=512,
54
+ crop_ratio=(1, 1),
55
+ det_model='retinaface_resnet50',
56
+ save_ext='png',
57
+ use_parse=True,
58
+ device=self.device)
59
+
60
+ if model_path.startswith('https://'):
61
+ model_path = load_file_from_url(
62
+ url=model_path, model_dir=os.path.join(ROOT_DIR, 'gfpgan/weights'), progress=True, file_name=None)
63
+ loadnet = torch.load(model_path)
64
+ if 'params_ema' in loadnet:
65
+ keyname = 'params_ema'
66
+ else:
67
+ keyname = 'params'
68
+ self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
69
+ self.gfpgan.eval()
70
+ self.gfpgan = self.gfpgan.to(self.device)
71
+
72
+ @torch.no_grad()
73
+ def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True):
74
+ self.face_helper.clean_all()
75
+
76
+ if has_aligned: # the inputs are already aligned
77
+ img = cv2.resize(img, (512, 512))
78
+ self.face_helper.cropped_faces = [img]
79
+ else:
80
+ self.face_helper.read_image(img)
81
+ # get face landmarks for each face
82
+ self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
83
+ # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
84
+ # TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
85
+ # align and warp each face
86
+ self.face_helper.align_warp_face()
87
+
88
+ # face restoration
89
+ for cropped_face in self.face_helper.cropped_faces:
90
+ # prepare data
91
+ cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
92
+ normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
93
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)
94
+
95
+ try:
96
+ output = self.gfpgan(cropped_face_t, return_rgb=False)[0]
97
+ # convert to image
98
+ restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
99
+ except RuntimeError as error:
100
+ print(f'\tFailed inference for GFPGAN: {error}.')
101
+ restored_face = cropped_face
102
+
103
+ restored_face = restored_face.astype('uint8')
104
+ self.face_helper.add_restored_face(restored_face)
105
+
106
+ if not has_aligned and paste_back:
107
+ # upsample the background
108
+ if self.bg_upsampler is not None:
109
+ # Now only support RealESRGAN for upsampling background
110
+ bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
111
+ else:
112
+ bg_img = None
113
+
114
+ self.face_helper.get_inverse_affine(None)
115
+ # paste each restored face to the input image
116
+ restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img)
117
+ return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
118
+ else:
119
+ return self.face_helper.cropped_faces, self.face_helper.restored_faces, None
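
As a usage note, the has_aligned flag above switches off detection entirely. The sketch below is illustrative only (the pre-cropped input file name is a hypothetical assumption): with has_aligned=True the helper just resizes the input to 512x512 and restores it, without detection, warping, or pasting back.

    import cv2
    from gfpgan_utils import GFPGANer

    restorer = GFPGANer(model_path='weights/GFPGANv1.3.pth', upscale=1, arch='clean',
                        channel_multiplier=2, bg_upsampler=None)

    face = cv2.imread('cropped_face.png')  # hypothetical, already-cropped face image (BGR)
    # has_aligned=True: treat the input as a single aligned face; paste_back is skipped
    cropped_faces, restored_faces, _ = restorer.enhance(face, has_aligned=True, paste_back=False)
    cv2.imwrite('restored_face.png', restored_faces[0])
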
gfpganv1_clean_arch.py ADDED
@@ -0,0 +1,325 @@
1
+ import math
2
+ import random
3
+
4
+ import torch
5
+ from basicsr.utils.registry import ARCH_REGISTRY
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+ from stylegan2_clean_arch import StyleGAN2GeneratorClean
10
+
11
+
12
+ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
13
+ """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
14
+
15
+ It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
16
+
17
+ Args:
18
+ out_size (int): The spatial size of outputs.
19
+ num_style_feat (int): Channel number of style features. Default: 512.
20
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
21
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
22
+ narrow (float): The narrow ratio for channels. Default: 1.
23
+ sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
24
+ """
25
+
26
+ def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1, sft_half=False):
27
+ super(StyleGAN2GeneratorCSFT, self).__init__(
28
+ out_size,
29
+ num_style_feat=num_style_feat,
30
+ num_mlp=num_mlp,
31
+ channel_multiplier=channel_multiplier,
32
+ narrow=narrow)
33
+ self.sft_half = sft_half
34
+
35
+ def forward(self,
36
+ styles,
37
+ conditions,
38
+ input_is_latent=False,
39
+ noise=None,
40
+ randomize_noise=True,
41
+ truncation=1,
42
+ truncation_latent=None,
43
+ inject_index=None,
44
+ return_latents=False):
45
+ """Forward function for StyleGAN2GeneratorCSFT.
46
+
47
+ Args:
48
+ styles (list[Tensor]): Sample codes of styles.
49
+ conditions (list[Tensor]): SFT conditions to generators.
50
+ input_is_latent (bool): Whether input is latent style. Default: False.
51
+ noise (Tensor | None): Input noise or None. Default: None.
52
+ randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
53
+ truncation (float): The truncation ratio. Default: 1.
54
+ truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
55
+ inject_index (int | None): The injection index for mixing noise. Default: None.
56
+ return_latents (bool): Whether to return style latents. Default: False.
57
+ """
58
+ # style codes -> latents with Style MLP layer
59
+ if not input_is_latent:
60
+ styles = [self.style_mlp(s) for s in styles]
61
+ # noises
62
+ if noise is None:
63
+ if randomize_noise:
64
+ noise = [None] * self.num_layers # for each style conv layer
65
+ else: # use the stored noise
66
+ noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
67
+ # style truncation
68
+ if truncation < 1:
69
+ style_truncation = []
70
+ for style in styles:
71
+ style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
72
+ styles = style_truncation
73
+ # get style latents with injection
74
+ if len(styles) == 1:
75
+ inject_index = self.num_latent
76
+
77
+ if styles[0].ndim < 3:
78
+ # repeat latent code for all the layers
79
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
80
+ else: # used for encoder with different latent code for each layer
81
+ latent = styles[0]
82
+ elif len(styles) == 2: # mixing noises
83
+ if inject_index is None:
84
+ inject_index = random.randint(1, self.num_latent - 1)
85
+ latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
86
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
87
+ latent = torch.cat([latent1, latent2], 1)
88
+
89
+ # main generation
90
+ out = self.constant_input(latent.shape[0])
91
+ out = self.style_conv1(out, latent[:, 0], noise=noise[0])
92
+ skip = self.to_rgb1(out, latent[:, 1])
93
+
94
+ i = 1
95
+ for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
96
+ noise[2::2], self.to_rgbs):
97
+ out = conv1(out, latent[:, i], noise=noise1)
98
+
99
+ # the conditions may have fewer levels
100
+ if i < len(conditions):
101
+ # SFT part to combine the conditions
102
+ if self.sft_half: # only apply SFT to half of the channels
103
+ out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
104
+ out_sft = out_sft * conditions[i - 1] + conditions[i]
105
+ out = torch.cat([out_same, out_sft], dim=1)
106
+ else: # apply SFT to all the channels
107
+ out = out * conditions[i - 1] + conditions[i]
108
+
109
+ out = conv2(out, latent[:, i + 1], noise=noise2)
110
+ skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
111
+ i += 2
112
+
113
+ image = skip
114
+
115
+ if return_latents:
116
+ return image, latent
117
+ else:
118
+ return image, None
119
+
120
+
121
+ class ResBlock(nn.Module):
122
+ """Residual block with bilinear upsampling/downsampling.
123
+
124
+ Args:
125
+ in_channels (int): Channel number of the input.
126
+ out_channels (int): Channel number of the output.
127
+ mode (str): Upsampling/downsampling mode. Options: down | up. Default: down.
128
+ """
129
+
130
+ def __init__(self, in_channels, out_channels, mode='down'):
131
+ super(ResBlock, self).__init__()
132
+
133
+ self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
134
+ self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
135
+ self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)
136
+ if mode == 'down':
137
+ self.scale_factor = 0.5
138
+ elif mode == 'up':
139
+ self.scale_factor = 2
140
+
141
+ def forward(self, x):
142
+ out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
143
+ # upsample/downsample
144
+ out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
145
+ out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
146
+ # skip
147
+ x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
148
+ skip = self.skip(x)
149
+ out = out + skip
150
+ return out
151
+
152
+
153
+ @ARCH_REGISTRY.register()
154
+ class GFPGANv1Clean(nn.Module):
155
+ """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
156
+
157
+ It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
158
+
159
+ Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
160
+
161
+ Args:
162
+ out_size (int): The spatial size of outputs.
163
+ num_style_feat (int): Channel number of style features. Default: 512.
164
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
165
+ decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
166
+ fix_decoder (bool): Whether to fix the decoder. Default: True.
167
+
168
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
169
+ input_is_latent (bool): Whether input is latent style. Default: False.
170
+ different_w (bool): Whether to use different latent w for different layers. Default: False.
171
+ narrow (float): The narrow ratio for channels. Default: 1.
172
+ sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
173
+ """
174
+
175
+ def __init__(
176
+ self,
177
+ out_size,
178
+ num_style_feat=512,
179
+ channel_multiplier=1,
180
+ decoder_load_path=None,
181
+ fix_decoder=True,
182
+ # for stylegan decoder
183
+ num_mlp=8,
184
+ input_is_latent=False,
185
+ different_w=False,
186
+ narrow=1,
187
+ sft_half=False):
188
+
189
+ super(GFPGANv1Clean, self).__init__()
190
+ self.input_is_latent = input_is_latent
191
+ self.different_w = different_w
192
+ self.num_style_feat = num_style_feat
193
+
194
+ unet_narrow = narrow * 0.5 # by default, use a half of input channels
195
+ channels = {
196
+ '4': int(512 * unet_narrow),
197
+ '8': int(512 * unet_narrow),
198
+ '16': int(512 * unet_narrow),
199
+ '32': int(512 * unet_narrow),
200
+ '64': int(256 * channel_multiplier * unet_narrow),
201
+ '128': int(128 * channel_multiplier * unet_narrow),
202
+ '256': int(64 * channel_multiplier * unet_narrow),
203
+ '512': int(32 * channel_multiplier * unet_narrow),
204
+ '1024': int(16 * channel_multiplier * unet_narrow)
205
+ }
206
+
207
+ self.log_size = int(math.log(out_size, 2))
208
+ first_out_size = 2**(int(math.log(out_size, 2)))
209
+
210
+ self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)
211
+
212
+ # downsample
213
+ in_channels = channels[f'{first_out_size}']
214
+ self.conv_body_down = nn.ModuleList()
215
+ for i in range(self.log_size, 2, -1):
216
+ out_channels = channels[f'{2**(i - 1)}']
217
+ self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))
218
+ in_channels = out_channels
219
+
220
+ self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1)
221
+
222
+ # upsample
223
+ in_channels = channels['4']
224
+ self.conv_body_up = nn.ModuleList()
225
+ for i in range(3, self.log_size + 1):
226
+ out_channels = channels[f'{2**i}']
227
+ self.conv_body_up.append(ResBlock(in_channels, out_channels, mode='up'))
228
+ in_channels = out_channels
229
+
230
+ # to RGB
231
+ self.toRGB = nn.ModuleList()
232
+ for i in range(3, self.log_size + 1):
233
+ self.toRGB.append(nn.Conv2d(channels[f'{2**i}'], 3, 1))
234
+
235
+ if different_w:
236
+ linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
237
+ else:
238
+ linear_out_channel = num_style_feat
239
+
240
+ self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)
241
+
242
+ # the decoder: stylegan2 generator with SFT modulations
243
+ self.stylegan_decoder = StyleGAN2GeneratorCSFT(
244
+ out_size=out_size,
245
+ num_style_feat=num_style_feat,
246
+ num_mlp=num_mlp,
247
+ channel_multiplier=channel_multiplier,
248
+ narrow=narrow,
249
+ sft_half=sft_half)
250
+
251
+ # load pre-trained stylegan2 model if necessary
252
+ if decoder_load_path:
253
+ self.stylegan_decoder.load_state_dict(
254
+ torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
255
+ # fix decoder without updating params
256
+ if fix_decoder:
257
+ for _, param in self.stylegan_decoder.named_parameters():
258
+ param.requires_grad = False
259
+
260
+ # for SFT modulations (scale and shift)
261
+ self.condition_scale = nn.ModuleList()
262
+ self.condition_shift = nn.ModuleList()
263
+ for i in range(3, self.log_size + 1):
264
+ out_channels = channels[f'{2**i}']
265
+ if sft_half:
266
+ sft_out_channels = out_channels
267
+ else:
268
+ sft_out_channels = out_channels * 2
269
+ self.condition_scale.append(
270
+ nn.Sequential(
271
+ nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
272
+ nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
273
+ self.condition_shift.append(
274
+ nn.Sequential(
275
+ nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
276
+ nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
277
+
278
+ def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
279
+ """Forward function for GFPGANv1Clean.
280
+
281
+ Args:
282
+ x (Tensor): Input images.
283
+ return_latents (bool): Whether to return style latents. Default: False.
284
+ return_rgb (bool): Whether return intermediate rgb images. Default: True.
285
+ randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
286
+ """
287
+ conditions = []
288
+ unet_skips = []
289
+ out_rgbs = []
290
+
291
+ # encoder
292
+ feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2)
293
+ for i in range(self.log_size - 2):
294
+ feat = self.conv_body_down[i](feat)
295
+ unet_skips.insert(0, feat)
296
+ feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)
297
+
298
+ # style code
299
+ style_code = self.final_linear(feat.view(feat.size(0), -1))
300
+ if self.different_w:
301
+ style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
302
+
303
+ # decode
304
+ for i in range(self.log_size - 2):
305
+ # add unet skip
306
+ feat = feat + unet_skips[i]
307
+ # ResUpLayer
308
+ feat = self.conv_body_up[i](feat)
309
+ # generate scale and shift for SFT layers
310
+ scale = self.condition_scale[i](feat)
311
+ conditions.append(scale.clone())
312
+ shift = self.condition_shift[i](feat)
313
+ conditions.append(shift.clone())
314
+ # generate rgb images
315
+ if return_rgb:
316
+ out_rgbs.append(self.toRGB[i](feat))
317
+
318
+ # decoder
319
+ image, _ = self.stylegan_decoder([style_code],
320
+ conditions,
321
+ return_latents=return_latents,
322
+ input_is_latent=self.input_is_latent,
323
+ randomize_noise=randomize_noise)
324
+
325
+ return image, out_rgbs
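
For orientation, the illustrative sketch below runs the clean network once with random weights, using constructor arguments that mirror the ones app.py passes: the U-Net encoder compresses a 512x512 crop into a style code plus SFT conditions, and the StyleGAN2 decoder brings it back to a 512x512 image.

    import torch
    from gfpganv1_clean_arch import GFPGANv1Clean

    net = GFPGANv1Clean(out_size=512, num_style_feat=512, channel_multiplier=2,
                        decoder_load_path=None, fix_decoder=False, num_mlp=8,
                        input_is_latent=True, different_w=True, narrow=1, sft_half=True)
    net.eval()

    with torch.no_grad():
        image, out_rgbs = net(torch.randn(1, 3, 512, 512), return_rgb=False)
    print(image.shape)  # torch.Size([1, 3, 512, 512]); out_rgbs stays empty when return_rgb=False
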
packages.txt ADDED
@@ -0,0 +1,3 @@
+ ffmpeg
+ libsm6
+ libxext6
realesrgan_utils.py ADDED
@@ -0,0 +1,280 @@
1
+ import cv2
2
+ import math
3
+ import numpy as np
4
+ import os
5
+ import queue
6
+ import threading
7
+ import torch
8
+ from basicsr.utils.download_util import load_file_from_url
9
+ from torch.nn import functional as F
10
+
11
+ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
12
+
13
+
14
+ class RealESRGANer():
15
+ """A helper class for upsampling images with RealESRGAN.
16
+
17
+ Args:
18
+ scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
19
+ model_path (str): The path to the pretrained model. It can also be a URL (the weights will be downloaded automatically first).
20
+ model (nn.Module): The defined network. Default: None.
21
+ tile (int): Since very large images can exhaust GPU memory, this option first crops the
+ input image into tiles, processes each tile, and finally merges them back into one image.
+ 0 means tiling is not used. Default: 0.
24
+ tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
25
+ pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
26
+ half (bool): Whether to use half precision during inference. Default: False.
27
+ """
28
+
29
+ def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False):
30
+ self.scale = scale
31
+ self.tile_size = tile
32
+ self.tile_pad = tile_pad
33
+ self.pre_pad = pre_pad
34
+ self.mod_scale = None
35
+ self.half = half
36
+
37
+ # initialize model
38
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
39
+ # if the model_path starts with https, it will first download models to the folder: realesrgan/weights
40
+ if model_path.startswith('https://'):
41
+ model_path = load_file_from_url(
42
+ url=model_path, model_dir=os.path.join(ROOT_DIR, 'realesrgan/weights'), progress=True, file_name=None)
43
+ loadnet = torch.load(model_path, map_location=torch.device('cpu'))
44
+ # prefer to use params_ema
45
+ if 'params_ema' in loadnet:
46
+ keyname = 'params_ema'
47
+ else:
48
+ keyname = 'params'
49
+ model.load_state_dict(loadnet[keyname], strict=True)
50
+ model.eval()
51
+ self.model = model.to(self.device)
52
+ if self.half:
53
+ self.model = self.model.half()
54
+
55
+ def pre_process(self, img):
56
+ """Pre-process, such as pre-pad and mod pad, so that the images can be divisible
57
+ """
58
+ img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
59
+ self.img = img.unsqueeze(0).to(self.device)
60
+ if self.half:
61
+ self.img = self.img.half()
62
+
63
+ # pre_pad
64
+ if self.pre_pad != 0:
65
+ self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
66
+ # mod pad for divisible borders
67
+ if self.scale == 2:
68
+ self.mod_scale = 2
69
+ elif self.scale == 1:
70
+ self.mod_scale = 4
71
+ if self.mod_scale is not None:
72
+ self.mod_pad_h, self.mod_pad_w = 0, 0
73
+ _, _, h, w = self.img.size()
74
+ if (h % self.mod_scale != 0):
75
+ self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
76
+ if (w % self.mod_scale != 0):
77
+ self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
78
+ self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
79
+
80
+ def process(self):
81
+ # model inference
82
+ self.output = self.model(self.img)
83
+
84
+ def tile_process(self):
85
+ """It will first crop input images to tiles, and then process each tile.
86
+ Finally, all the processed tiles are merged into one images.
87
+
88
+ Modified from: https://github.com/ata4/esrgan-launcher
89
+ """
90
+ batch, channel, height, width = self.img.shape
91
+ output_height = height * self.scale
92
+ output_width = width * self.scale
93
+ output_shape = (batch, channel, output_height, output_width)
94
+
95
+ # start with black image
96
+ self.output = self.img.new_zeros(output_shape)
97
+ tiles_x = math.ceil(width / self.tile_size)
98
+ tiles_y = math.ceil(height / self.tile_size)
99
+
100
+ # loop over all tiles
101
+ for y in range(tiles_y):
102
+ for x in range(tiles_x):
103
+ # extract tile from input image
104
+ ofs_x = x * self.tile_size
105
+ ofs_y = y * self.tile_size
106
+ # input tile area on total image
107
+ input_start_x = ofs_x
108
+ input_end_x = min(ofs_x + self.tile_size, width)
109
+ input_start_y = ofs_y
110
+ input_end_y = min(ofs_y + self.tile_size, height)
111
+
112
+ # input tile area on total image with padding
113
+ input_start_x_pad = max(input_start_x - self.tile_pad, 0)
114
+ input_end_x_pad = min(input_end_x + self.tile_pad, width)
115
+ input_start_y_pad = max(input_start_y - self.tile_pad, 0)
116
+ input_end_y_pad = min(input_end_y + self.tile_pad, height)
117
+
118
+ # input tile dimensions
119
+ input_tile_width = input_end_x - input_start_x
120
+ input_tile_height = input_end_y - input_start_y
121
+ tile_idx = y * tiles_x + x + 1
122
+ input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
123
+
124
+ # upscale tile
125
+ try:
126
+ with torch.no_grad():
127
+ output_tile = self.model(input_tile)
128
+ except RuntimeError as error:
129
+ print('Error', error)
130
+ print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
131
+
132
+ # output tile area on total image
133
+ output_start_x = input_start_x * self.scale
134
+ output_end_x = input_end_x * self.scale
135
+ output_start_y = input_start_y * self.scale
136
+ output_end_y = input_end_y * self.scale
137
+
138
+ # output tile area without padding
139
+ output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
140
+ output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
141
+ output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
142
+ output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
143
+
144
+ # put tile into output image
145
+ self.output[:, :, output_start_y:output_end_y,
146
+ output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
147
+ output_start_x_tile:output_end_x_tile]
148
+
149
+ def post_process(self):
150
+ # remove extra pad
151
+ if self.mod_scale is not None:
152
+ _, _, h, w = self.output.size()
153
+ self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
154
+ # remove prepad
155
+ if self.pre_pad != 0:
156
+ _, _, h, w = self.output.size()
157
+ self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
158
+ return self.output
159
+
160
+ @torch.no_grad()
161
+ def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
162
+ h_input, w_input = img.shape[0:2]
163
+ # img: numpy
164
+ img = img.astype(np.float32)
165
+ if np.max(img) > 256: # 16-bit image
166
+ max_range = 65535
167
+ print('\tInput is a 16-bit image')
168
+ else:
169
+ max_range = 255
170
+ img = img / max_range
171
+ if len(img.shape) == 2: # gray image
172
+ img_mode = 'L'
173
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
174
+ elif img.shape[2] == 4: # RGBA image with alpha channel
175
+ img_mode = 'RGBA'
176
+ alpha = img[:, :, 3]
177
+ img = img[:, :, 0:3]
178
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
179
+ if alpha_upsampler == 'realesrgan':
180
+ alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
181
+ else:
182
+ img_mode = 'RGB'
183
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
184
+
185
+ # ------------------- process image (without the alpha channel) ------------------- #
186
+ self.pre_process(img)
187
+ if self.tile_size > 0:
188
+ self.tile_process()
189
+ else:
190
+ self.process()
191
+ output_img = self.post_process()
192
+ output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
193
+ output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
194
+ if img_mode == 'L':
195
+ output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
196
+
197
+ # ------------------- process the alpha channel if necessary ------------------- #
198
+ if img_mode == 'RGBA':
199
+ if alpha_upsampler == 'realesrgan':
200
+ self.pre_process(alpha)
201
+ if self.tile_size > 0:
202
+ self.tile_process()
203
+ else:
204
+ self.process()
205
+ output_alpha = self.post_process()
206
+ output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
207
+ output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
208
+ output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
209
+ else: # use the cv2 resize for alpha channel
210
+ h, w = alpha.shape[0:2]
211
+ output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
212
+
213
+ # merge the alpha channel
214
+ output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
215
+ output_img[:, :, 3] = output_alpha
216
+
217
+ # ------------------------------ return ------------------------------ #
218
+ if max_range == 65535: # 16-bit image
219
+ output = (output_img * 65535.0).round().astype(np.uint16)
220
+ else:
221
+ output = (output_img * 255.0).round().astype(np.uint8)
222
+
223
+ if outscale is not None and outscale != float(self.scale):
224
+ output = cv2.resize(
225
+ output, (
226
+ int(w_input * outscale),
227
+ int(h_input * outscale),
228
+ ), interpolation=cv2.INTER_LANCZOS4)
229
+
230
+ return output, img_mode
231
+
232
+
233
+ class PrefetchReader(threading.Thread):
234
+ """Prefetch images.
235
+
236
+ Args:
237
+ img_list (list[str]): A list of image paths to be read.
238
+ num_prefetch_queue (int): Number of prefetch queue.
239
+ """
240
+
241
+ def __init__(self, img_list, num_prefetch_queue):
242
+ super().__init__()
243
+ self.que = queue.Queue(num_prefetch_queue)
244
+ self.img_list = img_list
245
+
246
+ def run(self):
247
+ for img_path in self.img_list:
248
+ img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
249
+ self.que.put(img)
250
+
251
+ self.que.put(None)
252
+
253
+ def __next__(self):
254
+ next_item = self.que.get()
255
+ if next_item is None:
256
+ raise StopIteration
257
+ return next_item
258
+
259
+ def __iter__(self):
260
+ return self
261
+
262
+
263
+ class IOConsumer(threading.Thread):
264
+
265
+ def __init__(self, opt, que, qid):
266
+ super().__init__()
267
+ self._queue = que
268
+ self.qid = qid
269
+ self.opt = opt
270
+
271
+ def run(self):
272
+ while True:
273
+ msg = self._queue.get()
274
+ if isinstance(msg, str) and msg == 'quit':
275
+ break
276
+
277
+ output = msg['output']
278
+ save_path = msg['save_path']
279
+ cv2.imwrite(save_path, output)
280
+ print(f'IO worker {self.qid} is done.')
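
Relating this to the CUDA out-of-memory hint printed in app.py: tiling is what that hint refers to. The sketch below is an assumption-laden example (not part of the Space) that builds the same upsampler with tile=400 so large inputs are processed in bounded chunks; half=False keeps it runnable on CPU.

    import cv2
    from realesrgan_utils import RealESRGANer
    from srvgg_arch import SRVGGNetCompact

    model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
    upsampler = RealESRGANer(scale=4, model_path='weights/realesr-general-x4v3.pth', model=model,
                             tile=400, tile_pad=10, pre_pad=0, half=False)  # 400px tiles bound memory use

    img = cv2.imread('lincoln.jpg', cv2.IMREAD_UNCHANGED)
    output, img_mode = upsampler.enhance(img, outscale=2)  # upscale by 4, then resize to 2x the input
    cv2.imwrite('lincoln_x2.png', output)
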
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ torch>=1.7
+ numpy
+ opencv-python
+ torchvision
+ scipy
+ tqdm
+ basicsr>=1.4.1
+ facexlib>=0.2.4
+ lmdb
+ pyyaml
+ yapf
srvgg_arch.py ADDED
@@ -0,0 +1,67 @@
1
+ from torch import nn as nn
2
+ from torch.nn import functional as F
3
+
4
+
5
+ class SRVGGNetCompact(nn.Module):
6
+ """A compact VGG-style network structure for super-resolution.
7
+
8
+ It is a compact network that performs upsampling only in the last layer, so no convolution is
+ computed on the HR feature space.
10
+
11
+ Args:
12
+ num_in_ch (int): Channel number of inputs. Default: 3.
13
+ num_out_ch (int): Channel number of outputs. Default: 3.
14
+ num_feat (int): Channel number of intermediate features. Default: 64.
15
+ num_conv (int): Number of convolution layers in the body network. Default: 16.
16
+ upscale (int): Upsampling factor. Default: 4.
17
+ act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
18
+ """
19
+
20
+ def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
21
+ super(SRVGGNetCompact, self).__init__()
22
+ self.num_in_ch = num_in_ch
23
+ self.num_out_ch = num_out_ch
24
+ self.num_feat = num_feat
25
+ self.num_conv = num_conv
26
+ self.upscale = upscale
27
+ self.act_type = act_type
28
+
29
+ self.body = nn.ModuleList()
30
+ # the first conv
31
+ self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
32
+ # the first activation
33
+ if act_type == 'relu':
34
+ activation = nn.ReLU(inplace=True)
35
+ elif act_type == 'prelu':
36
+ activation = nn.PReLU(num_parameters=num_feat)
37
+ elif act_type == 'leakyrelu':
38
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
39
+ self.body.append(activation)
40
+
41
+ # the body structure
42
+ for _ in range(num_conv):
43
+ self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
44
+ # activation
45
+ if act_type == 'relu':
46
+ activation = nn.ReLU(inplace=True)
47
+ elif act_type == 'prelu':
48
+ activation = nn.PReLU(num_parameters=num_feat)
49
+ elif act_type == 'leakyrelu':
50
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
51
+ self.body.append(activation)
52
+
53
+ # the last conv
54
+ self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
55
+ # upsample
56
+ self.upsampler = nn.PixelShuffle(upscale)
57
+
58
+ def forward(self, x):
59
+ out = x
60
+ for i in range(0, len(self.body)):
61
+ out = self.body[i](out)
62
+
63
+ out = self.upsampler(out)
64
+ # add the nearest upsampled image, so that the network learns the residual
65
+ base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
66
+ out += base
67
+ return out
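
As a quick sanity check on the residual design described in the docstring (an illustrative sketch with random weights, not part of the Space): the PixelShuffle head turns num_out_ch * upscale^2 feature maps into an upscale-times-larger image, and the nearest-neighbour upsampled input is added back as a base.

    import torch
    from srvgg_arch import SRVGGNetCompact

    net = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
    with torch.no_grad():
        y = net(torch.randn(1, 3, 64, 64))
    print(y.shape)  # torch.Size([1, 3, 256, 256]) -- 4x the 64x64 input
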
stylegan2_clean_arch.py ADDED
@@ -0,0 +1,369 @@
1
+ import math
2
+ import random
3
+
4
+ import torch
5
+ from basicsr.archs.arch_util import default_init_weights
6
+ from basicsr.utils.registry import ARCH_REGISTRY
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+
10
+
11
+ class NormStyleCode(nn.Module):
12
+
13
+ def forward(self, x):
14
+ """Normalize the style codes.
15
+
16
+ Args:
17
+ x (Tensor): Style codes with shape (b, c).
18
+
19
+ Returns:
20
+ Tensor: Normalized tensor.
21
+ """
22
+ return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
23
+
24
+
25
+ class ModulatedConv2d(nn.Module):
26
+ """Modulated Conv2d used in StyleGAN2.
27
+
28
+ There is no bias in ModulatedConv2d.
29
+
30
+ Args:
31
+ in_channels (int): Channel number of the input.
32
+ out_channels (int): Channel number of the output.
33
+ kernel_size (int): Size of the convolving kernel.
34
+ num_style_feat (int): Channel number of style features.
35
+ demodulate (bool): Whether to demodulate in the conv layer. Default: True.
36
+ sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
37
+ eps (float): A value added to the denominator for numerical stability. Default: 1e-8.
38
+ """
39
+
40
+ def __init__(self,
41
+ in_channels,
42
+ out_channels,
43
+ kernel_size,
44
+ num_style_feat,
45
+ demodulate=True,
46
+ sample_mode=None,
47
+ eps=1e-8):
48
+ super(ModulatedConv2d, self).__init__()
49
+ self.in_channels = in_channels
50
+ self.out_channels = out_channels
51
+ self.kernel_size = kernel_size
52
+ self.demodulate = demodulate
53
+ self.sample_mode = sample_mode
54
+ self.eps = eps
55
+
56
+ # modulation inside each modulated conv
57
+ self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
58
+ # initialization
59
+ default_init_weights(self.modulation, scale=1, bias_fill=1, a=0, mode='fan_in', nonlinearity='linear')
60
+
61
+ self.weight = nn.Parameter(
62
+ torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) /
63
+ math.sqrt(in_channels * kernel_size**2))
64
+ self.padding = kernel_size // 2
65
+
66
+ def forward(self, x, style):
67
+ """Forward function.
68
+
69
+ Args:
70
+ x (Tensor): Tensor with shape (b, c, h, w).
71
+ style (Tensor): Tensor with shape (b, num_style_feat).
72
+
73
+ Returns:
74
+ Tensor: Modulated tensor after convolution.
75
+ """
76
+ b, c, h, w = x.shape # c = c_in
77
+ # weight modulation
78
+ style = self.modulation(style).view(b, 1, c, 1, 1)
79
+ # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
80
+ weight = self.weight * style # (b, c_out, c_in, k, k)
81
+
82
+ if self.demodulate:
83
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
84
+ weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
85
+
86
+ weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
87
+
88
+ # upsample or downsample if necessary
89
+ if self.sample_mode == 'upsample':
90
+ x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
91
+ elif self.sample_mode == 'downsample':
92
+ x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)
93
+
94
+ b, c, h, w = x.shape
95
+ x = x.view(1, b * c, h, w)
96
+ # weight: (b*c_out, c_in, k, k), groups=b
97
+ out = F.conv2d(x, weight, padding=self.padding, groups=b)
98
+ out = out.view(b, self.out_channels, *out.shape[2:4])
99
+
100
+ return out
101
+
102
+ def __repr__(self):
103
+ return (f'{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, '
104
+ f'kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})')
105
+
106
+
107
+ class StyleConv(nn.Module):
108
+ """Style conv used in StyleGAN2.
109
+
110
+ Args:
111
+ in_channels (int): Channel number of the input.
112
+ out_channels (int): Channel number of the output.
113
+ kernel_size (int): Size of the convolving kernel.
114
+ num_style_feat (int): Channel number of style features.
115
+ demodulate (bool): Whether demodulate in the conv layer. Default: True.
116
+ sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
117
+ """
118
+
119
+ def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None):
120
+ super(StyleConv, self).__init__()
121
+ self.modulated_conv = ModulatedConv2d(
122
+ in_channels, out_channels, kernel_size, num_style_feat, demodulate=demodulate, sample_mode=sample_mode)
123
+ self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
124
+ self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
125
+ self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
126
+
127
+ def forward(self, x, style, noise=None):
128
+ # modulate
129
+ out = self.modulated_conv(x, style) * 2**0.5 # for conversion
130
+ # noise injection
131
+ if noise is None:
132
+ b, _, h, w = out.shape
133
+ noise = out.new_empty(b, 1, h, w).normal_()
134
+ out = out + self.weight * noise
135
+ # add bias
136
+ out = out + self.bias
137
+ # activation
138
+ out = self.activate(out)
139
+ return out
140
+
141
+
142
+ class ToRGB(nn.Module):
143
+ """To RGB (image space) from features.
144
+
145
+ Args:
146
+ in_channels (int): Channel number of input.
147
+ num_style_feat (int): Channel number of style features.
148
+ upsample (bool): Whether to upsample. Default: True.
149
+ """
150
+
151
+ def __init__(self, in_channels, num_style_feat, upsample=True):
152
+ super(ToRGB, self).__init__()
153
+ self.upsample = upsample
154
+ self.modulated_conv = ModulatedConv2d(
155
+ in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None)
156
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
157
+
158
+ def forward(self, x, style, skip=None):
159
+ """Forward function.
160
+
161
+ Args:
162
+ x (Tensor): Feature tensor with shape (b, c, h, w).
163
+ style (Tensor): Tensor with shape (b, num_style_feat).
164
+ skip (Tensor): Base/skip tensor. Default: None.
165
+
166
+ Returns:
167
+ Tensor: RGB images.
168
+ """
169
+ out = self.modulated_conv(x, style)
170
+ out = out + self.bias
171
+ if skip is not None:
172
+ if self.upsample:
173
+ skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
174
+ out = out + skip
175
+ return out
176
+
177
+
178
+ class ConstantInput(nn.Module):
179
+ """Constant input.
180
+
181
+ Args:
182
+ num_channel (int): Channel number of constant input.
183
+ size (int): Spatial size of constant input.
184
+ """
185
+
186
+ def __init__(self, num_channel, size):
187
+ super(ConstantInput, self).__init__()
188
+ self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
189
+
190
+ def forward(self, batch):
191
+ out = self.weight.repeat(batch, 1, 1, 1)
192
+ return out
193
+
194
+
195
+ @ARCH_REGISTRY.register()
196
+ class StyleGAN2GeneratorClean(nn.Module):
197
+ """Clean version of StyleGAN2 Generator.
198
+
199
+ Args:
200
+ out_size (int): The spatial size of outputs.
201
+ num_style_feat (int): Channel number of style features. Default: 512.
202
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
203
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
204
+ narrow (float): Narrow ratio for channels. Default: 1.0.
205
+ """
206
+
207
+ def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1):
208
+ super(StyleGAN2GeneratorClean, self).__init__()
209
+ # Style MLP layers
210
+ self.num_style_feat = num_style_feat
211
+ style_mlp_layers = [NormStyleCode()]
212
+ for i in range(num_mlp):
213
+ style_mlp_layers.extend(
214
+ [nn.Linear(num_style_feat, num_style_feat, bias=True),
215
+ nn.LeakyReLU(negative_slope=0.2, inplace=True)])
216
+ self.style_mlp = nn.Sequential(*style_mlp_layers)
217
+ # initialization
218
+ default_init_weights(self.style_mlp, scale=1, bias_fill=0, a=0.2, mode='fan_in', nonlinearity='leaky_relu')
219
+
220
+ # channel list
221
+ channels = {
222
+ '4': int(512 * narrow),
223
+ '8': int(512 * narrow),
224
+ '16': int(512 * narrow),
225
+ '32': int(512 * narrow),
226
+ '64': int(256 * channel_multiplier * narrow),
227
+ '128': int(128 * channel_multiplier * narrow),
228
+ '256': int(64 * channel_multiplier * narrow),
229
+ '512': int(32 * channel_multiplier * narrow),
230
+ '1024': int(16 * channel_multiplier * narrow)
231
+ }
232
+ self.channels = channels
233
+
234
+ self.constant_input = ConstantInput(channels['4'], size=4)
235
+ self.style_conv1 = StyleConv(
236
+ channels['4'],
237
+ channels['4'],
238
+ kernel_size=3,
239
+ num_style_feat=num_style_feat,
240
+ demodulate=True,
241
+ sample_mode=None)
242
+ self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False)
243
+
244
+ self.log_size = int(math.log(out_size, 2))
245
+ self.num_layers = (self.log_size - 2) * 2 + 1
246
+ self.num_latent = self.log_size * 2 - 2
247
+
248
+ self.style_convs = nn.ModuleList()
249
+ self.to_rgbs = nn.ModuleList()
250
+ self.noises = nn.Module()
251
+
252
+ in_channels = channels['4']
253
+ # noise
254
+ for layer_idx in range(self.num_layers):
255
+ resolution = 2**((layer_idx + 5) // 2)
256
+ shape = [1, 1, resolution, resolution]
257
+ self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
258
+ # style convs and to_rgbs
259
+ for i in range(3, self.log_size + 1):
260
+ out_channels = channels[f'{2**i}']
261
+ self.style_convs.append(
262
+ StyleConv(
263
+ in_channels,
264
+ out_channels,
265
+ kernel_size=3,
266
+ num_style_feat=num_style_feat,
267
+ demodulate=True,
268
+ sample_mode='upsample'))
269
+ self.style_convs.append(
270
+ StyleConv(
271
+ out_channels,
272
+ out_channels,
273
+ kernel_size=3,
274
+ num_style_feat=num_style_feat,
275
+ demodulate=True,
276
+ sample_mode=None))
277
+ self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
278
+ in_channels = out_channels
279
+
280
+ def make_noise(self):
281
+ """Make noise for noise injection."""
282
+ device = self.constant_input.weight.device
283
+ noises = [torch.randn(1, 1, 4, 4, device=device)]
284
+
285
+ for i in range(3, self.log_size + 1):
286
+ for _ in range(2):
287
+ noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
288
+
289
+ return noises
290
+
291
+ def get_latent(self, x):
292
+ return self.style_mlp(x)
293
+
294
+ def mean_latent(self, num_latent):
295
+ latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
296
+ latent = self.style_mlp(latent_in).mean(0, keepdim=True)
297
+ return latent
298
+
299
+ def forward(self,
300
+ styles,
301
+ input_is_latent=False,
302
+ noise=None,
303
+ randomize_noise=True,
304
+ truncation=1,
305
+ truncation_latent=None,
306
+ inject_index=None,
307
+ return_latents=False):
308
+ """Forward function for StyleGAN2GeneratorClean.
309
+
310
+ Args:
311
+ styles (list[Tensor]): Sample codes of styles.
312
+ input_is_latent (bool): Whether input is latent style. Default: False.
313
+ noise (Tensor | None): Input noise or None. Default: None.
314
+ randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
315
+ truncation (float): The truncation ratio. Default: 1.
316
+ truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
317
+ inject_index (int | None): The injection index for mixing noise. Default: None.
318
+ return_latents (bool): Whether to return style latents. Default: False.
319
+ """
320
+ # style codes -> latents with Style MLP layer
321
+ if not input_is_latent:
322
+ styles = [self.style_mlp(s) for s in styles]
323
+ # noises
324
+ if noise is None:
325
+ if randomize_noise:
326
+ noise = [None] * self.num_layers # for each style conv layer
327
+ else: # use the stored noise
328
+ noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
329
+ # style truncation
330
+ if truncation < 1:
331
+ style_truncation = []
332
+ for style in styles:
333
+ style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
334
+ styles = style_truncation
335
+ # get style latents with injection
336
+ if len(styles) == 1:
337
+ inject_index = self.num_latent
338
+
339
+ if styles[0].ndim < 3:
340
+ # repeat latent code for all the layers
341
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
342
+ else: # used for encoder with different latent code for each layer
343
+ latent = styles[0]
344
+ elif len(styles) == 2: # mixing noises
345
+ if inject_index is None:
346
+ inject_index = random.randint(1, self.num_latent - 1)
347
+ latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
348
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
349
+ latent = torch.cat([latent1, latent2], 1)
350
+
351
+ # main generation
352
+ out = self.constant_input(latent.shape[0])
353
+ out = self.style_conv1(out, latent[:, 0], noise=noise[0])
354
+ skip = self.to_rgb1(out, latent[:, 1])
355
+
356
+ i = 1
357
+ for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
358
+ noise[2::2], self.to_rgbs):
359
+ out = conv1(out, latent[:, i], noise=noise1)
360
+ out = conv2(out, latent[:, i + 1], noise=noise2)
361
+ skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
362
+ i += 2
363
+
364
+ image = skip
365
+
366
+ if return_latents:
367
+ return image, latent
368
+ else:
369
+ return image, None
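
To show the calling convention of the clean generator (an illustrative sketch with untrained weights; the 64-pixel output size is an arbitrary choice for speed): a (b, 512) style code is mapped through the style MLP, broadcast across the layers, and decoded into an RGB image.

    import torch
    from stylegan2_clean_arch import StyleGAN2GeneratorClean

    gen = StyleGAN2GeneratorClean(out_size=64, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1)
    gen.eval()

    styles = [torch.randn(1, 512)]  # one style code -> repeated for every layer
    with torch.no_grad():
        image, _ = gen(styles, input_is_latent=False, randomize_noise=True)
    print(image.shape)  # torch.Size([1, 3, 64, 64])
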
weights/PutWeightsHere ADDED
File without changes