Spaces:
Paused
Paused
Upload model and app
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +126 -0
- checkpoints/30_net_gen.pth +3 -0
- checkpoints/BFM.zip +3 -0
- checkpoints/BFM/.gitkeep +0 -0
- checkpoints/BFM/01_MorphableModel.mat +3 -0
- checkpoints/BFM/BFM_exp_idx.mat +3 -0
- checkpoints/BFM/BFM_front_idx.mat +3 -0
- checkpoints/BFM/BFM_model_front.mat +3 -0
- checkpoints/BFM/Exp_Pca.bin +3 -0
- checkpoints/BFM/facemodel_info.mat +3 -0
- checkpoints/BFM/select_vertex_id.mat +3 -0
- checkpoints/BFM/similarity_Lm3D_all.mat +3 -0
- checkpoints/BFM/std_exp.txt +1 -0
- checkpoints/DNet.pt +3 -0
- checkpoints/ENet.pth +3 -0
- checkpoints/GFPGANv1.3.pth +3 -0
- checkpoints/GPEN-BFR-512.pth +3 -0
- checkpoints/LNet.pth +3 -0
- checkpoints/ParseNet-latest.pth +3 -0
- checkpoints/RetinaFace-R50.pth +3 -0
- checkpoints/expression.mat +3 -0
- checkpoints/face3d_pretrain_epoch_20.pth +3 -0
- checkpoints/shape_predictor_68_face_landmarks.dat +3 -0
- inference.py +345 -0
- models/DNet.py +118 -0
- models/ENet.py +139 -0
- models/LNet.py +139 -0
- models/__init__.py +36 -0
- models/base_blocks.py +554 -0
- models/ffc.py +233 -0
- models/transformer.py +119 -0
- requirements.txt +21 -0
- third_part/GFPGAN/LICENSE +351 -0
- third_part/GFPGAN/gfpgan/__init__.py +8 -0
- third_part/GFPGAN/gfpgan/archs/__init__.py +10 -0
- third_part/GFPGAN/gfpgan/archs/arcface_arch.py +245 -0
- third_part/GFPGAN/gfpgan/archs/gfpgan_bilinear_arch.py +312 -0
- third_part/GFPGAN/gfpgan/archs/gfpganv1_arch.py +439 -0
- third_part/GFPGAN/gfpgan/archs/gfpganv1_clean_arch.py +324 -0
- third_part/GFPGAN/gfpgan/archs/stylegan2_bilinear_arch.py +613 -0
- third_part/GFPGAN/gfpgan/archs/stylegan2_clean_arch.py +368 -0
- third_part/GFPGAN/gfpgan/data/__init__.py +10 -0
- third_part/GFPGAN/gfpgan/data/ffhq_degradation_dataset.py +230 -0
- third_part/GFPGAN/gfpgan/models/__init__.py +10 -0
- third_part/GFPGAN/gfpgan/models/gfpgan_model.py +580 -0
- third_part/GFPGAN/gfpgan/train.py +11 -0
- third_part/GFPGAN/gfpgan/utils.py +143 -0
- third_part/GFPGAN/gfpgan/version.py +5 -0
- third_part/GFPGAN/gfpgan/weights/README.md +3 -0
- third_part/GFPGAN/options/train_gfpgan_v1.yml +216 -0
app.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
import subprocess
import os

# Redirect matplotlib's config/cache directory to a writable local folder
# (presumably the default home directory is read-only in the hosting
# environment — TODO confirm). Must be set before matplotlib is imported.
os.environ['MPLCONFIGDIR'] = os.getcwd() + "/configs/"
import gradio
import gradio as gr
import shutil

# Absolute path of the directory containing this file.
current_dir = os.path.dirname(os.path.abspath(__file__))
|
10 |
+
|
11 |
+
|
12 |
+
def convert(segment_length, video, audio, progress=gradio.Progress()):
    """Lip-sync `video` to `audio`, optionally processing in segments.

    Args:
        segment_length: segment duration in seconds; 0 or None disables
            segmentation and processes the whole clip at once.
        video: path to the source face video (gradio upload temp path).
        audio: path to the target audio file.
        progress: gradio progress tracker for the per-segment loop.

    Returns:
        Path to the concatenated output video under results/.
    """
    if segment_length is None:
        segment_length = 0
    print(video, audio)

    if segment_length != 0:
        video_segments = cut_video_segments(video, segment_length)
        audio_segments = cut_audio_segments(audio, segment_length)
    else:
        # Bug fix: the temp directories are only created by the cut_* helpers,
        # so on the no-segmentation path they may not exist yet and
        # shutil.move would fail on a fresh run.
        os.makedirs('temp/video', exist_ok=True)
        os.makedirs('temp/audio', exist_ok=True)
        video_path = os.path.join('temp/video', os.path.basename(video))
        shutil.move(video, video_path)
        video_segments = [video_path]
        audio_path = os.path.join('temp/audio', os.path.basename(audio))
        shutil.move(audio, audio_path)
        audio_segments = [audio_path]

    # Run inference on each (video, audio) segment pair in order.
    processed_segments = []
    for i, (video_seg, audio_seg) in progress.tqdm(enumerate(zip(video_segments, audio_segments))):
        processed_segments.append(process_segment(video_seg, audio_seg, i))

    # Bug fix: make sure the results directory exists before writing into it.
    os.makedirs('results', exist_ok=True)
    output_file = f"results/output_{random.randint(0,1000)}.mp4"
    concatenate_videos(processed_segments, output_file)

    # Remove temporary files
    cleanup_temp_files(video_segments + audio_segments)

    # Return the concatenated video file
    return output_file
|
41 |
+
|
42 |
+
|
43 |
+
def cleanup_temp_files(file_list):
    """Delete every existing regular file in `file_list`.

    Non-files (missing paths, directories) are silently skipped.
    """
    for candidate in file_list:
        if os.path.isfile(candidate):
            os.remove(candidate)
|
47 |
+
|
48 |
+
|
49 |
+
def cut_video_segments(video_file, segment_length):
    """Split `video_file` into `segment_length`-second chunks via ffmpeg
    stream copy and return the ordered list of segment paths.

    NOTE(review): the work directory is named 'temp/audio' even though it
    holds video segments — apparently swapped with cut_audio_segments.
    Kept as-is so the two functions' on-disk layout stays consistent.
    """
    work_dir = 'temp/audio'
    # Start from an empty directory so counting its entries below is a
    # reliable way to recover how many segments ffmpeg produced.
    shutil.rmtree(work_dir, ignore_errors=True)
    shutil.os.makedirs(work_dir, exist_ok=True)
    template = f"{work_dir}/{random.randint(0,1000)}_%03d.mp4"
    subprocess.run(
        ["ffmpeg", "-i", video_file, "-c", "copy", "-f",
         "segment", "-segment_time", str(segment_length), template],
        check=True)

    segment_count = len(os.listdir(work_dir))
    return [template % index for index in range(segment_count)]
|
61 |
+
|
62 |
+
|
63 |
+
def cut_audio_segments(audio_file, segment_length):
    """Split `audio_file` into `segment_length`-second mp3 chunks via ffmpeg
    and return the ordered list of segment paths.

    NOTE(review): the work directory is named 'temp/video' even though it
    holds audio segments — apparently swapped with cut_video_segments.
    Kept as-is so the two functions' on-disk layout stays consistent.
    """
    work_dir = 'temp/video'
    # Start from an empty directory so counting its entries below is a
    # reliable way to recover how many segments ffmpeg produced.
    shutil.rmtree(work_dir, ignore_errors=True)
    shutil.os.makedirs(work_dir, exist_ok=True)
    template = f"{work_dir}/{random.randint(0,1000)}_%03d.mp3"
    subprocess.run(
        ["ffmpeg", "-i", audio_file, "-f", "segment",
         "-segment_time", str(segment_length), template],
        check=True)

    segment_count = len(os.listdir(work_dir))
    return [template % index for index in range(segment_count)]
|
75 |
+
|
76 |
+
|
77 |
+
def process_segment(video_seg, audio_seg, i):
    """Lip-sync one (video, audio) segment pair by invoking inference.py
    in a subprocess; return the path of the rendered segment."""
    result_path = f"results/{random.randint(10,100000)}_{i}.mp4"
    cmd = ["python", "inference.py", "--face", video_seg,
           "--audio", audio_seg, "--outfile", result_path]
    subprocess.run(cmd, check=True)
    return result_path
|
84 |
+
|
85 |
+
|
86 |
+
def concatenate_videos(video_segments, output_file):
    """Concatenate `video_segments` into `output_file` using ffmpeg's
    concat demuxer (stream copy, no re-encoding).

    Writes the segment manifest to segments.txt in the working directory.
    """
    manifest_lines = [f"file '{segment}'\n" for segment in video_segments]
    with open("segments.txt", "w") as manifest:
        manifest.writelines(manifest_lines)
    subprocess.run(["ffmpeg", "-f", "concat", "-i",
                    "segments.txt", "-c", "copy", output_file],
                   check=True)
|
93 |
+
|
94 |
+
|
95 |
+
# Gradio UI: wires the segment-length / source-video / target-audio inputs
# to convert() and shows the synthesized result.
with gradio.Blocks(
    title="Audio-based Lip Synchronization",
    theme=gr.themes.Base(
        primary_hue=gr.themes.colors.green,
        font=["Source Sans Pro", "Arial", "sans-serif"],
        font_mono=['JetBrains mono', "Consolas", 'Courier New']
    ),
) as demo:
    with gradio.Row():
        gradio.Markdown("# Audio-based Lip Synchronization")
    with gradio.Row():
        with gradio.Column():
            with gradio.Row():
                seg = gradio.Number(
                    label="segment length (Second), 0 for no segmentation")
            with gradio.Row():
                with gradio.Column():
                    # Bug fix: corrected label typo "SOurce Face".
                    v = gradio.Video(label='Source Face')

                with gradio.Column():
                    a = gradio.Audio(
                        type='filepath', label='Target Audio')

            with gradio.Row():
                btn = gradio.Button(value="Synthesize", variant="primary")

        with gradio.Column():
            o = gradio.Video(label="Output Video")

    btn.click(fn=convert, inputs=[seg, v, a], outputs=[o])

demo.queue().launch()
|
checkpoints/30_net_gen.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4db83e1727128e2c5de27bc80d2929586535e04a709af45016a63e7cf7c46b0c
|
3 |
+
size 33877439
|
checkpoints/BFM.zip
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:369eb3177ca5491fe04c2a9aba2d33a39642681f57796fbc611dab36f9a10656
|
3 |
+
size 404749663
|
checkpoints/BFM/.gitkeep
ADDED
File without changes
|
checkpoints/BFM/01_MorphableModel.mat
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:37b1f0742db356a3b1568a8365a06f5b0fe0ab687ac1c3068c803666cbd4d8e2
|
3 |
+
size 240875364
|
checkpoints/BFM/BFM_exp_idx.mat
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:62752a2cab3eea148569fb07e367e03535b4ee04aa71ea1a9aed36486d26c612
|
3 |
+
size 91931
|
checkpoints/BFM/BFM_front_idx.mat
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7d285dd018563113496127df9c364800183172adb4d3e802f726085dab66b087
|
3 |
+
size 44880
|
checkpoints/BFM/BFM_model_front.mat
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ae3ff544aba3246c5f2c117f2be76fa44a7b76145326aae0bbfbfb564d4f82af
|
3 |
+
size 127170280
|
checkpoints/BFM/Exp_Pca.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e7f31380e6cbdaf2aeec698db220bac4f221946e4d551d88c092d47ec49b1726
|
3 |
+
size 51086404
|
checkpoints/BFM/facemodel_info.mat
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:529398f76619ae7e22f43c25dd60a2473bcc2bcc8c894fd9c613c68624ce1c04
|
3 |
+
size 738861
|
checkpoints/BFM/select_vertex_id.mat
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6877a7d634330f25bf1e81bc062b6507ee53ea183838e471fa21b613048fa36b
|
3 |
+
size 62299
|
checkpoints/BFM/similarity_Lm3D_all.mat
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:53b83ce6e35c50ddc3e97603650cef4970320c157e75c241c844f29c1dcba65a
|
3 |
+
size 994
|
checkpoints/BFM/std_exp.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
453980 257264 263068 211890 135873 184721 47055.6 72732 62787.4 106226 56708.5 51439.8 34887.1 44378.7 51813.4 31030.7 23354.9 23128.1 19400 21827.6 22767.7 22057.4 19894.3 16172.8 17142.7 10035.3 14727.5 12972.5 10763.8 8953.93 8682.62 8941.81 6342.3 5205.3 7065.65 6083.35 6678.88 4666.63 5082.89 5134.76 4908.16 3964.93 3739.95 3180.09 2470.45 1866.62 1624.71 2423.74 1668.53 1471.65 1194.52 782.102 815.044 835.782 834.937 744.496 575.146 633.76 705.685 753.409 620.306 673.326 766.189 619.866 559.93 357.264 396.472 556.849 455.048 460.592 400.735 326.702 279.428 291.535 326.584 305.664 287.816 283.642 276.19
|
checkpoints/DNet.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:41220d2973c0ba2eab6e8f17ed00711aef5a0d76d19808f885dc0e3251df2e80
|
3 |
+
size 180424655
|
checkpoints/ENet.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:967ee3ed857619cedd92b6407dc8a124cbfe763cc11cad58316fe21271a8928f
|
3 |
+
size 573261168
|
checkpoints/GFPGANv1.3.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c953a88f2727c85c3d9ae72e2bd4846bbaf59fe6972ad94130e23e7017524a70
|
3 |
+
size 348632874
|
checkpoints/GPEN-BFR-512.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f1002c41add95b0decad69604d80455576f7187dd99ca16bd611bcfd44c10b51
|
3 |
+
size 284085738
|
checkpoints/LNet.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7ae06fef0454c421b828cc53e8d4b9c92d990867a858ea7bb9661ab6cf6ab774
|
3 |
+
size 1534697728
|
checkpoints/ParseNet-latest.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3d558d8d0e42c20224f13cf5a29c79eba2d59913419f945545d8cf7b72920de2
|
3 |
+
size 85331193
|
checkpoints/RetinaFace-R50.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6d1de9c2944f2ccddca5f5e010ea5ae64a39845a86311af6fdf30841b0a5a16d
|
3 |
+
size 109497761
|
checkpoints/expression.mat
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:93e9d69eb46e866ed5cbb569ed2bdb3813254720fb0cb745d5b56181faf9aec5
|
3 |
+
size 1456
|
checkpoints/face3d_pretrain_epoch_20.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6d17a6b23457b521801baae583cb6a58f7238fe6721fc3d65d76407460e9149b
|
3 |
+
size 288860037
|
checkpoints/shape_predictor_68_face_landmarks.dat
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fbdc2cb80eb9aa7a758672cbfdda32ba6300efe9b6e6c7a299ff7e736b11b92f
|
3 |
+
size 99693937
|
inference.py
ADDED
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2, os, sys, subprocess, platform, torch
|
3 |
+
from tqdm import tqdm
|
4 |
+
from PIL import Image
|
5 |
+
from scipy.io import loadmat
|
6 |
+
|
7 |
+
sys.path.insert(0, 'third_part')
|
8 |
+
sys.path.insert(0, 'third_part/GPEN')
|
9 |
+
sys.path.insert(0, 'third_part/GFPGAN')
|
10 |
+
|
11 |
+
# 3dmm extraction
|
12 |
+
from third_part.face3d.util.preprocess import align_img
|
13 |
+
from third_part.face3d.util.load_mats import load_lm3d
|
14 |
+
from third_part.face3d.extract_kp_videos import KeypointExtractor
|
15 |
+
# face enhancement
|
16 |
+
from third_part.GPEN.gpen_face_enhancer import FaceEnhancement
|
17 |
+
from third_part.GFPGAN.gfpgan import GFPGANer
|
18 |
+
# expression control
|
19 |
+
from third_part.ganimation_replicate.model.ganimation import GANimationModel
|
20 |
+
|
21 |
+
from utils import audio
|
22 |
+
from utils.ffhq_preprocess import Croper
|
23 |
+
from utils.alignment_stit import crop_faces, calc_alignment_coefficients, paste_image
|
24 |
+
from utils.inference_utils import Laplacian_Pyramid_Blending_with_mask, face_detect, load_model, options, split_coeff, \
|
25 |
+
trans_image, transform_semantic, find_crop_norm_ratio, load_face3d_net, exp_aus_dict
|
26 |
+
import warnings
|
27 |
+
warnings.filterwarnings("ignore")
|
28 |
+
|
29 |
+
# Parse command-line options once at module import; read by main() and datagen().
args = options()
|
30 |
+
|
31 |
+
def main():
    """End-to-end retalking pipeline driven by the module-level ``args``.

    Steps (intermediate artifacts are cached under temp/ keyed by the input
    file name, and reused unless ``args.re_preprocess`` is set):
      0. Load frames from ``args.face`` (image or video).
      1. Extract 68-point facial landmarks.
      2. Regress 3DMM coefficients per frame.
      3. Stabilize/retarget the expression with D_Net.
      4. Load the driving audio and chunk its mel spectrogram per frame.
      5. Enhance the stabilized reference frames (GPEN).
      6. Synthesize lip motion (LNet/ENet), restore the mouth region
         (GFPGAN), blend, and write frames to a temp video.
    Finally the temp video is muxed with the audio into ``args.outfile``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print('[Info] Using {} for inference.'.format(device))
    os.makedirs(os.path.join('temp', args.tmp_dir), exist_ok=True)

    # Face enhancement (GPEN) and restoration (GFPGAN) models used in
    # Steps 5 and 6.
    enhancer = FaceEnhancement(base_dir='checkpoints', size=512, model='GPEN-BFR-512', use_sr=False, \
        sr_model='rrdb_realesrnet_psnr', channel_multiplier=2, narrow=1, device=device)
    restorer = GFPGANer(model_path='checkpoints/GFPGANv1.3.pth', upscale=1, arch='clean', \
        channel_multiplier=2, bg_upsampler=None)

    base_name = args.face.split('/')[-1]
    # A still-image input implies static mode (the single frame is reused).
    if os.path.isfile(args.face) and args.face.split('.')[1] in ['jpg', 'png', 'jpeg']:
        args.static = True
    if not os.path.isfile(args.face):
        raise ValueError('--face argument must be a valid path to video/image file')
    elif args.face.split('.')[1] in ['jpg', 'png', 'jpeg']:
        full_frames = [cv2.imread(args.face)]
        fps = args.fps
    else:
        # Decode all frames, applying the optional --crop window
        # (-1 means "up to the frame edge").
        video_stream = cv2.VideoCapture(args.face)
        fps = video_stream.get(cv2.CAP_PROP_FPS)

        full_frames = []
        while True:
            still_reading, frame = video_stream.read()
            if not still_reading:
                video_stream.release()
                break
            y1, y2, x1, x2 = args.crop
            if x2 == -1: x2 = frame.shape[1]
            if y2 == -1: y2 = frame.shape[0]
            frame = frame[y1:y2, x1:x2]
            full_frames.append(frame)

    print ("[Step 0] Number of frames available for inference: "+str(len(full_frames)))
    # face detection & cropping, cropping the first frame as the style of FFHQ
    croper = Croper('checkpoints/shape_predictor_68_face_landmarks.dat')
    full_frames_RGB = [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in full_frames]
    full_frames_RGB, crop, quad = croper.crop(full_frames_RGB, xsize=512)

    # Crop rectangle + quad give the face region in original-frame
    # coordinates; (oy1, oy2, ox1, ox2) is that region clamped to the frame.
    clx, cly, crx, cry = crop
    lx, ly, rx, ry = quad
    lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
    oy1, oy2, ox1, ox2 = cly+ly, min(cly+ry, full_frames[0].shape[0]), clx+lx, min(clx+rx, full_frames[0].shape[1])
    # original_size = (ox2 - ox1, oy2 - oy1)
    frames_pil = [Image.fromarray(cv2.resize(frame,(256,256))) for frame in full_frames_RGB]

    # get the landmark according to the detected face.
    if not os.path.isfile('temp/'+base_name+'_landmarks.txt') or args.re_preprocess:
        print('[Step 1] Landmarks Extraction in Video.')
        kp_extractor = KeypointExtractor()
        lm = kp_extractor.extract_keypoint(frames_pil, './temp/'+base_name+'_landmarks.txt')
    else:
        print('[Step 1] Using saved landmarks.')
        lm = np.loadtxt('temp/'+base_name+'_landmarks.txt').astype(np.float32)
        lm = lm.reshape([len(full_frames), -1, 2])

    # Step 2: per-frame 3DMM coefficients. Also (re)computed whenever an
    # expression image is supplied, because that path reuses net_recon below.
    if not os.path.isfile('temp/'+base_name+'_coeffs.npy') or args.exp_img is not None or args.re_preprocess:
        net_recon = load_face3d_net(args.face3d_net_path, device)
        lm3d_std = load_lm3d('checkpoints/BFM')

        video_coeffs = []
        for idx in tqdm(range(len(frames_pil)), desc="[Step 2] 3DMM Extraction In Video:"):
            frame = frames_pil[idx]
            W, H = frame.size
            lm_idx = lm[idx].reshape([-1, 2])
            # Mean of -1 marks a frame where landmark detection failed;
            # fall back to the standard 3D landmark template.
            if np.mean(lm_idx) == -1:
                lm_idx = (lm3d_std[:, :2]+1) / 2.
                lm_idx = np.concatenate([lm_idx[:, :1] * W, lm_idx[:, 1:2] * H], 1)
            else:
                # Flip y to the image coordinate convention expected by align_img.
                lm_idx[:, -1] = H - 1 - lm_idx[:, -1]

            trans_params, im_idx, lm_idx, _ = align_img(frame, lm_idx, lm3d_std)
            trans_params = np.array([float(item) for item in np.hsplit(trans_params, 5)]).astype(np.float32)
            im_idx_tensor = torch.tensor(np.array(im_idx)/255., dtype=torch.float32).permute(2, 0, 1).to(device).unsqueeze(0)
            with torch.no_grad():
                coeffs = split_coeff(net_recon(im_idx_tensor))

            # Pack identity/expression/texture/pose/lighting coefficients and
            # the alignment params into one row per frame.
            pred_coeff = {key:coeffs[key].cpu().numpy() for key in coeffs}
            pred_coeff = np.concatenate([pred_coeff['id'], pred_coeff['exp'], pred_coeff['tex'], pred_coeff['angle'],\
                pred_coeff['gamma'], pred_coeff['trans'], trans_params[None]], 1)
            video_coeffs.append(pred_coeff)
        semantic_npy = np.array(video_coeffs)[:,0]
        np.save('temp/'+base_name+'_coeffs.npy', semantic_npy)
    else:
        print('[Step 2] Using saved coeffs.')
        semantic_npy = np.load('temp/'+base_name+'_coeffs.npy').astype(np.float32)

    # generate the 3dmm coeff from a single image
    if args.exp_img is not None and ('.png' in args.exp_img or '.jpg' in args.exp_img):
        print('extract the exp from',args.exp_img)
        exp_pil = Image.open(args.exp_img).convert('RGB')
        lm3d_std = load_lm3d('third_part/face3d/BFM')

        W, H = exp_pil.size
        kp_extractor = KeypointExtractor()
        lm_exp = kp_extractor.extract_keypoint([exp_pil], 'temp/'+base_name+'_temp.txt')[0]
        if np.mean(lm_exp) == -1:
            lm_exp = (lm3d_std[:, :2] + 1) / 2.
            lm_exp = np.concatenate(
                [lm_exp[:, :1] * W, lm_exp[:, 1:2] * H], 1)
        else:
            lm_exp[:, -1] = H - 1 - lm_exp[:, -1]

        trans_params, im_exp, lm_exp, _ = align_img(exp_pil, lm_exp, lm3d_std)
        trans_params = np.array([float(item) for item in np.hsplit(trans_params, 5)]).astype(np.float32)
        im_exp_tensor = torch.tensor(np.array(im_exp)/255., dtype=torch.float32).permute(2, 0, 1).to(device).unsqueeze(0)
        with torch.no_grad():
            # net_recon is guaranteed to exist here: exp_img is not None, so
            # the Step 2 branch above loaded it.
            expression = split_coeff(net_recon(im_exp_tensor))['exp'][0]
        del net_recon
    elif args.exp_img == 'smile':
        expression = torch.tensor(loadmat('checkpoints/expression.mat')['expression_mouth'])[0]
    else:
        print('using expression center')
        expression = torch.tensor(loadmat('checkpoints/expression.mat')['expression_center'])[0]

    # load DNet, model(LNet and ENet)
    D_Net, model = load_model(args, device)

    # Step 3: render expression-stabilized frames with D_Net. In one-shot
    # mode the first frame is always used as the appearance source.
    if not os.path.isfile('temp/'+base_name+'_stablized.npy') or args.re_preprocess:
        imgs = []
        for idx in tqdm(range(len(frames_pil)), desc="[Step 3] Stablize the expression In Video:"):
            if args.one_shot:
                source_img = trans_image(frames_pil[0]).unsqueeze(0).to(device)
                semantic_source_numpy = semantic_npy[0:1]
            else:
                source_img = trans_image(frames_pil[idx]).unsqueeze(0).to(device)
                semantic_source_numpy = semantic_npy[idx:idx+1]
            ratio = find_crop_norm_ratio(semantic_source_numpy, semantic_npy)
            coeff = transform_semantic(semantic_npy, idx, ratio).unsqueeze(0).to(device)

            # Overwrite the first 64 coefficients (expression channels) with
            # the target expression chosen above.
            coeff[:, :64, :] = expression[None, :64, None].to(device)
            with torch.no_grad():
                output = D_Net(source_img, coeff)
            # Map D_Net output from [-1, 1] to uint8 [0, 255].
            img_stablized = np.uint8((output['fake_image'].squeeze(0).permute(1,2,0).cpu().clamp_(-1, 1).numpy() + 1 )/2. * 255)
            imgs.append(cv2.cvtColor(img_stablized,cv2.COLOR_RGB2BGR))
        np.save('temp/'+base_name+'_stablized.npy',imgs)
        del D_Net
    else:
        print('[Step 3] Using saved stablized video.')
        imgs = np.load('temp/'+base_name+'_stablized.npy')
    torch.cuda.empty_cache()

    # Step 4: convert the audio to 16 kHz wav if needed, then chunk the mel
    # spectrogram into one 16-step window per video frame (80 mel frames/s).
    if not args.audio.endswith('.wav'):
        command = 'ffmpeg -loglevel error -y -i {} -strict -2 {}'.format(args.audio, 'temp/{}/temp.wav'.format(args.tmp_dir))
        subprocess.call(command, shell=True)
        args.audio = 'temp/{}/temp.wav'.format(args.tmp_dir)
    wav = audio.load_wav(args.audio, 16000)
    mel = audio.melspectrogram(wav)
    if np.isnan(mel.reshape(-1)).sum() > 0:
        raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')

    mel_step_size, mel_idx_multiplier, i, mel_chunks = 16, 80./fps, 0, []
    while True:
        start_idx = int(i * mel_idx_multiplier)
        if start_idx + mel_step_size > len(mel[0]):
            # Last chunk: anchor to the end of the spectrogram.
            mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
            break
        mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
        i += 1

    print("[Step 4] Load audio; Length of mel chunks: {}".format(len(mel_chunks)))
    # Trim video-side data to the audio length.
    imgs = imgs[:len(mel_chunks)]
    full_frames = full_frames[:len(mel_chunks)]
    lm = lm[:len(mel_chunks)]

    imgs_enhanced = []
    for idx in tqdm(range(len(imgs)), desc='[Step 5] Reference Enhancement'):
        img = imgs[idx]
        pred, _, _ = enhancer.process(img, img, face_enhance=True, possion_blending=False)
        imgs_enhanced.append(pred)
    gen = datagen(imgs_enhanced.copy(), mel_chunks, full_frames, None, (oy1,oy2,ox1,ox2))

    frame_h, frame_w = full_frames[0].shape[:-1]
    out = cv2.VideoWriter('temp/{}/result.mp4'.format(args.tmp_dir), cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_w, frame_h))

    # Optional upper-face expression editing via GANimation.
    if args.up_face != 'original':
        instance = GANimationModel()
        instance.initialize()
        instance.setup()

    kp_extractor = KeypointExtractor()
    for i, (img_batch, mel_batch, frames, coords, img_original, f_frames) in enumerate(tqdm(gen, desc='[Step 6] Lip Synthesis:', total=int(np.ceil(float(len(mel_chunks)) / args.LNet_batch_size)))):
        # NHWC uint8 -> NCHW float tensors.
        img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
        mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
        img_original = torch.FloatTensor(np.transpose(img_original, (0, 3, 1, 2))).to(device)/255. # BGR -> RGB

        with torch.no_grad():
            incomplete, reference = torch.split(img_batch, 3, dim=1)
            pred, low_res = model(mel_batch, img_batch, reference)
            pred = torch.clamp(pred, 0, 1)

            if args.up_face in ['sad', 'angry', 'surprise']:
                tar_aus = exp_aus_dict[args.up_face]
            else:
                pass

            if args.up_face == 'original':
                cur_gen_faces = img_original
            else:
                test_batch = {'src_img': torch.nn.functional.interpolate((img_original * 2 - 1), size=(128, 128), mode='bilinear'),
                              'tar_aus': tar_aus.repeat(len(incomplete), 1)}
                instance.feed_batch(test_batch)
                instance.forward()
                cur_gen_faces = torch.nn.functional.interpolate(instance.fake_img / 2. + 0.5, size=(384, 384), mode='bilinear')

            # Keep the generated pixels only where the masked (zeroed) lower
            # half was; take the rest from the (possibly edited) face.
            if args.without_rl1 is not False:
                incomplete, reference = torch.split(img_batch, 3, dim=1)
                mask = torch.where(incomplete==0, torch.ones_like(incomplete), torch.zeros_like(incomplete))
                pred = pred * mask + cur_gen_faces * (1 - mask)

        pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

        torch.cuda.empty_cache()
        for p, f, xf, c in zip(pred, frames, f_frames, coords):
            # Paste the synthesized face patch back into the full frame.
            y1, y2, x1, x2 = c
            p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))

            ff = xf.copy()
            ff[y1:y2, x1:x2] = p

            # mouth region enhancement by GFPGAN
            cropped_faces, restored_faces, restored_img = restorer.enhance(
                ff, has_aligned=False, only_center_face=True, paste_back=True)
            # Face-parsing label weights: keep only classes 10-12
            # (presumably the mouth/lip region — TODO confirm label map).
            mm = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 0, 0, 0, 0, 0, 0]
            mouse_mask = np.zeros_like(restored_img)
            tmp_mask = enhancer.faceparser.process(restored_img[y1:y2, x1:x2], mm)[0]
            mouse_mask[y1:y2, x1:x2]= cv2.resize(tmp_mask, (x2 - x1, y2 - y1))[:, :, np.newaxis] / 255.

            # Blend the GFPGAN-restored mouth into the frame with a
            # Laplacian pyramid, then resize back to the original frame size.
            height, width = ff.shape[:2]
            restored_img, ff, full_mask = [cv2.resize(x, (512, 512)) for x in (restored_img, ff, np.float32(mouse_mask))]
            img = Laplacian_Pyramid_Blending_with_mask(restored_img, ff, full_mask[:, :, 0], 10)
            pp = np.uint8(cv2.resize(np.clip(img, 0 ,255), (width, height)))

            pp, orig_faces, enhanced_faces = enhancer.process(pp, xf, bbox=c, face_enhance=False, possion_blending=True)
            out.write(pp)
    out.release()

    # Mux the silent result video with the driving audio into args.outfile.
    if not os.path.isdir(os.path.dirname(args.outfile)):
        os.makedirs(os.path.dirname(args.outfile), exist_ok=True)
    command = 'ffmpeg -loglevel error -y -i {} -i {} -strict -2 -q:v 1 {}'.format(args.audio, 'temp/{}/result.mp4'.format(args.tmp_dir), args.outfile)
    subprocess.call(command, shell=platform.system() != 'Windows')
    print('outfile:', args.outfile)
|
276 |
+
|
277 |
+
|
278 |
+
# frames: 256x256 enhanced reference frames, full_frames: original size
def datagen(frames, mels, full_frames, frames_pil, cox):
    """Generator yielding batches for the Step 6 lip-synthesis loop.

    Args:
        frames: 256x256 enhanced (stabilized) face frames.
        mels: per-frame mel-spectrogram chunks.
        full_frames: original, full-resolution frames.
        frames_pil: unused — immediately overwritten below (main passes None).
        cox: (oy1, oy2, ox1, ox2) face-region box in full-frame coordinates.

    Yields:
        (img_batch, mel_batch, frame_batch, coords_batch, img_original,
        full_frame_batch) — img_batch is the half-masked face concatenated
        with the reference face along channels, scaled to [0, 1].
    """
    img_batch, mel_batch, frame_batch, coords_batch, ref_batch, full_frame_batch = [], [], [], [], [], []
    base_name = args.face.split('/')[-1]
    refs = []
    image_size = 256

    # Landmarks + FFHQ-style alignment of the enhanced frames, and the
    # inverse transforms needed to paste aligned crops back.
    kp_extractor = KeypointExtractor()
    fr_pil = [Image.fromarray(frame) for frame in frames]
    lms = kp_extractor.extract_keypoint(fr_pil, 'temp/'+base_name+'x12_landmarks.txt')
    frames_pil = [ (lm, frame) for frame,lm in zip(fr_pil, lms)] # frames is the croped version of modified face
    crops, orig_images, quads = crop_faces(image_size, frames_pil, scale=1.0, use_fa=True)
    inverse_transforms = [calc_alignment_coefficients(quad + 0.5, [[0, 0], [0, image_size], [image_size, image_size], [image_size, 0]]) for quad in quads]
    # Free the face-alignment detector early; it is no longer needed.
    del kp_extractor.detector

    oy1,oy2,ox1,ox2 = cox
    face_det_results = face_detect(full_frames, args, jaw_correction=True)

    # Build one reference face per frame: paste the aligned crop back into
    # the full frame, then cut out the detected face box.
    for inverse_transform, crop, full_frame, face_det in zip(inverse_transforms, crops, full_frames, face_det_results):
        imc_pil = paste_image(inverse_transform, crop, Image.fromarray(
            cv2.resize(full_frame[int(oy1):int(oy2), int(ox1):int(ox2)], (256, 256))))

        ff = full_frame.copy()
        ff[int(oy1):int(oy2), int(ox1):int(ox2)] = cv2.resize(np.array(imc_pil.convert('RGB')), (ox2 - ox1, oy2 - oy1))
        oface, coords = face_det
        y1, y2, x1, x2 = coords
        refs.append(ff[y1: y2, x1:x2])

    # Accumulate per-mel-chunk samples and emit them in LNet_batch_size batches.
    for i, m in enumerate(mels):
        # Static mode reuses frame 0; otherwise loop over frames cyclically.
        idx = 0 if args.static else i % len(frames)
        frame_to_save = frames[idx].copy()
        face = refs[idx]
        oface, coords = face_det_results[idx].copy()

        face = cv2.resize(face, (args.img_size, args.img_size))
        oface = cv2.resize(oface, (args.img_size, args.img_size))

        img_batch.append(oface)
        ref_batch.append(face)
        mel_batch.append(m)
        coords_batch.append(coords)
        frame_batch.append(frame_to_save)
        full_frame_batch.append(full_frames[idx].copy())

        if len(img_batch) >= args.LNet_batch_size:
            img_batch, mel_batch, ref_batch = np.asarray(img_batch), np.asarray(mel_batch), np.asarray(ref_batch)
            img_masked = img_batch.copy()
            img_original = img_batch.copy()
            # Zero the lower half: the network inpaints the mouth region.
            img_masked[:, args.img_size//2:] = 0
            img_batch = np.concatenate((img_masked, ref_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

            yield img_batch, mel_batch, frame_batch, coords_batch, img_original, full_frame_batch
            img_batch, mel_batch, frame_batch, coords_batch, img_original, full_frame_batch, ref_batch = [], [], [], [], [], [], []

    # Flush the final partial batch, if any.
    if len(img_batch) > 0:
        img_batch, mel_batch, ref_batch = np.asarray(img_batch), np.asarray(mel_batch), np.asarray(ref_batch)
        img_masked = img_batch.copy()
        img_original = img_batch.copy()
        img_masked[:, args.img_size//2:] = 0
        img_batch = np.concatenate((img_masked, ref_batch), axis=3) / 255.
        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
        yield img_batch, mel_batch, frame_batch, coords_batch, img_original, full_frame_batch
|
342 |
+
|
343 |
+
|
344 |
+
# Script entry point: run the full inference pipeline.
if __name__ == '__main__':
    main()
|
models/DNet.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# TODO
|
2 |
+
import functools
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
from utils import flow_util
|
10 |
+
from models.base_blocks import LayerNorm2d, ADAINHourglass, FineEncoder, FineDecoder
|
11 |
+
|
12 |
+
# DNet
class DNet(nn.Module):
    """Deformation network.

    Combines a MappingNet (3DMM coefficients -> motion descriptor), a
    WarpingNet (descriptor-conditioned flow + warp of the input image) and an
    EditingNet (refinement of the warped image into a final fake image).
    """

    def __init__(self):
        super(DNet, self).__init__()
        self.mapping_net = MappingNet()
        self.warpping_net = WarpingNet()
        self.editing_net = EditingNet()

    def forward(self, input_image, driving_source, stage=None):
        # Both stages share the mapping + warping computation.
        descriptor = self.mapping_net(driving_source)
        output = self.warpping_net(input_image, descriptor)
        if stage != 'warp':
            # Full pipeline: additionally refine the warped image.
            output['fake_image'] = self.editing_net(
                input_image, output['warp_image'], descriptor)
        return output
|
29 |
+
|
30 |
+
class MappingNet(nn.Module):
    """Map a 1D sequence of 3DMM coefficients to a compact motion descriptor.

    A 7-wide stem convolution followed by `layer` dilated residual 1D convs,
    finally pooled down to a single temporal position.
    """

    def __init__(self, coeff_nc=73, descriptor_nc=256, layer=3):
        super(MappingNet, self).__init__()
        self.layer = layer
        act = nn.LeakyReLU(0.1)

        self.first = nn.Sequential(
            torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True))

        for idx in range(layer):
            block = nn.Sequential(
                act,
                torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3,
                                padding=0, dilation=3))
            setattr(self, 'encoder' + str(idx), block)

        self.pooling = nn.AdaptiveAvgPool1d(1)
        self.output_nc = descriptor_nc

    def forward(self, input_3dmm):
        feat = self.first(input_3dmm)
        for idx in range(self.layer):
            encoder = getattr(self, 'encoder' + str(idx))
            # Each dilated conv shrinks the temporal axis by 6 positions;
            # crop the residual path to match before adding.
            feat = encoder(feat) + feat[:, :, 3:-3]
        return self.pooling(feat)
|
55 |
+
|
56 |
+
class WarpingNet(nn.Module):
    """Predict a dense 2-channel flow field from an image and a motion
    descriptor, and return both the flow and the correspondingly warped
    input image.
    """

    def __init__(self,
                 image_nc=3,
                 descriptor_nc=256,
                 base_nc=32,
                 max_nc=256,
                 encoder_layer=5,
                 decoder_layer=3,
                 use_spect=False):
        super(WarpingNet, self).__init__()

        act = nn.LeakyReLU(0.1)
        norm = functools.partial(LayerNorm2d, affine=True)

        self.descriptor_nc = descriptor_nc
        self.hourglass = ADAINHourglass(image_nc, self.descriptor_nc, base_nc,
                                        max_nc, encoder_layer, decoder_layer,
                                        nonlinearity=act, use_spect=use_spect)

        # Two output channels: per-pixel (dx, dy) displacement.
        self.flow_out = nn.Sequential(
            norm(self.hourglass.output_nc),
            act,
            nn.Conv2d(self.hourglass.output_nc, 2, kernel_size=7, stride=1, padding=3))

        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, input_image, descriptor):
        final_output = {}
        feat = self.hourglass(input_image, descriptor)
        final_output['flow_field'] = self.flow_out(feat)

        deformation = flow_util.convert_flow_to_deformation(final_output['flow_field'])
        final_output['warp_image'] = flow_util.warp_image(input_image, deformation)
        return final_output
|
91 |
+
|
92 |
+
|
93 |
+
class EditingNet(nn.Module):
    """Refine a warped image into the final output.

    Encodes the concatenation of the original and the warped image with a
    FineEncoder, then decodes it conditioned on the motion descriptor.
    """

    def __init__(self,
                 image_nc=3,
                 descriptor_nc=256,
                 layer=3,
                 base_nc=64,
                 max_nc=256,
                 num_res_blocks=2,
                 use_spect=False):
        super(EditingNet, self).__init__()

        act = nn.LeakyReLU(0.1)
        norm = functools.partial(LayerNorm2d, affine=True)
        self.descriptor_nc = descriptor_nc

        # encoder part: input is [input_image, warp_image] stacked on channels
        self.encoder = FineEncoder(image_nc * 2, base_nc, max_nc, layer,
                                   norm_layer=norm, nonlinearity=act,
                                   use_spect=use_spect)
        self.decoder = FineDecoder(image_nc, self.descriptor_nc, base_nc, max_nc,
                                   layer, num_res_blocks, norm_layer=norm,
                                   nonlinearity=act, use_spect=use_spect)

    def forward(self, input_image, warp_image, descriptor):
        stacked = torch.cat([input_image, warp_image], 1)
        features = self.encoder(stacked)
        return self.decoder(features, descriptor)
|
models/ENet.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from models.base_blocks import ResBlock, StyleConv, ToRGB
|
6 |
+
|
7 |
+
|
8 |
+
class ENet(nn.Module):
    """Enhancement network.

    Runs the frozen low-resolution lip generator (LNet) on a 96x96 crop and
    upsamples its output with style-modulated convolutions (StyleGAN2-style),
    conditioned on a global style code extracted from the reference frame.

    Args:
        num_style_feat: dimensionality of the per-layer style vectors.
        lnet: pre-trained low-resolution generator; its weights are frozen.
        concat: if True, `lnet` also returns an intermediate feature map that
            is concatenated with its image output before upsampling.
    """

    def __init__(self, num_style_feat=512, lnet=None, concat=False):
        super(ENet, self).__init__()

        # Frozen low-resolution generator: no gradients into its parameters.
        self.low_res = lnet
        for param in self.low_res.parameters():
            param.requires_grad = False

        channel_multiplier, narrow = 2, 1
        channels = {
            '4': int(512 * narrow),
            '8': int(512 * narrow),
            '16': int(512 * narrow),
            '32': int(512 * narrow),
            '64': int(256 * channel_multiplier * narrow),
            '128': int(128 * channel_multiplier * narrow),
            '256': int(64 * channel_multiplier * narrow),
            '512': int(32 * channel_multiplier * narrow),
            '1024': int(16 * channel_multiplier * narrow)
        }

        self.log_size = 8
        first_out_size = 128
        self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)  # 256 -> 128

        # downsample pyramid used to extract the global style code
        in_channels = channels[f'{first_out_size}']
        self.conv_body_down = nn.ModuleList()
        for i in range(8, 2, -1):
            out_channels = channels[f'{2**(i - 1)}']
            self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))
            in_channels = out_channels

        self.num_style_feat = num_style_feat
        linear_out_channel = num_style_feat
        self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)
        self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1)

        self.style_convs = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        self.concat = concat
        if concat:
            in_channels = 3 + 32  # channels['64']
        else:
            in_channels = 3

        for i in range(7, 9):  # 128, 256
            out_channels = channels[f'{2**i}']
            self.style_convs.append(
                StyleConv(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode='upsample'))
            self.style_convs.append(
                StyleConv(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode=None))
            self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
            in_channels = out_channels

    def forward(self, audio_sequences, face_sequences, gt_sequences):
        """Return (enhanced_outputs, low_res_img).

        When `face_sequences` has a temporal axis (dim > 4), it is folded
        into the batch axis for processing and un-folded before returning.
        """
        B = audio_sequences.size(0)
        input_dim_size = len(face_sequences.size())
        inp, ref = torch.split(face_sequences, 3, dim=1)

        if input_dim_size > 4:
            # fold the temporal axis into the batch axis
            audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
            inp = torch.cat([inp[:, :, i] for i in range(inp.size(2))], dim=0)
            ref = torch.cat([ref[:, :, i] for i in range(ref.size(2))], dim=0)
            gt_sequences = torch.cat([gt_sequences[:, :, i] for i in range(gt_sequences.size(2))], dim=0)

        # get the global style from the reference frame
        feat = F.leaky_relu_(self.conv_body_first(F.interpolate(ref, size=(256, 256), mode='bilinear')), negative_slope=0.2)
        for i in range(self.log_size - 2):
            feat = self.conv_body_down[i](feat)
        feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)

        # style code
        style_code = self.final_linear(feat.reshape(feat.size(0), -1))
        style_code = style_code.reshape(style_code.size(0), -1, self.num_style_feat)

        LNet_input = torch.cat([inp, gt_sequences], dim=1)
        LNet_input = F.interpolate(LNet_input, size=(96, 96), mode='bilinear')

        if self.concat:
            low_res_img, low_res_feat = self.low_res(audio_sequences, LNet_input)
            # BUGFIX: Tensor.detach() is NOT in-place — the original code
            # called it without rebinding, which silently did nothing.
            # Rebind so no gradients flow back into the frozen LNet.
            low_res_img = low_res_img.detach()
            low_res_feat = low_res_feat.detach()
            out = torch.cat([low_res_img, low_res_feat], dim=1)
        else:
            low_res_img = self.low_res(audio_sequences, LNet_input)
            low_res_img = low_res_img.detach()  # same no-op fix as above
            # 96 x 96
            out = low_res_img

        # reflect-pad so upsampled borders stay clean; stripped again below
        p2d = (2, 2, 2, 2)
        out = F.pad(out, p2d, "reflect", 0)
        skip = out

        for conv1, conv2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], self.to_rgbs):
            out = conv1(out, style_code)  # 96, 192, 384
            out = conv2(out, style_code)
            skip = to_rgb(out, style_code, skip)
        _outputs = skip

        # remove padding (2px pad -> 8px at 4x upsampled resolution)
        _outputs = _outputs[:, :, 8:-8, 8:-8]

        if input_dim_size > 4:
            # un-fold the temporal axis: (B*T, C, H, W) -> (B, C, T, H, W)
            _outputs = torch.split(_outputs, B, dim=0)
            outputs = torch.stack(_outputs, dim=2)
            low_res_img = F.interpolate(low_res_img, outputs.size()[3:])
            low_res_img = torch.split(low_res_img, B, dim=0)
            low_res_img = torch.stack(low_res_img, dim=2)
        else:
            outputs = _outputs
        return outputs, low_res_img
|
models/LNet.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
from models.transformer import RETURNX, Transformer
|
6 |
+
from models.base_blocks import Conv2d, LayerNorm2d, FirstBlock2d, DownBlock2d, UpBlock2d, \
|
7 |
+
FFCADAINResBlocks, Jump, FinalBlock2d
|
8 |
+
|
9 |
+
|
10 |
+
class Visual_Encoder(nn.Module):
    """Two-stream encoder for the masked ground-truth frame and the
    reference frame.

    The two streams are downsampled in parallel; at each level the masked
    stream attends to the reference stream (identity at the two shallowest
    levels, transformer cross-attention deeper). Returns a feature pyramid
    whose deepest entry also carries the reference features.
    """

    def __init__(self, image_nc, ngf, img_f, layers, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(Visual_Encoder, self).__init__()
        self.layers = layers
        self.first_inp = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect)
        self.first_ref = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect)
        for idx in range(layers):
            cin = min(ngf * (2 ** idx), img_f)
            cout = min(ngf * (2 ** (idx + 1)), img_f)
            model_ref = DownBlock2d(cin, cout, norm_layer, nonlinearity, use_spect)
            model_inp = DownBlock2d(cin, cout, norm_layer, nonlinearity, use_spect)
            if idx < 2:
                attn = RETURNX()
            else:
                attn = Transformer(2 ** (idx + 1) * ngf, 2, 4, ngf, ngf * 4)
            setattr(self, 'ca' + str(idx), attn)
            setattr(self, 'ref_down' + str(idx), model_ref)
            setattr(self, 'inp_down' + str(idx), model_inp)
        self.output_nc = cout * 2

    def forward(self, maskGT, ref):
        x_inp = self.first_inp(maskGT)
        x_ref = self.first_ref(ref)
        pyramid = [x_inp]
        for idx in range(self.layers):
            x_inp = getattr(self, 'inp_down' + str(idx))(x_inp)
            x_ref = getattr(self, 'ref_down' + str(idx))(x_ref)
            x_inp = getattr(self, 'ca' + str(idx))(x_inp, x_ref)
            if idx == self.layers - 1:
                # deepest level: concat ref features !
                pyramid.append(torch.cat([x_inp, x_ref], dim=1))
            else:
                pyramid.append(x_inp)
        return pyramid
|
44 |
+
|
45 |
+
|
46 |
+
class Decoder(nn.Module):
    """Audio-conditioned decoder.

    Walks the encoder pyramid from deepest to shallowest: at each level it
    applies FFC-ADAIN residual blocks (conditioned on the audio feature z),
    upsamples, and merges the encoder skip connection; finally maps the
    features to an image through a sigmoid output block.
    """

    def __init__(self, image_nc, feature_nc, ngf, img_f, layers, num_block, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(Decoder, self).__init__()
        self.layers = layers
        for idx in reversed(range(layers)):
            if idx == layers - 1:
                # deepest level carries concatenated [inp, ref] features
                cin = ngf * (2 ** (idx + 1)) * 2
            else:
                cin = min(ngf * (2 ** (idx + 1)), img_f)
            cout = min(ngf * (2 ** idx), img_f)
            up = UpBlock2d(cin, cout, norm_layer, nonlinearity, use_spect)
            res = FFCADAINResBlocks(num_block, cin, feature_nc, norm_layer, nonlinearity, use_spect)
            jump = Jump(cout, norm_layer, nonlinearity, use_spect)

            setattr(self, 'up' + str(idx), up)
            setattr(self, 'res' + str(idx), res)
            setattr(self, 'jump' + str(idx), jump)

        self.final = FinalBlock2d(cout, image_nc, use_spect, 'sigmoid')
        self.output_nc = cout

    def forward(self, x, z):
        out = x.pop()
        for idx in reversed(range(self.layers)):
            out = getattr(self, 'res' + str(idx))(out, z)
            out = getattr(self, 'up' + str(idx))(out)
            out = getattr(self, 'jump' + str(idx))(x.pop()) + out
        return self.final(out)
|
78 |
+
|
79 |
+
|
80 |
+
class LNet(nn.Module):
    """Low-resolution audio-driven lip generator.

    Encodes the (masked frame, reference frame) pair with a cross-attention
    visual encoder, encodes the mel-spectrogram with a small CNN, and
    decodes an audio-conditioned mouth image.
    """

    def __init__(self,
                 image_nc=3,
                 descriptor_nc=512,
                 layer=3,
                 base_nc=64,
                 max_nc=512,
                 num_res_blocks=9,
                 use_spect=True,
                 encoder=Visual_Encoder,
                 decoder=Decoder):
        super(LNet, self).__init__()

        act = nn.LeakyReLU(0.1)
        norm = functools.partial(LayerNorm2d, affine=True)
        self.descriptor_nc = descriptor_nc

        self.encoder = encoder(image_nc, base_nc, max_nc, layer,
                               norm_layer=norm, nonlinearity=act, use_spect=use_spect)
        self.decoder = decoder(image_nc, self.descriptor_nc, base_nc, max_nc,
                               layer, num_res_blocks, norm_layer=norm,
                               nonlinearity=act, use_spect=use_spect)
        # Mel-spectrogram encoder; squeezes the spectrogram down to a single
        # descriptor vector. Assumes input laid out as (B, 1, n_mel, T) —
        # TODO confirm against caller.
        self.audio_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
            Conv2d(512, descriptor_nc, kernel_size=1, stride=1, padding=0),
        )

    def forward(self, audio_sequences, face_sequences):
        batch = audio_sequences.size(0)
        folded = face_sequences.dim() > 4
        if folded:
            # Fold the temporal axis into the batch axis.
            audio_sequences = torch.cat(
                [audio_sequences[:, t] for t in range(audio_sequences.size(1))], dim=0)
            face_sequences = torch.cat(
                [face_sequences[:, :, t] for t in range(face_sequences.size(2))], dim=0)
        cropped, ref = torch.split(face_sequences, 3, dim=1)

        vis_feat = self.encoder(cropped, ref)
        audio_feat = self.audio_encoder(audio_sequences)
        decoded = self.decoder(vis_feat, audio_feat)

        if folded:
            # Un-fold: (B*T, C, H, W) -> (B, C, T, H, W)
            return torch.stack(torch.split(decoded, batch, dim=0), dim=2)
        return decoded
|
models/__init__.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from models.DNet import DNet
|
3 |
+
from models.LNet import LNet
|
4 |
+
from models.ENet import ENet
|
5 |
+
|
6 |
+
|
7 |
+
def _load(checkpoint_path):
|
8 |
+
checkpoint = torch.load(checkpoint_path)
|
9 |
+
return checkpoint
|
10 |
+
|
11 |
+
def load_checkpoint(path, model):
    """Load weights from `path` into `model` and return the model.

    Strips any DataParallel 'module.' prefix, skips the frozen 'low_res'
    sub-network, and loads non-strictly so extra/missing keys are tolerated.
    """
    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path)
    # 'arcface' checkpoints are raw state dicts; others wrap it.
    state = checkpoint if 'arcface' in path else checkpoint["state_dict"]
    cleaned = {k.replace('module.', ''): v
               for k, v in state.items() if 'low_res' not in k}
    model.load_state_dict(cleaned, strict=False)
    return model
|
23 |
+
|
24 |
+
def load_network(args):
    """Build the LNet + ENet stack from the checkpoint paths in `args` and
    return the combined model in eval mode."""
    lnet = load_checkpoint(args.LNet_path, LNet())
    enet = load_checkpoint(args.ENet_path, ENet(lnet=lnet))
    return enet.eval()
|
30 |
+
|
31 |
+
def load_DNet(args):
    """Load the deformation network from `args.DNet_path` (CPU-mapped) and
    return it in eval mode."""
    net = DNet()
    print("Load checkpoint from: {}".format(args.DNet_path))
    # map_location keeps this usable on CPU-only machines
    state = torch.load(args.DNet_path, map_location=lambda storage, loc: storage)
    net.load_state_dict(state['net_G_ema'], strict=False)
    return net.eval()
|
models/base_blocks.py
ADDED
@@ -0,0 +1,554 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch.nn.modules.batchnorm import BatchNorm2d
|
6 |
+
from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm
|
7 |
+
|
8 |
+
from models.ffc import FFC
|
9 |
+
from basicsr.archs.arch_util import default_init_weights
|
10 |
+
|
11 |
+
|
12 |
+
class Conv2d(nn.Module):
    """Conv2d + BatchNorm + ReLU, with an optional residual connection.

    When `residual` is True the input is added to the conv output before
    the activation (requires cin == cout and unchanged spatial size).
    """

    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.Conv2d(cin, cout, kernel_size, stride, padding),
            nn.BatchNorm2d(cout)
        )
        self.act = nn.ReLU()
        self.residual = residual

    def forward(self, x):
        y = self.conv_block(x)
        if self.residual:
            y = y + x
        return self.act(y)
|
27 |
+
|
28 |
+
|
29 |
+
class ResBlock(nn.Module):
    """Residual block that rescales by x0.5 ('down') or x2 ('up') using
    bilinear interpolation on both the main and the skip path."""

    def __init__(self, in_channels, out_channels, mode='down'):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
        self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
        self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        if mode == 'down':
            self.scale_factor = 0.5
        elif mode == 'up':
            self.scale_factor = 2

    def forward(self, x):
        def rescale(t):
            return F.interpolate(t, scale_factor=self.scale_factor,
                                 mode='bilinear', align_corners=False)

        # main branch: conv -> rescale -> conv
        main = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
        main = F.leaky_relu_(self.conv2(rescale(main)), negative_slope=0.2)
        # skip branch: rescale -> 1x1 conv
        return main + self.skip(rescale(x))
|
50 |
+
|
51 |
+
|
52 |
+
class LayerNorm2d(nn.Module):
    """Layer normalization over every non-batch dimension of a 4D tensor,
    with optional per-channel affine parameters broadcast over H and W."""

    def __init__(self, n_out, affine=True):
        super(LayerNorm2d, self).__init__()
        self.n_out = n_out
        self.affine = affine
        if affine:
            self.weight = nn.Parameter(torch.ones(n_out, 1, 1))
            self.bias = nn.Parameter(torch.zeros(n_out, 1, 1))

    def forward(self, x):
        shape = x.size()[1:]
        if not self.affine:
            return F.layer_norm(x, shape)
        return F.layer_norm(x, shape,
                            self.weight.expand(shape),
                            self.bias.expand(shape))
|
70 |
+
|
71 |
+
|
72 |
+
def spectral_norm(module, use_spect=True):
    """Wrap `module` in spectral normalization when `use_spect` is True;
    otherwise return the module unchanged."""
    return SpectralNorm(module) if use_spect else module
|
77 |
+
|
78 |
+
|
79 |
+
class FirstBlock2d(nn.Module):
    """Stem block: 7x7 stride-1 convolution, optional normalization, then a
    nonlinearity. Spatial size is preserved."""

    def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FirstBlock2d, self).__init__()
        conv = spectral_norm(
            nn.Conv2d(input_nc, output_nc, kernel_size=7, stride=1, padding=3),
            use_spect)
        modules = [conv]
        if norm_layer is not None:
            modules.append(norm_layer(output_nc))
        modules.append(nonlinearity)
        self.model = nn.Sequential(*modules)

    def forward(self, x):
        return self.model(x)
|
93 |
+
|
94 |
+
|
95 |
+
class DownBlock2d(nn.Module):
    """Downsampling block: 3x3 conv, optional norm, nonlinearity, then 2x2
    average pooling (halves the spatial size)."""

    def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(DownBlock2d, self).__init__()
        conv = spectral_norm(
            nn.Conv2d(input_nc, output_nc, kernel_size=3, stride=1, padding=1),
            use_spect)
        pool = nn.AvgPool2d(kernel_size=(2, 2))
        modules = [conv]
        if norm_layer is not None:
            modules.append(norm_layer(output_nc))
        modules.extend([nonlinearity, pool])
        self.model = nn.Sequential(*modules)

    def forward(self, x):
        return self.model(x)
|
110 |
+
|
111 |
+
|
112 |
+
class UpBlock2d(nn.Module):
    """Upsampling block: 2x nearest upsample followed by a 3x3 conv,
    optional norm, and a nonlinearity."""

    def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(UpBlock2d, self).__init__()
        conv = spectral_norm(
            nn.Conv2d(input_nc, output_nc, kernel_size=3, stride=1, padding=1),
            use_spect)
        modules = [conv]
        if norm_layer is not None:
            modules.append(norm_layer(output_nc))
        modules.append(nonlinearity)
        self.model = nn.Sequential(*modules)

    def forward(self, x):
        return self.model(F.interpolate(x, scale_factor=2))
|
125 |
+
|
126 |
+
|
127 |
+
class ADAIN(nn.Module):
    """Adaptive instance normalization.

    Instance-normalizes the input, then applies a per-channel scale (1+gamma)
    and shift (beta) predicted from a conditioning feature vector by a small
    MLP.
    """

    def __init__(self, norm_nc, feature_nc):
        super().__init__()
        self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)

        nhidden = 128
        self.mlp_shared = nn.Sequential(
            nn.Linear(feature_nc, nhidden, bias=True),
            nn.ReLU()
        )
        self.mlp_gamma = nn.Linear(nhidden, norm_nc, bias=True)
        self.mlp_beta = nn.Linear(nhidden, norm_nc, bias=True)

    def forward(self, x, feature):
        # parameter-free normalization of the activations
        normalized = self.param_free_norm(x)

        # scale/shift conditioned on the (flattened) feature vector
        flat = feature.view(feature.size(0), -1)
        hidden = self.mlp_shared(flat)
        gamma = self.mlp_gamma(hidden).unsqueeze(-1).unsqueeze(-1)
        beta = self.mlp_beta(hidden).unsqueeze(-1).unsqueeze(-1)

        return normalized * (1 + gamma) + beta
|
158 |
+
|
159 |
+
|
160 |
+
class FineADAINResBlock2d(nn.Module):
    """
    Residual block whose two conv layers are each followed by ADAIN
    conditioning on an external feature vector `z`.
    """
    def __init__(self, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FineADAINResBlock2d, self).__init__()
        kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
        self.conv1 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
        self.conv2 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
        self.norm1 = ADAIN(input_nc, feature_nc)
        self.norm2 = ADAIN(input_nc, feature_nc)
        self.actvn = nonlinearity

    def forward(self, x, z):
        # Residual branch conditioned on z via ADAIN.
        dx = self.actvn(self.norm1(self.conv1(x), z))
        # NOTE(review): conv2 is applied to the original input `x`, not to
        # `dx`, so the first conv's result is discarded here. This looks like
        # a bug (conv2(dx) would be the conventional residual form), but it
        # matches the released pretrained weights — confirm before changing.
        dx = self.norm2(self.conv2(x), z)
        out = dx + x
        return out
|
178 |
+
|
179 |
+
|
180 |
+
class FineADAINResBlocks(nn.Module):
    """A chain of `num_block` FineADAINResBlock2d modules, all conditioned
    on the same feature vector."""

    def __init__(self, num_block, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FineADAINResBlocks, self).__init__()
        self.num_block = num_block
        for idx in range(num_block):
            setattr(self, 'res' + str(idx),
                    FineADAINResBlock2d(input_nc, feature_nc, norm_layer,
                                        nonlinearity, use_spect))

    def forward(self, x, z):
        out = x
        for idx in range(self.num_block):
            out = getattr(self, 'res' + str(idx))(out, z)
        return out
|
193 |
+
|
194 |
+
|
195 |
+
class ADAINEncoderBlock(nn.Module):
    """Downsampling encoder block: a stride-2 4x4 conv followed by a
    refining 3x3 conv, each preceded by ADAIN conditioning and the
    nonlinearity."""

    def __init__(self, input_nc, output_nc, feature_nc, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(ADAINEncoderBlock, self).__init__()
        self.conv_0 = spectral_norm(
            nn.Conv2d(input_nc, output_nc, kernel_size=4, stride=2, padding=1),
            use_spect)
        self.conv_1 = spectral_norm(
            nn.Conv2d(output_nc, output_nc, kernel_size=3, stride=1, padding=1),
            use_spect)

        self.norm_0 = ADAIN(input_nc, feature_nc)
        self.norm_1 = ADAIN(output_nc, feature_nc)
        self.actvn = nonlinearity

    def forward(self, x, z):
        out = self.conv_0(self.actvn(self.norm_0(x, z)))
        return self.conv_1(self.actvn(self.norm_1(out, z)))
|
213 |
+
|
214 |
+
|
215 |
+
class ADAINDecoderBlock(nn.Module):
    """Upsampling decoder block with an ADAIN-conditioned residual shortcut.

    Upsampling is performed either with a stride-2 transposed convolution
    (use_transpose=True) or with a 3x3 conv followed by nearest Upsample.
    """

    def __init__(self, input_nc, output_nc, hidden_nc, feature_nc, use_transpose=True, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(ADAINDecoderBlock, self).__init__()
        self.actvn = nonlinearity
        if hidden_nc is None:
            hidden_nc = min(input_nc, output_nc)

        kwargs_fine = {'kernel_size': 3, 'stride': 1, 'padding': 1}
        self.conv_0 = spectral_norm(nn.Conv2d(input_nc, hidden_nc, **kwargs_fine), use_spect)
        if use_transpose:
            kwargs_up = {'kernel_size': 3, 'stride': 2, 'padding': 1, 'output_padding': 1}
            self.conv_1 = spectral_norm(nn.ConvTranspose2d(hidden_nc, output_nc, **kwargs_up), use_spect)
            self.conv_s = spectral_norm(nn.ConvTranspose2d(input_nc, output_nc, **kwargs_up), use_spect)
        else:
            self.conv_1 = nn.Sequential(
                spectral_norm(nn.Conv2d(hidden_nc, output_nc, **kwargs_fine), use_spect),
                nn.Upsample(scale_factor=2))
            self.conv_s = nn.Sequential(
                spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_fine), use_spect),
                nn.Upsample(scale_factor=2))

        # normalization layers, one per conv path
        self.norm_0 = ADAIN(input_nc, feature_nc)
        self.norm_1 = ADAIN(hidden_nc, feature_nc)
        self.norm_s = ADAIN(input_nc, feature_nc)

    def forward(self, x, z):
        residual = self.conv_0(self.actvn(self.norm_0(x, z)))
        residual = self.conv_1(self.actvn(self.norm_1(residual, z)))
        return self.shortcut(x, z) + residual

    def shortcut(self, x, z):
        # conditioned skip path
        return self.conv_s(self.actvn(self.norm_s(x, z)))
|
253 |
+
|
254 |
+
|
255 |
+
class FineEncoder(nn.Module):
    """Image encoder: a stem block followed by `layers` downsampling blocks.

    Returns the full feature pyramid (stem output plus every downsampled
    feature map) for use by a skip-connected decoder.
    """

    def __init__(self, image_nc, ngf, img_f, layers, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FineEncoder, self).__init__()
        self.layers = layers
        self.first = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect)
        for idx in range(layers):
            cin = min(ngf * (2 ** idx), img_f)
            cout = min(ngf * (2 ** (idx + 1)), img_f)
            setattr(self, 'down' + str(idx),
                    DownBlock2d(cin, cout, norm_layer, nonlinearity, use_spect))
        self.output_nc = cout

    def forward(self, x):
        feat = self.first(x)
        pyramid = [feat]
        for idx in range(self.layers):
            feat = getattr(self, 'down' + str(idx))(feat)
            pyramid.append(feat)
        return pyramid
|
276 |
+
|
277 |
+
|
278 |
+
class FineDecoder(nn.Module):
    """Decoder mirroring FineEncoder: per scale, ADAIN res-blocks, a 2x
    upsample and a 'jump' skip from the encoder pyramid; ends with a tanh head.

    Note: ``forward`` consumes the input feature list via ``pop()``.
    """

    def __init__(self, image_nc, feature_nc, ngf, img_f, layers, num_block, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FineDecoder, self).__init__()
        self.layers = layers
        for idx in reversed(range(layers)):
            c_in = min(ngf * 2 ** (idx + 1), img_f)
            c_out = min(ngf * 2 ** idx, img_f)
            setattr(self, 'up' + str(idx),
                    UpBlock2d(c_in, c_out, norm_layer, nonlinearity, use_spect))
            setattr(self, 'res' + str(idx),
                    FineADAINResBlocks(num_block, c_in, feature_nc, norm_layer, nonlinearity, use_spect))
            setattr(self, 'jump' + str(idx),
                    Jump(c_out, norm_layer, nonlinearity, use_spect))
        # loop ends at idx == 0, so c_out == min(ngf, img_f) here
        self.final = FinalBlock2d(c_out, image_nc, use_spect, 'tanh')
        self.output_nc = c_out

    def forward(self, x, z):
        out = x.pop()
        for idx in reversed(range(self.layers)):
            out = getattr(self, 'res' + str(idx))(out, z)
            out = getattr(self, 'up' + str(idx))(out)
            out = getattr(self, 'jump' + str(idx))(x.pop()) + out
        return self.final(out)
|
306 |
+
|
307 |
+
|
308 |
+
class ADAINEncoder(nn.Module):
    """Stack of ADAIN-conditioned downsampling encoder blocks.

    Returns the stem feature followed by the output of each level
    (coarsest last), for later use as skip connections.
    """

    def __init__(self, image_nc, pose_nc, ngf, img_f, layers, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(ADAINEncoder, self).__init__()
        self.layers = layers
        self.input_layer = nn.Conv2d(image_nc, ngf, kernel_size=7, stride=1, padding=3)
        for idx in range(layers):
            c_in = min(ngf * 2 ** idx, img_f)
            c_out = min(ngf * 2 ** (idx + 1), img_f)
            setattr(self, 'encoder' + str(idx),
                    ADAINEncoderBlock(c_in, c_out, pose_nc, nonlinearity, use_spect))
        self.output_nc = c_out

    def forward(self, x, z):
        feat = self.input_layer(x)
        features = [feat]
        for idx in range(self.layers):
            feat = getattr(self, 'encoder' + str(idx))(feat, z)
            features.append(feat)
        return features
|
328 |
+
|
329 |
+
|
330 |
+
class ADAINDecoder(nn.Module):
    """ADAIN-conditioned decoder tower mirroring the top ``decoder_layers``
    levels of ADAINEncoder.

    With ``skip_connect``, ``forward`` receives the encoder's feature list
    (consumed via ``pop``); each decoded level is concatenated with the
    matching encoder feature, doubling the channel count fed to the next
    level and of the returned tensor.
    """
    def __init__(self, pose_nc, ngf, img_f, encoder_layers, decoder_layers, skip_connect=True,
                 nonlinearity=nn.LeakyReLU(), use_spect=False):

        super(ADAINDecoder, self).__init__()
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.skip_connect = skip_connect
        use_transpose = True
        for i in range(encoder_layers-decoder_layers, encoder_layers)[::-1]:
            in_channels = min(ngf * (2**(i+1)), img_f)
            # doubled input when the previous (coarser) level was concatenated
            # with a skip feature; the topmost level never is
            in_channels = in_channels*2 if i != (encoder_layers-1) and self.skip_connect else in_channels
            out_channels = min(ngf * (2**i), img_f)
            model = ADAINDecoderBlock(in_channels, out_channels, out_channels, pose_nc, use_transpose, nonlinearity, use_spect)
            setattr(self, 'decoder' + str(i), model)
        self.output_nc = out_channels*2 if self.skip_connect else out_channels

    def forward(self, x, z):
        # with skips, x is a list of encoder features; otherwise a tensor
        out = x.pop() if self.skip_connect else x
        for i in range(self.encoder_layers-self.decoder_layers, self.encoder_layers)[::-1]:
            model = getattr(self, 'decoder' + str(i))
            out = model(out, z)
            out = torch.cat([out, x.pop()], 1) if self.skip_connect else out
        return out
|
355 |
+
|
356 |
+
|
357 |
+
class ADAINHourglass(nn.Module):
    """Encoder/decoder pair conditioned on a pose code ``z`` via ADAIN,
    wired together with skip connections (skip_connect fixed to True)."""

    def __init__(self, image_nc, pose_nc, ngf, img_f, encoder_layers, decoder_layers, nonlinearity, use_spect):
        super(ADAINHourglass, self).__init__()
        self.encoder = ADAINEncoder(image_nc, pose_nc, ngf, img_f, encoder_layers, nonlinearity, use_spect)
        self.decoder = ADAINDecoder(pose_nc, ngf, img_f, encoder_layers, decoder_layers, True, nonlinearity, use_spect)
        self.output_nc = self.decoder.output_nc

    def forward(self, x, z):
        encoded = self.encoder(x, z)
        return self.decoder(encoded, z)
|
366 |
+
|
367 |
+
|
368 |
+
class FineADAINLama(nn.Module):
    """A single FFC convolution whose local and global outputs are each
    ADAIN-normalized on the conditioning code ``z`` and activated."""

    def __init__(self, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FineADAINLama, self).__init__()
        self.actvn = nonlinearity
        # fixed 75% global / 25% local channel split
        ratio_gin = 0.75
        ratio_gout = 0.75
        self.ffc = FFC(input_nc, input_nc, 3,
                       ratio_gin, ratio_gout, 1, 1, 1,
                       1, False, False, padding_type='reflect')
        global_channels = int(input_nc * ratio_gout)
        self.bn_l = ADAIN(input_nc - global_channels, feature_nc)
        self.bn_g = ADAIN(global_channels, feature_nc)

    def forward(self, x, z):
        local_feat, global_feat = self.ffc(x)
        local_feat = self.actvn(self.bn_l(local_feat, z))
        global_feat = self.actvn(self.bn_g(global_feat, z))
        return local_feat, global_feat
|
387 |
+
|
388 |
+
|
389 |
+
class FFCResnetBlock(nn.Module):
    """Residual block of two ADAIN-conditioned FFC convolutions (LaMa-style).

    Operates on a (local, global) feature pair; in inline mode the pair is
    packed into / unpacked from a single tensor along the channel axis.
    """
    def __init__(self, dim, feature_dim, padding_type='reflect', norm_layer=BatchNorm2d, activation_layer=nn.ReLU, dilation=1,
                 spatial_transform_kwargs=None, inline=False, **conv_kwargs):
        super().__init__()
        self.conv1 = FineADAINLama(dim, feature_dim, **conv_kwargs)
        self.conv2 = FineADAINLama(dim, feature_dim, **conv_kwargs)
        # NOTE(review): the `inline` argument (and padding_type / norm_layer /
        # activation_layer / dilation) is ignored; inline mode is hardcoded
        # on -- confirm this is intentional.
        self.inline = True

    def forward(self, x, z):
        if self.inline:
            # split the packed tensor back into (local, global) parts
            x_l, x_g = x[:, :-self.conv1.ffc.global_in_num], x[:, -self.conv1.ffc.global_in_num:]
        else:
            x_l, x_g = x if type(x) is tuple else (x, 0)

        id_l, id_g = x_l, x_g
        x_l, x_g = self.conv1((x_l, x_g), z)
        x_l, x_g = self.conv2((x_l, x_g), z)

        # residual connection on both branches
        x_l, x_g = id_l + x_l, id_g + x_g
        out = x_l, x_g
        if self.inline:
            out = torch.cat(out, dim=1)
        return out
|
412 |
+
|
413 |
+
|
414 |
+
class FFCADAINResBlocks(nn.Module):
    """Chain of ``num_block`` FFCResnetBlock modules sharing the condition ``z``."""

    def __init__(self, num_block, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FFCADAINResBlocks, self).__init__()
        self.num_block = num_block
        # NOTE(review): norm_layer/nonlinearity/use_spect land positionally in
        # FFCResnetBlock's (padding_type, norm_layer, activation_layer) slots;
        # those parameters are unused there, so this is harmless -- confirm.
        for idx in range(num_block):
            block = FFCResnetBlock(input_nc, feature_nc, norm_layer, nonlinearity, use_spect)
            setattr(self, 'res' + str(idx), block)

    def forward(self, x, z):
        out = x
        for idx in range(self.num_block):
            out = getattr(self, 'res' + str(idx))(out, z)
        return out
|
427 |
+
|
428 |
+
|
429 |
+
class Jump(nn.Module):
    """3x3 conv (+ optional normalization) and nonlinearity, used on skip features."""

    def __init__(self, input_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(Jump, self).__init__()
        conv = spectral_norm(
            nn.Conv2d(input_nc, input_nc, kernel_size=3, stride=1, padding=1),
            use_spect)
        if norm_layer is None:
            self.model = nn.Sequential(conv, nonlinearity)
        else:
            self.model = nn.Sequential(conv, norm_layer(input_nc), nonlinearity)

    def forward(self, x):
        return self.model(x)
|
442 |
+
|
443 |
+
|
444 |
+
class FinalBlock2d(nn.Module):
    """7x7 output convolution followed by tanh (default) or sigmoid."""

    def __init__(self, input_nc, output_nc, use_spect=False, tanh_or_sigmoid='tanh'):
        super(FinalBlock2d, self).__init__()
        conv = spectral_norm(
            nn.Conv2d(input_nc, output_nc, kernel_size=7, stride=1, padding=3),
            use_spect)
        squash = nn.Sigmoid() if tanh_or_sigmoid == 'sigmoid' else nn.Tanh()
        self.model = nn.Sequential(conv, squash)

    def forward(self, x):
        return self.model(x)
|
458 |
+
|
459 |
+
|
460 |
+
class ModulatedConv2d(nn.Module):
    """StyleGAN2 modulated convolution layer.

    A style vector is mapped to per-input-channel scales that modulate the
    shared weight; with ``demodulate`` each output filter is renormalized.
    The convolution runs as a grouped conv with one group per batch sample so
    every sample uses its own modulated filters. ``sample_mode`` optionally
    bilinearly up-/down-samples the input by 2x before the conv.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 num_style_feat,
                 demodulate=True,
                 sample_mode=None,
                 eps=1e-8):
        super(ModulatedConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.demodulate = demodulate
        self.sample_mode = sample_mode
        self.eps = eps  # numerical guard inside the demodulation rsqrt

        # modulation inside each modulated conv
        self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
        # initialization
        default_init_weights(self.modulation, scale=1, bias_fill=1, a=0, mode='fan_in', nonlinearity='linear')

        # shared weight, scaled for unit fan-in variance
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) /
            math.sqrt(in_channels * kernel_size**2))
        self.padding = kernel_size // 2

    def forward(self, x, style):
        """Apply the style-modulated convolution.

        Args:
            x: input features, shape (b, in_channels, h, w).
            style: style codes, shape (b, num_style_feat).
        """
        b, c, h, w = x.shape
        style = self.modulation(style).view(b, 1, c, 1, 1)
        weight = self.weight * style  # (b, out_c, in_c, k, k)

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
            weight = weight * demod.view(b, self.out_channels, 1, 1, 1)

        weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)

        # upsample or downsample if necessary
        if self.sample_mode == 'upsample':
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        elif self.sample_mode == 'downsample':
            x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)

        b, c, h, w = x.shape
        # fold batch into channels so groups=b applies per-sample filters
        x = x.view(1, b * c, h, w)
        out = F.conv2d(x, weight, padding=self.padding, groups=b)
        out = out.view(b, self.out_channels, *out.shape[2:4])
        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})')
|
513 |
+
|
514 |
+
|
515 |
+
class StyleConv(nn.Module):
    """Modulated conv with learned-strength noise injection, bias and LeakyReLU."""

    def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None):
        super(StyleConv, self).__init__()
        self.modulated_conv = ModulatedConv2d(
            in_channels, out_channels, kernel_size, num_style_feat,
            demodulate=demodulate, sample_mode=sample_mode)
        # scalar noise strength, zero-initialized
        self.weight = nn.Parameter(torch.zeros(1))
        self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
        self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x, style, noise=None):
        # sqrt(2) gain "for conversion" (matches the original implementation)
        out = self.modulated_conv(x, style) * 2**0.5
        # noise injection: fresh Gaussian noise unless provided
        if noise is None:
            batch, _, height, width = out.shape
            noise = out.new_empty(batch, 1, height, width).normal_()
        out = out + self.weight * noise
        out = out + self.bias
        return self.activate(out)
|
537 |
+
|
538 |
+
|
539 |
+
class ToRGB(nn.Module):
    """Project features to a 3-channel image; optionally add an (upsampled) skip."""

    def __init__(self, in_channels, num_style_feat, upsample=True):
        super(ToRGB, self).__init__()
        self.upsample = upsample
        self.modulated_conv = ModulatedConv2d(
            in_channels, 3, kernel_size=1, num_style_feat=num_style_feat,
            demodulate=False, sample_mode=None)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, x, style, skip=None):
        rgb = self.modulated_conv(x, style) + self.bias
        if skip is None:
            return rgb
        if self.upsample:
            skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
        return rgb + skip
|
models/ffc.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Fast Fourier Convolution NeurIPS 2020
|
2 |
+
# original implementation https://github.com/pkumivision/FFC/blob/main/model_zoo/ffc.py
|
3 |
+
# paper https://proceedings.neurips.cc/paper/2020/file/2fd5d41ec6cfab47e32164d5624269b1-Paper.pdf
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
# from models.modules.squeeze_excitation import SELayer
|
9 |
+
import torch.fft
|
10 |
+
|
11 |
+
class SELayer(nn.Module):
    """Squeeze-and-Excitation channel gate (Hu et al., 2018).

    Global-average-pools each channel, passes the vector through a two-layer
    bottleneck MLP with sigmoid output, and rescales the input channels.
    """

    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        batch, channels = x.size(0), x.size(1)
        gate = self.avg_pool(x).view(batch, channels)
        gate = self.fc(gate).view(batch, channels, 1, 1)
        return x * gate.expand_as(x)
|
28 |
+
|
29 |
+
|
30 |
+
class FFCSE_block(nn.Module):
    """Squeeze-and-excitation over concatenated local/global FFC branches.

    Pools the joint feature, then emits a separate sigmoid channel gate for
    the local and the global branch; a branch with zero channels is skipped
    (its head is None and its output is the scalar 0).
    """

    def __init__(self, channels, ratio_g):
        super(FFCSE_block, self).__init__()
        in_cg = int(channels * ratio_g)
        in_cl = channels - in_cg
        r = 16

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.conv1 = nn.Conv2d(channels, channels // r,
                               kernel_size=1, bias=True)
        self.relu1 = nn.ReLU(inplace=True)
        # per-branch excitation heads; absent when the branch is empty
        self.conv_a2l = nn.Conv2d(channels // r, in_cl, kernel_size=1, bias=True) if in_cl != 0 else None
        self.conv_a2g = nn.Conv2d(channels // r, in_cg, kernel_size=1, bias=True) if in_cg != 0 else None
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        id_l, id_g = x if type(x) is tuple else (x, 0)

        joint = id_l if type(id_g) is int else torch.cat([id_l, id_g], dim=1)
        squeezed = self.relu1(self.conv1(self.avgpool(joint)))

        x_l = 0 if self.conv_a2l is None else id_l * self.sigmoid(self.conv_a2l(squeezed))
        x_g = 0 if self.conv_a2g is None else id_g * self.sigmoid(self.conv_a2g(squeezed))
        return x_l, x_g
|
60 |
+
|
61 |
+
|
62 |
+
class FourierUnit(nn.Module):
    """Spectral convolution: rFFT -> 1x1 conv + BN + ReLU over stacked
    real/imag channels -> inverse rFFT back to the spatial domain.

    ``ffc3d`` extends the transform to the last three dims; ``fft_norm`` is
    forwarded to ``torch.fft``. Optional bilinear rescaling brackets the
    transform, and optional SE / spectral position encoding act in the
    frequency domain.
    """

    def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear',
                 spectral_pos_encoding=False, use_se=False, se_kwargs=None, ffc3d=False, fft_norm='ortho'):
        # bn_layer not used
        super(FourierUnit, self).__init__()
        self.groups = groups

        # real+imag are stacked along channels, hence the *2 widths
        self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0),
                                          out_channels=out_channels * 2,
                                          kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
        self.bn = torch.nn.BatchNorm2d(out_channels * 2)
        self.relu = torch.nn.ReLU(inplace=True)

        # squeeze and excitation block
        self.use_se = use_se
        if use_se:
            if se_kwargs is None:
                se_kwargs = {}
            self.se = SELayer(self.conv_layer.in_channels, **se_kwargs)

        self.spatial_scale_factor = spatial_scale_factor
        self.spatial_scale_mode = spatial_scale_mode
        self.spectral_pos_encoding = spectral_pos_encoding
        self.ffc3d = ffc3d
        self.fft_norm = fft_norm

    def forward(self, x):
        batch = x.shape[0]

        if self.spatial_scale_factor is not None:
            orig_size = x.shape[-2:]
            x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode, align_corners=False)

        r_size = x.size()
        # (batch, c, h, w/2+1, 2)
        fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
        ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
        ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
        ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()  # (batch, c, 2, h, w/2+1)
        ffted = ffted.view((batch, -1,) + ffted.size()[3:])

        if self.spectral_pos_encoding:
            # prepend normalized vertical/horizontal coordinate planes
            height, width = ffted.shape[-2:]
            coords_vert = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 1, height, width).to(ffted)
            coords_hor = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(ffted)
            ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1)

        if self.use_se:
            ffted = self.se(ffted)

        ffted = self.conv_layer(ffted)  # (batch, c*2, h, w/2+1)
        ffted = self.relu(self.bn(ffted))

        # re-split real/imag and rebuild the complex tensor
        ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
            0, 1, 3, 4, 2).contiguous()  # (batch,c, t, h, w/2+1, 2)
        ffted = torch.complex(ffted[..., 0], ffted[..., 1])

        # invert at the (possibly scaled) spatial resolution
        ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
        output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)

        if self.spatial_scale_factor is not None:
            output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False)

        return output
|
127 |
+
|
128 |
+
|
129 |
+
class SpectralTransform(nn.Module):
    """Global branch of FFC: 1x1 reduce -> FourierUnit (plus an optional
    'local Fourier unit' over 2x2 sub-tiles) -> residual add -> 1x1 expand.
    """

    def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=True, **fu_kwargs):
        # bn_layer not used
        super(SpectralTransform, self).__init__()
        self.enable_lfu = enable_lfu
        if stride == 2:
            self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        else:
            self.downsample = nn.Identity()

        self.stride = stride
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels //
                      2, kernel_size=1, groups=groups, bias=False),
            nn.BatchNorm2d(out_channels // 2),
            nn.ReLU(inplace=True)
        )
        self.fu = FourierUnit(
            out_channels // 2, out_channels // 2, groups, **fu_kwargs)
        if self.enable_lfu:
            self.lfu = FourierUnit(
                out_channels // 2, out_channels // 2, groups)
        self.conv2 = torch.nn.Conv2d(
            out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False)

    def forward(self, x):
        x = self.downsample(x)
        x = self.conv1(x)
        output = self.fu(x)

        if self.enable_lfu:
            # local Fourier unit: fold 2x2 spatial tiles of a quarter of the
            # channels into the channel axis, transform, then tile back.
            # NOTE(review): the tile size is derived from h for both dims,
            # which assumes square (h == w) inputs -- confirm.
            n, c, h, w = x.shape
            split_no = 2
            split_s = h // split_no
            xs = torch.cat(torch.split(
                x[:, :c // 4], split_s, dim=-2), dim=1).contiguous()
            xs = torch.cat(torch.split(xs, split_s, dim=-1),
                           dim=1).contiguous()
            xs = self.lfu(xs)
            xs = xs.repeat(1, 1, split_no, split_no).contiguous()
        else:
            xs = 0

        output = self.conv2(x + output + xs)
        return output
|
174 |
+
|
175 |
+
|
176 |
+
class FFC(nn.Module):
    """Fast Fourier Convolution layer (Chi et al., NeurIPS 2020).

    Splits channels into a local and a global part (by ``ratio_gin`` /
    ``ratio_gout``) and mixes them through four paths: local->local,
    global->local and local->global via plain convs, and global->global via a
    SpectralTransform. Paths with zero channels degenerate to ``nn.Identity``.
    Returns the (local, global) output pair.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 ratio_gin, ratio_gout, stride=1, padding=0,
                 dilation=1, groups=1, bias=False, enable_lfu=True,
                 padding_type='reflect', gated=False, **spectral_kwargs):
        super(FFC, self).__init__()

        assert stride == 1 or stride == 2, "Stride should be 1 or 2."
        self.stride = stride

        in_cg = int(in_channels * ratio_gin)
        in_cl = in_channels - in_cg
        out_cg = int(out_channels * ratio_gout)
        out_cl = out_channels - out_cg

        self.ratio_gin = ratio_gin
        self.ratio_gout = ratio_gout
        self.global_in_num = in_cg

        module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
        self.convl2l = module(in_cl, out_cl, kernel_size,
                              stride, padding, dilation, groups, bias, padding_mode=padding_type)
        module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
        self.convl2g = module(in_cl, out_cg, kernel_size,
                              stride, padding, dilation, groups, bias, padding_mode=padding_type)
        module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
        self.convg2l = module(in_cg, out_cl, kernel_size,
                              stride, padding, dilation, groups, bias, padding_mode=padding_type)
        module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
        self.convg2g = module(
            in_cg, out_cg, stride, 1 if groups == 1 else groups // 2, enable_lfu, **spectral_kwargs)

        self.gated = gated
        module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d
        self.gate = module(in_channels, 2, 1)

    def forward(self, x):
        # accept either a (local, global) pair or a plain tensor (global = 0)
        x_l, x_g = x if type(x) is tuple else (x, 0)
        out_xl, out_xg = 0, 0

        if self.gated:
            total_input_parts = [x_l]
            if torch.is_tensor(x_g):
                total_input_parts.append(x_g)
            total_input = torch.cat(total_input_parts, dim=1)

            # two sigmoid gates: one for g->l mixing, one for l->g mixing
            gates = torch.sigmoid(self.gate(total_input))
            g2l_gate, l2g_gate = gates.chunk(2, dim=1)
        else:
            g2l_gate, l2g_gate = 1, 1

        if self.ratio_gout != 1:
            out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
        if self.ratio_gout != 0:
            out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g)

        return out_xl, out_xg
|
models/transformer.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
from einops import rearrange
|
5 |
+
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
|
11 |
+
class GELU(nn.Module):
    """Gaussian Error Linear Unit, tanh approximation (Hendrycks & Gimpel, 2016).

    Computes 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
    Uses ``torch.tanh`` instead of the deprecated ``F.tanh`` alias, which
    emits a deprecation warning on modern PyTorch.
    """

    def __init__(self):
        super(GELU, self).__init__()

    def forward(self, x):
        return 0.5 * x * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
16 |
+
|
17 |
+
# helpers
|
18 |
+
|
19 |
+
def pair(t):
    """Return *t* unchanged if it is already a tuple, else duplicate it as (t, t)."""
    if isinstance(t, tuple):
        return t
    return (t, t)
|
21 |
+
|
22 |
+
# classes
|
23 |
+
|
24 |
+
class PreNorm(nn.Module):
    """Apply LayerNorm to the input before delegating to the wrapped callable."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)
|
31 |
+
|
32 |
+
class DualPreNorm(nn.Module):
    """LayerNorm each of two inputs independently before the wrapped callable."""

    def __init__(self, dim, fn):
        super().__init__()
        self.normx = nn.LayerNorm(dim)
        self.normy = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, y, **kwargs):
        return self.fn(self.normx(x), self.normy(y), **kwargs)
|
40 |
+
|
41 |
+
class FeedForward(nn.Module):
    """Transformer MLP block: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),
            GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
|
53 |
+
|
54 |
+
class Attention(nn.Module):
    """Multi-head attention where, per the original comments, queries AND keys
    both come from ``x`` while values come from the reference ``y`` — so the
    attention pattern is computed within x but aggregates features of y.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        # projection back to `dim` is skipped for single-head identity-sized attention
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_k = nn.Linear(dim, inner_dim, bias = False)
        self.to_v = nn.Linear(dim, inner_dim, bias = False)


        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x, y):
        # qk = self.to_qk(x).chunk(2, dim = -1) #
        q = rearrange(self.to_q(x), 'b n (h d) -> b h n d', h = self.heads) # q,k from the zero feature
        k = rearrange(self.to_k(x), 'b n (h d) -> b h n d', h = self.heads) # v from the reference features
        v = rearrange(self.to_v(y), 'b n (h d) -> b h n d', h = self.heads)

        # scaled dot-product attention over the token dimension
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
|
88 |
+
|
89 |
+
class Transformer(nn.Module):
    """Stack of cross-attention + feed-forward layers over flattened feature maps.

    ``x`` (the cropped feature) attends with values drawn from ``y`` (the
    foreign reference); both must share the same (bs, c, h, w) shape, and the
    h*w spatial positions act as tokens.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                DualPreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))


    def forward(self, x, y): # x is the cropped, y is the foreign reference
        bs,c,h,w = x.size()

        # img to embedding: (bs, c, h, w) -> (bs, h*w, c)
        x = x.view(bs,c,-1).permute(0,2,1)
        y = y.view(bs,c,-1).permute(0,2,1)

        # pre-norm residual blocks: cross-attention then MLP
        for attn, ff in self.layers:
            x = attn(x, y) + x
            x = ff(x) + x

        # tokens back to the spatial grid
        x = x.view(bs,h,w,c).permute(0,3,1,2)
        return x
|
113 |
+
|
114 |
+
class RETURNX(nn.Module):
    """Identity stand-in for Transformer: ignores the reference and returns x."""

    def __init__(self,):
        super().__init__()

    def forward(self, x, y):  # x is the cropped, y is the foreign reference
        # y is intentionally unused
        return x
|
requirements.txt
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
basicsr==1.4.2
|
2 |
+
dlib==19.24.2
|
3 |
+
docopt==0.6.2
|
4 |
+
dominate==2.8.0
|
5 |
+
easydict==1.10
|
6 |
+
einops==0.7.0
|
7 |
+
face_alignment==1.4.1
|
8 |
+
facexlib==0.3.0
|
9 |
+
gradio==3.46.1
|
10 |
+
imageio==2.31.5
|
11 |
+
insightface==0.7.3
|
12 |
+
iou==0.1.0
|
13 |
+
kornia==0.7.0
|
14 |
+
librosa==0.8.0
|
15 |
+
matplotlib==3.7.1
|
16 |
+
menpo==0.11.0
|
17 |
+
mxnet==1.9.1
|
18 |
+
numpy==1.23.5
|
19 |
+
onnx==1.14.1
|
20 |
+
onnxruntime==1.16.0
|
21 |
+
onnxsim==0.4.33
|
third_part/GFPGAN/LICENSE
ADDED
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Tencent is pleased to support the open source community by making GFPGAN available.
|
2 |
+
|
3 |
+
Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
|
4 |
+
|
5 |
+
GFPGAN is licensed under the Apache License Version 2.0 except for the third-party components listed below.
|
6 |
+
|
7 |
+
|
8 |
+
Terms of the Apache License Version 2.0:
|
9 |
+
---------------------------------------------
|
10 |
+
Apache License
|
11 |
+
|
12 |
+
Version 2.0, January 2004
|
13 |
+
|
14 |
+
http://www.apache.org/licenses/
|
15 |
+
|
16 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
17 |
+
1. Definitions.
|
18 |
+
|
19 |
+
“License” shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
|
20 |
+
|
21 |
+
“Licensor” shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
|
22 |
+
|
23 |
+
“Legal Entity” shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, “control” means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
|
24 |
+
|
25 |
+
“You” (or “Your”) shall mean an individual or Legal Entity exercising permissions granted by this License.
|
26 |
+
|
27 |
+
“Source” form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
|
28 |
+
|
29 |
+
“Object” form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
|
30 |
+
|
31 |
+
“Work” shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
|
32 |
+
|
33 |
+
“Derivative Works” shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
|
34 |
+
|
35 |
+
“Contribution” shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, “submitted” means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as “Not a Contribution.”
|
36 |
+
|
37 |
+
“Contributor” shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
|
38 |
+
|
39 |
+
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
|
40 |
+
|
41 |
+
3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
|
42 |
+
|
43 |
+
4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
|
44 |
+
|
45 |
+
You must give any other recipients of the Work or Derivative Works a copy of this License; and
|
46 |
+
|
47 |
+
You must cause any modified files to carry prominent notices stating that You changed the files; and
|
48 |
+
|
49 |
+
You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
|
50 |
+
|
51 |
+
If the Work includes a “NOTICE” text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
|
52 |
+
|
53 |
+
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
|
54 |
+
|
55 |
+
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
|
56 |
+
|
57 |
+
6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
|
58 |
+
|
59 |
+
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
|
60 |
+
|
61 |
+
8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
|
62 |
+
|
63 |
+
9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
|
64 |
+
|
65 |
+
END OF TERMS AND CONDITIONS
|
66 |
+
|
67 |
+
|
68 |
+
|
69 |
+
Other dependencies and licenses:
|
70 |
+
|
71 |
+
|
72 |
+
Open Source Software licensed under the Apache 2.0 license and Other Licenses of the Third-Party Components therein:
|
73 |
+
---------------------------------------------
|
74 |
+
1. basicsr
|
75 |
+
Copyright 2018-2020 BasicSR Authors
|
76 |
+
|
77 |
+
|
78 |
+
This BasicSR project is released under the Apache 2.0 license.
|
79 |
+
|
80 |
+
A copy of Apache 2.0 is included in this file.
|
81 |
+
|
82 |
+
StyleGAN2
|
83 |
+
The codes are modified from the repository stylegan2-pytorch. Many thanks to the author - Kim Seonghyeon 😊 for translating from the official TensorFlow codes to PyTorch ones. Here is the license of stylegan2-pytorch.
|
84 |
+
The official repository is https://github.com/NVlabs/stylegan2, and here is the NVIDIA license.
|
85 |
+
DFDNet
|
86 |
+
The codes are largely modified from the repository DFDNet. Their license is Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
|
87 |
+
|
88 |
+
Terms of the Nvidia License:
|
89 |
+
---------------------------------------------
|
90 |
+
|
91 |
+
1. Definitions
|
92 |
+
|
93 |
+
"Licensor" means any person or entity that distributes its Work.
|
94 |
+
|
95 |
+
"Software" means the original work of authorship made available under
|
96 |
+
this License.
|
97 |
+
|
98 |
+
"Work" means the Software and any additions to or derivative works of
|
99 |
+
the Software that are made available under this License.
|
100 |
+
|
101 |
+
"Nvidia Processors" means any central processing unit (CPU), graphics
|
102 |
+
processing unit (GPU), field-programmable gate array (FPGA),
|
103 |
+
application-specific integrated circuit (ASIC) or any combination
|
104 |
+
thereof designed, made, sold, or provided by Nvidia or its affiliates.
|
105 |
+
|
106 |
+
The terms "reproduce," "reproduction," "derivative works," and
|
107 |
+
"distribution" have the meaning as provided under U.S. copyright law;
|
108 |
+
provided, however, that for the purposes of this License, derivative
|
109 |
+
works shall not include works that remain separable from, or merely
|
110 |
+
link (or bind by name) to the interfaces of, the Work.
|
111 |
+
|
112 |
+
Works, including the Software, are "made available" under this License
|
113 |
+
by including in or with the Work either (a) a copyright notice
|
114 |
+
referencing the applicability of this License to the Work, or (b) a
|
115 |
+
copy of this License.
|
116 |
+
|
117 |
+
2. License Grants
|
118 |
+
|
119 |
+
2.1 Copyright Grant. Subject to the terms and conditions of this
|
120 |
+
License, each Licensor grants to you a perpetual, worldwide,
|
121 |
+
non-exclusive, royalty-free, copyright license to reproduce,
|
122 |
+
prepare derivative works of, publicly display, publicly perform,
|
123 |
+
sublicense and distribute its Work and any resulting derivative
|
124 |
+
works in any form.
|
125 |
+
|
126 |
+
3. Limitations
|
127 |
+
|
128 |
+
3.1 Redistribution. You may reproduce or distribute the Work only
|
129 |
+
if (a) you do so under this License, (b) you include a complete
|
130 |
+
copy of this License with your distribution, and (c) you retain
|
131 |
+
without modification any copyright, patent, trademark, or
|
132 |
+
attribution notices that are present in the Work.
|
133 |
+
|
134 |
+
3.2 Derivative Works. You may specify that additional or different
|
135 |
+
terms apply to the use, reproduction, and distribution of your
|
136 |
+
derivative works of the Work ("Your Terms") only if (a) Your Terms
|
137 |
+
provide that the use limitation in Section 3.3 applies to your
|
138 |
+
derivative works, and (b) you identify the specific derivative
|
139 |
+
works that are subject to Your Terms. Notwithstanding Your Terms,
|
140 |
+
this License (including the redistribution requirements in Section
|
141 |
+
3.1) will continue to apply to the Work itself.
|
142 |
+
|
143 |
+
3.3 Use Limitation. The Work and any derivative works thereof only
|
144 |
+
may be used or intended for use non-commercially. The Work or
|
145 |
+
derivative works thereof may be used or intended for use by Nvidia
|
146 |
+
or its affiliates commercially or non-commercially. As used herein,
|
147 |
+
"non-commercially" means for research or evaluation purposes only.
|
148 |
+
|
149 |
+
3.4 Patent Claims. If you bring or threaten to bring a patent claim
|
150 |
+
against any Licensor (including any claim, cross-claim or
|
151 |
+
counterclaim in a lawsuit) to enforce any patents that you allege
|
152 |
+
are infringed by any Work, then your rights under this License from
|
153 |
+
such Licensor (including the grants in Sections 2.1 and 2.2) will
|
154 |
+
terminate immediately.
|
155 |
+
|
156 |
+
3.5 Trademarks. This License does not grant any rights to use any
|
157 |
+
Licensor's or its affiliates' names, logos, or trademarks, except
|
158 |
+
as necessary to reproduce the notices described in this License.
|
159 |
+
|
160 |
+
3.6 Termination. If you violate any term of this License, then your
|
161 |
+
rights under this License (including the grants in Sections 2.1 and
|
162 |
+
2.2) will terminate immediately.
|
163 |
+
|
164 |
+
4. Disclaimer of Warranty.
|
165 |
+
|
166 |
+
THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
167 |
+
KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
|
168 |
+
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
|
169 |
+
NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
|
170 |
+
THIS LICENSE.
|
171 |
+
|
172 |
+
5. Limitation of Liability.
|
173 |
+
|
174 |
+
EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
|
175 |
+
THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
|
176 |
+
SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
|
177 |
+
INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
|
178 |
+
OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
|
179 |
+
(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
|
180 |
+
LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
|
181 |
+
COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
|
182 |
+
THE POSSIBILITY OF SUCH DAMAGES.
|
183 |
+
|
184 |
+
MIT License
|
185 |
+
|
186 |
+
Copyright (c) 2019 Kim Seonghyeon
|
187 |
+
|
188 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
189 |
+
of this software and associated documentation files (the "Software"), to deal
|
190 |
+
in the Software without restriction, including without limitation the rights
|
191 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
192 |
+
copies of the Software, and to permit persons to whom the Software is
|
193 |
+
furnished to do so, subject to the following conditions:
|
194 |
+
|
195 |
+
The above copyright notice and this permission notice shall be included in all
|
196 |
+
copies or substantial portions of the Software.
|
197 |
+
|
198 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
199 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
200 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
201 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
202 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
203 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
204 |
+
SOFTWARE.
|
205 |
+
|
206 |
+
|
207 |
+
|
208 |
+
Open Source Software licensed under the BSD 3-Clause license:
|
209 |
+
---------------------------------------------
|
210 |
+
1. torchvision
|
211 |
+
Copyright (c) Soumith Chintala 2016,
|
212 |
+
All rights reserved.
|
213 |
+
|
214 |
+
2. torch
|
215 |
+
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
|
216 |
+
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
|
217 |
+
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
|
218 |
+
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
|
219 |
+
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
|
220 |
+
Copyright (c) 2011-2013 NYU (Clement Farabet)
|
221 |
+
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
|
222 |
+
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
|
223 |
+
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
|
224 |
+
|
225 |
+
|
226 |
+
Terms of the BSD 3-Clause License:
|
227 |
+
---------------------------------------------
|
228 |
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
229 |
+
|
230 |
+
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
231 |
+
|
232 |
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
233 |
+
|
234 |
+
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
235 |
+
|
236 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
237 |
+
|
238 |
+
|
239 |
+
|
240 |
+
Open Source Software licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
|
241 |
+
---------------------------------------------
|
242 |
+
1. numpy
|
243 |
+
Copyright (c) 2005-2020, NumPy Developers.
|
244 |
+
All rights reserved.
|
245 |
+
|
246 |
+
A copy of BSD 3-Clause License is included in this file.
|
247 |
+
|
248 |
+
The NumPy repository and source distributions bundle several libraries that are
|
249 |
+
compatibly licensed. We list these here.
|
250 |
+
|
251 |
+
Name: Numpydoc
|
252 |
+
Files: doc/sphinxext/numpydoc/*
|
253 |
+
License: BSD-2-Clause
|
254 |
+
For details, see doc/sphinxext/LICENSE.txt
|
255 |
+
|
256 |
+
Name: scipy-sphinx-theme
|
257 |
+
Files: doc/scipy-sphinx-theme/*
|
258 |
+
License: BSD-3-Clause AND PSF-2.0 AND Apache-2.0
|
259 |
+
For details, see doc/scipy-sphinx-theme/LICENSE.txt
|
260 |
+
|
261 |
+
Name: lapack-lite
|
262 |
+
Files: numpy/linalg/lapack_lite/*
|
263 |
+
License: BSD-3-Clause
|
264 |
+
For details, see numpy/linalg/lapack_lite/LICENSE.txt
|
265 |
+
|
266 |
+
Name: tempita
|
267 |
+
Files: tools/npy_tempita/*
|
268 |
+
License: MIT
|
269 |
+
For details, see tools/npy_tempita/license.txt
|
270 |
+
|
271 |
+
Name: dragon4
|
272 |
+
Files: numpy/core/src/multiarray/dragon4.c
|
273 |
+
License: MIT
|
274 |
+
For license text, see numpy/core/src/multiarray/dragon4.c
|
275 |
+
|
276 |
+
|
277 |
+
|
278 |
+
Open Source Software licensed under the MIT license:
|
279 |
+
---------------------------------------------
|
280 |
+
1. facexlib
|
281 |
+
Copyright (c) 2020 Xintao Wang
|
282 |
+
|
283 |
+
2. opencv-python
|
284 |
+
Copyright (c) Olli-Pekka Heinisuo
|
285 |
+
Please note that only files in cv2 package are used.
|
286 |
+
|
287 |
+
|
288 |
+
Terms of the MIT License:
|
289 |
+
---------------------------------------------
|
290 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
291 |
+
|
292 |
+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
293 |
+
|
294 |
+
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
295 |
+
|
296 |
+
|
297 |
+
|
298 |
+
Open Source Software licensed under the MIT license and Other Licenses of the Third-Party Components therein:
|
299 |
+
---------------------------------------------
|
300 |
+
1. tqdm
|
301 |
+
Copyright (c) 2013 noamraph
|
302 |
+
|
303 |
+
`tqdm` is a product of collaborative work.
|
304 |
+
Unless otherwise stated, all authors (see commit logs) retain copyright
|
305 |
+
for their respective work, and release the work under the MIT licence
|
306 |
+
(text below).
|
307 |
+
|
308 |
+
Exceptions or notable authors are listed below
|
309 |
+
in reverse chronological order:
|
310 |
+
|
311 |
+
* files: *
|
312 |
+
MPLv2.0 2015-2020 (c) Casper da Costa-Luis
|
313 |
+
[casperdcl](https://github.com/casperdcl).
|
314 |
+
* files: tqdm/_tqdm.py
|
315 |
+
MIT 2016 (c) [PR #96] on behalf of Google Inc.
|
316 |
+
* files: tqdm/_tqdm.py setup.py README.rst MANIFEST.in .gitignore
|
317 |
+
MIT 2013 (c) Noam Yorav-Raphael, original author.
|
318 |
+
|
319 |
+
[PR #96]: https://github.com/tqdm/tqdm/pull/96
|
320 |
+
|
321 |
+
|
322 |
+
Mozilla Public Licence (MPL) v. 2.0 - Exhibit A
|
323 |
+
-----------------------------------------------
|
324 |
+
|
325 |
+
This Source Code Form is subject to the terms of the
|
326 |
+
Mozilla Public License, v. 2.0.
|
327 |
+
If a copy of the MPL was not distributed with this file,
|
328 |
+
You can obtain one at https://mozilla.org/MPL/2.0/.
|
329 |
+
|
330 |
+
|
331 |
+
MIT License (MIT)
|
332 |
+
-----------------
|
333 |
+
|
334 |
+
Copyright (c) 2013 noamraph
|
335 |
+
|
336 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
337 |
+
this software and associated documentation files (the "Software"), to deal in
|
338 |
+
the Software without restriction, including without limitation the rights to
|
339 |
+
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
340 |
+
the Software, and to permit persons to whom the Software is furnished to do so,
|
341 |
+
subject to the following conditions:
|
342 |
+
|
343 |
+
The above copyright notice and this permission notice shall be included in all
|
344 |
+
copies or substantial portions of the Software.
|
345 |
+
|
346 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
347 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
348 |
+
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
349 |
+
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
350 |
+
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
351 |
+
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
third_part/GFPGAN/gfpgan/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# flake8: noqa
|
2 |
+
|
3 |
+
from .archs import *
|
4 |
+
from .data import *
|
5 |
+
from .models import *
|
6 |
+
from .utils import *
|
7 |
+
|
8 |
+
# from .version import *
|
third_part/GFPGAN/gfpgan/archs/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
from basicsr.utils import scandir
|
3 |
+
from os import path as osp
|
4 |
+
|
5 |
+
# Auto-discover architecture modules for basicsr's registry: every file in
# this folder named '*_arch.py' is imported so the @ARCH_REGISTRY classes
# it defines register themselves as an import side effect.
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [
    osp.splitext(osp.basename(fname))[0]
    for fname in scandir(arch_folder)
    if fname.endswith('_arch.py')
]
# Import all the arch modules (kept in a list so the imports are not optimized away).
_arch_modules = [
    importlib.import_module(f'gfpgan.archs.{stem}') for stem in arch_filenames
]
|
third_part/GFPGAN/gfpgan/archs/arcface_arch.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
3 |
+
|
4 |
+
|
5 |
+
def conv3x3(inplanes, outplanes, stride=1):
    """Build a bias-free 3x3 convolution with padding 1.

    Args:
        inplanes (int): Channel number of inputs.
        outplanes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.

    Returns:
        nn.Conv2d: The configured convolution layer.
    """
    return nn.Conv2d(
        inplanes,
        outplanes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
|
14 |
+
|
15 |
+
|
16 |
+
class BasicBlock(nn.Module):
    """Standard two-convolution residual block (ResNet style).

    Main branch is conv-bn-relu-conv-bn; the shortcut is the identity
    unless a ``downsample`` module is supplied, and the sum is passed
    through a final ReLU.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of outputs.
        stride (int): Stride of the first convolution. Default: 1.
        downsample (nn.Module): Optional module applied to the shortcut
            branch when shape or stride changes. Default: None.
    """
    expansion = 1  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Main branch: conv-bn-relu then conv-bn.
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Shortcut branch: identity, or the provided projection.
        shortcut = x if self.downsample is None else self.downsample(x)

        return self.relu(out + shortcut)
|
54 |
+
|
55 |
+
|
56 |
+
class IRBlock(nn.Module):
    """Improved residual (IR) block used in the ResNetArcFace architecture.

    Pre-activation layout bn-conv-bn-prelu-conv-bn, optionally followed by
    a squeeze-and-excitation recalibration, with a residual shortcut and a
    final PReLU on the sum.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of outputs.
        stride (int): Stride of the second convolution. Default: 1.
        downsample (nn.Module): Optional module applied to the shortcut
            branch. Default: None.
        use_se (bool): Whether to apply an SEBlock (squeeze and excitation)
            to the main branch. Default: True.
    """
    expansion = 1  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
        super(IRBlock, self).__init__()
        self.bn0 = nn.BatchNorm2d(inplanes)
        self.conv1 = conv3x3(inplanes, inplanes)
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.prelu = nn.PReLU()
        self.conv2 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.use_se = use_se
        if self.use_se:
            self.se = SEBlock(planes)

    def forward(self, x):
        # Main branch: bn0 -> conv1 -> bn1 -> prelu -> conv2 -> bn2 (-> SE).
        out = self.prelu(self.bn1(self.conv1(self.bn0(x))))
        out = self.bn2(self.conv2(out))
        if self.use_se:
            out = self.se(out)

        # Shortcut branch: identity unless a projection was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        return self.prelu(out + shortcut)
|
101 |
+
|
102 |
+
|
103 |
+
class Bottleneck(nn.Module):
    """Bottleneck residual block used in the ResNetArcFace architecture.

    Layout: 1x1 conv (reduce) -> 3x3 conv (spatial, stride) -> 1x1 conv (expand
    by ``expansion``), each followed by BN, then a residual sum with the
    (optionally downsampled) input and a final ReLU.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of the intermediate (reduced) features;
            the block outputs ``planes * expansion`` channels.
        stride (int): Stride of the 3x3 convolution. Default: 1.
        downsample (nn.Module | None): Module applied to the shortcut path. Default: None.
    """
    expansion = 4  # the final 1x1 conv widens the channels by this factor

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Shortcut path: identity, or the provided projection.
        shortcut = x if self.downsample is None else self.downsample(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        y = y + shortcut
        return self.relu(y)
|
147 |
+
|
148 |
+
|
149 |
+
class SEBlock(nn.Module):
    """Squeeze-and-excitation block (channel attention) used in the IRBlock.

    Global-average-pools the input to a per-channel descriptor, maps it through
    a two-layer bottleneck MLP ending in a sigmoid, and rescales each input
    channel by the resulting [0, 1] gate.

    Args:
        channel (int): Channel number of inputs.
        reduction (int): Channel reduction ratio of the bottleneck MLP. Default: 16.
    """

    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # squeeze: (b, c, h, w) -> (b, c, 1, 1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),
            nn.Sigmoid())

    def forward(self, x):
        batch, channels = x.size(0), x.size(1)
        descriptor = self.avg_pool(x).view(batch, channels)
        gates = self.fc(descriptor).view(batch, channels, 1, 1)
        # Excite: broadcast the per-channel gates over the spatial dimensions.
        return x * gates
|
169 |
+
|
170 |
+
|
171 |
+
@ARCH_REGISTRY.register()
class ResNetArcFace(nn.Module):
    """ArcFace with ResNet architectures.

    Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.

    Args:
        block (str): Block used in the ArcFace architecture. Only 'IRBlock' is
            supported (a block class may also be passed directly).
        layers (tuple(int)): Block numbers in each layer.
        use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
    """

    def __init__(self, block, layers, use_se=True):
        if block == 'IRBlock':
            block = IRBlock
        elif isinstance(block, str):
            # Fail fast: an unknown block name would otherwise fall through and
            # later crash with an opaque AttributeError on `block.expansion`.
            raise ValueError(f"Unsupported block type {block!r}; only 'IRBlock' is supported.")
        self.inplanes = 64
        self.use_se = use_se
        super(ResNetArcFace, self).__init__()

        # Stem: single-channel (grayscale) input.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.prelu = nn.PReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.bn4 = nn.BatchNorm2d(512)
        self.dropout = nn.Dropout()
        # The total downsampling factor is 16 (maxpool + three stride-2 layers),
        # so fc5's 512*8*8 input size implies 128x128 input images.
        self.fc5 = nn.Linear(512 * 8 * 8, 512)
        self.bn5 = nn.BatchNorm1d(512)

        # initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, num_blocks, stride=1):
        """Stack ``num_blocks`` blocks; the first one may downsample/reshape the shortcut."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Projection shortcut: match spatial size and channel count.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
        self.inplanes = planes
        for _ in range(1, num_blocks):
            layers.append(block(self.inplanes, planes, use_se=self.use_se))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return a 512-d embedding for a batch of single-channel face images."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.bn4(x)
        x = self.dropout(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 512*8*8)
        x = self.fc5(x)
        x = self.bn5(x)

        return x
|
third_part/GFPGAN/gfpgan/archs/gfpgan_bilinear_arch.py
ADDED
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import random
|
3 |
+
import torch
|
4 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
5 |
+
from torch import nn
|
6 |
+
|
7 |
+
from .gfpganv1_arch import ResUpBlock
|
8 |
+
from .stylegan2_bilinear_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
|
9 |
+
StyleGAN2GeneratorBilinear)
|
10 |
+
|
11 |
+
|
12 |
+
class StyleGAN2GeneratorBilinearSFT(StyleGAN2GeneratorBilinear):
    """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).

    It is the bilinear version. It does not use the complicated UpFirDnSmooth function that is not friendly for
    deployment. It can be easily converted to the clean version: StyleGAN2GeneratorCSFT.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """

    def __init__(self,
                 out_size,
                 num_style_feat=512,
                 num_mlp=8,
                 channel_multiplier=2,
                 lr_mlp=0.01,
                 narrow=1,
                 sft_half=False):
        super(StyleGAN2GeneratorBilinearSFT, self).__init__(
            out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            lr_mlp=lr_mlp,
            narrow=narrow)
        self.sft_half = sft_half  # if True, SFT modulates only the second half of the channels

    def forward(self,
                styles,
                conditions,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2GeneratorBilinearSFT.

        Args:
            styles (list[Tensor]): Sample codes of styles.
            conditions (list[Tensor]): SFT conditions to generators.
            input_is_latent (bool): Whether input is latent style. Default: False.
            noise (Tensor | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
            truncation (float): The truncation ratio. Default: 1.
            truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
            inject_index (int | None): The injection index for mixing noise. Default: None.
            return_latents (bool): Whether to return style latents. Default: False.
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
        # style truncation
        if truncation < 1:
            style_truncation = []
            for style in styles:
                # interpolate towards the mean latent to trade diversity for fidelity
                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
            styles = style_truncation
        # get style latents with injection
        if len(styles) == 1:
            inject_index = self.num_latent

            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)

        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        i = 1  # latent index; latents 0 and 1 were consumed by style_conv1 / to_rgb1 above
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)

            # the conditions may have fewer levels
            if i < len(conditions):
                # SFT part to combine the conditions; conditions alternate (scale, shift) pairs
                if self.sft_half:  # only apply SFT to half of the channels
                    out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
                    out_sft = out_sft * conditions[i - 1] + conditions[i]
                    out = torch.cat([out_same, out_sft], dim=1)
                else:  # apply SFT to all the channels
                    out = out * conditions[i - 1] + conditions[i]

            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None
|
129 |
+
|
130 |
+
|
131 |
+
@ARCH_REGISTRY.register()
class GFPGANBilinear(nn.Module):
    """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.

    It is the bilinear version and it does not use the complicated UpFirDnSmooth function that is not friendly for
    deployment. It can be easily converted to the clean version: GFPGANv1Clean.


    Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
        fix_decoder (bool): Whether to fix the decoder. Default: True.

        num_mlp (int): Layer number of MLP style layers. Default: 8.
        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
        input_is_latent (bool): Whether input is latent style. Default: False.
        different_w (bool): Whether to use different latent w for different layers. Default: False.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """

    def __init__(
            self,
            out_size,
            num_style_feat=512,
            channel_multiplier=1,
            decoder_load_path=None,
            fix_decoder=True,
            # for stylegan decoder
            num_mlp=8,
            lr_mlp=0.01,
            input_is_latent=False,
            different_w=False,
            narrow=1,
            sft_half=False):

        super(GFPGANBilinear, self).__init__()
        self.input_is_latent = input_is_latent
        self.different_w = different_w
        self.num_style_feat = num_style_feat

        unet_narrow = narrow * 0.5  # by default, use a half of input channels
        # channel counts of the UNet, keyed by spatial resolution
        channels = {
            '4': int(512 * unet_narrow),
            '8': int(512 * unet_narrow),
            '16': int(512 * unet_narrow),
            '32': int(512 * unet_narrow),
            '64': int(256 * channel_multiplier * unet_narrow),
            '128': int(128 * channel_multiplier * unet_narrow),
            '256': int(64 * channel_multiplier * unet_narrow),
            '512': int(32 * channel_multiplier * unet_narrow),
            '1024': int(16 * channel_multiplier * unet_narrow)
        }

        self.log_size = int(math.log(out_size, 2))
        first_out_size = 2**(int(math.log(out_size, 2)))

        self.conv_body_first = ConvLayer(3, channels[f'{first_out_size}'], 1, bias=True, activate=True)

        # downsample
        in_channels = channels[f'{first_out_size}']
        self.conv_body_down = nn.ModuleList()
        for i in range(self.log_size, 2, -1):
            out_channels = channels[f'{2**(i - 1)}']
            self.conv_body_down.append(ResBlock(in_channels, out_channels))
            in_channels = out_channels

        self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True)

        # upsample
        in_channels = channels['4']
        self.conv_body_up = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
            in_channels = out_channels

        # to RGB (one 1x1 projection per decoder resolution, for intermediate previews)
        self.toRGB = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            self.toRGB.append(EqualConv2d(channels[f'{2**i}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0))

        if different_w:
            # one w per decoder layer: (log2(out_size) * 2 - 2) latents of num_style_feat each
            linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
        else:
            linear_out_channel = num_style_feat

        self.final_linear = EqualLinear(
            channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None)

        # the decoder: stylegan2 generator with SFT modulations
        self.stylegan_decoder = StyleGAN2GeneratorBilinearSFT(
            out_size=out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            lr_mlp=lr_mlp,
            narrow=narrow,
            sft_half=sft_half)

        # load pre-trained stylegan2 model if necessary
        if decoder_load_path:
            self.stylegan_decoder.load_state_dict(
                torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
        # fix decoder without updating params
        if fix_decoder:
            for _, param in self.stylegan_decoder.named_parameters():
                param.requires_grad = False

        # for SFT modulations (scale and shift)
        self.condition_scale = nn.ModuleList()
        self.condition_shift = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            if sft_half:
                # SFT touches only half of the decoder channels
                sft_out_channels = out_channels
            else:
                sft_out_channels = out_channels * 2
            # scale branch: bias of the last conv starts at 1 (identity-like scaling)
            self.condition_scale.append(
                nn.Sequential(
                    EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
                    ScaledLeakyReLU(0.2),
                    EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1)))
            # shift branch: bias starts at 0 (no initial shift)
            self.condition_shift.append(
                nn.Sequential(
                    EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
                    ScaledLeakyReLU(0.2),
                    EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0)))

    def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
        """Forward function for GFPGANBilinear.

        Args:
            x (Tensor): Input images.
            return_latents (bool): Whether to return style latents. Default: False.
            return_rgb (bool): Whether return intermediate rgb images. Default: True.
            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
        """
        conditions = []
        unet_skips = []
        out_rgbs = []

        # encoder
        feat = self.conv_body_first(x)
        for i in range(self.log_size - 2):
            feat = self.conv_body_down[i](feat)
            # store skips deepest-last-first so they pair with the upsampling order below
            unet_skips.insert(0, feat)

        feat = self.final_conv(feat)

        # style code
        style_code = self.final_linear(feat.view(feat.size(0), -1))
        if self.different_w:
            # reshape to (batch, num_latents, num_style_feat)
            style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)

        # decode
        for i in range(self.log_size - 2):
            # add unet skip
            feat = feat + unet_skips[i]
            # ResUpLayer
            feat = self.conv_body_up[i](feat)
            # generate scale and shift for SFT layers (appended as alternating scale/shift pairs)
            scale = self.condition_scale[i](feat)
            conditions.append(scale.clone())
            shift = self.condition_shift[i](feat)
            conditions.append(shift.clone())
            # generate rgb images
            if return_rgb:
                out_rgbs.append(self.toRGB[i](feat))

        # decoder
        image, _ = self.stylegan_decoder([style_code],
                                         conditions,
                                         return_latents=return_latents,
                                         input_is_latent=self.input_is_latent,
                                         randomize_noise=randomize_noise)

        return image, out_rgbs
|
third_part/GFPGAN/gfpgan/archs/gfpganv1_arch.py
ADDED
@@ -0,0 +1,439 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import random
|
3 |
+
import torch
|
4 |
+
from basicsr.archs.stylegan2_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
|
5 |
+
StyleGAN2Generator)
|
6 |
+
from basicsr.ops.fused_act import FusedLeakyReLU
|
7 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
8 |
+
from torch import nn
|
9 |
+
from torch.nn import functional as F
|
10 |
+
|
11 |
+
|
12 |
+
class StyleGAN2GeneratorSFT(StyleGAN2Generator):
    """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be
            applied to extent 1D resample kernel to 2D resample kernel. Default: (1, 3, 3, 1).
        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """

    def __init__(self,
                 out_size,
                 num_style_feat=512,
                 num_mlp=8,
                 channel_multiplier=2,
                 resample_kernel=(1, 3, 3, 1),
                 lr_mlp=0.01,
                 narrow=1,
                 sft_half=False):
        super(StyleGAN2GeneratorSFT, self).__init__(
            out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            resample_kernel=resample_kernel,
            lr_mlp=lr_mlp,
            narrow=narrow)
        self.sft_half = sft_half  # if True, SFT modulates only the second half of the channels

    def forward(self,
                styles,
                conditions,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2GeneratorSFT.

        Args:
            styles (list[Tensor]): Sample codes of styles.
            conditions (list[Tensor]): SFT conditions to generators.
            input_is_latent (bool): Whether input is latent style. Default: False.
            noise (Tensor | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
            truncation (float): The truncation ratio. Default: 1.
            truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
            inject_index (int | None): The injection index for mixing noise. Default: None.
            return_latents (bool): Whether to return style latents. Default: False.
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
        # style truncation
        if truncation < 1:
            style_truncation = []
            for style in styles:
                # interpolate towards the mean latent to trade diversity for fidelity
                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
            styles = style_truncation
        # get style latents with injection
        if len(styles) == 1:
            inject_index = self.num_latent

            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)

        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        i = 1  # latent index; latents 0 and 1 were consumed by style_conv1 / to_rgb1 above
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)

            # the conditions may have fewer levels
            if i < len(conditions):
                # SFT part to combine the conditions; conditions alternate (scale, shift) pairs
                if self.sft_half:  # only apply SFT to half of the channels
                    out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
                    out_sft = out_sft * conditions[i - 1] + conditions[i]
                    out = torch.cat([out_same, out_sft], dim=1)
                else:  # apply SFT to all the channels
                    out = out * conditions[i - 1] + conditions[i]

            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None
|
130 |
+
|
131 |
+
|
132 |
+
class ConvUpLayer(nn.Module):
    """2x upsampling convolution: bilinear interpolation followed by an equalized conv.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        stride (int): Stride of the convolution. Default: 1
        padding (int): Zero-padding added to both sides of the input. Default: 0.
        bias (bool): If ``True``, adds a learnable bias to the output. Default: ``True``.
        bias_init_val (float): Bias initialized value. Default: 0.
        activate (bool): Whether use activateion. Default: True.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 bias=True,
                 bias_init_val=0,
                 activate=True):
        super(ConvUpLayer, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        # Equalized learning rate: weights are scaled at runtime instead of at init.
        self.scale = 1 / math.sqrt(in_channels * kernel_size**2)

        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))

        # When an activation is used, the bias is folded into FusedLeakyReLU below,
        # so a separate bias parameter is only registered for the bias-without-activation case.
        if bias and not activate:
            self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
        else:
            self.register_parameter('bias', None)

        # activation
        if not activate:
            self.activation = None
        elif bias:
            self.activation = FusedLeakyReLU(out_channels)
        else:
            self.activation = ScaledLeakyReLU(0.2)

    def forward(self, x):
        # bilinear 2x upsample, then equalized conv
        upsampled = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        out = F.conv2d(
            upsampled,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )
        if self.activation is None:
            return out
        return self.activation(out)
|
195 |
+
|
196 |
+
|
197 |
+
class ResUpBlock(nn.Module):
    """Residual block that doubles the spatial resolution.

    Main path: 3x3 conv followed by an upsampling 3x3 conv; skip path: an
    upsampling 1x1 conv without activation. The two are averaged with a
    1/sqrt(2) factor.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
    """

    def __init__(self, in_channels, out_channels):
        super(ResUpBlock, self).__init__()

        self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
        self.conv2 = ConvUpLayer(in_channels, out_channels, 3, stride=1, padding=1, bias=True, activate=True)
        self.skip = ConvUpLayer(in_channels, out_channels, 1, bias=False, activate=False)

    def forward(self, x):
        main = self.conv2(self.conv1(x))
        shortcut = self.skip(x)
        # divide by sqrt(2) to keep the variance of the sum comparable to each branch
        return (main + shortcut) / math.sqrt(2)
|
218 |
+
|
219 |
+
|
220 |
+
@ARCH_REGISTRY.register()
class GFPGANv1(nn.Module):
    """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.

    Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be
            applied to extent 1D resample kernel to 2D resample kernel. Default: (1, 3, 3, 1).
        decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
        fix_decoder (bool): Whether to fix the decoder. Default: True.

        num_mlp (int): Layer number of MLP style layers. Default: 8.
        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
        input_is_latent (bool): Whether input is latent style. Default: False.
        different_w (bool): Whether to use different latent w for different layers. Default: False.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """

    def __init__(
            self,
            out_size,
            num_style_feat=512,
            channel_multiplier=1,
            resample_kernel=(1, 3, 3, 1),
            decoder_load_path=None,
            fix_decoder=True,
            # for stylegan decoder
            num_mlp=8,
            lr_mlp=0.01,
            input_is_latent=False,
            different_w=False,
            narrow=1,
            sft_half=False):

        super(GFPGANv1, self).__init__()
        self.input_is_latent = input_is_latent
        self.different_w = different_w
        self.num_style_feat = num_style_feat

        unet_narrow = narrow * 0.5  # by default, use a half of input channels
        # channel count per feature-map resolution (keys are spatial sizes)
        channels = {
            '4': int(512 * unet_narrow),
            '8': int(512 * unet_narrow),
            '16': int(512 * unet_narrow),
            '32': int(512 * unet_narrow),
            '64': int(256 * channel_multiplier * unet_narrow),
            '128': int(128 * channel_multiplier * unet_narrow),
            '256': int(64 * channel_multiplier * unet_narrow),
            '512': int(32 * channel_multiplier * unet_narrow),
            '1024': int(16 * channel_multiplier * unet_narrow)
        }

        self.log_size = int(math.log(out_size, 2))
        # equals out_size whenever out_size is an exact power of two
        first_out_size = 2**(int(math.log(out_size, 2)))

        self.conv_body_first = ConvLayer(3, channels[f'{first_out_size}'], 1, bias=True, activate=True)

        # downsample: one ResBlock per octave, from out_size down to 4x4
        in_channels = channels[f'{first_out_size}']
        self.conv_body_down = nn.ModuleList()
        for i in range(self.log_size, 2, -1):
            out_channels = channels[f'{2**(i - 1)}']
            self.conv_body_down.append(ResBlock(in_channels, out_channels, resample_kernel))
            in_channels = out_channels

        self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True)

        # upsample: mirror of the encoder, from 8x8 back to out_size
        in_channels = channels['4']
        self.conv_body_up = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
            in_channels = out_channels

        # to RGB: per-resolution heads for intermediate supervision images
        self.toRGB = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            self.toRGB.append(EqualConv2d(channels[f'{2**i}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0))

        if different_w:
            # one w vector per StyleGAN layer (2 * log2(out_size) - 2 layers)
            linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
        else:
            linear_out_channel = num_style_feat

        # maps the flattened 4x4 bottleneck features to the style code(s)
        self.final_linear = EqualLinear(
            channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None)

        # the decoder: stylegan2 generator with SFT modulations
        self.stylegan_decoder = StyleGAN2GeneratorSFT(
            out_size=out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            resample_kernel=resample_kernel,
            lr_mlp=lr_mlp,
            narrow=narrow,
            sft_half=sft_half)

        # load pre-trained stylegan2 model if necessary
        if decoder_load_path:
            self.stylegan_decoder.load_state_dict(
                torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
        # fix decoder without updating params
        if fix_decoder:
            for _, param in self.stylegan_decoder.named_parameters():
                param.requires_grad = False

        # for SFT modulations (scale and shift)
        self.condition_scale = nn.ModuleList()
        self.condition_shift = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            if sft_half:
                sft_out_channels = out_channels
            else:
                sft_out_channels = out_channels * 2
            # scale head is bias-initialized to 1 (identity modulation),
            # shift head to 0 (no offset)
            self.condition_scale.append(
                nn.Sequential(
                    EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
                    ScaledLeakyReLU(0.2),
                    EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1)))
            self.condition_shift.append(
                nn.Sequential(
                    EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
                    ScaledLeakyReLU(0.2),
                    EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0)))

    def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
        """Forward function for GFPGANv1.

        Args:
            x (Tensor): Input images.
            return_latents (bool): Whether to return style latents. Default: False.
            return_rgb (bool): Whether return intermediate rgb images. Default: True.
            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.

        Returns:
            tuple: (restored image, list of intermediate rgb images — empty
            when return_rgb is False).
        """
        conditions = []
        unet_skips = []
        out_rgbs = []

        # encoder
        feat = self.conv_body_first(x)
        for i in range(self.log_size - 2):
            feat = self.conv_body_down[i](feat)
            # insert at the front so skips are ordered coarse -> fine for the decoder
            unet_skips.insert(0, feat)

        feat = self.final_conv(feat)

        # style code
        style_code = self.final_linear(feat.view(feat.size(0), -1))
        if self.different_w:
            # reshape flat vector into per-layer w codes: (b, num_latent, num_style_feat)
            style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)

        # decode
        for i in range(self.log_size - 2):
            # add unet skip
            feat = feat + unet_skips[i]
            # ResUpLayer
            feat = self.conv_body_up[i](feat)
            # generate scale and shift for SFT layers
            # cloned so the stored condition is independent of later ops on
            # feat — presumably a defensive copy; confirm against upstream
            scale = self.condition_scale[i](feat)
            conditions.append(scale.clone())
            shift = self.condition_shift[i](feat)
            conditions.append(shift.clone())
            # generate rgb images
            if return_rgb:
                out_rgbs.append(self.toRGB[i](feat))

        # decoder
        image, _ = self.stylegan_decoder([style_code],
                                         conditions,
                                         return_latents=return_latents,
                                         input_is_latent=self.input_is_latent,
                                         randomize_noise=randomize_noise)

        return image, out_rgbs
|
402 |
+
|
403 |
+
|
404 |
+
@ARCH_REGISTRY.register()
class FacialComponentDiscriminator(nn.Module):
    """Facial component (eyes, mouth, nose) discriminator used in GFPGAN.

    A small fixed-size VGG-style CNN that maps an image crop to a
    single-channel realness map.
    """

    def __init__(self):
        super(FacialComponentDiscriminator, self).__init__()
        # Two downsampling stages, each followed by a resolution-preserving
        # conv, then a 1-channel prediction head.
        self.conv1 = ConvLayer(3, 64, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
        self.conv2 = ConvLayer(64, 128, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
        self.conv3 = ConvLayer(128, 128, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
        self.conv4 = ConvLayer(128, 256, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
        self.conv5 = ConvLayer(256, 256, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
        self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False)

    def forward(self, x, return_feats=False):
        """Score input crops; optionally expose intermediate features.

        Args:
            x (Tensor): Input images.
            return_feats (bool): Whether to return intermediate features. Default: False.

        Returns:
            tuple: (score map, [feat_128, feat_256]) when return_feats is
            True, otherwise (score map, None).
        """
        feats = [] if return_feats else None
        h = self.conv3(self.conv2(self.conv1(x)))
        if feats is not None:
            feats.append(h.clone())
        h = self.conv5(self.conv4(h))
        if feats is not None:
            feats.append(h.clone())
        return self.final_conv(h), feats
|
third_part/GFPGAN/gfpgan/archs/gfpganv1_clean_arch.py
ADDED
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import random
|
3 |
+
import torch
|
4 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
5 |
+
from torch import nn
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
from .stylegan2_clean_arch import StyleGAN2GeneratorClean
|
9 |
+
|
10 |
+
|
11 |
+
class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
    """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).

    It is the clean version without custom compiled CUDA extensions used in StyleGAN2.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """

    def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1, sft_half=False):
        super(StyleGAN2GeneratorCSFT, self).__init__(
            out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            narrow=narrow)
        self.sft_half = sft_half

    def forward(self,
                styles,
                conditions,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2GeneratorCSFT.

        Args:
            styles (list[Tensor]): Sample codes of styles.
            conditions (list[Tensor]): SFT conditions to generators.
            input_is_latent (bool): Whether input is latent style. Default: False.
            noise (Tensor | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
            truncation (float): The truncation ratio. Default: 1.
            truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
            inject_index (int | None): The injection index for mixing noise. Default: None.
            return_latents (bool): Whether to return style latents. Default: False.

        Returns:
            tuple: (image, latent) when return_latents is True, else (image, None).
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
        # style truncation: interpolate each style toward the mean latent
        if truncation < 1:
            style_truncation = []
            for style in styles:
                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
            styles = style_truncation
        # get style latents with injection
        if len(styles) == 1:
            inject_index = self.num_latent

            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)

        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        # i indexes latent layers; conditions are consumed in (scale, shift)
        # pairs at conditions[i - 1] / conditions[i]
        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)

            # the conditions may have fewer levels
            if i < len(conditions):
                # SFT part to combine the conditions
                if self.sft_half:  # only apply SFT to half of the channels
                    out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
                    out_sft = out_sft * conditions[i - 1] + conditions[i]
                    out = torch.cat([out_same, out_sft], dim=1)
                else:  # apply SFT to all the channels
                    out = out * conditions[i - 1] + conditions[i]

            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None
|
118 |
+
|
119 |
+
|
120 |
+
class ResBlock(nn.Module):
    """Residual block with bilinear upsampling/downsampling.

    Both the main branch and the 1x1 shortcut are resampled by the same
    factor, so input and output resolutions always agree before the sum.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        mode (str): Upsampling/downsampling mode. Options: down | up. Default: down.

    Raises:
        ValueError: If ``mode`` is neither 'down' nor 'up'.
    """

    def __init__(self, in_channels, out_channels, mode='down'):
        super(ResBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
        self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
        self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        if mode == 'down':
            self.scale_factor = 0.5
        elif mode == 'up':
            self.scale_factor = 2
        else:
            # fail fast: previously an unknown mode left scale_factor unset
            # and only surfaced as an AttributeError during forward()
            raise ValueError(f"Wrong resample mode in ResBlock: {mode}. Supported ones are: ['down', 'up'].")

    def forward(self, x):
        out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
        # upsample/downsample
        out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
        out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
        # skip: resample the input itself so shapes match for the residual sum
        x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
        skip = self.skip(x)
        out = out + skip
        return out
|
150 |
+
|
151 |
+
|
152 |
+
@ARCH_REGISTRY.register()
class GFPGANv1Clean(nn.Module):
    """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.

    It is the clean version without custom compiled CUDA extensions used in StyleGAN2.

    Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
        fix_decoder (bool): Whether to fix the decoder. Default: True.

        num_mlp (int): Layer number of MLP style layers. Default: 8.
        input_is_latent (bool): Whether input is latent style. Default: False.
        different_w (bool): Whether to use different latent w for different layers. Default: False.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """

    def __init__(
            self,
            out_size,
            num_style_feat=512,
            channel_multiplier=1,
            decoder_load_path=None,
            fix_decoder=True,
            # for stylegan decoder
            num_mlp=8,
            input_is_latent=False,
            different_w=False,
            narrow=1,
            sft_half=False):

        super(GFPGANv1Clean, self).__init__()
        self.input_is_latent = input_is_latent
        self.different_w = different_w
        self.num_style_feat = num_style_feat

        unet_narrow = narrow * 0.5  # by default, use a half of input channels
        # channel count per feature-map resolution (keys are spatial sizes)
        channels = {
            '4': int(512 * unet_narrow),
            '8': int(512 * unet_narrow),
            '16': int(512 * unet_narrow),
            '32': int(512 * unet_narrow),
            '64': int(256 * channel_multiplier * unet_narrow),
            '128': int(128 * channel_multiplier * unet_narrow),
            '256': int(64 * channel_multiplier * unet_narrow),
            '512': int(32 * channel_multiplier * unet_narrow),
            '1024': int(16 * channel_multiplier * unet_narrow)
        }

        self.log_size = int(math.log(out_size, 2))
        # equals out_size whenever out_size is an exact power of two
        first_out_size = 2**(int(math.log(out_size, 2)))

        self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)

        # downsample: one ResBlock per octave, from out_size down to 4x4
        in_channels = channels[f'{first_out_size}']
        self.conv_body_down = nn.ModuleList()
        for i in range(self.log_size, 2, -1):
            out_channels = channels[f'{2**(i - 1)}']
            self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))
            in_channels = out_channels

        self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1)

        # upsample: mirror of the encoder, from 8x8 back to out_size
        in_channels = channels['4']
        self.conv_body_up = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.conv_body_up.append(ResBlock(in_channels, out_channels, mode='up'))
            in_channels = out_channels

        # to RGB: per-resolution heads for intermediate supervision images
        self.toRGB = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            self.toRGB.append(nn.Conv2d(channels[f'{2**i}'], 3, 1))

        if different_w:
            # one w vector per StyleGAN layer (2 * log2(out_size) - 2 layers)
            linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
        else:
            linear_out_channel = num_style_feat

        # maps the flattened 4x4 bottleneck features to the style code(s)
        self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)

        # the decoder: stylegan2 generator with SFT modulations
        self.stylegan_decoder = StyleGAN2GeneratorCSFT(
            out_size=out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            narrow=narrow,
            sft_half=sft_half)

        # load pre-trained stylegan2 model if necessary
        if decoder_load_path:
            self.stylegan_decoder.load_state_dict(
                torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
        # fix decoder without updating params
        if fix_decoder:
            for _, param in self.stylegan_decoder.named_parameters():
                param.requires_grad = False

        # for SFT modulations (scale and shift)
        self.condition_scale = nn.ModuleList()
        self.condition_shift = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            if sft_half:
                sft_out_channels = out_channels
            else:
                sft_out_channels = out_channels * 2
            self.condition_scale.append(
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
                    nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
            self.condition_shift.append(
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
                    nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))

    def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
        """Forward function for GFPGANv1Clean.

        Args:
            x (Tensor): Input images.
            return_latents (bool): Whether to return style latents. Default: False.
            return_rgb (bool): Whether return intermediate rgb images. Default: True.
            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.

        Returns:
            tuple: (restored image, list of intermediate rgb images — empty
            when return_rgb is False).
        """
        conditions = []
        unet_skips = []
        out_rgbs = []

        # encoder
        feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2)
        for i in range(self.log_size - 2):
            feat = self.conv_body_down[i](feat)
            # insert at the front so skips are ordered coarse -> fine for the decoder
            unet_skips.insert(0, feat)
        feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)

        # style code
        style_code = self.final_linear(feat.view(feat.size(0), -1))
        if self.different_w:
            # reshape flat vector into per-layer w codes: (b, num_latent, num_style_feat)
            style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)

        # decode
        for i in range(self.log_size - 2):
            # add unet skip
            feat = feat + unet_skips[i]
            # ResUpLayer
            feat = self.conv_body_up[i](feat)
            # generate scale and shift for SFT layers
            # cloned so the stored condition is independent of later ops on
            # feat — presumably a defensive copy; confirm against upstream
            scale = self.condition_scale[i](feat)
            conditions.append(scale.clone())
            shift = self.condition_shift[i](feat)
            conditions.append(shift.clone())
            # generate rgb images
            if return_rgb:
                out_rgbs.append(self.toRGB[i](feat))

        # decoder
        image, _ = self.stylegan_decoder([style_code],
                                         conditions,
                                         return_latents=return_latents,
                                         input_is_latent=self.input_is_latent,
                                         randomize_noise=randomize_noise)

        return image, out_rgbs
|
third_part/GFPGAN/gfpgan/archs/stylegan2_bilinear_arch.py
ADDED
@@ -0,0 +1,613 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import random
|
3 |
+
import torch
|
4 |
+
from basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu
|
5 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
6 |
+
from torch import nn
|
7 |
+
from torch.nn import functional as F
|
8 |
+
|
9 |
+
|
10 |
+
class NormStyleCode(nn.Module):
    """Pixel-norm normalization applied to StyleGAN2 style codes."""

    def forward(self, x):
        """Scale each sample to (approximately) unit root-mean-square.

        Args:
            x (Tensor): Style codes with shape (b, c).

        Returns:
            Tensor: Normalized tensor of the same shape.
        """
        # epsilon guards against division by zero for all-zero codes
        mean_sq = torch.mean(x**2, dim=1, keepdim=True)
        return x * torch.rsqrt(mean_sq + 1e-8)
|
22 |
+
|
23 |
+
|
24 |
+
class EqualLinear(nn.Module):
    """Equalized Linear as StyleGAN2.

    The stored weight is kept near unit scale and multiplied by a fixed
    equalization factor at run time (equalized learning-rate trick); the bias
    is likewise rescaled by ``lr_mul`` on the fly.

    Args:
        in_channels (int): Size of each sample.
        out_channels (int): Size of each output sample.
        bias (bool): If set to ``False``, the layer will not learn an additive
            bias. Default: ``True``.
        bias_init_val (float): Bias initialized value. Default: 0.
        lr_mul (float): Learning rate multiplier. Default: 1.
        activation (None | str): The activation after ``linear`` operation.
            Supported: 'fused_lrelu', None. Default: None.
    """

    def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul=1, activation=None):
        super(EqualLinear, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.lr_mul = lr_mul
        self.activation = activation
        if self.activation not in ['fused_lrelu', None]:
            raise ValueError(f'Wrong activation value in EqualLinear: {activation}'
                             "Supported ones are: ['fused_lrelu', None].")
        # equalization factor folded together with the lr multiplier
        self.scale = (1 / math.sqrt(in_channels)) * lr_mul

        self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        # rescale the bias at run time to match the lr-equalized weight
        scaled_bias = None if self.bias is None else self.bias * self.lr_mul
        if self.activation == 'fused_lrelu':
            # fused kernel applies bias + leaky relu in one pass
            out = fused_leaky_relu(F.linear(x, self.weight * self.scale), scaled_bias)
        else:
            out = F.linear(x, self.weight * self.scale, bias=scaled_bias)
        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, bias={self.bias is not None})')
|
70 |
+
|
71 |
+
|
72 |
+
class ModulatedConv2d(nn.Module):
    """Modulated Conv2d used in StyleGAN2.

    There is no bias in ModulatedConv2d.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether to demodulate in the conv layer.
            Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
            Default: None.
        eps (float): A value added to the denominator for numerical stability.
            Default: 1e-8.
        interpolation_mode (str): Interpolation mode for resampling
            ('bilinear' or 'nearest'). Default: 'bilinear'.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 num_style_feat,
                 demodulate=True,
                 sample_mode=None,
                 eps=1e-8,
                 interpolation_mode='bilinear'):
        super(ModulatedConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.demodulate = demodulate
        self.sample_mode = sample_mode
        self.eps = eps
        self.interpolation_mode = interpolation_mode
        # align_corners must be None for 'nearest' (F.interpolate rejects it otherwise)
        if self.interpolation_mode == 'nearest':
            self.align_corners = None
        else:
            self.align_corners = False

        # equalized learning-rate scale applied to the weight at run time
        self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
        # modulation inside each modulated conv
        self.modulation = EqualLinear(
            num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None)

        # leading batch-broadcast dim of 1: per-sample weights are built in forward
        self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size))
        self.padding = kernel_size // 2

    def forward(self, x, style):
        """Forward function.

        Args:
            x (Tensor): Tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).

        Returns:
            Tensor: Modulated tensor after convolution.
        """
        b, c, h, w = x.shape  # c = c_in
        # weight modulation
        style = self.modulation(style).view(b, 1, c, 1, 1)
        # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
        weight = self.scale * self.weight * style  # (b, c_out, c_in, k, k)

        if self.demodulate:
            # normalize each per-sample output filter to unit norm
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
            weight = weight * demod.view(b, self.out_channels, 1, 1, 1)

        weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)

        if self.sample_mode == 'upsample':
            x = F.interpolate(x, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners)
        elif self.sample_mode == 'downsample':
            x = F.interpolate(x, scale_factor=0.5, mode=self.interpolation_mode, align_corners=self.align_corners)

        b, c, h, w = x.shape
        # fold the batch into the channel dim so a single grouped conv applies
        # a different filter bank to every sample
        x = x.view(1, b * c, h, w)
        # weight: (b*c_out, c_in, k, k), groups=b
        out = F.conv2d(x, weight, padding=self.padding, groups=b)
        out = out.view(b, self.out_channels, *out.shape[2:4])

        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size}, '
                f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')
|
160 |
+
|
161 |
+
|
162 |
+
class StyleConv(nn.Module):
    """Style-modulated convolution with noise injection and activation.

    Wraps a :class:`ModulatedConv2d`, adds single-channel Gaussian noise
    scaled by one learned scalar, then applies a fused leaky-ReLU.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether demodulate in the conv layer. Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
            Default: None.
        interpolation_mode (str): Interpolation mode used when resampling.
            Default: 'bilinear'.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 num_style_feat,
                 demodulate=True,
                 sample_mode=None,
                 interpolation_mode='bilinear'):
        super(StyleConv, self).__init__()
        self.modulated_conv = ModulatedConv2d(
            in_channels,
            out_channels,
            kernel_size,
            num_style_feat,
            demodulate=demodulate,
            sample_mode=sample_mode,
            interpolation_mode=interpolation_mode)
        # single learned scalar controlling noise strength
        self.weight = nn.Parameter(torch.zeros(1))
        self.activate = FusedLeakyReLU(out_channels)

    def forward(self, x, style, noise=None):
        feat = self.modulated_conv(x, style)
        if noise is None:
            # fresh per-sample, single-channel Gaussian noise map
            batch, _, height, width = feat.shape
            noise = feat.new_empty(batch, 1, height, width).normal_()
        feat = feat + self.weight * noise
        # FusedLeakyReLU applies the per-channel bias together with the activation
        return self.activate(feat)
|
206 |
+
|
207 |
+
|
208 |
+
class ToRGB(nn.Module):
    """Project features to an RGB image, optionally merging a skip image.

    Args:
        in_channels (int): Channel number of input.
        num_style_feat (int): Channel number of style features.
        upsample (bool): Whether to upsample the skip branch. Default: True.
        interpolation_mode (str): Interpolation mode for upsampling.
            Default: 'bilinear'.
    """

    def __init__(self, in_channels, num_style_feat, upsample=True, interpolation_mode='bilinear'):
        super(ToRGB, self).__init__()
        self.upsample = upsample
        self.interpolation_mode = interpolation_mode
        # 'nearest' interpolation rejects an explicit align_corners flag
        self.align_corners = None if self.interpolation_mode == 'nearest' else False
        self.modulated_conv = ModulatedConv2d(
            in_channels,
            3,
            kernel_size=1,
            num_style_feat=num_style_feat,
            demodulate=False,
            sample_mode=None,
            interpolation_mode=interpolation_mode)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, x, style, skip=None):
        """Forward function.

        Args:
            x (Tensor): Feature tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).
            skip (Tensor): Base/skip tensor. Default: None.

        Returns:
            Tensor: RGB images.
        """
        rgb = self.modulated_conv(x, style) + self.bias
        if skip is None:
            return rgb
        if self.upsample:
            skip = F.interpolate(
                skip, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners)
        return rgb + skip
|
254 |
+
|
255 |
+
|
256 |
+
class ConstantInput(nn.Module):
    """Learned constant tensor that seeds the synthesis network.

    Args:
        num_channel (int): Channel number of constant input.
        size (int): Spatial size of constant input.
    """

    def __init__(self, num_channel, size):
        super(ConstantInput, self).__init__()
        self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))

    def forward(self, batch):
        # replicate the single learned map across the batch dimension
        return self.weight.repeat(batch, 1, 1, 1)
|
271 |
+
|
272 |
+
|
273 |
+
@ARCH_REGISTRY.register()
class StyleGAN2GeneratorBilinear(nn.Module):
    """StyleGAN2 Generator.

    Bilinear-resampling variant (avoids the custom up/down CUDA ops).

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of
            StyleGAN2. Default: 2.
        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
        narrow (float): Narrow ratio for channels. Default: 1.0.
        interpolation_mode (str): Interpolation mode for resampling layers.
            Default: 'bilinear'.
    """

    def __init__(self,
                 out_size,
                 num_style_feat=512,
                 num_mlp=8,
                 channel_multiplier=2,
                 lr_mlp=0.01,
                 narrow=1,
                 interpolation_mode='bilinear'):
        super(StyleGAN2GeneratorBilinear, self).__init__()
        # Style MLP layers: z -> w mapping network
        self.num_style_feat = num_style_feat
        style_mlp_layers = [NormStyleCode()]
        for i in range(num_mlp):
            style_mlp_layers.append(
                EqualLinear(
                    num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp,
                    activation='fused_lrelu'))
        self.style_mlp = nn.Sequential(*style_mlp_layers)

        # channel count per resolution (keys are spatial sizes as strings)
        channels = {
            '4': int(512 * narrow),
            '8': int(512 * narrow),
            '16': int(512 * narrow),
            '32': int(512 * narrow),
            '64': int(256 * channel_multiplier * narrow),
            '128': int(128 * channel_multiplier * narrow),
            '256': int(64 * channel_multiplier * narrow),
            '512': int(32 * channel_multiplier * narrow),
            '1024': int(16 * channel_multiplier * narrow)
        }
        self.channels = channels

        self.constant_input = ConstantInput(channels['4'], size=4)
        self.style_conv1 = StyleConv(
            channels['4'],
            channels['4'],
            kernel_size=3,
            num_style_feat=num_style_feat,
            demodulate=True,
            sample_mode=None,
            interpolation_mode=interpolation_mode)
        self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False, interpolation_mode=interpolation_mode)

        self.log_size = int(math.log(out_size, 2))
        # one conv at 4x4 plus two per higher resolution
        self.num_layers = (self.log_size - 2) * 2 + 1
        # number of w latents consumed per forward pass
        self.num_latent = self.log_size * 2 - 2

        self.style_convs = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channels = channels['4']
        # noise: fixed per-layer buffers used when randomize_noise is False
        for layer_idx in range(self.num_layers):
            resolution = 2**((layer_idx + 5) // 2)
            shape = [1, 1, resolution, resolution]
            self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
        # style convs and to_rgbs, from 8x8 up to out_size
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.style_convs.append(
                StyleConv(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode='upsample',
                    interpolation_mode=interpolation_mode))
            self.style_convs.append(
                StyleConv(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode=None,
                    interpolation_mode=interpolation_mode))
            self.to_rgbs.append(
                ToRGB(out_channels, num_style_feat, upsample=True, interpolation_mode=interpolation_mode))
            in_channels = out_channels

    def make_noise(self):
        """Make noise for noise injection."""
        device = self.constant_input.weight.device
        noises = [torch.randn(1, 1, 4, 4, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))

        return noises

    def get_latent(self, x):
        # map a style code z to latent w via the MLP
        return self.style_mlp(x)

    def mean_latent(self, num_latent):
        # average latent of num_latent random samples (used for truncation)
        latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
        latent = self.style_mlp(latent_in).mean(0, keepdim=True)
        return latent

    def forward(self,
                styles,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2Generator.

        Args:
            styles (list[Tensor]): Sample codes of styles.
            input_is_latent (bool): Whether input is latent style.
                Default: False.
            noise (Tensor | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is
                False. Default: True.
            truncation (float): Truncation ratio; values < 1 pull latents
                towards truncation_latent. Default: 1.
            truncation_latent (Tensor | None): Center latent for truncation
                (required when truncation < 1). Default: None.
            inject_index (int | None): The injection index for mixing noise.
                Default: None.
            return_latents (bool): Whether to return style latents.
                Default: False.
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
        # style truncation
        if truncation < 1:
            style_truncation = []
            for style in styles:
                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
            styles = style_truncation
        # get style latent with injection
        if len(styles) == 1:
            inject_index = self.num_latent

            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                # NOTE: random split point — forward is non-deterministic here
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)

        # main generation: constant input, then alternating conv/conv/to_rgb
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            # progressive RGB skip accumulation
            skip = to_rgb(out, latent[:, i + 2], skip)
            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None
|
463 |
+
|
464 |
+
|
465 |
+
class ScaledLeakyReLU(nn.Module):
    """Leaky ReLU whose output is multiplied by sqrt(2).

    Args:
        negative_slope (float): Negative slope. Default: 0.2.
    """

    def __init__(self, negative_slope=0.2):
        super(ScaledLeakyReLU, self).__init__()
        self.negative_slope = negative_slope

    def forward(self, x):
        # sqrt(2) gain applied after the leaky ReLU
        return math.sqrt(2) * F.leaky_relu(x, negative_slope=self.negative_slope)
|
479 |
+
|
480 |
+
|
481 |
+
class EqualConv2d(nn.Module):
    """Equalized-learning-rate Conv2d as used in StyleGAN2.

    The stored weight stays at unit variance; it is rescaled at run time by
    ``1 / sqrt(fan_in)`` so every layer trains at the same effective rate.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        stride (int): Stride of the convolution. Default: 1
        padding (int): Zero-padding added to both sides of the input.
            Default: 0.
        bias (bool): If ``True``, adds a learnable bias to the output.
            Default: ``True``.
        bias_init_val (float): Bias initialized value. Default: 0.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, bias_init_val=0):
        super(EqualConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        # equalized-lr runtime scale: 1 / sqrt(fan_in)
        self.scale = 1 / math.sqrt(in_channels * kernel_size**2)

        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        scaled_weight = self.weight * self.scale
        return F.conv2d(x, scaled_weight, bias=self.bias, stride=self.stride, padding=self.padding)

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size},'
                f' stride={self.stride}, padding={self.padding}, '
                f'bias={self.bias is not None})')
|
528 |
+
|
529 |
+
|
530 |
+
class ConvLayer(nn.Sequential):
    """Conv Layer used in StyleGAN2 Discriminator.

    Optional interpolation-based downsampling, an equalized-lr conv, and an
    optional activation, assembled as an nn.Sequential.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Kernel size.
        downsample (bool): Whether downsample by a factor of 2.
            Default: False.
        bias (bool): Whether with bias. Default: True.
        activate (bool): Whether use activateion. Default: True.
        interpolation_mode (str): Interpolation mode for downsampling.
            Default: 'bilinear'.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 downsample=False,
                 bias=True,
                 activate=True,
                 interpolation_mode='bilinear'):
        layers = []
        self.interpolation_mode = interpolation_mode
        # downsample
        if downsample:
            if self.interpolation_mode == 'nearest':
                # 'nearest' interpolation rejects an explicit align_corners flag
                self.align_corners = None
            else:
                self.align_corners = False

            # downsampling by interpolation (scale 0.5), not by strided conv
            layers.append(
                torch.nn.Upsample(scale_factor=0.5, mode=interpolation_mode, align_corners=self.align_corners))
        stride = 1
        self.padding = kernel_size // 2
        # conv; when an activation follows, the bias is applied by
        # FusedLeakyReLU instead of the conv itself (hence `bias and not activate`)
        layers.append(
            EqualConv2d(
                in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias
                and not activate))
        # activation
        if activate:
            if bias:
                layers.append(FusedLeakyReLU(out_channels))
            else:
                layers.append(ScaledLeakyReLU(0.2))

        super(ConvLayer, self).__init__(*layers)
|
577 |
+
|
578 |
+
|
579 |
+
class ResBlock(nn.Module):
    """Residual block used in StyleGAN2 Discriminator.

    Two 3x3 conv layers (the second downsampling by 2) plus a 1x1
    downsampling skip path; the sum is rescaled by 1/sqrt(2).

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        interpolation_mode (str): Interpolation mode for downsampling.
            Default: 'bilinear'.
    """

    def __init__(self, in_channels, out_channels, interpolation_mode='bilinear'):
        super(ResBlock, self).__init__()

        self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
        self.conv2 = ConvLayer(
            in_channels,
            out_channels,
            3,
            downsample=True,
            interpolation_mode=interpolation_mode,
            bias=True,
            activate=True)
        self.skip = ConvLayer(
            in_channels,
            out_channels,
            1,
            downsample=True,
            interpolation_mode=interpolation_mode,
            bias=False,
            activate=False)

    def forward(self, x):
        residual = self.conv2(self.conv1(x))
        shortcut = self.skip(x)
        # rescale so the sum keeps roughly unit variance
        return (residual + shortcut) / math.sqrt(2)
|
third_part/GFPGAN/gfpgan/archs/stylegan2_clean_arch.py
ADDED
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import random
|
3 |
+
import torch
|
4 |
+
from basicsr.archs.arch_util import default_init_weights
|
5 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
6 |
+
from torch import nn
|
7 |
+
from torch.nn import functional as F
|
8 |
+
|
9 |
+
|
10 |
+
class NormStyleCode(nn.Module):
    """Pixel-norm normalization applied to style codes."""

    def forward(self, x):
        """Normalize the style codes.

        Args:
            x (Tensor): Style codes with shape (b, c).

        Returns:
            Tensor: Normalized tensor.
        """
        # divide each code by its RMS over the channel dimension
        inv_rms = torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
        return x * inv_rms
|
22 |
+
|
23 |
+
|
24 |
+
class ModulatedConv2d(nn.Module):
    """Modulated Conv2d used in StyleGAN2.

    There is no bias in ModulatedConv2d.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether to demodulate in the conv layer. Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
        eps (float): A value added to the denominator for numerical stability. Default: 1e-8.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 num_style_feat,
                 demodulate=True,
                 sample_mode=None,
                 eps=1e-8):
        super(ModulatedConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.demodulate = demodulate
        self.sample_mode = sample_mode
        self.eps = eps

        # modulation inside each modulated conv
        self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
        # initialization
        default_init_weights(self.modulation, scale=1, bias_fill=1, a=0, mode='fan_in', nonlinearity='linear')

        # equalized-lr scaling folded into the stored weight at init time
        # (the bilinear variant instead rescales at run time)
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) /
            math.sqrt(in_channels * kernel_size**2))
        self.padding = kernel_size // 2

    def forward(self, x, style):
        """Forward function.

        Args:
            x (Tensor): Tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).

        Returns:
            Tensor: Modulated tensor after convolution.
        """
        b, c, h, w = x.shape  # c = c_in
        # weight modulation
        style = self.modulation(style).view(b, 1, c, 1, 1)
        # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
        weight = self.weight * style  # (b, c_out, c_in, k, k)

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
            weight = weight * demod.view(b, self.out_channels, 1, 1, 1)

        weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)

        # upsample or downsample if necessary
        if self.sample_mode == 'upsample':
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        elif self.sample_mode == 'downsample':
            x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)

        b, c, h, w = x.shape
        # fold batch into channels so one grouped conv applies a different
        # (modulated) kernel per sample
        x = x.view(1, b * c, h, w)
        # weight: (b*c_out, c_in, k, k), groups=b
        out = F.conv2d(x, weight, padding=self.padding, groups=b)
        out = out.view(b, self.out_channels, *out.shape[2:4])

        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})')
|
104 |
+
|
105 |
+
|
106 |
+
class StyleConv(nn.Module):
    """Style conv used in StyleGAN2 (clean version, no fused CUDA ops).

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether demodulate in the conv layer. Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
    """

    def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None):
        super(StyleConv, self).__init__()
        self.modulated_conv = ModulatedConv2d(
            in_channels, out_channels, kernel_size, num_style_feat, demodulate=demodulate, sample_mode=sample_mode)
        # single learned scalar controlling noise strength
        self.weight = nn.Parameter(torch.zeros(1))
        self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
        self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x, style, noise=None):
        # modulated conv; the sqrt(2) factor is kept "for conversion"
        # compatibility with the fused-op implementation (per the original)
        feat = self.modulated_conv(x, style) * 2**0.5
        # noise injection
        if noise is None:
            batch, _, height, width = feat.shape
            noise = feat.new_empty(batch, 1, height, width).normal_()
        # add noise, then bias, then activate
        feat = feat + self.weight * noise + self.bias
        return self.activate(feat)
|
139 |
+
|
140 |
+
|
141 |
+
class ToRGB(nn.Module):
    """To RGB (image space) from features.

    Args:
        in_channels (int): Channel number of input.
        num_style_feat (int): Channel number of style features.
        upsample (bool): Whether to upsample. Default: True.
    """

    def __init__(self, in_channels, num_style_feat, upsample=True):
        super(ToRGB, self).__init__()
        self.upsample = upsample
        self.modulated_conv = ModulatedConv2d(
            in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, x, style, skip=None):
        """Project features to RGB and optionally add an upsampled skip image.

        Args:
            x (Tensor): Feature tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).
            skip (Tensor): Base/skip tensor. Default: None.

        Returns:
            Tensor: RGB images.
        """
        rgb = self.modulated_conv(x, style) + self.bias
        if skip is None:
            return rgb
        if self.upsample:
            skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
        return rgb + skip
|
175 |
+
|
176 |
+
|
177 |
+
class ConstantInput(nn.Module):
    """Learned constant input map replicated per batch sample.

    Args:
        num_channel (int): Channel number of constant input.
        size (int): Spatial size of constant input.
    """

    def __init__(self, num_channel, size):
        super(ConstantInput, self).__init__()
        self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))

    def forward(self, batch):
        # expand the single learned tensor along the batch dimension
        return self.weight.repeat(batch, 1, 1, 1)
|
192 |
+
|
193 |
+
|
194 |
+
@ARCH_REGISTRY.register()
|
195 |
+
class StyleGAN2GeneratorClean(nn.Module):
|
196 |
+
"""Clean version of StyleGAN2 Generator.
|
197 |
+
|
198 |
+
Args:
|
199 |
+
out_size (int): The spatial size of outputs.
|
200 |
+
num_style_feat (int): Channel number of style features. Default: 512.
|
201 |
+
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
202 |
+
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
203 |
+
narrow (float): Narrow ratio for channels. Default: 1.0.
|
204 |
+
"""
|
205 |
+
|
206 |
+
def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1):
|
207 |
+
super(StyleGAN2GeneratorClean, self).__init__()
|
208 |
+
# Style MLP layers
|
209 |
+
self.num_style_feat = num_style_feat
|
210 |
+
style_mlp_layers = [NormStyleCode()]
|
211 |
+
for i in range(num_mlp):
|
212 |
+
style_mlp_layers.extend(
|
213 |
+
[nn.Linear(num_style_feat, num_style_feat, bias=True),
|
214 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True)])
|
215 |
+
self.style_mlp = nn.Sequential(*style_mlp_layers)
|
216 |
+
# initialization
|
217 |
+
default_init_weights(self.style_mlp, scale=1, bias_fill=0, a=0.2, mode='fan_in', nonlinearity='leaky_relu')
|
218 |
+
|
219 |
+
# channel list
|
220 |
+
channels = {
|
221 |
+
'4': int(512 * narrow),
|
222 |
+
'8': int(512 * narrow),
|
223 |
+
'16': int(512 * narrow),
|
224 |
+
'32': int(512 * narrow),
|
225 |
+
'64': int(256 * channel_multiplier * narrow),
|
226 |
+
'128': int(128 * channel_multiplier * narrow),
|
227 |
+
'256': int(64 * channel_multiplier * narrow),
|
228 |
+
'512': int(32 * channel_multiplier * narrow),
|
229 |
+
'1024': int(16 * channel_multiplier * narrow)
|
230 |
+
}
|
231 |
+
self.channels = channels
|
232 |
+
|
233 |
+
self.constant_input = ConstantInput(channels['4'], size=4)
|
234 |
+
self.style_conv1 = StyleConv(
|
235 |
+
channels['4'],
|
236 |
+
channels['4'],
|
237 |
+
kernel_size=3,
|
238 |
+
num_style_feat=num_style_feat,
|
239 |
+
demodulate=True,
|
240 |
+
sample_mode=None)
|
241 |
+
self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False)
|
242 |
+
|
243 |
+
self.log_size = int(math.log(out_size, 2))
|
244 |
+
self.num_layers = (self.log_size - 2) * 2 + 1
|
245 |
+
self.num_latent = self.log_size * 2 - 2
|
246 |
+
|
247 |
+
self.style_convs = nn.ModuleList()
|
248 |
+
self.to_rgbs = nn.ModuleList()
|
249 |
+
self.noises = nn.Module()
|
250 |
+
|
251 |
+
in_channels = channels['4']
|
252 |
+
# noise
|
253 |
+
for layer_idx in range(self.num_layers):
|
254 |
+
resolution = 2**((layer_idx + 5) // 2)
|
255 |
+
shape = [1, 1, resolution, resolution]
|
256 |
+
self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
|
257 |
+
# style convs and to_rgbs
|
258 |
+
for i in range(3, self.log_size + 1):
|
259 |
+
out_channels = channels[f'{2**i}']
|
260 |
+
self.style_convs.append(
|
261 |
+
StyleConv(
|
262 |
+
in_channels,
|
263 |
+
out_channels,
|
264 |
+
kernel_size=3,
|
265 |
+
num_style_feat=num_style_feat,
|
266 |
+
demodulate=True,
|
267 |
+
sample_mode='upsample'))
|
268 |
+
self.style_convs.append(
|
269 |
+
StyleConv(
|
270 |
+
out_channels,
|
271 |
+
out_channels,
|
272 |
+
kernel_size=3,
|
273 |
+
num_style_feat=num_style_feat,
|
274 |
+
demodulate=True,
|
275 |
+
sample_mode=None))
|
276 |
+
self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
|
277 |
+
in_channels = out_channels
|
278 |
+
|
279 |
+
def make_noise(self):
|
280 |
+
"""Make noise for noise injection."""
|
281 |
+
device = self.constant_input.weight.device
|
282 |
+
noises = [torch.randn(1, 1, 4, 4, device=device)]
|
283 |
+
|
284 |
+
for i in range(3, self.log_size + 1):
|
285 |
+
for _ in range(2):
|
286 |
+
noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
|
287 |
+
|
288 |
+
return noises
|
289 |
+
|
290 |
+
def get_latent(self, x):
|
291 |
+
return self.style_mlp(x)
|
292 |
+
|
293 |
+
def mean_latent(self, num_latent):
|
294 |
+
latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
|
295 |
+
latent = self.style_mlp(latent_in).mean(0, keepdim=True)
|
296 |
+
return latent
|
297 |
+
|
298 |
+
def forward(self,
            styles,
            input_is_latent=False,
            noise=None,
            randomize_noise=True,
            truncation=1,
            truncation_latent=None,
            inject_index=None,
            return_latents=False):
    """Forward function for StyleGAN2GeneratorClean.

    Args:
        styles (list[Tensor]): Sample codes of styles.
        input_is_latent (bool): Whether input is latent style. Default: False.
        noise (Tensor | None): Input noise or None. Default: None.
        randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
        truncation (float): The truncation ratio. Default: 1.
        truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
        inject_index (int | None): The injection index for mixing noise. Default: None.
        return_latents (bool): Whether to return style latents. Default: False.
    """
    # style codes -> latents with Style MLP layer
    if not input_is_latent:
        styles = [self.style_mlp(s) for s in styles]
    # noises: a None entry means "sample fresh noise inside that StyleConv"
    if noise is None:
        if randomize_noise:
            noise = [None] * self.num_layers  # for each style conv layer
        else:  # use the stored noise
            noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
    # style truncation: interpolate each style towards truncation_latent
    # NOTE(review): truncation < 1 requires truncation_latent to be provided
    if truncation < 1:
        style_truncation = []
        for style in styles:
            style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
        styles = style_truncation
    # get style latents with injection
    if len(styles) == 1:
        inject_index = self.num_latent

        if styles[0].ndim < 3:
            # repeat latent code for all the layers
            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
        else:  # used for encoder with different latent code for each layer
            latent = styles[0]
    elif len(styles) == 2:  # mixing noises
        if inject_index is None:
            inject_index = random.randint(1, self.num_latent - 1)
        latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
        latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
        latent = torch.cat([latent1, latent2], 1)

    # main generation: constant input, then pairs of style convs with an
    # RGB skip connection accumulated at every resolution
    out = self.constant_input(latent.shape[0])
    out = self.style_conv1(out, latent[:, 0], noise=noise[0])
    skip = self.to_rgb1(out, latent[:, 1])

    i = 1
    # each iteration consumes two latents for the convs and one for to_rgb
    for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                    noise[2::2], self.to_rgbs):
        out = conv1(out, latent[:, i], noise=noise1)
        out = conv2(out, latent[:, i + 1], noise=noise2)
        skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
        i += 2

    image = skip

    if return_latents:
        return image, latent
    else:
        return image, None
|
third_part/GFPGAN/gfpgan/data/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
from basicsr.utils import scandir
from os import path as osp

# automatically scan and import dataset modules for registry
# scan all the files that end with '_dataset.py' under the data folder
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
# import all the dataset modules; importing is what triggers each module's
# @DATASET_REGISTRY.register() decorators as a side effect
_dataset_modules = [importlib.import_module(f'gfpgan.data.{file_name}') for file_name in dataset_filenames]
|
third_part/GFPGAN/gfpgan/data/ffhq_degradation_dataset.py
ADDED
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import math
|
3 |
+
import numpy as np
|
4 |
+
import os.path as osp
|
5 |
+
import torch
|
6 |
+
import torch.utils.data as data
|
7 |
+
from basicsr.data import degradations as degradations
|
8 |
+
from basicsr.data.data_util import paths_from_folder
|
9 |
+
from basicsr.data.transforms import augment
|
10 |
+
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
|
11 |
+
from basicsr.utils.registry import DATASET_REGISTRY
|
12 |
+
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
|
13 |
+
normalize)
|
14 |
+
|
15 |
+
|
16 |
+
@DATASET_REGISTRY.register()
class FFHQDegradationDataset(data.Dataset):
    """FFHQ dataset for GFPGAN.

    It reads high resolution images, and then generate low-quality (LQ) images on-the-fly.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            io_backend (dict): IO backend type and other kwarg.
            mean (list | tuple): Image mean.
            std (list | tuple): Image std.
            use_hflip (bool): Whether to horizontally flip.
        Please see more options in the codes.
    """

    def __init__(self, opt):
        super(FFHQDegradationDataset, self).__init__()
        self.opt = opt
        # file client (io backend); created lazily in __getitem__ so each
        # dataloader worker instantiates its own client
        self.file_client = None
        self.io_backend_opt = opt['io_backend']

        self.gt_folder = opt['dataroot_gt']
        self.mean = opt['mean']
        self.std = opt['std']
        self.out_size = opt['out_size']

        self.crop_components = opt.get('crop_components', False)  # facial components
        self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1)  # whether enlarge eye regions

        if self.crop_components:
            # load component list from a pre-process pth files
            self.components_list = torch.load(opt.get('component_path'))

        # file client (lmdb io backend)
        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = self.gt_folder
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
            # lmdb keys are the file stems listed in meta_info.txt
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            # disk backend: scan file list from a folder
            self.paths = paths_from_folder(self.gt_folder)

        # degradation configurations
        self.blur_kernel_size = opt['blur_kernel_size']
        self.kernel_list = opt['kernel_list']
        self.kernel_prob = opt['kernel_prob']
        self.blur_sigma = opt['blur_sigma']
        self.downsample_range = opt['downsample_range']
        self.noise_range = opt['noise_range']
        self.jpeg_range = opt['jpeg_range']

        # color jitter
        self.color_jitter_prob = opt.get('color_jitter_prob')
        self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob')
        self.color_jitter_shift = opt.get('color_jitter_shift', 20)
        # to gray
        self.gray_prob = opt.get('gray_prob')

        logger = get_root_logger()
        logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
        logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
        logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
        logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')

        if self.color_jitter_prob is not None:
            logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
        if self.gray_prob is not None:
            logger.info(f'Use random gray. Prob: {self.gray_prob}')
        # shift is configured in [0, 255] units; images here are float [0, 1]
        self.color_jitter_shift /= 255.

    @staticmethod
    def color_jitter(img, shift):
        """jitter color: randomly jitter the RGB values, in numpy formats"""
        jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
        # one offset per channel, clipped back into the valid [0, 1] range
        img = img + jitter_val
        img = np.clip(img, 0, 1)
        return img

    @staticmethod
    def color_jitter_pt(img, brightness, contrast, saturation, hue):
        """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
        # apply the four adjustments in a random order; a None range skips
        # that adjustment entirely
        fn_idx = torch.randperm(4)
        for fn_id in fn_idx:
            if fn_id == 0 and brightness is not None:
                brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
                img = adjust_brightness(img, brightness_factor)

            if fn_id == 1 and contrast is not None:
                contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
                img = adjust_contrast(img, contrast_factor)

            if fn_id == 2 and saturation is not None:
                saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
                img = adjust_saturation(img, saturation_factor)

            if fn_id == 3 and hue is not None:
                hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
                img = adjust_hue(img, hue_factor)
        return img

    def get_component_coordinates(self, index, status):
        """Get facial component (left_eye, right_eye, mouth) coordinates from a pre-loaded pth file"""
        # components are keyed by zero-padded 8-digit image index
        components_bbox = self.components_list[f'{index:08d}']
        if status[0]:  # hflip
            # exchange right and left eye
            tmp = components_bbox['left_eye']
            components_bbox['left_eye'] = components_bbox['right_eye']
            components_bbox['right_eye'] = tmp
            # modify the width coordinate to mirror the box centers
            components_bbox['left_eye'][0] = self.out_size - components_bbox['left_eye'][0]
            components_bbox['right_eye'][0] = self.out_size - components_bbox['right_eye'][0]
            components_bbox['mouth'][0] = self.out_size - components_bbox['mouth'][0]

        # get coordinates: each entry is (center_x, center_y, half_len) ->
        # converted to an (x1, y1, x2, y2) box tensor
        locations = []
        for part in ['left_eye', 'right_eye', 'mouth']:
            mean = components_bbox[part][0:2]
            half_len = components_bbox[part][2]
            if 'eye' in part:
                half_len *= self.eye_enlarge_ratio
            loc = np.hstack((mean - half_len + 1, mean + half_len))
            loc = torch.from_numpy(loc).float()
            locations.append(loc)
        return locations

    def __getitem__(self, index):
        # lazy init so each dataloader worker owns its own file client
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load gt image
        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
        gt_path = self.paths[index]
        img_bytes = self.file_client.get(gt_path)
        img_gt = imfrombytes(img_bytes, float32=True)

        # random horizontal flip
        img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
        h, w, _ = img_gt.shape

        # get facial component coordinates (flip-aware via `status`)
        if self.crop_components:
            locations = self.get_component_coordinates(index, status)
            loc_left_eye, loc_right_eye, loc_mouth = locations

        # ------------------------ generate lq image ------------------------ #
        # blur
        kernel = degradations.random_mixed_kernels(
            self.kernel_list,
            self.kernel_prob,
            self.blur_kernel_size,
            self.blur_sigma,
            self.blur_sigma, [-math.pi, math.pi],
            noise_range=None)
        img_lq = cv2.filter2D(img_gt, -1, kernel)
        # downsample by a random factor within downsample_range
        scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
        img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR)
        # noise
        if self.noise_range is not None:
            img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range)
        # jpeg compression
        if self.jpeg_range is not None:
            img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range)

        # resize to original size
        img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)

        # random color jitter (only for lq)
        if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
            img_lq = self.color_jitter(img_lq, self.color_jitter_shift)
        # random to gray (only for lq)
        if self.gray_prob and np.random.uniform() < self.gray_prob:
            img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
            img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
            if self.opt.get('gt_gray'):  # whether convert GT to gray images
                img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
                img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])  # repeat the color channels

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)

        # random color jitter (pytorch version) (only for lq)
        if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
            brightness = self.opt.get('brightness', (0.5, 1.5))
            contrast = self.opt.get('contrast', (0.5, 1.5))
            saturation = self.opt.get('saturation', (0, 1.5))
            hue = self.opt.get('hue', (-0.1, 0.1))
            img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue)

        # round and clip: emulate 8-bit quantization of the LQ image
        img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.

        # normalize
        normalize(img_gt, self.mean, self.std, inplace=True)
        normalize(img_lq, self.mean, self.std, inplace=True)

        if self.crop_components:
            return_dict = {
                'lq': img_lq,
                'gt': img_gt,
                'gt_path': gt_path,
                'loc_left_eye': loc_left_eye,
                'loc_right_eye': loc_right_eye,
                'loc_mouth': loc_mouth
            }
            return return_dict
        else:
            return {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
|
third_part/GFPGAN/gfpgan/models/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
from basicsr.utils import scandir
from os import path as osp

# automatically scan and import model modules for registry
# scan all the files that end with '_model.py' under the model folder
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
# import all the model modules; importing is what triggers each module's
# @MODEL_REGISTRY.register() decorators as a side effect
_model_modules = [importlib.import_module(f'gfpgan.models.{file_name}') for file_name in model_filenames]
|
third_part/GFPGAN/gfpgan/models/gfpgan_model.py
ADDED
@@ -0,0 +1,580 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os.path as osp
|
3 |
+
import torch
|
4 |
+
from basicsr.archs import build_network
|
5 |
+
from basicsr.losses import build_loss
|
6 |
+
# from basicsr.losses.losses import r1_penalty
|
7 |
+
from basicsr.losses import r1_penalty
|
8 |
+
from basicsr.metrics import calculate_metric
|
9 |
+
from basicsr.models.base_model import BaseModel
|
10 |
+
from basicsr.utils import get_root_logger, imwrite, tensor2img
|
11 |
+
from basicsr.utils.registry import MODEL_REGISTRY
|
12 |
+
from collections import OrderedDict
|
13 |
+
from torch.nn import functional as F
|
14 |
+
from torchvision.ops import roi_align
|
15 |
+
from tqdm import tqdm
|
16 |
+
|
17 |
+
|
18 |
+
@MODEL_REGISTRY.register()
|
19 |
+
class GFPGANModel(BaseModel):
|
20 |
+
"""The GFPGAN model for Towards real-world blind face restoratin with generative facial prior"""
|
21 |
+
|
22 |
+
def __init__(self, opt):
    """Build the generator, optionally restore its weights, and set up
    training machinery when running in training mode."""
    super(GFPGANModel, self).__init__(opt)
    self.idx = 0  # it is used for saving data for check

    # generator network
    self.net_g = build_network(opt['network_g'])
    self.net_g = self.model_to_device(self.net_g)
    self.print_network(self.net_g)

    # restore pretrained generator weights if a checkpoint path is configured
    g_ckpt = self.opt['path'].get('pretrain_network_g', None)
    if g_ckpt is not None:
        g_key = self.opt['path'].get('param_key_g', 'params')
        self.load_network(self.net_g, g_ckpt, self.opt['path'].get('strict_load_g', True), g_key)

    self.log_size = int(math.log(self.opt['network_g']['out_size'], 2))

    if self.is_train:
        self.init_training_settings()
|
41 |
+
|
42 |
+
def init_training_settings(self):
    """Instantiate discriminators, the EMA generator, all losses and the
    optimizers/schedulers required for training."""
    train_opt = self.opt['train']

    # ----------- define net_d ----------- #
    self.net_d = build_network(self.opt['network_d'])
    self.net_d = self.model_to_device(self.net_d)
    self.print_network(self.net_d)
    # load pretrained model
    load_path = self.opt['path'].get('pretrain_network_d', None)
    if load_path is not None:
        self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))

    # ----------- define net_g with Exponential Moving Average (EMA) ----------- #
    # net_g_ema only used for testing on one GPU and saving. There is no need to wrap with DistributedDataParallel
    self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
    # load pretrained model
    load_path = self.opt['path'].get('pretrain_network_g', None)
    if load_path is not None:
        self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
    else:
        self.model_ema(0)  # copy net_g weight

    self.net_g.train()
    self.net_d.train()
    self.net_g_ema.eval()

    # ----------- facial component networks ----------- #
    # only enabled when all three component discriminators are configured
    if ('network_d_left_eye' in self.opt and 'network_d_right_eye' in self.opt and 'network_d_mouth' in self.opt):
        self.use_facial_disc = True
    else:
        self.use_facial_disc = False

    if self.use_facial_disc:
        # left eye
        self.net_d_left_eye = build_network(self.opt['network_d_left_eye'])
        self.net_d_left_eye = self.model_to_device(self.net_d_left_eye)
        self.print_network(self.net_d_left_eye)
        load_path = self.opt['path'].get('pretrain_network_d_left_eye')
        if load_path is not None:
            self.load_network(self.net_d_left_eye, load_path, True, 'params')
        # right eye
        self.net_d_right_eye = build_network(self.opt['network_d_right_eye'])
        self.net_d_right_eye = self.model_to_device(self.net_d_right_eye)
        self.print_network(self.net_d_right_eye)
        load_path = self.opt['path'].get('pretrain_network_d_right_eye')
        if load_path is not None:
            self.load_network(self.net_d_right_eye, load_path, True, 'params')
        # mouth
        self.net_d_mouth = build_network(self.opt['network_d_mouth'])
        self.net_d_mouth = self.model_to_device(self.net_d_mouth)
        self.print_network(self.net_d_mouth)
        load_path = self.opt['path'].get('pretrain_network_d_mouth')
        if load_path is not None:
            self.load_network(self.net_d_mouth, load_path, True, 'params')

        self.net_d_left_eye.train()
        self.net_d_right_eye.train()
        self.net_d_mouth.train()

        # ----------- define facial component gan loss ----------- #
        self.cri_component = build_loss(train_opt['gan_component_opt']).to(self.device)

    # ----------- define losses ----------- #
    # pixel loss
    if train_opt.get('pixel_opt'):
        self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
    else:
        self.cri_pix = None

    # perceptual loss
    if train_opt.get('perceptual_opt'):
        self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
    else:
        self.cri_perceptual = None

    # L1 loss is used in pyramid loss, component style loss and identity loss
    self.cri_l1 = build_loss(train_opt['L1_opt']).to(self.device)

    # gan loss (wgan)
    self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)

    # ----------- define identity loss ----------- #
    if 'network_identity' in self.opt:
        self.use_identity = True
    else:
        self.use_identity = False

    if self.use_identity:
        # define identity network (frozen: eval mode, gradients disabled)
        self.network_identity = build_network(self.opt['network_identity'])
        self.network_identity = self.model_to_device(self.network_identity)
        self.print_network(self.network_identity)
        load_path = self.opt['path'].get('pretrain_network_identity')
        if load_path is not None:
            self.load_network(self.network_identity, load_path, True, None)
        self.network_identity.eval()
        for param in self.network_identity.parameters():
            param.requires_grad = False

    # regularization weights
    self.r1_reg_weight = train_opt['r1_reg_weight']  # for discriminator
    self.net_d_iters = train_opt.get('net_d_iters', 1)
    self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
    self.net_d_reg_every = train_opt['net_d_reg_every']

    # set up optimizers and schedulers
    self.setup_optimizers()
    self.setup_schedulers()
|
150 |
+
|
151 |
+
def setup_optimizers(self):
    """Create optimizers for the generator, the discriminator, and — when
    facial component discriminators are enabled — one optimizer each for
    the left-eye, right-eye and mouth discriminators.

    NOTE: `pop('type')` below mutates the shared train_opt dicts, so this
    method must only run once per config.
    """
    train_opt = self.opt['train']

    # ----------- optimizer g ----------- #
    net_g_reg_ratio = 1
    normal_params = []
    for _, param in self.net_g.named_parameters():
        normal_params.append(param)
    optim_params_g = [{  # add normal params first
        'params': normal_params,
        'lr': train_opt['optim_g']['lr']
    }]
    optim_type = train_opt['optim_g'].pop('type')
    lr = train_opt['optim_g']['lr'] * net_g_reg_ratio
    betas = (0**net_g_reg_ratio, 0.99**net_g_reg_ratio)
    self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, lr, betas=betas)
    self.optimizers.append(self.optimizer_g)

    # ----------- optimizer d ----------- #
    # scale lr/betas by net_d_reg_every/(net_d_reg_every+1) — presumably to
    # compensate for applying the d regularization only every
    # net_d_reg_every iterations (confirm against optimize_parameters)
    net_d_reg_ratio = self.net_d_reg_every / (self.net_d_reg_every + 1)
    normal_params = []
    for _, param in self.net_d.named_parameters():
        normal_params.append(param)
    optim_params_d = [{  # add normal params first
        'params': normal_params,
        'lr': train_opt['optim_d']['lr']
    }]
    optim_type = train_opt['optim_d'].pop('type')
    lr = train_opt['optim_d']['lr'] * net_d_reg_ratio
    betas = (0**net_d_reg_ratio, 0.99**net_d_reg_ratio)
    self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas)
    self.optimizers.append(self.optimizer_d)

    # ----------- optimizers for facial component networks ----------- #
    if self.use_facial_disc:
        # setup optimizers for facial component discriminators
        optim_type = train_opt['optim_component'].pop('type')
        lr = train_opt['optim_component']['lr']
        # left eye
        self.optimizer_d_left_eye = self.get_optimizer(
            optim_type, self.net_d_left_eye.parameters(), lr, betas=(0.9, 0.99))
        self.optimizers.append(self.optimizer_d_left_eye)
        # right eye
        self.optimizer_d_right_eye = self.get_optimizer(
            optim_type, self.net_d_right_eye.parameters(), lr, betas=(0.9, 0.99))
        self.optimizers.append(self.optimizer_d_right_eye)
        # mouth
        self.optimizer_d_mouth = self.get_optimizer(
            optim_type, self.net_d_mouth.parameters(), lr, betas=(0.9, 0.99))
        self.optimizers.append(self.optimizer_d_mouth)
|
201 |
+
|
202 |
+
def feed_data(self, data):
    """Receive one dataloader batch and stash it on the model.

    Moves 'lq' (and 'gt' when present) onto self.device; facial component
    bounding boxes, when provided by the dataset, are kept as-is for the
    later ROI crops.
    """
    self.lq = data['lq'].to(self.device)
    if 'gt' in data:
        self.gt = data['gt'].to(self.device)

    if 'loc_left_eye' in data:
        # facial component locations, each of shape (batch, 4)
        self.loc_left_eyes = data['loc_left_eye']
        self.loc_right_eyes = data['loc_right_eye']
        self.loc_mouths = data['loc_mouth']

    # Debug aid (disabled): dump self.gt / self.lq batches with
    # torchvision.utils.save_image into tmp/gt and tmp/lq (incrementing
    # self.idx) to visually verify the degradation pipeline.
|
225 |
+
|
226 |
+
def construct_img_pyramid(self):
    """Build a coarse-to-fine pyramid of ground-truth images.

    Repeatedly halves self.gt with bilinear interpolation; the returned
    list is ordered from the smallest resolution up to the original,
    matching the generator's intermediate RGB outputs for the pyramid loss.
    """
    pyramid = [self.gt]
    current = self.gt
    for _ in range(self.log_size - 3):
        current = F.interpolate(current, scale_factor=0.5, mode='bilinear', align_corners=False)
        pyramid.insert(0, current)
    return pyramid
|
234 |
+
|
235 |
+
def get_roi_regions(self, eye_out_size=80, mouth_out_size=120):
    """Crop eye and mouth patches from both GT and generator output.

    Uses torchvision roi_align with the boxes stored by feed_data and sets
    self.left_eyes(_gt), self.right_eyes(_gt) and self.mouths(_gt).
    ROI sizes scale with generators whose out_size exceeds 512.
    """
    face_ratio = int(self.opt['network_g']['out_size'] / 512)
    eye_out_size *= face_ratio
    mouth_out_size *= face_ratio

    rois_eyes = []
    rois_mouths = []
    for b in range(self.loc_left_eyes.size(0)):  # loop for batch size
        # left eye and right eye; roi_align rows are (batch_idx, x1, y1, x2, y2)
        img_inds = self.loc_left_eyes.new_full((2, 1), b)
        bbox = torch.stack([self.loc_left_eyes[b, :], self.loc_right_eyes[b, :]], dim=0)  # shape: (2, 4)
        rois = torch.cat([img_inds, bbox], dim=-1)  # shape: (2, 5)
        rois_eyes.append(rois)
        # mouth
        img_inds = self.loc_left_eyes.new_full((1, 1), b)
        rois = torch.cat([img_inds, self.loc_mouths[b:b + 1, :]], dim=-1)  # shape: (1, 5)
        rois_mouths.append(rois)

    rois_eyes = torch.cat(rois_eyes, 0).to(self.device)
    rois_mouths = torch.cat(rois_mouths, 0).to(self.device)

    # real images; eye crops interleave as [left_0, right_0, left_1, right_1, ...]
    all_eyes = roi_align(self.gt, boxes=rois_eyes, output_size=eye_out_size) * face_ratio
    self.left_eyes_gt = all_eyes[0::2, :, :, :]
    self.right_eyes_gt = all_eyes[1::2, :, :, :]
    self.mouths_gt = roi_align(self.gt, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio
    # output
    all_eyes = roi_align(self.output, boxes=rois_eyes, output_size=eye_out_size) * face_ratio
    self.left_eyes = all_eyes[0::2, :, :, :]
    self.right_eyes = all_eyes[1::2, :, :, :]
    self.mouths = roi_align(self.output, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio
|
266 |
+
|
267 |
+
def _gram_mat(self, x):
|
268 |
+
"""Calculate Gram matrix.
|
269 |
+
|
270 |
+
Args:
|
271 |
+
x (torch.Tensor): Tensor with shape of (n, c, h, w).
|
272 |
+
|
273 |
+
Returns:
|
274 |
+
torch.Tensor: Gram matrix.
|
275 |
+
"""
|
276 |
+
n, c, h, w = x.size()
|
277 |
+
features = x.view(n, c, w * h)
|
278 |
+
features_t = features.transpose(1, 2)
|
279 |
+
gram = features.bmm(features_t) / (c * h * w)
|
280 |
+
return gram
|
281 |
+
|
282 |
+
def gray_resize_for_identity(self, out, size=128):
    """Convert an RGB batch to grayscale and resize for the identity net.

    Applies the 0.2989/0.5870/0.1140 luma weights per channel, then
    bilinearly resizes to (size, size); returns an (n, 1, size, size) tensor.
    """
    red, green, blue = out[:, 0, :, :], out[:, 1, :, :], out[:, 2, :, :]
    luma = 0.2989 * red + 0.5870 * green + 0.1140 * blue
    luma = luma.unsqueeze(1)
    return F.interpolate(luma, (size, size), mode='bilinear', align_corners=False)
|
287 |
+
|
288 |
+
def optimize_parameters(self, current_iter):
    """Run one full GAN training iteration.

    Phase 1 updates the generator (pixel, pyramid, perceptual, GAN,
    facial-component and identity losses), phase 2 updates the global
    discriminator (with periodic R1 regularization) and, when enabled,
    the three facial-component discriminators.

    Args:
        current_iter (int): Current training iteration (controls net_d
            update cadence, pyramid-loss removal and R1 frequency).
    """
    # optimize net_g
    for p in self.net_d.parameters():
        p.requires_grad = False
    self.optimizer_g.zero_grad()

    # do not update facial component net_d
    if self.use_facial_disc:
        for p in self.net_d_left_eye.parameters():
            p.requires_grad = False
        for p in self.net_d_right_eye.parameters():
            p.requires_grad = False
        for p in self.net_d_mouth.parameters():
            p.requires_grad = False

    # image pyramid loss weight
    pyramid_loss_weight = self.opt['train'].get('pyramid_loss_weight', 0)
    if pyramid_loss_weight > 0 and current_iter > self.opt['train'].get('remove_pyramid_loss', float('inf')):
        pyramid_loss_weight = 1e-12  # very small weight to avoid unused param error
    if pyramid_loss_weight > 0:
        self.output, out_rgbs = self.net_g(self.lq, return_rgb=True)
        pyramid_gt = self.construct_img_pyramid()
    else:
        self.output, out_rgbs = self.net_g(self.lq, return_rgb=False)

    # get roi-align regions
    if self.use_facial_disc:
        self.get_roi_regions(eye_out_size=80, mouth_out_size=120)

    l_g_total = 0
    loss_dict = OrderedDict()
    # Generator is only updated every net_d_iters steps, after net_d warmup.
    if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
        # pixel loss
        if self.cri_pix:
            l_g_pix = self.cri_pix(self.output, self.gt)
            l_g_total += l_g_pix
            loss_dict['l_g_pix'] = l_g_pix

        # image pyramid loss
        if pyramid_loss_weight > 0:
            for i in range(0, self.log_size - 2):
                l_pyramid = self.cri_l1(out_rgbs[i], pyramid_gt[i]) * pyramid_loss_weight
                l_g_total += l_pyramid
                # keyed by the spatial resolution of pyramid level i (8, 16, ...)
                loss_dict[f'l_p_{2**(i+3)}'] = l_pyramid

        # perceptual loss
        if self.cri_perceptual:
            l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
            if l_g_percep is not None:
                l_g_total += l_g_percep
                loss_dict['l_g_percep'] = l_g_percep
            if l_g_style is not None:
                l_g_total += l_g_style
                loss_dict['l_g_style'] = l_g_style

        # gan loss
        fake_g_pred = self.net_d(self.output)
        l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
        l_g_total += l_g_gan
        loss_dict['l_g_gan'] = l_g_gan

        # facial component loss
        if self.use_facial_disc:
            # left eye
            fake_left_eye, fake_left_eye_feats = self.net_d_left_eye(self.left_eyes, return_feats=True)
            l_g_gan = self.cri_component(fake_left_eye, True, is_disc=False)
            l_g_total += l_g_gan
            loss_dict['l_g_gan_left_eye'] = l_g_gan
            # right eye
            fake_right_eye, fake_right_eye_feats = self.net_d_right_eye(self.right_eyes, return_feats=True)
            l_g_gan = self.cri_component(fake_right_eye, True, is_disc=False)
            l_g_total += l_g_gan
            loss_dict['l_g_gan_right_eye'] = l_g_gan
            # mouth
            fake_mouth, fake_mouth_feats = self.net_d_mouth(self.mouths, return_feats=True)
            l_g_gan = self.cri_component(fake_mouth, True, is_disc=False)
            l_g_total += l_g_gan
            loss_dict['l_g_gan_mouth'] = l_g_gan

            if self.opt['train'].get('comp_style_weight', 0) > 0:
                # get gt feat
                _, real_left_eye_feats = self.net_d_left_eye(self.left_eyes_gt, return_feats=True)
                _, real_right_eye_feats = self.net_d_right_eye(self.right_eyes_gt, return_feats=True)
                _, real_mouth_feats = self.net_d_mouth(self.mouths_gt, return_feats=True)

                def _comp_style(feat, feat_gt, criterion):
                    # Style loss over two feature levels; the first level is
                    # down-weighted by 0.5. GT features are detached so only
                    # the generator receives gradients.
                    return criterion(self._gram_mat(feat[0]), self._gram_mat(
                        feat_gt[0].detach())) * 0.5 + criterion(
                            self._gram_mat(feat[1]), self._gram_mat(feat_gt[1].detach()))

                # facial component style loss
                comp_style_loss = 0
                comp_style_loss += _comp_style(fake_left_eye_feats, real_left_eye_feats, self.cri_l1)
                comp_style_loss += _comp_style(fake_right_eye_feats, real_right_eye_feats, self.cri_l1)
                comp_style_loss += _comp_style(fake_mouth_feats, real_mouth_feats, self.cri_l1)
                comp_style_loss = comp_style_loss * self.opt['train']['comp_style_weight']
                l_g_total += comp_style_loss
                loss_dict['l_g_comp_style_loss'] = comp_style_loss

        # identity loss
        if self.use_identity:
            identity_weight = self.opt['train']['identity_weight']
            # get gray images and resize
            out_gray = self.gray_resize_for_identity(self.output)
            gt_gray = self.gray_resize_for_identity(self.gt)

            identity_gt = self.network_identity(gt_gray).detach()
            identity_out = self.network_identity(out_gray)
            l_identity = self.cri_l1(identity_out, identity_gt) * identity_weight
            l_g_total += l_identity
            loss_dict['l_identity'] = l_identity

        l_g_total.backward()
        self.optimizer_g.step()

    # EMA
    self.model_ema(decay=0.5**(32 / (10 * 1000)))

    # ----------- optimize net_d ----------- #
    for p in self.net_d.parameters():
        p.requires_grad = True
    self.optimizer_d.zero_grad()
    if self.use_facial_disc:
        for p in self.net_d_left_eye.parameters():
            p.requires_grad = True
        for p in self.net_d_right_eye.parameters():
            p.requires_grad = True
        for p in self.net_d_mouth.parameters():
            p.requires_grad = True
        self.optimizer_d_left_eye.zero_grad()
        self.optimizer_d_right_eye.zero_grad()
        self.optimizer_d_mouth.zero_grad()

    # self.output is detached so discriminator gradients do not reach net_g.
    fake_d_pred = self.net_d(self.output.detach())
    real_d_pred = self.net_d(self.gt)
    l_d = self.cri_gan(real_d_pred, True, is_disc=True) + self.cri_gan(fake_d_pred, False, is_disc=True)
    loss_dict['l_d'] = l_d
    # In WGAN, real_score should be positive and fake_score should be negative
    loss_dict['real_score'] = real_d_pred.detach().mean()
    loss_dict['fake_score'] = fake_d_pred.detach().mean()
    l_d.backward()

    # regularization loss
    if current_iter % self.net_d_reg_every == 0:
        # R1 penalty needs gradients w.r.t. the real images.
        self.gt.requires_grad = True
        real_pred = self.net_d(self.gt)
        l_d_r1 = r1_penalty(real_pred, self.gt)
        # "+ 0 * real_pred[0]" keeps real_pred in the graph to avoid
        # unused-parameter errors (lazy-regularization trick).
        l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.net_d_reg_every + 0 * real_pred[0])
        loss_dict['l_d_r1'] = l_d_r1.detach().mean()
        l_d_r1.backward()

    self.optimizer_d.step()

    # optimize facial component discriminators
    if self.use_facial_disc:
        # NOTE(review): the real term uses cri_component while the fake term
        # uses cri_gan; this mirrors upstream GFPGAN but looks asymmetric —
        # confirm both criteria are configured compatibly.
        # left eye
        fake_d_pred, _ = self.net_d_left_eye(self.left_eyes.detach())
        real_d_pred, _ = self.net_d_left_eye(self.left_eyes_gt)
        l_d_left_eye = self.cri_component(
            real_d_pred, True, is_disc=True) + self.cri_gan(
                fake_d_pred, False, is_disc=True)
        loss_dict['l_d_left_eye'] = l_d_left_eye
        l_d_left_eye.backward()
        # right eye
        fake_d_pred, _ = self.net_d_right_eye(self.right_eyes.detach())
        real_d_pred, _ = self.net_d_right_eye(self.right_eyes_gt)
        l_d_right_eye = self.cri_component(
            real_d_pred, True, is_disc=True) + self.cri_gan(
                fake_d_pred, False, is_disc=True)
        loss_dict['l_d_right_eye'] = l_d_right_eye
        l_d_right_eye.backward()
        # mouth
        fake_d_pred, _ = self.net_d_mouth(self.mouths.detach())
        real_d_pred, _ = self.net_d_mouth(self.mouths_gt)
        l_d_mouth = self.cri_component(
            real_d_pred, True, is_disc=True) + self.cri_gan(
                fake_d_pred, False, is_disc=True)
        loss_dict['l_d_mouth'] = l_d_mouth
        l_d_mouth.backward()

        self.optimizer_d_left_eye.step()
        self.optimizer_d_right_eye.step()
        self.optimizer_d_mouth.step()

    self.log_dict = self.reduce_loss_dict(loss_dict)
|
473 |
+
|
474 |
+
def test(self):
    """Run inference on self.lq without gradients.

    Prefers the EMA generator when available; otherwise falls back to
    net_g, temporarily switching it to eval mode.
    """
    use_ema = hasattr(self, 'net_g_ema')
    with torch.no_grad():
        if use_ema:
            self.net_g_ema.eval()
            self.output, _ = self.net_g_ema(self.lq)
        else:
            get_root_logger().warning('Do not have self.net_g_ema, use self.net_g.')
            self.net_g.eval()
            self.output, _ = self.net_g(self.lq)
            self.net_g.train()
|
485 |
+
|
486 |
+
def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
    """Distributed validation: only rank 0 runs the actual evaluation."""
    if self.opt['rank'] != 0:
        return
    self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
|
489 |
+
|
490 |
+
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
    """Validate on a single process: run inference over the dataloader,
    optionally save result images, and accumulate/average metrics.

    Args:
        dataloader: Validation dataloader; its dataset must expose opt['name'].
        current_iter (int): Current iteration, used in saved filenames and logs.
        tb_logger: Tensorboard logger (may be None).
        save_img (bool): Whether to write result images to disk.
    """
    dataset_name = dataloader.dataset.opt['name']
    with_metrics = self.opt['val'].get('metrics') is not None
    use_pbar = self.opt['val'].get('pbar', False)

    if with_metrics:
        if not hasattr(self, 'metric_results'):  # only execute in the first run
            self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
        # initialize the best metric results for each dataset_name (supporting multiple validation datasets)
        self._initialize_best_metric_results(dataset_name)
        # zero self.metric_results
        self.metric_results = {metric: 0 for metric in self.metric_results}

    metric_data = dict()
    if use_pbar:
        pbar = tqdm(total=len(dataloader), unit='image')

    for idx, val_data in enumerate(dataloader):
        img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
        self.feed_data(val_data)
        self.test()

        # Outputs are in [-1, 1]; convert to uint8 images for metrics/saving.
        sr_img = tensor2img(self.output.detach().cpu(), min_max=(-1, 1))
        metric_data['img'] = sr_img
        if hasattr(self, 'gt'):
            gt_img = tensor2img(self.gt.detach().cpu(), min_max=(-1, 1))
            metric_data['img2'] = gt_img
            del self.gt

        # tentative for out of GPU memory
        del self.lq
        del self.output
        torch.cuda.empty_cache()

        if save_img:
            if self.opt['is_train']:
                save_img_path = osp.join(self.opt['path']['visualization'], img_name,
                                         f'{img_name}_{current_iter}.png')
            else:
                if self.opt['val']['suffix']:
                    save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                             f'{img_name}_{self.opt["val"]["suffix"]}.png')
                else:
                    save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                             f'{img_name}_{self.opt["name"]}.png')
            imwrite(sr_img, save_img_path)

        if with_metrics:
            # calculate metrics
            for name, opt_ in self.opt['val']['metrics'].items():
                self.metric_results[name] += calculate_metric(metric_data, opt_)
        if use_pbar:
            pbar.update(1)
            pbar.set_description(f'Test {img_name}')
    if use_pbar:
        pbar.close()

    if with_metrics:
        for metric in self.metric_results.keys():
            # Average accumulated per-image metrics over the dataset.
            self.metric_results[metric] /= (idx + 1)
            # update the best metric result
            self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)

        self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
|
554 |
+
|
555 |
+
def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
    """Emit validation metrics to the root logger and, if given, tensorboard."""
    parts = [f'Validation {dataset_name}\n']
    track_best = hasattr(self, 'best_metric_results')
    for metric, value in self.metric_results.items():
        parts.append(f'\t # {metric}: {value:.4f}')
        if track_best:
            best = self.best_metric_results[dataset_name][metric]
            parts.append(f'\tBest: {best["val"]:.4f} @ {best["iter"]} iter')
        parts.append('\n')

    get_root_logger().info(''.join(parts))
    if tb_logger:
        for metric, value in self.metric_results.items():
            tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter)
|
569 |
+
|
570 |
+
def save(self, epoch, current_iter):
    """Checkpoint the generator (regular + EMA weights), all discriminators,
    and the training state."""
    # save net_g and net_d
    self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
    self.save_network(self.net_d, 'net_d', current_iter)
    # save component discriminators
    if self.use_facial_disc:
        components = (
            (self.net_d_left_eye, 'net_d_left_eye'),
            (self.net_d_right_eye, 'net_d_right_eye'),
            (self.net_d_mouth, 'net_d_mouth'),
        )
        for net, label in components:
            self.save_network(net, label, current_iter)
    # save training state
    self.save_training_state(epoch, current_iter)
|
third_part/GFPGAN/gfpgan/train.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# flake8: noqa
import os.path as osp
from basicsr.train import train_pipeline

# Imported for their side effects — presumably registering the GFPGAN
# architectures/datasets/models with basicsr's registries (TODO confirm).
import gfpgan.archs
import gfpgan.data
import gfpgan.models

if __name__ == '__main__':
    # Resolve the repo root two directory levels above this file.
    root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
    train_pipeline(root_path)
|
third_part/GFPGAN/gfpgan/utils.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
from basicsr.utils import img2tensor, tensor2img
|
5 |
+
from basicsr.utils.download_util import load_file_from_url
|
6 |
+
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
7 |
+
from torchvision.transforms.functional import normalize
|
8 |
+
|
9 |
+
from gfpgan.archs.gfpgan_bilinear_arch import GFPGANBilinear
|
10 |
+
from gfpgan.archs.gfpganv1_arch import GFPGANv1
|
11 |
+
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
|
12 |
+
|
13 |
+
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
14 |
+
|
15 |
+
|
16 |
+
class GFPGANer():
    """Helper for restoration with GFPGAN.

    It will detect and crop faces, and then resize the faces to 512x512.
    GFPGAN is used to restore the resized faces.
    The background is upsampled with the bg_upsampler.
    Finally, the faces will be pasted back to the upsampled background image.

    Args:
        model_path (str): The path to the GFPGAN model. It can be urls (will first download it automatically).
        upscale (float): The upscale of the final output. Default: 2.
        arch (str): The GFPGAN architecture. Option: clean | bilinear | original. Default: clean.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        bg_upsampler (nn.Module): The upsampler for the background. Default: None.

    Raises:
        ValueError: If ``arch`` is not a supported architecture name.
    """

    def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None):
        self.upscale = upscale
        self.bg_upsampler = bg_upsampler

        # initialize model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # initialize the GFP-GAN
        if arch == 'clean':
            self.gfpgan = GFPGANv1Clean(
                out_size=512,
                num_style_feat=512,
                channel_multiplier=channel_multiplier,
                decoder_load_path=None,
                fix_decoder=False,
                num_mlp=8,
                input_is_latent=True,
                different_w=True,
                narrow=1,
                sft_half=True)
        elif arch == 'bilinear':
            self.gfpgan = GFPGANBilinear(
                out_size=512,
                num_style_feat=512,
                channel_multiplier=channel_multiplier,
                decoder_load_path=None,
                fix_decoder=False,
                num_mlp=8,
                input_is_latent=True,
                different_w=True,
                narrow=1,
                sft_half=True)
        elif arch == 'original':
            self.gfpgan = GFPGANv1(
                out_size=512,
                num_style_feat=512,
                channel_multiplier=channel_multiplier,
                decoder_load_path=None,
                fix_decoder=True,
                num_mlp=8,
                input_is_latent=True,
                different_w=True,
                narrow=1,
                sft_half=True)
        else:
            # Previously an unknown arch silently left self.gfpgan undefined
            # and crashed later with AttributeError; fail fast instead.
            raise ValueError(f'Unsupported GFPGAN architecture: {arch}. Options: clean | bilinear | original.')
        # initialize face helper
        self.face_helper = FaceRestoreHelper(
            upscale,
            face_size=512,
            crop_ratio=(1, 1),
            det_model='retinaface_resnet50',
            save_ext='png',
            device=self.device)

        if model_path.startswith('https://'):
            model_path = load_file_from_url(
                url=model_path, model_dir=os.path.join(ROOT_DIR, 'gfpgan/weights'), progress=True, file_name=None)
        # Load on CPU first so checkpoints saved on CUDA also load on
        # CPU-only machines; the model is moved to self.device below.
        loadnet = torch.load(model_path, map_location=lambda storage, loc: storage)
        if 'params_ema' in loadnet:
            keyname = 'params_ema'
        else:
            keyname = 'params'
        self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
        self.gfpgan.eval()
        self.gfpgan = self.gfpgan.to(self.device)

    @torch.no_grad()
    def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True):
        """Restore all faces in an image.

        Args:
            img (numpy.ndarray): Input image (BGR, as read by cv2 — TODO confirm against callers).
            has_aligned (bool): Whether the input is already an aligned 512x512 face.
            only_center_face (bool): Only restore the center face.
            paste_back (bool): Paste the restored faces back onto the (optionally upsampled) input.

        Returns:
            tuple: ``(cropped_faces, restored_faces, restored_img)`` where
            ``restored_img`` is None when no paste-back is performed.
        """
        self.face_helper.clean_all()

        if has_aligned:  # the inputs are already aligned
            img = cv2.resize(img, (512, 512))
            self.face_helper.cropped_faces = [img]
        else:
            self.face_helper.read_image(img)
            # get face landmarks for each face
            self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
            # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
            # TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
            # align and warp each face
            self.face_helper.align_warp_face()

        # face restoration
        for cropped_face in self.face_helper.cropped_faces:
            # prepare data: scale to [0, 1], then normalize to [-1, 1]
            cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)

            try:
                output = self.gfpgan(cropped_face_t, return_rgb=False)[0]
                # convert to image
                restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
            except RuntimeError as error:
                # Best-effort fallback (e.g. CUDA OOM): keep the unrestored crop.
                print(f'\tFailed inference for GFPGAN: {error}.')
                restored_face = cropped_face

            restored_face = restored_face.astype('uint8')
            self.face_helper.add_restored_face(restored_face)

        if not has_aligned and paste_back:
            # upsample the background
            if self.bg_upsampler is not None:
                # Now only support RealESRGAN for upsampling background
                bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
            else:
                bg_img = None

            self.face_helper.get_inverse_affine(None)
            # paste each restored face to the input image
            restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img)
            return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
        else:
            return self.face_helper.cropped_faces, self.face_helper.restored_faces, None
|
third_part/GFPGAN/gfpgan/version.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GENERATED VERSION FILE
|
2 |
+
# TIME: Wed Apr 20 14:43:06 2022
|
3 |
+
__version__ = '1.3.2'
|
4 |
+
__gitsha__ = '924ce47'
|
5 |
+
version_info = (1, 3, 2)
|
third_part/GFPGAN/gfpgan/weights/README.md
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
# Weights
|
2 |
+
|
3 |
+
Put the downloaded weights into this folder.
|
third_part/GFPGAN/options/train_gfpgan_v1.yml
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# general settings
|
2 |
+
name: train_GFPGANv1_512
|
3 |
+
model_type: GFPGANModel
|
4 |
+
num_gpu: auto # officially, we use 4 GPUs
|
5 |
+
manual_seed: 0
|
6 |
+
|
7 |
+
# dataset and data loader settings
|
8 |
+
datasets:
|
9 |
+
train:
|
10 |
+
name: FFHQ
|
11 |
+
type: FFHQDegradationDataset
|
12 |
+
# dataroot_gt: datasets/ffhq/ffhq_512.lmdb
|
13 |
+
dataroot_gt: datasets/ffhq/ffhq_512
|
14 |
+
io_backend:
|
15 |
+
# type: lmdb
|
16 |
+
type: disk
|
17 |
+
|
18 |
+
use_hflip: true
|
19 |
+
mean: [0.5, 0.5, 0.5]
|
20 |
+
std: [0.5, 0.5, 0.5]
|
21 |
+
out_size: 512
|
22 |
+
|
23 |
+
blur_kernel_size: 41
|
24 |
+
kernel_list: ['iso', 'aniso']
|
25 |
+
kernel_prob: [0.5, 0.5]
|
26 |
+
blur_sigma: [0.1, 10]
|
27 |
+
downsample_range: [0.8, 8]
|
28 |
+
noise_range: [0, 20]
|
29 |
+
jpeg_range: [60, 100]
|
30 |
+
|
31 |
+
# color jitter and gray
|
32 |
+
color_jitter_prob: 0.3
|
33 |
+
color_jitter_shift: 20
|
34 |
+
color_jitter_pt_prob: 0.3
|
35 |
+
gray_prob: 0.01
|
36 |
+
|
37 |
+
# If you do not want colorization, please set
|
38 |
+
# color_jitter_prob: ~
|
39 |
+
# color_jitter_pt_prob: ~
|
40 |
+
# gray_prob: 0.01
|
41 |
+
# gt_gray: True
|
42 |
+
|
43 |
+
crop_components: true
|
44 |
+
component_path: experiments/pretrained_models/FFHQ_eye_mouth_landmarks_512.pth
|
45 |
+
eye_enlarge_ratio: 1.4
|
46 |
+
|
47 |
+
# data loader
|
48 |
+
use_shuffle: true
|
49 |
+
num_worker_per_gpu: 6
|
50 |
+
batch_size_per_gpu: 3
|
51 |
+
dataset_enlarge_ratio: 1
|
52 |
+
prefetch_mode: ~
|
53 |
+
|
54 |
+
val:
|
55 |
+
# Please modify accordingly to use your own validation
|
56 |
+
# Or comment the val block if do not need validation during training
|
57 |
+
name: validation
|
58 |
+
type: PairedImageDataset
|
59 |
+
dataroot_lq: datasets/faces/validation/input
|
60 |
+
dataroot_gt: datasets/faces/validation/reference
|
61 |
+
io_backend:
|
62 |
+
type: disk
|
63 |
+
mean: [0.5, 0.5, 0.5]
|
64 |
+
std: [0.5, 0.5, 0.5]
|
65 |
+
scale: 1
|
66 |
+
|
67 |
+
# network structures
|
68 |
+
network_g:
|
69 |
+
type: GFPGANv1
|
70 |
+
out_size: 512
|
71 |
+
num_style_feat: 512
|
72 |
+
channel_multiplier: 1
|
73 |
+
resample_kernel: [1, 3, 3, 1]
|
74 |
+
decoder_load_path: experiments/pretrained_models/StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth
|
75 |
+
fix_decoder: true
|
76 |
+
num_mlp: 8
|
77 |
+
lr_mlp: 0.01
|
78 |
+
input_is_latent: true
|
79 |
+
different_w: true
|
80 |
+
narrow: 1
|
81 |
+
sft_half: true
|
82 |
+
|
83 |
+
network_d:
|
84 |
+
type: StyleGAN2Discriminator
|
85 |
+
out_size: 512
|
86 |
+
channel_multiplier: 1
|
87 |
+
resample_kernel: [1, 3, 3, 1]
|
88 |
+
|
89 |
+
network_d_left_eye:
|
90 |
+
type: FacialComponentDiscriminator
|
91 |
+
|
92 |
+
network_d_right_eye:
|
93 |
+
type: FacialComponentDiscriminator
|
94 |
+
|
95 |
+
network_d_mouth:
|
96 |
+
type: FacialComponentDiscriminator
|
97 |
+
|
98 |
+
network_identity:
|
99 |
+
type: ResNetArcFace
|
100 |
+
block: IRBlock
|
101 |
+
layers: [2, 2, 2, 2]
|
102 |
+
use_se: False
|
103 |
+
|
104 |
+
# path
|
105 |
+
path:
|
106 |
+
pretrain_network_g: ~
|
107 |
+
param_key_g: params_ema
|
108 |
+
strict_load_g: ~
|
109 |
+
pretrain_network_d: ~
|
110 |
+
pretrain_network_d_left_eye: ~
|
111 |
+
pretrain_network_d_right_eye: ~
|
112 |
+
pretrain_network_d_mouth: ~
|
113 |
+
pretrain_network_identity: experiments/pretrained_models/arcface_resnet18.pth
|
114 |
+
# resume
|
115 |
+
resume_state: ~
|
116 |
+
ignore_resume_networks: ['network_identity']
|
117 |
+
|
118 |
+
# training settings
|
119 |
+
train:
|
120 |
+
optim_g:
|
121 |
+
type: Adam
|
122 |
+
lr: !!float 2e-3
|
123 |
+
optim_d:
|
124 |
+
type: Adam
|
125 |
+
lr: !!float 2e-3
|
126 |
+
optim_component:
|
127 |
+
type: Adam
|
128 |
+
lr: !!float 2e-3
|
129 |
+
|
130 |
+
scheduler:
|
131 |
+
type: MultiStepLR
|
132 |
+
milestones: [600000, 700000]
|
133 |
+
gamma: 0.5
|
134 |
+
|
135 |
+
total_iter: 800000
|
136 |
+
warmup_iter: -1 # no warm up
|
137 |
+
|
138 |
+
# losses
|
139 |
+
# pixel loss
|
140 |
+
pixel_opt:
|
141 |
+
type: L1Loss
|
142 |
+
loss_weight: !!float 1e-1
|
143 |
+
reduction: mean
|
144 |
+
# L1 loss used in pyramid loss, component style loss and identity loss
|
145 |
+
L1_opt:
|
146 |
+
type: L1Loss
|
147 |
+
loss_weight: 1
|
148 |
+
reduction: mean
|
149 |
+
|
150 |
+
# image pyramid loss
|
151 |
+
pyramid_loss_weight: 1
|
152 |
+
remove_pyramid_loss: 50000
|
153 |
+
# perceptual loss (content and style losses)
|
154 |
+
perceptual_opt:
|
155 |
+
type: PerceptualLoss
|
156 |
+
layer_weights:
|
157 |
+
# before relu
|
158 |
+
'conv1_2': 0.1
|
159 |
+
'conv2_2': 0.1
|
160 |
+
'conv3_4': 1
|
161 |
+
'conv4_4': 1
|
162 |
+
'conv5_4': 1
|
163 |
+
vgg_type: vgg19
|
164 |
+
use_input_norm: true
|
165 |
+
perceptual_weight: !!float 1
|
166 |
+
style_weight: 50
|
167 |
+
range_norm: true
|
168 |
+
criterion: l1
|
169 |
+
# gan loss
|
170 |
+
gan_opt:
|
171 |
+
type: GANLoss
|
172 |
+
gan_type: wgan_softplus
|
173 |
+
loss_weight: !!float 1e-1
|
174 |
+
# r1 regularization for discriminator
|
175 |
+
r1_reg_weight: 10
|
176 |
+
# facial component loss
|
177 |
+
gan_component_opt:
|
178 |
+
type: GANLoss
|
179 |
+
gan_type: vanilla
|
180 |
+
real_label_val: 1.0
|
181 |
+
fake_label_val: 0.0
|
182 |
+
loss_weight: !!float 1
|
183 |
+
comp_style_weight: 200
|
184 |
+
# identity loss
|
185 |
+
identity_weight: 10
|
186 |
+
|
187 |
+
net_d_iters: 1
|
188 |
+
net_d_init_iters: 0
|
189 |
+
net_d_reg_every: 16
|
190 |
+
|
191 |
+
# validation settings
|
192 |
+
val:
|
193 |
+
val_freq: !!float 5e3
|
194 |
+
save_img: true
|
195 |
+
|
196 |
+
metrics:
|
197 |
+
psnr: # metric name
|
198 |
+
type: calculate_psnr
|
199 |
+
crop_border: 0
|
200 |
+
test_y_channel: false
|
201 |
+
|
202 |
+
# logging settings
|
203 |
+
logger:
|
204 |
+
print_freq: 100
|
205 |
+
save_checkpoint_freq: !!float 5e3
|
206 |
+
use_tb_logger: true
|
207 |
+
wandb:
|
208 |
+
project: ~
|
209 |
+
resume_id: ~
|
210 |
+
|
211 |
+
# dist training settings
|
212 |
+
dist_params:
|
213 |
+
backend: nccl
|
214 |
+
port: 29500
|
215 |
+
|
216 |
+
find_unused_parameters: true
|