LB5 commited on
Commit
22b8701
1 Parent(s): e6a22e6

Upload 45 files

Browse files
Files changed (46) hide show
  1. .gitattributes +1 -0
  2. Docs/pics/img_mask_blur_21.jpg +0 -0
  3. Docs/pics/img_mask_blur_41.jpg +0 -0
  4. Docs/pics/img_mask_blur_61.jpg +0 -0
  5. Docs/pics/img_mask_erode_0.jpg +0 -0
  6. Docs/pics/img_mask_erode_20.jpg +0 -0
  7. Docs/pics/img_mask_erode_40.jpg +0 -0
  8. README.md +178 -12
  9. app.py +74 -0
  10. app_web.py +160 -0
  11. configs/run_image.yaml +35 -0
  12. configs/run_image_specific.yaml +35 -0
  13. configs/run_video.yaml +36 -0
  14. configs/run_video_specific.yaml +36 -0
  15. demo_file/Iron_man.jpg +0 -0
  16. demo_file/multi_people.jpg +0 -0
  17. demo_file/multi_people_1080p.mp4 +3 -0
  18. demo_file/multispecific/DST_01.jpg +0 -0
  19. demo_file/multispecific/DST_02.jpg +0 -0
  20. demo_file/multispecific/DST_03.jpg +0 -0
  21. demo_file/multispecific/SRC_01.png +0 -0
  22. demo_file/multispecific/SRC_02.png +0 -0
  23. demo_file/multispecific/SRC_03.png +0 -0
  24. demo_file/specific1.png +0 -0
  25. demo_file/specific2.png +0 -0
  26. demo_file/specific3.png +0 -0
  27. requirements.txt +9 -0
  28. src/Blend/blend.py +12 -0
  29. src/DataManager/ImageDataManager.py +42 -0
  30. src/DataManager/VideoDataManager.py +73 -0
  31. src/DataManager/base.py +16 -0
  32. src/DataManager/utils.py +12 -0
  33. src/FaceAlign/face_align.py +244 -0
  34. src/FaceDetector/face_detector.py +37 -0
  35. src/FaceId/faceid.py +50 -0
  36. src/Generator/fs_networks_512.py +277 -0
  37. src/Generator/fs_networks_fix.py +245 -0
  38. src/Misc/types.py +11 -0
  39. src/Misc/utils.py +28 -0
  40. src/PostProcess/GFPGAN/gfpgan.py +341 -0
  41. src/PostProcess/GFPGAN/stylegan2.py +351 -0
  42. src/PostProcess/ParsingModel/model.py +323 -0
  43. src/PostProcess/ParsingModel/resnet.py +109 -0
  44. src/PostProcess/utils.py +122 -0
  45. src/model_loader.py +106 -0
  46. src/simswap.py +322 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
  simswap-inference-pytorch-main/demo_file/multi_people_1080p.mp4 filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
  simswap-inference-pytorch-main/demo_file/multi_people_1080p.mp4 filter=lfs diff=lfs merge=lfs -text
36
+ demo_file/multi_people_1080p.mp4 filter=lfs diff=lfs merge=lfs -text
Docs/pics/img_mask_blur_21.jpg ADDED
Docs/pics/img_mask_blur_41.jpg ADDED
Docs/pics/img_mask_blur_61.jpg ADDED
Docs/pics/img_mask_erode_0.jpg ADDED
Docs/pics/img_mask_erode_20.jpg ADDED
Docs/pics/img_mask_erode_40.jpg ADDED
README.md CHANGED
@@ -1,12 +1,178 @@
1
- ---
2
- title: Simswap55
3
- emoji: 🌖
4
- colorFrom: blue
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 3.20.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Unofficial Pytorch implementation (**inference only**) of the SimSwap: An Efficient Framework For High Fidelity Face Swapping
2
+
3
+ ## Updates
4
+ - improved performance (up to 40% in some scenarios, it depends on frame resolution and number of swaps per frame).
5
+ - fixed a problem with overlapped areas from close faces (https://github.com/mike9251/simswap-inference-pytorch/issues/21)
6
+ - added support for using GFPGAN model as an additional post-processing step to improve final image quality
7
+ - added a toy gui app. Might be useful to understand how different pipeline settings affect output
8
+
9
+ ## Attention
10
+ ***This project is for technical and academic use only. Please do not apply it to illegal and unethical scenarios.***
11
+
12
+ ***In the event of violation of the legal and ethical requirements of the user's country or region, this code repository is exempt from liability.***
13
+
14
+ ## Preparation
15
+ ### Installation
16
+ ```
17
+ # clone project
18
+ git clone https://github.com/mike9251/simswap-inference-pytorch
19
+ cd simswap-inference-pytorch
20
+
21
+ # [OPTIONAL] create conda environment
22
+ conda create -n myenv python=3.9
23
+ conda activate myenv
24
+
25
+ # install pytorch and torchvision according to instructions
26
+ # https://pytorch.org/get-started/
27
+
28
+ # install requirements
29
+ pip install -r requirements.txt
30
+ ```
31
+
32
+ ### Important
33
+ Face detection will be performed on CPU. To run it on GPU you need to install onnx gpu runtime:
34
+
35
+ ```pip install onnxruntime-gpu==1.11.1```
36
+
37
+ and modify one line of code in ```...Anaconda3\envs\myenv\Lib\site-packages\insightface\model_zoo\model_zoo.py```
38
+
39
+ Here, instead of passing **None** as the second argument to the onnx inference session
40
+ ```angular2html
41
+ class ModelRouter:
42
+ def __init__(self, onnx_file):
43
+ self.onnx_file = onnx_file
44
+
45
+ def get_model(self):
46
+ session = onnxruntime.InferenceSession(self.onnx_file, None)
47
+ input_cfg = session.get_inputs()[0]
48
+ ```
49
+ pass a list of providers
50
+ ```angular2html
51
+ class ModelRouter:
52
+ def __init__(self, onnx_file):
53
+ self.onnx_file = onnx_file
54
+
55
+ def get_model(self):
56
+ session = onnxruntime.InferenceSession(self.onnx_file, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
57
+ input_cfg = session.get_inputs()[0]
58
+ ```
59
+ Otherwise simply use CPU onnx runtime with only a minor performance drop.
60
+
61
+ ### Weights
62
+ #### Weights for all models get downloaded automatically.
63
+
64
+ You can also download weights manually and put inside `weights` folder:
65
+
66
+ - weights/<a href="https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/face_detector_scrfd_10g_bnkps.onnx">face_detector_scrfd_10g_bnkps.onnx</a>
67
+ - weights/<a href="https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/arcface_net.jit">arcface_net.jit</a>
68
+ - weights/<a href="https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/parsing_model_79999_iter.pth">79999_iter.pth</a>
69
+ - weights/<a href="https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/simswap_224_latest_net_G.pth">simswap_224_latest_net_G.pth</a> - official 224x224 model
70
+ - weights/<a href="https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/simswap_512_390000_net_G.pth">simswap_512_390000_net_G.pth</a> - unofficial 512x512 model (I took it <a href="https://github.com/neuralchen/SimSwap/issues/255">here</a>).
71
+ - weights/<a href="https://github.com/mike9251/simswap-inference-pytorch/releases/download/v1.1/GFPGANv1.4_ema.pth">GFPGANv1.4_ema.pth</a>
72
+ - weights/<a href="https://github.com/mike9251/simswap-inference-pytorch/releases/download/v1.2/blend_module.jit">blend_module.jit</a>
73
+
74
+ ## Inference
75
+ ### Web App
76
+ ```angular2html
77
+ streamlit run app_web.py
78
+ ```
79
+
80
+ ### Command line App
81
+ This repository supports inference in several modes, which can be easily configured with config files in the **configs** folder.
82
+ - **replace all faces on a target image / folder with images**
83
+ ```angular2html
84
+ python app.py --config-name=run_image.yaml
85
+ ```
86
+
87
+ - **replace all faces on a video**
88
+ ```angular2html
89
+ python app.py --config-name=run_video.yaml
90
+ ```
91
+
92
+ - **replace a specific face on a target image / folder with images**
93
+ ```angular2html
94
+ python app.py --config-name=run_image_specific.yaml
95
+ ```
96
+
97
+ - **replace a specific face on a video**
98
+ ```angular2html
99
+ python app.py --config-name=run_video_specific.yaml
100
+ ```
101
+
102
+ Config files contain two main parts:
103
+
104
+ - **data**
105
+ - *id_image* - source image, identity of this person will be transferred.
106
+ - *att_image* - target image, attributes of the person on this image will be mixed with the person's identity from the source image. Here you can also specify a folder with multiple images - identity translation will be applied to all images in the folder.
107
+ - *specific_id_image* - a specific person on the *att_image* you would like to replace, leaving others untouched (if there's any other person).
108
+ - *att_video* - the same as *att_image*
109
+ - *clean_work_dir* - whether remove temp folder with images or not (for video configs only).
110
+
111
+
112
+ - **pipeline**
113
+ - *face_detector_weights* - path to the weights file OR an empty string ("") for automatic weights downloading.
114
+ - *face_id_weights* - path to the weights file OR an empty string ("") for automatic weights downloading.
115
+ - *parsing_model_weights* - path to the weights file OR an empty string ("") for automatic weights downloading.
116
+ - *simswap_weights* - path to the weights file OR an empty string ("") for automatic weights downloading.
117
+ - *gfpgan_weights* - path to the weights file OR an empty string ("") for automatic weights downloading.
118
+ - *device* - whether you want to run the application using GPU or CPU.
119
+ - *crop_size* - size of images SimSwap models works with.
120
+ - *checkpoint_type* - the official model works with 224x224 crops and has different pre/post processings (imagenet like). Latest official repository allows you to train your own models, but the architecture and pre/post processings are slightly different (1. removed Tanh from the last layer; 2. normalization to [0...1] range). **If you run the official 224x224 model then set this parameter to "official_224", otherwise "none".**
121
+ - *face_alignment_type* - affects reference face key points coordinates. **Possible values are "ffhq" and "none". Try both of them to see which one works better for your data.**
122
+ - *smooth_mask_kernel_size* - a non-zero value. It's used for the post-processing mask size attenuation. You might want to play with this parameter.
123
+ - *smooth_mask_iter* - a non-zero value. The number of times a face mask is smoothed.
124
+ - *smooth_mask_threshold* - controls the face mask saturation. Valid values are in range [0.0...1.0]. Tune this parameter if there are artifacts around swapped faces.
125
+ - *face_detector_threshold* - values in range [0.0...1.0]. Higher value reduces probability of FP detections but increases the probability of FN.
126
+ - *specific_latent_match_threshold* - values in range [0.0...inf]. Usually takes small values around 0.05.
127
+ - *enhance_output* - whether to apply GFPGAN model or not as a post-processing step.
128
+
129
+
130
+ ### Overriding parameters with CMD
131
+ Every parameter in a config file can be overridden by specifying it directly with CMD. For example:
132
+
133
+ ```angular2html
134
+ python app.py --config-name=run_image.yaml data.specific_id_image="path/to/the/image" pipeline.erosion_kernel_size=20
135
+ ```
136
+
137
+ ## Video
138
+
139
+ <details>
140
+ <summary><b>Official 224x224 model, face alignment "none"</b></summary>
141
+
142
+ [![Video](https://i.imgur.com/iCujdRB.jpg)](https://vimeo.com/728346715)
143
+
144
+ </details>
145
+
146
+ <details>
147
+ <summary><b>Official 224x224 model, face alignment "ffhq"</b></summary>
148
+
149
+ [![Video](https://i.imgur.com/48hjJO4.jpg)](https://vimeo.com/728348520)
150
+
151
+ </details>
152
+
153
+ <details>
154
+ <summary><b>Unofficial 512x512 model, face alignment "none"</b></summary>
155
+
156
+ [![Video](https://i.imgur.com/rRltD4U.jpg)](https://vimeo.com/728346542)
157
+
158
+ </details>
159
+
160
+ <details>
161
+ <summary><b>Unofficial 512x512 model, face alignment "ffhq"</b></summary>
162
+
163
+ [![Video](https://i.imgur.com/gFkpyXS.jpg)](https://vimeo.com/728349219)
164
+
165
+ </details>
166
+
167
+ ## License
168
+ For academic and non-commercial use only.The whole project is under the CC-BY-NC 4.0 license. See [LICENSE](https://github.com/neuralchen/SimSwap/blob/main/LICENSE) for additional details.
169
+
170
+ ## Acknowledgements
171
+
172
+ <!--ts-->
173
+ * [SimSwap](https://github.com/neuralchen/SimSwap)
174
+ * [Insightface](https://github.com/deepinsight/insightface)
175
+ * [Face-parsing.PyTorch](https://github.com/zllrunning/face-parsing.PyTorch)
176
+ * [BiSeNet](https://github.com/CoinCheung/BiSeNet)
177
+ * [GFPGAN](https://github.com/TencentARC/GFPGAN)
178
+ <!--te-->
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Optional
3
+ from tqdm import tqdm
4
+
5
+ import hydra
6
+ from omegaconf import DictConfig
7
+ import numpy as np
8
+
9
+ from src.simswap import SimSwap
10
+ from src.DataManager.ImageDataManager import ImageDataManager
11
+ from src.DataManager.VideoDataManager import VideoDataManager
12
+ from src.DataManager.utils import imread_rgb
13
+
14
+
15
+ class Application:
16
+ def __init__(self, config: DictConfig):
17
+
18
+ id_image_path = Path(config.data.id_image)
19
+ specific_id_image_path = Path(config.data.specific_id_image)
20
+ att_image_path = Path(config.data.att_image)
21
+ att_video_path = Path(config.data.att_video)
22
+ output_dir = Path(config.data.output_dir)
23
+
24
+ assert id_image_path.exists(), f"Can't find {id_image_path} file!"
25
+
26
+ self.id_image: Optional[np.ndarray] = imread_rgb(id_image_path)
27
+ self.specific_id_image: Optional[np.ndarray] = (
28
+ imread_rgb(specific_id_image_path)
29
+ if specific_id_image_path and specific_id_image_path.is_file()
30
+ else None
31
+ )
32
+
33
+ self.att_image: Optional[ImageDataManager] = None
34
+ if att_image_path and (att_image_path.is_file() or att_image_path.is_dir()):
35
+ self.att_image: Optional[ImageDataManager] = ImageDataManager(
36
+ src_data=att_image_path, output_dir=output_dir
37
+ )
38
+
39
+ self.att_video: Optional[VideoDataManager] = None
40
+ if att_video_path and att_video_path.is_file():
41
+ self.att_video: Optional[VideoDataManager] = VideoDataManager(
42
+ src_data=att_video_path, output_dir=output_dir, clean_work_dir=config.data.clean_work_dir
43
+ )
44
+
45
+ assert not (self.att_video and self.att_image), "Only one attribute source can be used!"
46
+
47
+ self.data_manager = self.att_video if self.att_video else self.att_image
48
+
49
+ self.model = SimSwap(
50
+ config=config.pipeline,
51
+ id_image=self.id_image,
52
+ specific_image=self.specific_id_image,
53
+ )
54
+
55
+ def run(self):
56
+ for _ in tqdm(range(len(self.data_manager))):
57
+
58
+ att_img = self.data_manager.get()
59
+
60
+ output = self.model(att_img)
61
+
62
+ self.data_manager.save(output)
63
+
64
+
65
+ @hydra.main(config_path="configs/", config_name="run_image.yaml")
66
+ def main(config: DictConfig):
67
+
68
+ app = Application(config)
69
+
70
+ app.run()
71
+
72
+
73
+ if __name__ == "__main__":
74
+ main()
app_web.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ from io import BytesIO
4
+ from collections import namedtuple
5
+ import numpy as np
6
+
7
+ from src.simswap import SimSwap
8
+
9
+
10
+ def run(model):
11
+ id_image = None
12
+ attr_image = None
13
+ specific_image = None
14
+ output = None
15
+
16
+ def get_np_image(file):
17
+ return np.array(Image.open(file))[:, :, :3]
18
+
19
+ with st.sidebar:
20
+ uploaded_file = st.file_uploader("Select an ID image")
21
+ if uploaded_file is not None:
22
+ id_image = get_np_image(uploaded_file)
23
+
24
+ uploaded_file = st.file_uploader("Select an Attribute image")
25
+ if uploaded_file is not None:
26
+ attr_image = get_np_image(uploaded_file)
27
+
28
+ uploaded_file = st.file_uploader("Select a specific person image (Optional)")
29
+ if uploaded_file is not None:
30
+ specific_image = get_np_image(uploaded_file)
31
+
32
+ face_alignment_type = st.radio("Face alignment type:", ("none", "ffhq"))
33
+
34
+ enhance_output = st.radio("Enhance output:", ("yes", "no"))
35
+
36
+ smooth_mask_iter = st.slider(
37
+ label="smooth_mask_iter", min_value=1, max_value=60, step=1, value=7
38
+ )
39
+
40
+ smooth_mask_kernel_size = st.slider(
41
+ label="smooth_mask_kernel_size", min_value=1, max_value=61, step=2, value=17
42
+ )
43
+
44
+ smooth_mask_threshold = st.slider(label="smooth_mask_threshold", min_value=0.01, max_value=1.0, step=0.01, value=0.9)
45
+
46
+ specific_latent_match_threshold = st.slider(
47
+ label="specific_latent_match_threshold",
48
+ min_value=0.0,
49
+ max_value=10.0,
50
+ value=0.05,
51
+ )
52
+
53
+ num_cols = sum(
54
+ (id_image is not None, attr_image is not None, specific_image is not None)
55
+ )
56
+ cols = st.columns(num_cols if num_cols > 0 else 1)
57
+ i = 0
58
+
59
+ if id_image is not None:
60
+ with cols[i]:
61
+ i += 1
62
+ st.header("ID image")
63
+ st.image(id_image)
64
+
65
+ if attr_image is not None:
66
+ with cols[i]:
67
+ i += 1
68
+ st.header("Attribute image")
69
+ st.image(attr_image)
70
+
71
+ if specific_image is not None:
72
+ with cols[i]:
73
+ st.header("Specific image")
74
+ st.image(specific_image)
75
+
76
+ if id_image is not None and attr_image is not None:
77
+ model.set_face_alignment_type(face_alignment_type)
78
+ model.set_smooth_mask_iter(smooth_mask_iter)
79
+ model.set_smooth_mask_kernel_size(smooth_mask_kernel_size)
80
+ model.set_smooth_mask_threshold(smooth_mask_threshold)
81
+ model.set_specific_latent_match_threshold(specific_latent_match_threshold)
82
+ model.enhance_output = True if enhance_output == "yes" else False
83
+
84
+ model.specific_latent = None
85
+ model.specific_id_image = specific_image if specific_image is not None else None
86
+
87
+ model.id_latent = None
88
+ model.id_image = id_image
89
+
90
+ output = model(attr_image)
91
+
92
+ if output is not None:
93
+ with st.container():
94
+ st.header("SimSwap output")
95
+ st.image(output)
96
+
97
+ output_to_download = Image.fromarray(output.astype("uint8"), "RGB")
98
+ buf = BytesIO()
99
+ output_to_download.save(buf, format="JPEG")
100
+
101
+ st.download_button(
102
+ label="Download",
103
+ data=buf.getvalue(),
104
+ file_name="output.jpg",
105
+ mime="image/jpeg",
106
+ )
107
+
108
+
109
+ @st.cache(allow_output_mutation=True)
110
+ def load_model(config):
111
+ return SimSwap(
112
+ config=config,
113
+ id_image=None,
114
+ specific_image=None,
115
+ )
116
+
117
+
118
+ # TODO: remove it and use config files from 'configs'
119
+ Config = namedtuple(
120
+ "Config",
121
+ "face_detector_weights"
122
+ + " face_id_weights"
123
+ + " parsing_model_weights"
124
+ + " simswap_weights"
125
+ + " gfpgan_weights"
126
+ + " blend_module_weights"
127
+ + " device"
128
+ + " crop_size"
129
+ + " checkpoint_type"
130
+ + " face_alignment_type"
131
+ + " smooth_mask_iter"
132
+ + " smooth_mask_kernel_size"
133
+ + " smooth_mask_threshold"
134
+ + " face_detector_threshold"
135
+ + " specific_latent_match_threshold"
136
+ + " enhance_output",
137
+ )
138
+
139
+ if __name__ == "__main__":
140
+ config = Config(
141
+ face_detector_weights="weights/scrfd_10g_bnkps.onnx",
142
+ face_id_weights="weights/arcface_net.jit",
143
+ parsing_model_weights="weights/79999_iter.pth",
144
+ simswap_weights="weights/latest_net_G.pth",
145
+ gfpgan_weights="weights/GFPGANv1.4_ema.pth",
146
+ blend_module_weights="weights/blend.jit",
147
+ device="cuda",
148
+ crop_size=224,
149
+ checkpoint_type="official_224",
150
+ face_alignment_type="none",
151
+ smooth_mask_iter=7,
152
+ smooth_mask_kernel_size=17,
153
+ smooth_mask_threshold=0.9,
154
+ face_detector_threshold=0.6,
155
+ specific_latent_match_threshold=0.05,
156
+ enhance_output=True
157
+ )
158
+
159
+ model = load_model(config)
160
+ run(model)
configs/run_image.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ id_image: "${hydra:runtime.cwd}/demo_file/Iron_man.jpg"
3
+ att_image: "${hydra:runtime.cwd}/demo_file/multi_people.jpg"
4
+ specific_id_image: "none"
5
+ att_video: "none"
6
+ output_dir: ${hydra:runtime.cwd}/output
7
+
8
+ pipeline:
9
+ face_detector_weights: "${hydra:runtime.cwd}/weights/face_detector_scrfd_10g_bnkps.onnx"
10
+ face_id_weights: "${hydra:runtime.cwd}/weights/arcface_net.jit"
11
+ parsing_model_weights: "${hydra:runtime.cwd}/weights/79999_iter.pth"
12
+ simswap_weights: "${hydra:runtime.cwd}/weights/simswap_224_latest_net_G.pth"
13
+ gfpgan_weights: "${hydra:runtime.cwd}/weights/GFPGANv1.4_ema.pth"
14
+ blend_module_weights: "${hydra:runtime.cwd}/weights/blend_module.jit"
15
+ device: "cuda"
16
+ crop_size: 224
17
+ # it seems that the official 224 checkpoint works better with 'none' face alignment type
18
+ checkpoint_type: "official_224" #"none"
19
+ face_alignment_type: "none" #"ffhq"
20
+ smooth_mask_iter: 7
21
+ smooth_mask_kernel_size: 17
22
+ smooth_mask_threshold: 0.9
23
+ face_detector_threshold: 0.6
24
+ specific_latent_match_threshold: 0.05
25
+ enhance_output: True
26
+
27
+ defaults:
28
+ - _self_
29
+ - override hydra/hydra_logging: disabled
30
+ - override hydra/job_logging: disabled
31
+
32
+ hydra:
33
+ output_subdir: null
34
+ run:
35
+ dir: .
configs/run_image_specific.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ id_image: "${hydra:runtime.cwd}/demo_file/Iron_man.jpg"
3
+ att_image: "${hydra:runtime.cwd}/demo_file/multi_people.jpg"
4
+ specific_id_image: "${hydra:runtime.cwd}/demo_file/specific1.png"
5
+ att_video: "none"
6
+ output_dir: ${hydra:runtime.cwd}/output
7
+
8
+ pipeline:
9
+ face_detector_weights: "${hydra:runtime.cwd}/weights/face_detector_scrfd_10g_bnkps.onnx"
10
+ face_id_weights: "${hydra:runtime.cwd}/weights/arcface_net.jit"
11
+ parsing_model_weights: "${hydra:runtime.cwd}/weights/79999_iter.pth"
12
+ simswap_weights: "${hydra:runtime.cwd}/weights/simswap_224_latest_net_G.pth"
13
+ gfpgan_weights: "${hydra:runtime.cwd}/weights/GFPGANv1.4_ema.pth"
14
+ blend_module_weights: "${hydra:runtime.cwd}/weights/blend_module.jit"
15
+ device: "cuda"
16
+ crop_size: 224
17
+ # it seems that the official 224 checkpoint works better with 'none' face alignment type
18
+ checkpoint_type: "official_224" #"none"
19
+ face_alignment_type: "none" #"ffhq"
20
+ smooth_mask_iter: 7
21
+ smooth_mask_kernel_size: 17
22
+ smooth_mask_threshold: 0.9
23
+ face_detector_threshold: 0.6
24
+ specific_latent_match_threshold: 0.05
25
+ enhance_output: True
26
+
27
+ defaults:
28
+ - _self_
29
+ - override hydra/hydra_logging: disabled
30
+ - override hydra/job_logging: disabled
31
+
32
+ hydra:
33
+ output_subdir: null
34
+ run:
35
+ dir: .
configs/run_video.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ id_image: "${hydra:runtime.cwd}/demo_file/Iron_man.jpg"
3
+ att_image: "none"
4
+ specific_id_image: "none"
5
+ att_video: "${hydra:runtime.cwd}/demo_file/multi_people_1080p.mp4"
6
+ output_dir: ${hydra:runtime.cwd}/output
7
+ clean_work_dir: True
8
+
9
+ pipeline:
10
+ face_detector_weights: "${hydra:runtime.cwd}/weights/face_detector_scrfd_10g_bnkps.onnx"
11
+ face_id_weights: "${hydra:runtime.cwd}/weights/arcface_net.jit"
12
+ parsing_model_weights: "${hydra:runtime.cwd}/weights/79999_iter.pth"
13
+ simswap_weights: "${hydra:runtime.cwd}/weights/simswap_224_latest_net_G.pth"
14
+ gfpgan_weights: "${hydra:runtime.cwd}/weights/GFPGANv1.4_ema.pth"
15
+ blend_module_weights: "${hydra:runtime.cwd}/weights/blend_module.jit"
16
+ device: "cuda"
17
+ crop_size: 224
18
+ # it seems that the official 224 checkpoint works better with 'none' face alignment type
19
+ checkpoint_type: "official_224" #"none"
20
+ face_alignment_type: "none" #"ffhq"
21
+ smooth_mask_iter: 7
22
+ smooth_mask_kernel_size: 17
23
+ smooth_mask_threshold: 0.9
24
+ face_detector_threshold: 0.6
25
+ specific_latent_match_threshold: 0.05
26
+ enhance_output: True
27
+
28
+ defaults:
29
+ - _self_
30
+ - override hydra/hydra_logging: disabled
31
+ - override hydra/job_logging: disabled
32
+
33
+ hydra:
34
+ output_subdir: null
35
+ run:
36
+ dir: .
configs/run_video_specific.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ id_image: "${hydra:runtime.cwd}/demo_file/Iron_man.jpg"
3
+ att_image: "none"
4
+ specific_id_image: "${hydra:runtime.cwd}/demo_file/specific1.png"
5
+ att_video: "${hydra:runtime.cwd}/demo_file/multi_people_1080p.mp4"
6
+ output_dir: ${hydra:runtime.cwd}/output
7
+ clean_work_dir: True
8
+
9
+ pipeline:
10
+ face_detector_weights: "${hydra:runtime.cwd}/weights/face_detector_scrfd_10g_bnkps.onnx"
11
+ face_id_weights: "${hydra:runtime.cwd}/weights/arcface_net.jit"
12
+ parsing_model_weights: "${hydra:runtime.cwd}/weights/79999_iter.pth"
13
+ simswap_weights: "${hydra:runtime.cwd}/weights/simswap_224_latest_net_G.pth"
14
+ gfpgan_weights: "${hydra:runtime.cwd}/weights/GFPGANv1.4_ema.pth"
15
+ blend_module_weights: "${hydra:runtime.cwd}/weights/blend_module.jit"
16
+ device: "cuda"
17
+ crop_size: 224
18
+ # it seems that the official 224 checkpoint works better with 'none' face alignment type
19
+ checkpoint_type: "official_224" #"none"
20
+ face_alignment_type: "none" #"ffhq"
21
+ smooth_mask_iter: 7
22
+ smooth_mask_kernel_size: 17
23
+ smooth_mask_threshold: 0.9
24
+ face_detector_threshold: 0.6
25
+ specific_latent_match_threshold: 0.05
26
+ enhance_output: True
27
+
28
+ defaults:
29
+ - _self_
30
+ - override hydra/hydra_logging: disabled
31
+ - override hydra/job_logging: disabled
32
+
33
+ hydra:
34
+ output_subdir: null
35
+ run:
36
+ dir: .
demo_file/Iron_man.jpg ADDED
demo_file/multi_people.jpg ADDED
demo_file/multi_people_1080p.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:97fe960cc03abac34509ec69a68c7b75f2ca1325aea353456411fe7569d978e1
3
+ size 8735410
demo_file/multispecific/DST_01.jpg ADDED
demo_file/multispecific/DST_02.jpg ADDED
demo_file/multispecific/DST_03.jpg ADDED
demo_file/multispecific/SRC_01.png ADDED
demo_file/multispecific/SRC_02.png ADDED
demo_file/multispecific/SRC_03.png ADDED
demo_file/specific1.png ADDED
demo_file/specific2.png ADDED
demo_file/specific3.png ADDED
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ hydra-core>=1.1.0
2
+ insightface==0.2.1
3
+ kornia==0.6.5
4
+ moviepy==1.0.3
5
+ onnx==1.12.0
6
+ onnxruntime==1.11.1
7
+ opencv-python==4.6.0.66
8
+ tqdm==4.64.0
9
+ streamlit==1.14.0
src/Blend/blend.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class BlendModule(nn.Module):
6
+ def __init__(self, model_path, device):
7
+ super().__init__()
8
+
9
+ self.model = torch.jit.load(model_path).to(device)
10
+
11
+ def forward(self, swap, mask, att_img):
12
+ return self.model(swap, mask, att_img)
src/DataManager/ImageDataManager.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.DataManager.base import BaseDataManager
2
+ from src.DataManager.utils import imread_rgb, imwrite_rgb
3
+
4
+ import numpy as np
5
+ from pathlib import Path
6
+
7
+
8
+ class ImageDataManager(BaseDataManager):
9
+ def __init__(self, src_data: Path, output_dir: Path):
10
+ self.output_dir: Path = output_dir
11
+ self.output_dir.mkdir(exist_ok=True)
12
+ self.output_dir = output_dir / "img"
13
+ self.output_dir.mkdir(exist_ok=True)
14
+
15
+ self.data_paths = []
16
+ if src_data.is_file():
17
+ self.data_paths.append(src_data)
18
+ elif src_data.is_dir():
19
+ self.data_paths = (
20
+ list(src_data.glob("*.jpg"))
21
+ + list(src_data.glob("*.jpeg"))
22
+ + list(src_data.glob("*.png"))
23
+ )
24
+
25
+ assert len(self.data_paths), "Data must be supplied!"
26
+
27
+ self.data_paths_iter = iter(self.data_paths)
28
+
29
+ self.last_idx = -1
30
+
31
+ def __len__(self):
32
+ return len(self.data_paths)
33
+
34
+ def get(self) -> np.ndarray:
35
+ img_path = next(self.data_paths_iter)
36
+ self.last_idx += 1
37
+ return imread_rgb(img_path)
38
+
39
+ def save(self, img: np.ndarray):
40
+ filename = "swap_" + Path(self.data_paths[self.last_idx]).name
41
+
42
+ imwrite_rgb(self.output_dir / filename, img)
src/DataManager/VideoDataManager.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.DataManager.base import BaseDataManager
2
+ from src.DataManager.utils import imwrite_rgb
3
+
4
+ import cv2
5
+ import numpy as np
6
+ from pathlib import Path
7
+ import shutil
8
+ from typing import Optional, Union
9
+
10
+ from moviepy.editor import AudioFileClip, VideoFileClip
11
+ from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
12
+
13
+
14
+ class VideoDataManager(BaseDataManager):
15
+ def __init__(self, src_data: Path, output_dir: Path, clean_work_dir: bool = False):
16
+ self.video_handle: Optional[cv2.VideoCapture] = None
17
+ self.audio_handle: Optional[AudioFileClip] = None
18
+
19
+ self.output_dir = output_dir
20
+ self.output_img_dir = output_dir / "img"
21
+ self.output_dir.mkdir(exist_ok=True)
22
+ self.output_img_dir.mkdir(exist_ok=True)
23
+ self.video_name = None
24
+ self.clean_work_dir = clean_work_dir
25
+
26
+ if src_data.is_file():
27
+ self.video_name = "swap_" + src_data.name
28
+
29
+ if VideoFileClip(str(src_data)).audio is not None:
30
+ self.audio_handle = AudioFileClip(str(src_data))
31
+
32
+ self.video_handle = cv2.VideoCapture(str(src_data))
33
+ self.video_handle.set(cv2.CAP_PROP_POS_FRAMES, 0)
34
+
35
+ self.frame_count = int(self.video_handle.get(cv2.CAP_PROP_FRAME_COUNT))
36
+ self.fps = self.video_handle.get(cv2.CAP_PROP_FPS)
37
+
38
+ self.last_idx = -1
39
+
40
+ assert self.video_handle, "Video file must be specified!"
41
+
42
+ def __len__(self):
43
+ return self.frame_count
44
+
45
+ def get(self) -> np.ndarray:
46
+ img: Union[None, np.ndarray] = None
47
+
48
+ while img is None and self.last_idx < self.frame_count:
49
+ status, img = self.video_handle.read()
50
+ self.last_idx += 1
51
+
52
+ if img is not None:
53
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
54
+ return img
55
+
56
+ def save(self, img: np.ndarray):
57
+ filename = "frame_{:0>7d}.jpg".format(self.last_idx)
58
+ imwrite_rgb(self.output_img_dir / filename, img)
59
+
60
+ if (self.frame_count - 1) == self.last_idx:
61
+ self._close()
62
+
63
+ def _close(self):
64
+ image_filenames = [str(x) for x in sorted(self.output_img_dir.glob("*.jpg"))]
65
+ clip = ImageSequenceClip(image_filenames, fps=self.fps)
66
+
67
+ if self.audio_handle is not None:
68
+ clip = clip.set_audio(self.audio_handle)
69
+
70
+ clip.write_videofile(str(self.output_dir / self.video_name))
71
+
72
+ if self.clean_work_dir:
73
+ shutil.rmtree(self.output_img_dir, ignore_errors=True)
src/DataManager/base.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ import numpy as np
3
+
4
+
5
+ class BaseDataManager(ABC):
6
+ @abstractmethod
7
+ def __len__(self) -> int:
8
+ pass
9
+
10
+ @abstractmethod
11
+ def get(self) -> np.ndarray:
12
+ pass
13
+
14
+ @abstractmethod
15
+ def save(self, img: np.ndarray) -> None:
16
+ pass
src/DataManager/utils.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+
7
+ def imread_rgb(img_path: Union[str, Path]) -> np.ndarray:
8
+ return cv2.cvtColor(cv2.imread(str(img_path)), cv2.COLOR_BGR2RGB)
9
+
10
+
11
+ def imwrite_rgb(img_path: Union[str, Path], img):
12
+ return cv2.imwrite(str(img_path), cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
src/FaceAlign/face_align.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ from skimage import transform as skt
5
+ from typing import Iterable, Tuple
6
+
7
+ src1 = np.array(
8
+ [
9
+ [51.642, 50.115],
10
+ [57.617, 49.990],
11
+ [35.740, 69.007],
12
+ [51.157, 89.050],
13
+ [57.025, 89.702],
14
+ ],
15
+ dtype=np.float32,
16
+ )
17
+ # <--left
18
+ src2 = np.array(
19
+ [
20
+ [45.031, 50.118],
21
+ [65.568, 50.872],
22
+ [39.677, 68.111],
23
+ [45.177, 86.190],
24
+ [64.246, 86.758],
25
+ ],
26
+ dtype=np.float32,
27
+ )
28
+
29
+ # ---frontal
30
+ src3 = np.array(
31
+ [
32
+ [39.730, 51.138],
33
+ [72.270, 51.138],
34
+ [56.000, 68.493],
35
+ [42.463, 87.010],
36
+ [69.537, 87.010],
37
+ ],
38
+ dtype=np.float32,
39
+ )
40
+
41
+ # -->right
42
+ src4 = np.array(
43
+ [
44
+ [46.845, 50.872],
45
+ [67.382, 50.118],
46
+ [72.737, 68.111],
47
+ [48.167, 86.758],
48
+ [67.236, 86.190],
49
+ ],
50
+ dtype=np.float32,
51
+ )
52
+
53
+ # -->right profile
54
+ src5 = np.array(
55
+ [
56
+ [54.796, 49.990],
57
+ [60.771, 50.115],
58
+ [76.673, 69.007],
59
+ [55.388, 89.702],
60
+ [61.257, 89.050],
61
+ ],
62
+ dtype=np.float32,
63
+ )
64
+
65
+ src = np.array([src1, src2, src3, src4, src5])
66
+ src_map = src
67
+
68
+ ffhq_src = np.array(
69
+ [
70
+ [192.98138, 239.94708],
71
+ [318.90277, 240.1936],
72
+ [256.63416, 314.01935],
73
+ [201.26117, 371.41043],
74
+ [313.08905, 371.15118],
75
+ ]
76
+ )
77
+ ffhq_src = np.expand_dims(ffhq_src, axis=0)
78
+
79
+
80
+ # arcface_src = np.array(
81
+ # [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366],
82
+ # [41.5493, 92.3655], [70.7299, 92.2041]],
83
+ # dtype=np.float32)
84
+
85
+ # arcface_src = np.expand_dims(arcface_src, axis=0)
86
+
87
+ # In[66]:
88
+
89
+
90
+ # lmk is prediction; src is template
91
+ def estimate_norm(lmk, image_size=112, mode="ffhq"):
92
+ assert lmk.shape == (5, 2)
93
+ tform = skt.SimilarityTransform()
94
+ lmk_tran = np.insert(lmk, 2, values=np.ones(5), axis=1)
95
+ min_M = []
96
+ min_index = []
97
+ min_error = float("inf")
98
+ if mode == "ffhq":
99
+ # assert image_size == 112
100
+ src = ffhq_src * image_size / 512
101
+ else:
102
+ src = src_map * image_size / 112
103
+ for i in np.arange(src.shape[0]):
104
+ tform.estimate(lmk, src[i])
105
+ M = tform.params[0:2, :]
106
+ results = np.dot(M, lmk_tran.T)
107
+ results = results.T
108
+ error = np.sum(np.sqrt(np.sum((results - src[i]) ** 2, axis=1)))
109
+ if error < min_error:
110
+ min_error = error
111
+ min_M = M
112
+ min_index = i
113
+ return min_M, min_index
114
+
115
+
116
+ def norm_crop(img, landmark, image_size=112, mode="ffhq"):
117
+ if mode == "Both":
118
+ M_None, _ = estimate_norm(landmark, image_size, mode="newarc")
119
+ M_ffhq, _ = estimate_norm(landmark, image_size, mode="ffhq")
120
+ warped_None = cv2.warpAffine(
121
+ img, M_None, (image_size, image_size), borderValue=0.0
122
+ )
123
+ warped_ffhq = cv2.warpAffine(
124
+ img, M_ffhq, (image_size, image_size), borderValue=0.0
125
+ )
126
+ return warped_ffhq, warped_None
127
+ else:
128
+ M, pose_index = estimate_norm(landmark, image_size, mode)
129
+ warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0)
130
+ return warped
131
+
132
+
133
+ def square_crop(im, S):
134
+ if im.shape[0] > im.shape[1]:
135
+ height = S
136
+ width = int(float(im.shape[1]) / im.shape[0] * S)
137
+ scale = float(S) / im.shape[0]
138
+ else:
139
+ width = S
140
+ height = int(float(im.shape[0]) / im.shape[1] * S)
141
+ scale = float(S) / im.shape[1]
142
+ resized_im = cv2.resize(im, (width, height))
143
+ det_im = np.zeros((S, S, 3), dtype=np.uint8)
144
+ det_im[: resized_im.shape[0], : resized_im.shape[1], :] = resized_im
145
+ return det_im, scale
146
+
147
+
148
+ def transform(data, center, output_size, scale, rotation):
149
+ scale_ratio = scale
150
+ rot = float(rotation) * np.pi / 180.0
151
+ # translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio)
152
+ t1 = skt.SimilarityTransform(scale=scale_ratio)
153
+ cx = center[0] * scale_ratio
154
+ cy = center[1] * scale_ratio
155
+ t2 = skt.SimilarityTransform(translation=(-1 * cx, -1 * cy))
156
+ t3 = skt.SimilarityTransform(rotation=rot)
157
+ t4 = skt.SimilarityTransform(translation=(output_size / 2, output_size / 2))
158
+ t = t1 + t2 + t3 + t4
159
+ M = t.params[0:2]
160
+ cropped = cv2.warpAffine(data, M, (output_size, output_size), borderValue=0.0)
161
+ return cropped, M
162
+
163
+
164
+ def trans_points2d(pts, M):
165
+ new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
166
+ for i in range(pts.shape[0]):
167
+ pt = pts[i]
168
+ new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32)
169
+ new_pt = np.dot(M, new_pt)
170
+ # print('new_pt', new_pt.shape, new_pt)
171
+ new_pts[i] = new_pt[0:2]
172
+
173
+ return new_pts
174
+
175
+
176
+ def trans_points3d(pts, M):
177
+ scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1])
178
+ # print(scale)
179
+ new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
180
+ for i in range(pts.shape[0]):
181
+ pt = pts[i]
182
+ new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32)
183
+ new_pt = np.dot(M, new_pt)
184
+ # print('new_pt', new_pt.shape, new_pt)
185
+ new_pts[i][0:2] = new_pt[0:2]
186
+ new_pts[i][2] = pts[i][2] * scale
187
+
188
+ return new_pts
189
+
190
+
191
+ def trans_points(pts, M):
192
+ if pts.shape[1] == 2:
193
+ return trans_points2d(pts, M)
194
+ else:
195
+ return trans_points3d(pts, M)
196
+
197
+
198
+ def inverse_transform(mat: np.ndarray) -> np.ndarray:
199
+ # inverse the Affine transformation matrix
200
+ inv_mat = np.zeros([2, 3])
201
+ div1 = mat[0][0] * mat[1][1] - mat[0][1] * mat[1][0]
202
+ inv_mat[0][0] = mat[1][1] / div1
203
+ inv_mat[0][1] = -mat[0][1] / div1
204
+ inv_mat[0][2] = -(mat[0][2] * mat[1][1] - mat[0][1] * mat[1][2]) / div1
205
+ div2 = mat[0][1] * mat[1][0] - mat[0][0] * mat[1][1]
206
+ inv_mat[1][0] = mat[1][0] / div2
207
+ inv_mat[1][1] = -mat[0][0] / div2
208
+ inv_mat[1][2] = -(mat[0][2] * mat[1][0] - mat[0][0] * mat[1][2]) / div2
209
+ return inv_mat
210
+
211
+
212
+ def inverse_transform_batch(mat: torch.Tensor) -> torch.Tensor:
213
+ # inverse the Affine transformation matrix
214
+ inv_mat = torch.zeros_like(mat)
215
+ div1 = mat[:, 0, 0] * mat[:, 1, 1] - mat[:, 0, 1] * mat[:, 1, 0]
216
+ inv_mat[:, 0, 0] = mat[:, 1, 1] / div1
217
+ inv_mat[:, 0, 1] = -mat[:, 0, 1] / div1
218
+ inv_mat[:, 0, 2] = (
219
+ -(mat[:, 0, 2] * mat[:, 1, 1] - mat[:, 0, 1] * mat[:, 1, 2]) / div1
220
+ )
221
+ div2 = mat[:, 0, 1] * mat[:, 1, 0] - mat[:, 0, 0] * mat[:, 1, 1]
222
+ inv_mat[:, 1, 0] = mat[:, 1, 0] / div2
223
+ inv_mat[:, 1, 1] = -mat[:, 0, 0] / div2
224
+ inv_mat[:, 1, 2] = (
225
+ -(mat[:, 0, 2] * mat[:, 1, 0] - mat[:, 0, 0] * mat[:, 1, 2]) / div2
226
+ )
227
+ return inv_mat
228
+
229
+
230
+ def align_face(
231
+ img: np.ndarray, key_points: np.ndarray, crop_size: int, mode: str = "ffhq"
232
+ ) -> Tuple[Iterable[np.ndarray], Iterable[np.ndarray]]:
233
+ align_imgs = []
234
+ transforms = []
235
+ for i in range(key_points.shape[0]):
236
+ kps = key_points[i]
237
+ transform_matrix, _ = estimate_norm(kps, crop_size, mode=mode)
238
+ align_img = cv2.warpAffine(
239
+ img, transform_matrix, (crop_size, crop_size), borderValue=0.0
240
+ )
241
+ align_imgs.append(align_img)
242
+ transforms.append(transform_matrix)
243
+
244
+ return align_imgs, transforms
src/FaceDetector/face_detector.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import NamedTuple, Optional, Tuple
2
+
3
+ from insightface.model_zoo import model_zoo
4
+ import numpy as np
5
+ from pathlib import Path
6
+
7
+
8
+ class Detection(NamedTuple):
9
+ bbox: Optional[np.ndarray]
10
+ score: Optional[np.ndarray]
11
+ key_points: Optional[np.ndarray]
12
+
13
+
14
+ class FaceDetector:
15
+ def __init__(
16
+ self,
17
+ model_path: Path,
18
+ det_thresh: float = 0.5,
19
+ det_size: Tuple[int, int] = (640, 640),
20
+ mode: str = "None",
21
+ device: str = "cpu",
22
+ ):
23
+ self.det_thresh = det_thresh
24
+ self.mode = mode
25
+ self.device = device
26
+ self.handler = model_zoo.get_model(str(model_path))
27
+ ctx_id = -1 if device == "cpu" else 0
28
+ self.handler.prepare(ctx_id, input_size=det_size)
29
+
30
+ def __call__(self, img: np.ndarray, max_num: int = 0) -> Detection:
31
+ bboxes, kpss = self.handler.detect(
32
+ img, threshold=self.det_thresh, max_num=max_num, metric="default"
33
+ )
34
+ if bboxes.shape[0] == 0:
35
+ return Detection(None, None, None)
36
+
37
+ return Detection(bboxes[..., :-1], bboxes[..., -1], kpss)
src/FaceId/faceid.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torchvision import transforms
6
+
7
+ from typing import Iterable, Union
8
+ from pathlib import Path
9
+
10
+
11
+ class FaceId(torch.nn.Module):
12
+ def __init__(
13
+ self, model_path: Path, device: str, input_shape: Iterable[int] = (112, 112)
14
+ ):
15
+ super().__init__()
16
+
17
+ self.input_shape = input_shape
18
+ self.net = torch.load(model_path, map_location=torch.device("cpu"))
19
+ self.net.eval()
20
+
21
+ self.transform = transforms.Compose(
22
+ [
23
+ transforms.ToTensor(),
24
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
25
+ ]
26
+ )
27
+
28
+ for n, p in self.net.named_parameters():
29
+ assert (
30
+ not p.requires_grad
31
+ ), f"Parameter {n}: requires_grad: {p.requires_grad}"
32
+
33
+ self.device = torch.device(device)
34
+ self.to(self.device)
35
+
36
+ def forward(
37
+ self, img_id: Union[np.ndarray, Iterable[np.ndarray]], normalize: bool = True
38
+ ) -> torch.Tensor:
39
+ if isinstance(img_id, Iterable):
40
+ img_id = [self.transform(x) for x in img_id]
41
+ img_id = torch.stack(img_id, dim=0)
42
+ else:
43
+ img_id = self.transform(img_id)
44
+ img_id = img_id.unsqueeze(0)
45
+
46
+ img_id = img_id.to(self.device)
47
+
48
+ img_id_112 = F.interpolate(img_id, size=self.input_shape)
49
+ latent_id = self.net(img_id_112)
50
+ return F.normalize(latent_id, p=2, dim=1) if normalize else latent_id
src/Generator/fs_networks_512.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Author: Naiyuan liu
3
+ Github: https://github.com/NNNNAI
4
+ Date: 2021-11-23 16:55:48
5
+ LastEditors: Naiyuan liu
6
+ LastEditTime: 2021-11-24 16:58:06
7
+ Description:
8
+ """
9
+ """
10
+ Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
11
+ Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
12
+ """
13
+ import torch
14
+ import torch.nn as nn
15
+
16
+
17
+ class InstanceNorm(nn.Module):
18
+ def __init__(self, epsilon=1e-8):
19
+ """
20
+ @notice: avoid in-place ops.
21
+ https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
22
+ """
23
+ super(InstanceNorm, self).__init__()
24
+ self.epsilon = epsilon
25
+
26
+ def forward(self, x):
27
+ x = x - torch.mean(x, (2, 3), True)
28
+ tmp = torch.mul(x, x) # or x ** 2
29
+ tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon)
30
+ return x * tmp
31
+
32
+
33
+ class ApplyStyle(nn.Module):
34
+ """
35
+ @ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
36
+ """
37
+
38
+ def __init__(self, latent_size, channels):
39
+ super(ApplyStyle, self).__init__()
40
+ self.linear = nn.Linear(latent_size, channels * 2)
41
+
42
+ def forward(self, x, latent):
43
+ style = self.linear(latent) # style => [batch_size, n_channels*2]
44
+ shape = [-1, 2, x.size(1), 1, 1]
45
+ style = style.view(shape) # [batch_size, 2, n_channels, ...]
46
+ # x = x * (style[:, 0] + 1.) + style[:, 1]
47
+ x = x * (style[:, 0] * 1 + 1.0) + style[:, 1] * 1
48
+ return x
49
+
50
+
51
+ class ResnetBlock_Adain(nn.Module):
52
+ def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)):
53
+ super(ResnetBlock_Adain, self).__init__()
54
+
55
+ p = 0
56
+ conv1 = []
57
+ if padding_type == "reflect":
58
+ conv1 += [nn.ReflectionPad2d(1)]
59
+ elif padding_type == "replicate":
60
+ conv1 += [nn.ReplicationPad2d(1)]
61
+ elif padding_type == "zero":
62
+ p = 1
63
+ else:
64
+ raise NotImplementedError("padding [%s] is not implemented" % padding_type)
65
+ conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()]
66
+ self.conv1 = nn.Sequential(*conv1)
67
+ self.style1 = ApplyStyle(latent_size, dim)
68
+ self.act1 = activation
69
+
70
+ p = 0
71
+ conv2 = []
72
+ if padding_type == "reflect":
73
+ conv2 += [nn.ReflectionPad2d(1)]
74
+ elif padding_type == "replicate":
75
+ conv2 += [nn.ReplicationPad2d(1)]
76
+ elif padding_type == "zero":
77
+ p = 1
78
+ else:
79
+ raise NotImplementedError("padding [%s] is not implemented" % padding_type)
80
+ conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()]
81
+ self.conv2 = nn.Sequential(*conv2)
82
+ self.style2 = ApplyStyle(latent_size, dim)
83
+
84
+ def forward(self, x, dlatents_in_slice):
85
+ y = self.conv1(x)
86
+ y = self.style1(y, dlatents_in_slice)
87
+ y = self.act1(y)
88
+ y = self.conv2(y)
89
+ y = self.style2(y, dlatents_in_slice)
90
+ out = x + y
91
+ return out
92
+
93
+
94
+ class Generator_Adain_Upsample(nn.Module):
95
+ def __init__(
96
+ self,
97
+ input_nc,
98
+ output_nc,
99
+ latent_size,
100
+ n_blocks=6,
101
+ deep=False,
102
+ norm_layer=nn.BatchNorm2d,
103
+ padding_type="reflect",
104
+ ):
105
+ assert n_blocks >= 0
106
+ super(Generator_Adain_Upsample, self).__init__()
107
+ activation = nn.ReLU(True)
108
+ self.deep = deep
109
+
110
+ self.first_layer = nn.Sequential(
111
+ nn.ReflectionPad2d(3),
112
+ nn.Conv2d(input_nc, 32, kernel_size=7, padding=0),
113
+ norm_layer(32),
114
+ activation,
115
+ )
116
+ # downsample
117
+ self.down0 = nn.Sequential(
118
+ nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
119
+ norm_layer(64),
120
+ activation,
121
+ )
122
+ self.down1 = nn.Sequential(
123
+ nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
124
+ norm_layer(128),
125
+ activation,
126
+ )
127
+ self.down2 = nn.Sequential(
128
+ nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
129
+ norm_layer(256),
130
+ activation,
131
+ )
132
+ self.down3 = nn.Sequential(
133
+ nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
134
+ norm_layer(512),
135
+ activation,
136
+ )
137
+ if self.deep:
138
+ self.down4 = nn.Sequential(
139
+ nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
140
+ norm_layer(512),
141
+ activation,
142
+ )
143
+
144
+ # resnet blocks
145
+ BN = []
146
+ for i in range(n_blocks):
147
+ BN += [
148
+ ResnetBlock_Adain(
149
+ 512,
150
+ latent_size=latent_size,
151
+ padding_type=padding_type,
152
+ activation=activation,
153
+ )
154
+ ]
155
+ self.BottleNeck = nn.Sequential(*BN)
156
+
157
+ if self.deep:
158
+ self.up4 = nn.Sequential(
159
+ nn.Upsample(scale_factor=2, mode="bilinear"),
160
+ nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
161
+ nn.BatchNorm2d(512),
162
+ activation,
163
+ )
164
+ self.up3 = nn.Sequential(
165
+ nn.Upsample(scale_factor=2, mode="bilinear"),
166
+ nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
167
+ nn.BatchNorm2d(256),
168
+ activation,
169
+ )
170
+ self.up2 = nn.Sequential(
171
+ nn.Upsample(scale_factor=2, mode="bilinear"),
172
+ nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
173
+ nn.BatchNorm2d(128),
174
+ activation,
175
+ )
176
+ self.up1 = nn.Sequential(
177
+ nn.Upsample(scale_factor=2, mode="bilinear"),
178
+ nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
179
+ nn.BatchNorm2d(64),
180
+ activation,
181
+ )
182
+ self.up0 = nn.Sequential(
183
+ nn.Upsample(scale_factor=2, mode="bilinear"),
184
+ nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
185
+ nn.BatchNorm2d(32),
186
+ activation,
187
+ )
188
+ self.last_layer = nn.Sequential(
189
+ nn.ReflectionPad2d(3),
190
+ nn.Conv2d(32, output_nc, kernel_size=7, padding=0),
191
+ nn.Tanh(),
192
+ )
193
+
194
+ def forward(self, input, dlatents):
195
+ x = input # 3*224*224
196
+
197
+ skip0 = self.first_layer(x)
198
+ skip1 = self.down0(skip0)
199
+ skip2 = self.down1(skip1)
200
+ skip3 = self.down2(skip2)
201
+ if self.deep:
202
+ skip4 = self.down3(skip3)
203
+ x = self.down4(skip4)
204
+ else:
205
+ x = self.down3(skip3)
206
+
207
+ for i in range(len(self.BottleNeck)):
208
+ x = self.BottleNeck[i](x, dlatents)
209
+
210
+ if self.deep:
211
+ x = self.up4(x)
212
+ x = self.up3(x)
213
+ x = self.up2(x)
214
+ x = self.up1(x)
215
+ x = self.up0(x)
216
+ x = self.last_layer(x)
217
+ x = (x + 1) / 2
218
+
219
+ return x
220
+
221
+
222
+ class Discriminator(nn.Module):
223
+ def __init__(self, input_nc, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
224
+ super(Discriminator, self).__init__()
225
+
226
+ kw = 4
227
+ padw = 1
228
+ self.down1 = nn.Sequential(
229
+ nn.Conv2d(input_nc, 64, kernel_size=kw, stride=2, padding=padw),
230
+ nn.LeakyReLU(0.2, True),
231
+ )
232
+ self.down2 = nn.Sequential(
233
+ nn.Conv2d(64, 128, kernel_size=kw, stride=2, padding=padw),
234
+ norm_layer(128),
235
+ nn.LeakyReLU(0.2, True),
236
+ )
237
+ self.down3 = nn.Sequential(
238
+ nn.Conv2d(128, 256, kernel_size=kw, stride=2, padding=padw),
239
+ norm_layer(256),
240
+ nn.LeakyReLU(0.2, True),
241
+ )
242
+ self.down4 = nn.Sequential(
243
+ nn.Conv2d(256, 512, kernel_size=kw, stride=2, padding=padw),
244
+ norm_layer(512),
245
+ nn.LeakyReLU(0.2, True),
246
+ )
247
+ self.conv1 = nn.Sequential(
248
+ nn.Conv2d(512, 512, kernel_size=kw, stride=1, padding=padw),
249
+ norm_layer(512),
250
+ nn.LeakyReLU(0.2, True),
251
+ )
252
+
253
+ if use_sigmoid:
254
+ self.conv2 = nn.Sequential(
255
+ nn.Conv2d(512, 1, kernel_size=kw, stride=1, padding=padw), nn.Sigmoid()
256
+ )
257
+ else:
258
+ self.conv2 = nn.Sequential(
259
+ nn.Conv2d(512, 1, kernel_size=kw, stride=1, padding=padw)
260
+ )
261
+
262
+ def forward(self, input):
263
+ out = []
264
+ x = self.down1(input)
265
+ out.append(x)
266
+ x = self.down2(x)
267
+ out.append(x)
268
+ x = self.down3(x)
269
+ out.append(x)
270
+ x = self.down4(x)
271
+ out.append(x)
272
+ x = self.conv1(x)
273
+ out.append(x)
274
+ x = self.conv2(x)
275
+ out.append(x)
276
+
277
+ return out
src/Generator/fs_networks_fix.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3
+ Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torchvision import transforms
9
+
10
+ from typing import Iterable
11
+ import numpy as np
12
+
13
+
14
+ class InstanceNorm(nn.Module):
15
+ def __init__(self, epsilon=1e-8):
16
+ """
17
+ @notice: avoid in-place ops.
18
+ https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
19
+ """
20
+ super(InstanceNorm, self).__init__()
21
+ self.epsilon = epsilon
22
+
23
+ def forward(self, x):
24
+ x = x - torch.mean(x, (2, 3), True)
25
+ tmp = torch.mul(x, x) # or x ** 2
26
+ tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon)
27
+ return x * tmp
28
+
29
+
30
+ class ApplyStyle(nn.Module):
31
+ """
32
+ @ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
33
+ """
34
+
35
+ def __init__(self, latent_size, channels):
36
+ super(ApplyStyle, self).__init__()
37
+ self.linear = nn.Linear(latent_size, channels * 2)
38
+
39
+ def forward(self, x, latent):
40
+ style = self.linear(latent) # style => [batch_size, n_channels*2]
41
+ shape = [-1, 2, x.size(1), 1, 1]
42
+ style = style.view(shape) # [batch_size, 2, n_channels, ...]
43
+ # x = x * (style[:, 0] + 1.) + style[:, 1]
44
+ x = x * (style[:, 0] * 1 + 1.0) + style[:, 1] * 1
45
+ return x
46
+
47
+
48
+ class ResnetBlock_Adain(nn.Module):
49
+ def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)):
50
+ super(ResnetBlock_Adain, self).__init__()
51
+
52
+ p = 0
53
+ conv1 = []
54
+ if padding_type == "reflect":
55
+ conv1 += [nn.ReflectionPad2d(1)]
56
+ elif padding_type == "replicate":
57
+ conv1 += [nn.ReplicationPad2d(1)]
58
+ elif padding_type == "zero":
59
+ p = 1
60
+ else:
61
+ raise NotImplementedError("padding [%s] is not implemented" % padding_type)
62
+ conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()]
63
+ self.conv1 = nn.Sequential(*conv1)
64
+ self.style1 = ApplyStyle(latent_size, dim)
65
+ self.act1 = activation
66
+
67
+ p = 0
68
+ conv2 = []
69
+ if padding_type == "reflect":
70
+ conv2 += [nn.ReflectionPad2d(1)]
71
+ elif padding_type == "replicate":
72
+ conv2 += [nn.ReplicationPad2d(1)]
73
+ elif padding_type == "zero":
74
+ p = 1
75
+ else:
76
+ raise NotImplementedError("padding [%s] is not implemented" % padding_type)
77
+ conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()]
78
+ self.conv2 = nn.Sequential(*conv2)
79
+ self.style2 = ApplyStyle(latent_size, dim)
80
+
81
+ def forward(self, x, dlatents_in_slice):
82
+ y = self.conv1(x)
83
+ y = self.style1(y, dlatents_in_slice)
84
+ y = self.act1(y)
85
+ y = self.conv2(y)
86
+ y = self.style2(y, dlatents_in_slice)
87
+ out = x + y
88
+ return out
89
+
90
+
91
+ class Generator_Adain_Upsample(nn.Module):
92
+ def __init__(
93
+ self,
94
+ input_nc: int,
95
+ output_nc: int,
96
+ latent_size: int,
97
+ n_blocks: int = 6,
98
+ deep: bool = False,
99
+ use_last_act: bool = True,
100
+ norm_layer: torch.nn.Module = nn.BatchNorm2d,
101
+ padding_type: str = "reflect",
102
+ ):
103
+ assert n_blocks >= 0
104
+ super(Generator_Adain_Upsample, self).__init__()
105
+
106
+ activation = nn.ReLU(True)
107
+
108
+ self.deep = deep
109
+ self.use_last_act = use_last_act
110
+
111
+ self.to_tensor_normalize = transforms.Compose(
112
+ [
113
+ transforms.ToTensor(),
114
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
115
+ ]
116
+ )
117
+
118
+ self.to_tensor = transforms.Compose([transforms.ToTensor()])
119
+
120
+ self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
121
+ self.imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
122
+
123
+ self.first_layer = nn.Sequential(
124
+ nn.ReflectionPad2d(3),
125
+ nn.Conv2d(input_nc, 64, kernel_size=7, padding=0),
126
+ norm_layer(64),
127
+ activation,
128
+ )
129
+ # downsample
130
+ self.down1 = nn.Sequential(
131
+ nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
132
+ norm_layer(128),
133
+ activation,
134
+ )
135
+ self.down2 = nn.Sequential(
136
+ nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
137
+ norm_layer(256),
138
+ activation,
139
+ )
140
+ self.down3 = nn.Sequential(
141
+ nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
142
+ norm_layer(512),
143
+ activation,
144
+ )
145
+
146
+ if self.deep:
147
+ self.down4 = nn.Sequential(
148
+ nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
149
+ norm_layer(512),
150
+ activation,
151
+ )
152
+
153
+ # resnet blocks
154
+ BN = []
155
+ for i in range(n_blocks):
156
+ BN += [
157
+ ResnetBlock_Adain(
158
+ 512,
159
+ latent_size=latent_size,
160
+ padding_type=padding_type,
161
+ activation=activation,
162
+ )
163
+ ]
164
+ self.BottleNeck = nn.Sequential(*BN)
165
+
166
+ if self.deep:
167
+ self.up4 = nn.Sequential(
168
+ nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
169
+ nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
170
+ nn.BatchNorm2d(512),
171
+ activation,
172
+ )
173
+ self.up3 = nn.Sequential(
174
+ nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
175
+ nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
176
+ nn.BatchNorm2d(256),
177
+ activation,
178
+ )
179
+ self.up2 = nn.Sequential(
180
+ nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
181
+ nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
182
+ nn.BatchNorm2d(128),
183
+ activation,
184
+ )
185
+ self.up1 = nn.Sequential(
186
+ nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
187
+ nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
188
+ nn.BatchNorm2d(64),
189
+ activation,
190
+ )
191
+ if self.use_last_act:
192
+ self.last_layer = nn.Sequential(
193
+ nn.ReflectionPad2d(3),
194
+ nn.Conv2d(64, output_nc, kernel_size=7, padding=0),
195
+ torch.nn.Tanh(),
196
+ )
197
+ else:
198
+ self.last_layer = nn.Sequential(
199
+ nn.ReflectionPad2d(3),
200
+ nn.Conv2d(64, output_nc, kernel_size=7, padding=0),
201
+ )
202
+
203
+ def to(self, device):
204
+ super().to(device)
205
+ self.device = device
206
+ self.imagenet_mean = self.imagenet_mean.to(device)
207
+ self.imagenet_std = self.imagenet_std.to(device)
208
+ return self
209
+
210
+ def forward(self, x: Iterable[np.ndarray], dlatents: torch.Tensor):
211
+ if self.use_last_act:
212
+ x = [self.to_tensor(_) for _ in x]
213
+ else:
214
+ x = [self.to_tensor_normalize(_) for _ in x]
215
+
216
+ x = torch.stack(x, dim=0)
217
+
218
+ x = x.to(self.device)
219
+
220
+ skip1 = self.first_layer(x)
221
+ skip2 = self.down1(skip1)
222
+ skip3 = self.down2(skip2)
223
+ if self.deep:
224
+ skip4 = self.down3(skip3)
225
+ x = self.down4(skip4)
226
+ else:
227
+ x = self.down3(skip3)
228
+
229
+ for i in range(len(self.BottleNeck)):
230
+ x = self.BottleNeck[i](x, dlatents)
231
+
232
+ if self.deep:
233
+ x = self.up4(x)
234
+
235
+ x = self.up3(x)
236
+ x = self.up2(x)
237
+ x = self.up1(x)
238
+ x = self.last_layer(x)
239
+
240
+ if self.use_last_act:
241
+ x = (x + 1) / 2
242
+ else:
243
+ x = x * self.imagenet_std + self.imagenet_mean
244
+
245
+ return x
src/Misc/types.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
+
4
+ class CheckpointType(Enum):
5
+ OFFICIAL_224 = "official_224"
6
+ UNOFFICIAL = "none"
7
+
8
+
9
+ class FaceAlignmentType(Enum):
10
+ FFHQ = "ffhq"
11
+ DEFAULT = "none"
src/Misc/utils.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import cv2
4
+
5
+
6
+ def tensor2img_denorm(tensor):
7
+ std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
8
+ mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
9
+ tensor = std * tensor.detach().cpu() + mean
10
+ img = tensor.numpy()
11
+ img = img.transpose(0, 2, 3, 1)[0]
12
+ img = np.clip(img * 255, 0.0, 255.0).astype(np.uint8)
13
+ return img
14
+
15
+
16
+ def tensor2img(tensor):
17
+ tensor = tensor.detach().cpu().numpy()
18
+ img = tensor.transpose(0, 2, 3, 1)[0]
19
+ img = np.clip(img * 255, 0.0, 255.0).astype(np.uint8)
20
+ return img
21
+
22
+
23
+ def show_tensor(tensor, name):
24
+ img = cv2.cvtColor(tensor2img(tensor), cv2.COLOR_RGB2BGR)
25
+
26
+ cv2.namedWindow(name, cv2.WINDOW_NORMAL)
27
+ cv2.imshow(name, img)
28
+ cv2.waitKey()
src/PostProcess/GFPGAN/gfpgan.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from src.PostProcess.GFPGAN.stylegan2 import StyleGAN2GeneratorClean
7
+
8
+
9
+ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
10
+ """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
11
+ It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
12
+ Args:
13
+ out_size (int): The spatial size of outputs.
14
+ num_style_feat (int): Channel number of style features. Default: 512.
15
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
16
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
17
+ narrow (float): The narrow ratio for channels. Default: 1.
18
+ sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
19
+ """
20
+
21
+ def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1, sft_half=False):
22
+ super(StyleGAN2GeneratorCSFT, self).__init__(
23
+ out_size,
24
+ num_style_feat=num_style_feat,
25
+ num_mlp=num_mlp,
26
+ channel_multiplier=channel_multiplier,
27
+ narrow=narrow)
28
+ self.sft_half = sft_half
29
+
30
+ def forward(self,
31
+ styles,
32
+ conditions,
33
+ input_is_latent=False,
34
+ noise=None,
35
+ randomize_noise=True,
36
+ truncation=1,
37
+ truncation_latent=None,
38
+ inject_index=None,
39
+ return_latents=False):
40
+ """Forward function for StyleGAN2GeneratorCSFT.
41
+ Args:
42
+ styles (list[Tensor]): Sample codes of styles.
43
+ conditions (list[Tensor]): SFT conditions to generators.
44
+ input_is_latent (bool): Whether input is latent style. Default: False.
45
+ noise (Tensor | None): Input noise or None. Default: None.
46
+ randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
47
+ truncation (float): The truncation ratio. Default: 1.
48
+ truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
49
+ inject_index (int | None): The injection index for mixing noise. Default: None.
50
+ return_latents (bool): Whether to return style latents. Default: False.
51
+ """
52
+ # style codes -> latents with Style MLP layer
53
+ if not input_is_latent:
54
+ styles = [self.style_mlp(s) for s in styles]
55
+ # noises
56
+ if noise is None:
57
+ if randomize_noise:
58
+ noise = [None] * self.num_layers # for each style conv layer
59
+ else: # use the stored noise
60
+ noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
61
+ # style truncation
62
+ if truncation < 1:
63
+ style_truncation = []
64
+ for style in styles:
65
+ style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
66
+ styles = style_truncation
67
+ # get style latents with injection
68
+ if len(styles) == 1:
69
+ inject_index = self.num_latent
70
+
71
+ if styles[0].ndim < 3:
72
+ # repeat latent code for all the layers
73
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
74
+ else: # used for encoder with different latent code for each layer
75
+ latent = styles[0]
76
+ elif len(styles) == 2: # mixing noises
77
+ if inject_index is None:
78
+ inject_index = random.randint(1, self.num_latent - 1)
79
+ latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
80
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
81
+ latent = torch.cat([latent1, latent2], 1)
82
+
83
+ # main generation
84
+ out = self.constant_input(latent.shape[0])
85
+ out = self.style_conv1(out, latent[:, 0], noise=noise[0])
86
+ skip = self.to_rgb1(out, latent[:, 1])
87
+
88
+ i = 1
89
+ for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
90
+ noise[2::2], self.to_rgbs):
91
+ out = conv1(out, latent[:, i], noise=noise1)
92
+
93
+ # the conditions may have fewer levels
94
+ if i < len(conditions):
95
+ # SFT part to combine the conditions
96
+ if self.sft_half: # only apply SFT to half of the channels
97
+ out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
98
+ out_sft = out_sft * conditions[i - 1] + conditions[i]
99
+ out = torch.cat([out_same, out_sft], dim=1)
100
+ else: # apply SFT to all the channels
101
+ out = out * conditions[i - 1] + conditions[i]
102
+
103
+ out = conv2(out, latent[:, i + 1], noise=noise2)
104
+ skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
105
+ i += 2
106
+
107
+ image = skip
108
+
109
+ if return_latents:
110
+ return image, latent
111
+ else:
112
+ return image, None
113
+
114
+
115
+ class ResBlock(torch.nn.Module):
116
+ """Residual block with bilinear upsampling/downsampling.
117
+ Args:
118
+ in_channels (int): Channel number of the input.
119
+ out_channels (int): Channel number of the output.
120
+ mode (str): Upsampling/downsampling mode. Options: down | up. Default: down.
121
+ """
122
+
123
+ def __init__(self, in_channels, out_channels, mode='down'):
124
+ super(ResBlock, self).__init__()
125
+
126
+ self.conv1 = torch.nn.Conv2d(in_channels, in_channels, 3, 1, 1)
127
+ self.conv2 = torch.nn.Conv2d(in_channels, out_channels, 3, 1, 1)
128
+ self.skip = torch.nn.Conv2d(in_channels, out_channels, 1, bias=False)
129
+ if mode == 'down':
130
+ self.scale_factor = 0.5
131
+ elif mode == 'up':
132
+ self.scale_factor = 2
133
+
134
+ def forward(self, x):
135
+ out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
136
+ # upsample/downsample
137
+ out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
138
+ out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
139
+ # skip
140
+ x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
141
+ skip = self.skip(x)
142
+ out = out + skip
143
+ return out
144
+
145
+
146
+ class GFPGANv1Clean(torch.nn.Module):
147
+ """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
148
+ It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
149
+ Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
150
+ Args:
151
+ out_size (int): The spatial size of outputs.
152
+ num_style_feat (int): Channel number of style features. Default: 512.
153
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
154
+ decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
155
+ fix_decoder (bool): Whether to fix the decoder. Default: True.
156
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
157
+ input_is_latent (bool): Whether input is latent style. Default: False.
158
+ different_w (bool): Whether to use different latent w for different layers. Default: False.
159
+ narrow (float): The narrow ratio for channels. Default: 1.
160
+ sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
161
+ """
162
+
163
+ def __init__(
164
+ self,
165
+ out_size,
166
+ num_style_feat=512,
167
+ channel_multiplier=1,
168
+ decoder_load_path=None,
169
+ fix_decoder=True,
170
+ # for stylegan decoder
171
+ num_mlp=8,
172
+ input_is_latent=False,
173
+ different_w=False,
174
+ narrow=1,
175
+ sft_half=False):
176
+
177
+ super(GFPGANv1Clean, self).__init__()
178
+ self.input_is_latent = input_is_latent
179
+ self.different_w = different_w
180
+ self.num_style_feat = num_style_feat
181
+
182
+ unet_narrow = narrow * 0.5 # by default, use a half of input channels
183
+ channels = {
184
+ '4': int(512 * unet_narrow),
185
+ '8': int(512 * unet_narrow),
186
+ '16': int(512 * unet_narrow),
187
+ '32': int(512 * unet_narrow),
188
+ '64': int(256 * channel_multiplier * unet_narrow),
189
+ '128': int(128 * channel_multiplier * unet_narrow),
190
+ '256': int(64 * channel_multiplier * unet_narrow),
191
+ '512': int(32 * channel_multiplier * unet_narrow),
192
+ '1024': int(16 * channel_multiplier * unet_narrow)
193
+ }
194
+
195
+ self.log_size = int(math.log(out_size, 2))
196
+ first_out_size = 2**(int(math.log(out_size, 2)))
197
+
198
+ self.conv_body_first = torch.nn.Conv2d(3, channels[f'{first_out_size}'], 1)
199
+
200
+ # downsample
201
+ in_channels = channels[f'{first_out_size}']
202
+ self.conv_body_down = torch.nn.ModuleList()
203
+ for i in range(self.log_size, 2, -1):
204
+ out_channels = channels[f'{2**(i - 1)}']
205
+ self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))
206
+ in_channels = out_channels
207
+
208
+ self.final_conv = torch.nn.Conv2d(in_channels, channels['4'], 3, 1, 1)
209
+
210
+ # upsample
211
+ in_channels = channels['4']
212
+ self.conv_body_up = torch.nn.ModuleList()
213
+ for i in range(3, self.log_size + 1):
214
+ out_channels = channels[f'{2**i}']
215
+ self.conv_body_up.append(ResBlock(in_channels, out_channels, mode='up'))
216
+ in_channels = out_channels
217
+
218
+ # to RGB
219
+ self.toRGB = torch.nn.ModuleList()
220
+ for i in range(3, self.log_size + 1):
221
+ self.toRGB.append(torch.nn.Conv2d(channels[f'{2**i}'], 3, 1))
222
+
223
+ if different_w:
224
+ linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
225
+ else:
226
+ linear_out_channel = num_style_feat
227
+
228
+ self.final_linear = torch.nn.Linear(channels['4'] * 4 * 4, linear_out_channel)
229
+
230
+ # the decoder: stylegan2 generator with SFT modulations
231
+ self.stylegan_decoder = StyleGAN2GeneratorCSFT(
232
+ out_size=out_size,
233
+ num_style_feat=num_style_feat,
234
+ num_mlp=num_mlp,
235
+ channel_multiplier=channel_multiplier,
236
+ narrow=narrow,
237
+ sft_half=sft_half)
238
+
239
+ # load pre-trained stylegan2 model if necessary
240
+ if decoder_load_path:
241
+ self.stylegan_decoder.load_state_dict(
242
+ torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
243
+ # fix decoder without updating params
244
+ if fix_decoder:
245
+ for _, param in self.stylegan_decoder.named_parameters():
246
+ param.requires_grad = False
247
+
248
+ # for SFT modulations (scale and shift)
249
+ self.condition_scale = torch.nn.ModuleList()
250
+ self.condition_shift = torch.nn.ModuleList()
251
+ for i in range(3, self.log_size + 1):
252
+ out_channels = channels[f'{2**i}']
253
+ if sft_half:
254
+ sft_out_channels = out_channels
255
+ else:
256
+ sft_out_channels = out_channels * 2
257
+ self.condition_scale.append(
258
+ torch.nn.Sequential(
259
+ torch.nn.Conv2d(out_channels, out_channels, 3, 1, 1), torch.nn.LeakyReLU(0.2, True),
260
+ torch.nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
261
+ self.condition_shift.append(
262
+ torch.nn.Sequential(
263
+ torch.nn.Conv2d(out_channels, out_channels, 3, 1, 1), torch.nn.LeakyReLU(0.2, True),
264
+ torch.nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
265
+
266
+ def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs):
267
+ """Forward function for GFPGANv1Clean.
268
+ Args:
269
+ x (Tensor): Input images.
270
+ return_latents (bool): Whether to return style latents. Default: False.
271
+ return_rgb (bool): Whether return intermediate rgb images. Default: True.
272
+ randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
273
+ """
274
+ conditions = []
275
+ unet_skips = []
276
+ out_rgbs = []
277
+
278
+ # encoder
279
+ feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2)
280
+ for i in range(self.log_size - 2):
281
+ feat = self.conv_body_down[i](feat)
282
+ unet_skips.insert(0, feat)
283
+ feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)
284
+
285
+ # style code
286
+ style_code = self.final_linear(feat.view(feat.size(0), -1))
287
+ if self.different_w:
288
+ style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
289
+
290
+ # decode
291
+ for i in range(self.log_size - 2):
292
+ # add unet skip
293
+ feat = feat + unet_skips[i]
294
+ # ResUpLayer
295
+ feat = self.conv_body_up[i](feat)
296
+ # generate scale and shift for SFT layers
297
+ scale = self.condition_scale[i](feat)
298
+ conditions.append(scale.clone())
299
+ shift = self.condition_shift[i](feat)
300
+ conditions.append(shift.clone())
301
+ # generate rgb images
302
+ if return_rgb:
303
+ out_rgbs.append(self.toRGB[i](feat))
304
+
305
+ # decoder
306
+ image, _ = self.stylegan_decoder([style_code],
307
+ conditions,
308
+ return_latents=return_latents,
309
+ input_is_latent=self.input_is_latent,
310
+ randomize_noise=randomize_noise)
311
+
312
+ return image, out_rgbs
313
+
314
+
315
+ class GFPGANer(GFPGANv1Clean):
316
+ """Helper for restoration with GFPGAN."""
317
+
318
+ def __init__(self):
319
+ super().__init__(out_size=512, num_style_feat=512, channel_multiplier=2,
320
+ decoder_load_path=None, fix_decoder=False, num_mlp=8, input_is_latent=True,
321
+ different_w=True, narrow=1, sft_half=True)
322
+
323
+ self.min_max = (-1, 1)
324
+
325
+ @torch.no_grad()
326
+ def enhance(self, img, weight=0.5):
327
+ n, c, h, w = img.shape
328
+ img = F.interpolate(img, size=(512, 512), mode="bilinear")
329
+
330
+ img = (img - 0.5) / 0.5
331
+
332
+ try:
333
+ restored_faces = self.forward(img, return_rgb=False, weight=weight)[0]
334
+ except RuntimeError as error:
335
+ print(f'\tFailed inference for GFPGAN: {error}.')
336
+ restored_faces = img
337
+
338
+ restored_faces.clamp_(*self.min_max)
339
+ restored_faces = (restored_faces - self.min_max[0]) / (self.min_max[1] - self.min_max[0])
340
+
341
+ return F.interpolate(restored_faces, size=(h, w), mode="bilinear")
src/PostProcess/GFPGAN/stylegan2.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+ import torch
4
+ from torch.nn import functional as F
5
+
6
+
7
+ class NormStyleCode(torch.nn.Module):
8
+
9
+ def forward(self, x):
10
+ """Normalize the style codes.
11
+ Args:
12
+ x (Tensor): Style codes with shape (b, c).
13
+ Returns:
14
+ Tensor: Normalized tensor.
15
+ """
16
+ return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
17
+
18
+
19
+ class ModulatedConv2d(torch.nn.Module):
20
+ """Modulated Conv2d used in StyleGAN2.
21
+ There is no bias in ModulatedConv2d.
22
+ Args:
23
+ in_channels (int): Channel number of the input.
24
+ out_channels (int): Channel number of the output.
25
+ kernel_size (int): Size of the convolving kernel.
26
+ num_style_feat (int): Channel number of style features.
27
+ demodulate (bool): Whether to demodulate in the conv layer. Default: True.
28
+ sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
29
+ eps (float): A value added to the denominator for numerical stability. Default: 1e-8.
30
+ """
31
+
32
+ def __init__(self,
33
+ in_channels,
34
+ out_channels,
35
+ kernel_size,
36
+ num_style_feat,
37
+ demodulate=True,
38
+ sample_mode=None,
39
+ eps=1e-8):
40
+ super(ModulatedConv2d, self).__init__()
41
+ self.in_channels = in_channels
42
+ self.out_channels = out_channels
43
+ self.kernel_size = kernel_size
44
+ self.demodulate = demodulate
45
+ self.sample_mode = sample_mode
46
+ self.eps = eps
47
+
48
+ # modulation inside each modulated conv
49
+ self.modulation = torch.nn.Linear(num_style_feat, in_channels, bias=True)
50
+ # initialization
51
+ # default_init_weights(self.modulation, scale=1, bias_fill=1, a=0, mode='fan_in', nonlinearity='linear')
52
+
53
+ self.weight = torch.nn.Parameter(
54
+ torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) /
55
+ math.sqrt(in_channels * kernel_size**2))
56
+ self.padding = kernel_size // 2
57
+
58
+ def forward(self, x, style):
59
+ """Forward function.
60
+ Args:
61
+ x (Tensor): Tensor with shape (b, c, h, w).
62
+ style (Tensor): Tensor with shape (b, num_style_feat).
63
+ Returns:
64
+ Tensor: Modulated tensor after convolution.
65
+ """
66
+ b, c, h, w = x.shape # c = c_in
67
+ # weight modulation
68
+ style = self.modulation(style).view(b, 1, c, 1, 1)
69
+ # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
70
+ weight = self.weight * style # (b, c_out, c_in, k, k)
71
+
72
+ if self.demodulate:
73
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
74
+ weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
75
+
76
+ weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
77
+
78
+ # upsample or downsample if necessary
79
+ if self.sample_mode == 'upsample':
80
+ x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
81
+ elif self.sample_mode == 'downsample':
82
+ x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)
83
+
84
+ b, c, h, w = x.shape
85
+ x = x.view(1, b * c, h, w)
86
+ # weight: (b*c_out, c_in, k, k), groups=b
87
+ out = F.conv2d(x, weight, padding=self.padding, groups=b)
88
+ out = out.view(b, self.out_channels, *out.shape[2:4])
89
+
90
+ return out
91
+
92
+ def __repr__(self):
93
+ return (f'{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, '
94
+ f'kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})')
95
+
96
+
97
+ class StyleConv(torch.nn.Module):
98
+ """Style conv used in StyleGAN2.
99
+ Args:
100
+ in_channels (int): Channel number of the input.
101
+ out_channels (int): Channel number of the output.
102
+ kernel_size (int): Size of the convolving kernel.
103
+ num_style_feat (int): Channel number of style features.
104
+ demodulate (bool): Whether demodulate in the conv layer. Default: True.
105
+ sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
106
+ """
107
+
108
+ def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None):
109
+ super(StyleConv, self).__init__()
110
+ self.modulated_conv = ModulatedConv2d(
111
+ in_channels, out_channels, kernel_size, num_style_feat, demodulate=demodulate, sample_mode=sample_mode)
112
+ self.weight = torch.nn.Parameter(torch.zeros(1)) # for noise injection
113
+ self.bias = torch.nn.Parameter(torch.zeros(1, out_channels, 1, 1))
114
+ self.activate = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
115
+
116
+ def forward(self, x, style, noise=None):
117
+ # modulate
118
+ out = self.modulated_conv(x, style) * 2**0.5 # for conversion
119
+ # noise injection
120
+ if noise is None:
121
+ b, _, h, w = out.shape
122
+ noise = out.new_empty(b, 1, h, w).normal_()
123
+ out = out + self.weight * noise
124
+ # add bias
125
+ out = out + self.bias
126
+ # activation
127
+ out = self.activate(out)
128
+ return out
129
+
130
+
131
+ class ToRGB(torch.nn.Module):
132
+ """To RGB (image space) from features.
133
+ Args:
134
+ in_channels (int): Channel number of input.
135
+ num_style_feat (int): Channel number of style features.
136
+ upsample (bool): Whether to upsample. Default: True.
137
+ """
138
+
139
+ def __init__(self, in_channels, num_style_feat, upsample=True):
140
+ super(ToRGB, self).__init__()
141
+ self.upsample = upsample
142
+ self.modulated_conv = ModulatedConv2d(
143
+ in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None)
144
+ self.bias = torch.nn.Parameter(torch.zeros(1, 3, 1, 1))
145
+
146
+ def forward(self, x, style, skip=None):
147
+ """Forward function.
148
+ Args:
149
+ x (Tensor): Feature tensor with shape (b, c, h, w).
150
+ style (Tensor): Tensor with shape (b, num_style_feat).
151
+ skip (Tensor): Base/skip tensor. Default: None.
152
+ Returns:
153
+ Tensor: RGB images.
154
+ """
155
+ out = self.modulated_conv(x, style)
156
+ out = out + self.bias
157
+ if skip is not None:
158
+ if self.upsample:
159
+ skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
160
+ out = out + skip
161
+ return out
162
+
163
+
164
+ class ConstantInput(torch.nn.Module):
165
+ """Constant input.
166
+ Args:
167
+ num_channel (int): Channel number of constant input.
168
+ size (int): Spatial size of constant input.
169
+ """
170
+
171
+ def __init__(self, num_channel, size):
172
+ super(ConstantInput, self).__init__()
173
+ self.weight = torch.nn.Parameter(torch.randn(1, num_channel, size, size))
174
+
175
+ def forward(self, batch):
176
+ out = self.weight.repeat(batch, 1, 1, 1)
177
+ return out
178
+
179
+
180
+ class StyleGAN2GeneratorClean(torch.nn.Module):
181
+ """Clean version of StyleGAN2 Generator.
182
+ Args:
183
+ out_size (int): The spatial size of outputs.
184
+ num_style_feat (int): Channel number of style features. Default: 512.
185
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
186
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
187
+ narrow (float): Narrow ratio for channels. Default: 1.0.
188
+ """
189
+
190
+ def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1):
191
+ super(StyleGAN2GeneratorClean, self).__init__()
192
+ # Style MLP layers
193
+ self.num_style_feat = num_style_feat
194
+ style_mlp_layers = [NormStyleCode()]
195
+ for i in range(num_mlp):
196
+ style_mlp_layers.extend(
197
+ [torch.nn.Linear(num_style_feat, num_style_feat, bias=True),
198
+ torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)])
199
+ self.style_mlp = torch.nn.Sequential(*style_mlp_layers)
200
+ # initialization
201
+ # default_init_weights(self.style_mlp, scale=1, bias_fill=0, a=0.2, mode='fan_in', nonlinearity='leaky_relu')
202
+
203
+ # channel list
204
+ channels = {
205
+ '4': int(512 * narrow),
206
+ '8': int(512 * narrow),
207
+ '16': int(512 * narrow),
208
+ '32': int(512 * narrow),
209
+ '64': int(256 * channel_multiplier * narrow),
210
+ '128': int(128 * channel_multiplier * narrow),
211
+ '256': int(64 * channel_multiplier * narrow),
212
+ '512': int(32 * channel_multiplier * narrow),
213
+ '1024': int(16 * channel_multiplier * narrow)
214
+ }
215
+ self.channels = channels
216
+
217
+ self.constant_input = ConstantInput(channels['4'], size=4)
218
+ self.style_conv1 = StyleConv(
219
+ channels['4'],
220
+ channels['4'],
221
+ kernel_size=3,
222
+ num_style_feat=num_style_feat,
223
+ demodulate=True,
224
+ sample_mode=None)
225
+ self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False)
226
+
227
+ self.log_size = int(math.log(out_size, 2))
228
+ self.num_layers = (self.log_size - 2) * 2 + 1
229
+ self.num_latent = self.log_size * 2 - 2
230
+
231
+ self.style_convs = torch.nn.ModuleList()
232
+ self.to_rgbs = torch.nn.ModuleList()
233
+ self.noises = torch.nn.Module()
234
+
235
+ in_channels = channels['4']
236
+ # noise
237
+ for layer_idx in range(self.num_layers):
238
+ resolution = 2**((layer_idx + 5) // 2)
239
+ shape = [1, 1, resolution, resolution]
240
+ self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
241
+ # style convs and to_rgbs
242
+ for i in range(3, self.log_size + 1):
243
+ out_channels = channels[f'{2**i}']
244
+ self.style_convs.append(
245
+ StyleConv(
246
+ in_channels,
247
+ out_channels,
248
+ kernel_size=3,
249
+ num_style_feat=num_style_feat,
250
+ demodulate=True,
251
+ sample_mode='upsample'))
252
+ self.style_convs.append(
253
+ StyleConv(
254
+ out_channels,
255
+ out_channels,
256
+ kernel_size=3,
257
+ num_style_feat=num_style_feat,
258
+ demodulate=True,
259
+ sample_mode=None))
260
+ self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
261
+ in_channels = out_channels
262
+
263
+ def make_noise(self):
264
+ """Make noise for noise injection."""
265
+ device = self.constant_input.weight.device
266
+ noises = [torch.randn(1, 1, 4, 4, device=device)]
267
+
268
+ for i in range(3, self.log_size + 1):
269
+ for _ in range(2):
270
+ noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
271
+
272
+ return noises
273
+
274
+ def get_latent(self, x):
275
+ return self.style_mlp(x)
276
+
277
+ def mean_latent(self, num_latent):
278
+ latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
279
+ latent = self.style_mlp(latent_in).mean(0, keepdim=True)
280
+ return latent
281
+
282
+ def forward(self,
283
+ styles,
284
+ input_is_latent=False,
285
+ noise=None,
286
+ randomize_noise=True,
287
+ truncation=1,
288
+ truncation_latent=None,
289
+ inject_index=None,
290
+ return_latents=False):
291
+ """Forward function for StyleGAN2GeneratorClean.
292
+ Args:
293
+ styles (list[Tensor]): Sample codes of styles.
294
+ input_is_latent (bool): Whether input is latent style. Default: False.
295
+ noise (Tensor | None): Input noise or None. Default: None.
296
+ randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
297
+ truncation (float): The truncation ratio. Default: 1.
298
+ truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
299
+ inject_index (int | None): The injection index for mixing noise. Default: None.
300
+ return_latents (bool): Whether to return style latents. Default: False.
301
+ """
302
+ # style codes -> latents with Style MLP layer
303
+ if not input_is_latent:
304
+ styles = [self.style_mlp(s) for s in styles]
305
+ # noises
306
+ if noise is None:
307
+ if randomize_noise:
308
+ noise = [None] * self.num_layers # for each style conv layer
309
+ else: # use the stored noise
310
+ noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
311
+ # style truncation
312
+ if truncation < 1:
313
+ style_truncation = []
314
+ for style in styles:
315
+ style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
316
+ styles = style_truncation
317
+ # get style latents with injection
318
+ if len(styles) == 1:
319
+ inject_index = self.num_latent
320
+
321
+ if styles[0].ndim < 3:
322
+ # repeat latent code for all the layers
323
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
324
+ else: # used for encoder with different latent code for each layer
325
+ latent = styles[0]
326
+ elif len(styles) == 2: # mixing noises
327
+ if inject_index is None:
328
+ inject_index = random.randint(1, self.num_latent - 1)
329
+ latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
330
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
331
+ latent = torch.cat([latent1, latent2], 1)
332
+
333
+ # main generation
334
+ out = self.constant_input(latent.shape[0])
335
+ out = self.style_conv1(out, latent[:, 0], noise=noise[0])
336
+ skip = self.to_rgb1(out, latent[:, 1])
337
+
338
+ i = 1
339
+ for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
340
+ noise[2::2], self.to_rgbs):
341
+ out = conv1(out, latent[:, i], noise=noise1)
342
+ out = conv2(out, latent[:, i + 1], noise=noise2)
343
+ skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
344
+ i += 2
345
+
346
+ image = skip
347
+
348
+ if return_latents:
349
+ return image, latent
350
+ else:
351
+ return image, None
src/PostProcess/ParsingModel/model.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from src.PostProcess.ParsingModel.resnet import Resnet18
10
+
11
+ from src.PostProcess.utils import encode_segmentation_rgb_batch
12
+ from typing import Tuple
13
+
14
+
15
+ class ConvBNReLU(nn.Module):
16
+ def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
17
+ super(ConvBNReLU, self).__init__()
18
+ self.conv = nn.Conv2d(
19
+ in_chan,
20
+ out_chan,
21
+ kernel_size=ks,
22
+ stride=stride,
23
+ padding=padding,
24
+ bias=False,
25
+ )
26
+ self.bn = nn.BatchNorm2d(out_chan)
27
+ self.init_weight()
28
+
29
+ def forward(self, x):
30
+ x = self.conv(x)
31
+ x = F.relu(self.bn(x))
32
+ return x
33
+
34
+ def init_weight(self):
35
+ for ly in self.children():
36
+ if isinstance(ly, nn.Conv2d):
37
+ nn.init.kaiming_normal_(ly.weight, a=1)
38
+ if ly.bias is not None:
39
+ nn.init.constant_(ly.bias, 0)
40
+
41
+
42
+ class BiSeNetOutput(nn.Module):
43
+ def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
44
+ super(BiSeNetOutput, self).__init__()
45
+ self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
46
+ self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
47
+ self.init_weight()
48
+
49
+ def forward(self, x):
50
+ x = self.conv(x)
51
+ x = self.conv_out(x)
52
+ return x
53
+
54
+ def init_weight(self):
55
+ for ly in self.children():
56
+ if isinstance(ly, nn.Conv2d):
57
+ nn.init.kaiming_normal_(ly.weight, a=1)
58
+ if ly.bias is not None:
59
+ nn.init.constant_(ly.bias, 0)
60
+
61
+ def get_params(self):
62
+ wd_params, nowd_params = [], []
63
+ for name, module in self.named_modules():
64
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
65
+ wd_params.append(module.weight)
66
+ if module.bias is not None:
67
+ nowd_params.append(module.bias)
68
+ elif isinstance(module, nn.BatchNorm2d):
69
+ nowd_params += list(module.parameters())
70
+ return wd_params, nowd_params
71
+
72
+
73
+ class AttentionRefinementModule(nn.Module):
74
+ def __init__(self, in_chan, out_chan, *args, **kwargs):
75
+ super(AttentionRefinementModule, self).__init__()
76
+ self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
77
+ self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
78
+ self.bn_atten = nn.BatchNorm2d(out_chan)
79
+ self.sigmoid_atten = nn.Sigmoid()
80
+ self.init_weight()
81
+
82
+ def forward(self, x):
83
+ feat = self.conv(x)
84
+ atten = F.avg_pool2d(feat, feat.size()[2:])
85
+ atten = self.conv_atten(atten)
86
+ atten = self.bn_atten(atten)
87
+ atten = self.sigmoid_atten(atten)
88
+ out = torch.mul(feat, atten)
89
+ return out
90
+
91
+ def init_weight(self):
92
+ for ly in self.children():
93
+ if isinstance(ly, nn.Conv2d):
94
+ nn.init.kaiming_normal_(ly.weight, a=1)
95
+ if ly.bias is not None:
96
+ nn.init.constant_(ly.bias, 0)
97
+
98
+
99
+ class ContextPath(nn.Module):
100
+ def __init__(self, *args, **kwargs):
101
+ super(ContextPath, self).__init__()
102
+ self.resnet = Resnet18()
103
+ self.arm16 = AttentionRefinementModule(256, 128)
104
+ self.arm32 = AttentionRefinementModule(512, 128)
105
+ self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
106
+ self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
107
+ self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
108
+
109
+ self.init_weight()
110
+
111
+ def forward(self, x):
112
+ H0, W0 = x.size()[2:]
113
+ feat8, feat16, feat32 = self.resnet(x)
114
+ H8, W8 = feat8.size()[2:]
115
+ H16, W16 = feat16.size()[2:]
116
+ H32, W32 = feat32.size()[2:]
117
+
118
+ avg = F.avg_pool2d(feat32, feat32.size()[2:])
119
+ avg = self.conv_avg(avg)
120
+ avg_up = F.interpolate(avg, (H32, W32), mode="nearest")
121
+
122
+ feat32_arm = self.arm32(feat32)
123
+ feat32_sum = feat32_arm + avg_up
124
+ feat32_up = F.interpolate(feat32_sum, (H16, W16), mode="nearest")
125
+ feat32_up = self.conv_head32(feat32_up)
126
+
127
+ feat16_arm = self.arm16(feat16)
128
+ feat16_sum = feat16_arm + feat32_up
129
+ feat16_up = F.interpolate(feat16_sum, (H8, W8), mode="nearest")
130
+ feat16_up = self.conv_head16(feat16_up)
131
+
132
+ return feat8, feat16_up, feat32_up # x8, x8, x16
133
+
134
+ def init_weight(self):
135
+ for ly in self.children():
136
+ if isinstance(ly, nn.Conv2d):
137
+ nn.init.kaiming_normal_(ly.weight, a=1)
138
+ if ly.bias is not None:
139
+ nn.init.constant_(ly.bias, 0)
140
+
141
+ def get_params(self):
142
+ wd_params, nowd_params = [], []
143
+ for name, module in self.named_modules():
144
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
145
+ wd_params.append(module.weight)
146
+ if module.bias is not None:
147
+ nowd_params.append(module.bias)
148
+ elif isinstance(module, nn.BatchNorm2d):
149
+ nowd_params += list(module.parameters())
150
+ return wd_params, nowd_params
151
+
152
+
153
+ # This is not used, since I replace this with the resnet feature with the same size
154
+ class SpatialPath(nn.Module):
155
+ def __init__(self, *args, **kwargs):
156
+ super(SpatialPath, self).__init__()
157
+ self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
158
+ self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
159
+ self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
160
+ self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
161
+ self.init_weight()
162
+
163
+ def forward(self, x):
164
+ feat = self.conv1(x)
165
+ feat = self.conv2(feat)
166
+ feat = self.conv3(feat)
167
+ feat = self.conv_out(feat)
168
+ return feat
169
+
170
+ def init_weight(self):
171
+ for ly in self.children():
172
+ if isinstance(ly, nn.Conv2d):
173
+ nn.init.kaiming_normal_(ly.weight, a=1)
174
+ if ly.bias is not None:
175
+ nn.init.constant_(ly.bias, 0)
176
+
177
+ def get_params(self):
178
+ wd_params, nowd_params = [], []
179
+ for name, module in self.named_modules():
180
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
181
+ wd_params.append(module.weight)
182
+ if module.bias is not None:
183
+ nowd_params.append(module.bias)
184
+ elif isinstance(module, nn.BatchNorm2d):
185
+ nowd_params += list(module.parameters())
186
+ return wd_params, nowd_params
187
+
188
+
189
+ class FeatureFusionModule(nn.Module):
190
+ def __init__(self, in_chan, out_chan, *args, **kwargs):
191
+ super(FeatureFusionModule, self).__init__()
192
+ self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
193
+ self.conv1 = nn.Conv2d(
194
+ out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False
195
+ )
196
+ self.conv2 = nn.Conv2d(
197
+ out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False
198
+ )
199
+ self.relu = nn.ReLU(inplace=True)
200
+ self.sigmoid = nn.Sigmoid()
201
+ self.init_weight()
202
+
203
+ def forward(self, fsp, fcp):
204
+ fcat = torch.cat([fsp, fcp], dim=1)
205
+ feat = self.convblk(fcat)
206
+ atten = F.avg_pool2d(feat, feat.size()[2:])
207
+ atten = self.conv1(atten)
208
+ atten = self.relu(atten)
209
+ atten = self.conv2(atten)
210
+ atten = self.sigmoid(atten)
211
+ feat_atten = torch.mul(feat, atten)
212
+ feat_out = feat_atten + feat
213
+ return feat_out
214
+
215
+ def init_weight(self):
216
+ for ly in self.children():
217
+ if isinstance(ly, nn.Conv2d):
218
+ nn.init.kaiming_normal_(ly.weight, a=1)
219
+ if ly.bias is not None:
220
+ nn.init.constant_(ly.bias, 0)
221
+
222
+ def get_params(self):
223
+ wd_params, nowd_params = [], []
224
+ for name, module in self.named_modules():
225
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
226
+ wd_params.append(module.weight)
227
+ if module.bias is not None:
228
+ nowd_params.append(module.bias)
229
+ elif isinstance(module, nn.BatchNorm2d):
230
+ nowd_params += list(module.parameters())
231
+ return wd_params, nowd_params
232
+
233
+
234
+ class BiSeNet(nn.Module):
235
+ def __init__(self, n_classes, *args, **kwargs):
236
+ super(BiSeNet, self).__init__()
237
+ self.cp = ContextPath()
238
+ # here self.sp is deleted
239
+ self.ffm = FeatureFusionModule(256, 256)
240
+ self.conv_out = BiSeNetOutput(256, 256, n_classes)
241
+ self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
242
+ self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
243
+ self.init_weight()
244
+
245
+ def get_mask(
246
+ self, x: torch.Tensor, crop_size: int
247
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
248
+ x = F.interpolate(x, size=(512, 512))
249
+
250
+ parsed_face = self.forward(x)[0]
251
+
252
+ parsed_face = torch.argmax(parsed_face, dim=1, keepdim=True)
253
+
254
+ parsed_face = encode_segmentation_rgb_batch(parsed_face)
255
+
256
+ parsed_face = torch.where(
257
+ torch.sum(parsed_face, dim=[1, 2, 3], keepdim=True) > 5000,
258
+ parsed_face,
259
+ torch.zeros_like(parsed_face),
260
+ )
261
+
262
+ ignore_mask_ids = torch.sum(parsed_face, dim=[1, 2, 3]) == 0
263
+
264
+ parsed_face = parsed_face.float().mul_(1 / 255.0)
265
+
266
+ parsed_face = F.interpolate(
267
+ parsed_face, size=(crop_size, crop_size), mode="bilinear"
268
+ )
269
+
270
+ parsed_face = torch.sum(parsed_face, dim=1, keepdim=True)
271
+
272
+ return parsed_face, ignore_mask_ids
273
+
274
+ def forward(self, x):
275
+ H, W = x.size()[2:]
276
+ feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
277
+ feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
278
+ feat_fuse = self.ffm(feat_sp, feat_cp8)
279
+
280
+ feat_out = self.conv_out(feat_fuse)
281
+ feat_out16 = self.conv_out16(feat_cp8)
282
+ feat_out32 = self.conv_out32(feat_cp16)
283
+
284
+ feat_out = F.interpolate(feat_out, (H, W), mode="bilinear", align_corners=True)
285
+ feat_out16 = F.interpolate(
286
+ feat_out16, (H, W), mode="bilinear", align_corners=True
287
+ )
288
+ feat_out32 = F.interpolate(
289
+ feat_out32, (H, W), mode="bilinear", align_corners=True
290
+ )
291
+ return feat_out, feat_out16, feat_out32
292
+
293
+ def init_weight(self):
294
+ for ly in self.children():
295
+ if isinstance(ly, nn.Conv2d):
296
+ nn.init.kaiming_normal_(ly.weight, a=1)
297
+ if ly.bias is not None:
298
+ nn.init.constant_(ly.bias, 0)
299
+
300
+ def get_params(self):
301
+ wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
302
+ for name, child in self.named_children():
303
+ child_wd_params, child_nowd_params = child.get_params()
304
+ if isinstance(child, FeatureFusionModule) or isinstance(
305
+ child, BiSeNetOutput
306
+ ):
307
+ lr_mul_wd_params += child_wd_params
308
+ lr_mul_nowd_params += child_nowd_params
309
+ else:
310
+ wd_params += child_wd_params
311
+ nowd_params += child_nowd_params
312
+ return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
313
+
314
+
315
+ if __name__ == "__main__":
316
+ net = BiSeNet(19)
317
+ net.cuda()
318
+ net.eval()
319
+ in_ten = torch.randn(16, 3, 640, 480).cuda()
320
+ out, out16, out32 = net(in_ten)
321
+ print(out.shape)
322
+
323
+ net.get_params()
src/PostProcess/ParsingModel/resnet.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torch.utils.model_zoo as modelzoo
8
+
9
+ # from modules.bn import InPlaceABNSync as BatchNorm2d
10
+
11
+ resnet18_url = "https://download.pytorch.org/models/resnet18-5c106cde.pth"
12
+
13
+
14
+ def conv3x3(in_planes, out_planes, stride=1):
15
+ """3x3 convolution with padding"""
16
+ return nn.Conv2d(
17
+ in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
18
+ )
19
+
20
+
21
+ class BasicBlock(nn.Module):
22
+ def __init__(self, in_chan, out_chan, stride=1):
23
+ super(BasicBlock, self).__init__()
24
+ self.conv1 = conv3x3(in_chan, out_chan, stride)
25
+ self.bn1 = nn.BatchNorm2d(out_chan)
26
+ self.conv2 = conv3x3(out_chan, out_chan)
27
+ self.bn2 = nn.BatchNorm2d(out_chan)
28
+ self.relu = nn.ReLU(inplace=True)
29
+ self.downsample = None
30
+ if in_chan != out_chan or stride != 1:
31
+ self.downsample = nn.Sequential(
32
+ nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False),
33
+ nn.BatchNorm2d(out_chan),
34
+ )
35
+
36
+ def forward(self, x):
37
+ residual = self.conv1(x)
38
+ residual = F.relu(self.bn1(residual))
39
+ residual = self.conv2(residual)
40
+ residual = self.bn2(residual)
41
+
42
+ shortcut = x
43
+ if self.downsample is not None:
44
+ shortcut = self.downsample(x)
45
+
46
+ out = shortcut + residual
47
+ out = self.relu(out)
48
+ return out
49
+
50
+
51
+ def create_layer_basic(in_chan, out_chan, bnum, stride=1):
52
+ layers = [BasicBlock(in_chan, out_chan, stride=stride)]
53
+ for i in range(bnum - 1):
54
+ layers.append(BasicBlock(out_chan, out_chan, stride=1))
55
+ return nn.Sequential(*layers)
56
+
57
+
58
+ class Resnet18(nn.Module):
59
+ def __init__(self):
60
+ super(Resnet18, self).__init__()
61
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
62
+ self.bn1 = nn.BatchNorm2d(64)
63
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
64
+ self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
65
+ self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
66
+ self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
67
+ self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
68
+ self.init_weight()
69
+
70
+ def forward(self, x):
71
+ x = self.conv1(x)
72
+ x = F.relu(self.bn1(x))
73
+ x = self.maxpool(x)
74
+
75
+ x = self.layer1(x)
76
+ feat8 = self.layer2(x) # 1/8
77
+ feat16 = self.layer3(feat8) # 1/16
78
+ feat32 = self.layer4(feat16) # 1/32
79
+ return feat8, feat16, feat32
80
+
81
+ def init_weight(self):
82
+ state_dict = modelzoo.load_url(resnet18_url)
83
+ self_state_dict = self.state_dict()
84
+ for k, v in state_dict.items():
85
+ if "fc" in k:
86
+ continue
87
+ self_state_dict.update({k: v})
88
+ self.load_state_dict(self_state_dict)
89
+
90
+ def get_params(self):
91
+ wd_params, nowd_params = [], []
92
+ for name, module in self.named_modules():
93
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
94
+ wd_params.append(module.weight)
95
+ if module.bias is not None:
96
+ nowd_params.append(module.bias)
97
+ elif isinstance(module, nn.BatchNorm2d):
98
+ nowd_params += list(module.parameters())
99
+ return wd_params, nowd_params
100
+
101
+
102
+ if __name__ == "__main__":
103
+ net = Resnet18()
104
+ x = torch.randn(16, 3, 224, 224)
105
+ out = net(x)
106
+ print(out[0].size())
107
+ print(out[1].size())
108
+ print(out[2].size())
109
+ net.get_params()
src/PostProcess/utils.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+ from typing import Tuple
6
+
7
+
8
+ class SoftErosion(torch.nn.Module):
9
+ def __init__(
10
+ self, kernel_size: int = 15, threshold: float = 0.6, iterations: int = 1
11
+ ):
12
+ super(SoftErosion, self).__init__()
13
+ r = kernel_size // 2
14
+ self.padding = r
15
+ self.iterations = iterations
16
+ self.threshold = threshold
17
+
18
+ # Create kernel
19
+ y_indices, x_indices = torch.meshgrid(
20
+ torch.arange(0.0, kernel_size), torch.arange(0.0, kernel_size)
21
+ )
22
+ dist = torch.sqrt((x_indices - r) ** 2 + (y_indices - r) ** 2)
23
+ kernel = dist.max() - dist
24
+ kernel /= kernel.sum()
25
+ kernel = kernel.view(1, 1, *kernel.shape)
26
+ self.register_buffer("weight", kernel)
27
+
28
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
29
+ for i in range(self.iterations - 1):
30
+ x = torch.min(
31
+ x,
32
+ F.conv2d(
33
+ x, weight=self.weight, groups=x.shape[1], padding=self.padding
34
+ ),
35
+ )
36
+ x = F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding)
37
+
38
+ mask = x >= self.threshold
39
+
40
+ x[mask] = 1.0
41
+ # add small epsilon to avoid Nans
42
+ x[~mask] /= (x[~mask].max() + 1e-7)
43
+
44
+ return x, mask
45
+
46
+
47
+ def encode_segmentation_rgb(
48
+ segmentation: np.ndarray, no_neck: bool = True
49
+ ) -> np.ndarray:
50
+ parse = segmentation
51
+ # https://github.com/zllrunning/face-parsing.PyTorch/blob/master/prepropess_data.py
52
+ face_part_ids = (
53
+ [1, 2, 3, 4, 5, 6, 10, 12, 13]
54
+ if no_neck
55
+ else [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13, 14]
56
+ )
57
+ mouth_id = 11
58
+ # hair_id = 17
59
+ face_map = np.zeros([parse.shape[0], parse.shape[1]])
60
+ mouth_map = np.zeros([parse.shape[0], parse.shape[1]])
61
+ # hair_map = np.zeros([parse.shape[0], parse.shape[1]])
62
+
63
+ for valid_id in face_part_ids:
64
+ valid_index = np.where(parse == valid_id)
65
+ face_map[valid_index] = 255
66
+ valid_index = np.where(parse == mouth_id)
67
+ mouth_map[valid_index] = 255
68
+ # valid_index = np.where(parse==hair_id)
69
+ # hair_map[valid_index] = 255
70
+ # return np.stack([face_map, mouth_map,hair_map], axis=2)
71
+ return np.stack([face_map, mouth_map], axis=2)
72
+
73
+
74
+ def encode_segmentation_rgb_batch(
75
+ segmentation: torch.Tensor, no_neck: bool = True
76
+ ) -> torch.Tensor:
77
+ # https://github.com/zllrunning/face-parsing.PyTorch/blob/master/prepropess_data.py
78
+ face_part_ids = (
79
+ [1, 2, 3, 4, 5, 6, 10, 12, 13]
80
+ if no_neck
81
+ else [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13, 14]
82
+ )
83
+ mouth_id = 11
84
+ # hair_id = 17
85
+ segmentation = segmentation.int()
86
+ face_map = torch.zeros_like(segmentation)
87
+ mouth_map = torch.zeros_like(segmentation)
88
+ # hair_map = np.zeros([parse.shape[0], parse.shape[1]])
89
+
90
+ white_tensor = face_map + 255
91
+ for valid_id in face_part_ids:
92
+ face_map = torch.where(segmentation == valid_id, white_tensor, face_map)
93
+ mouth_map = torch.where(segmentation == mouth_id, white_tensor, mouth_map)
94
+
95
+ return torch.cat([face_map, mouth_map], dim=1)
96
+
97
+
98
+ def postprocess(
99
+ swapped_face: np.ndarray,
100
+ target: np.ndarray,
101
+ target_mask: np.ndarray,
102
+ smooth_mask: torch.nn.Module,
103
+ ) -> np.ndarray:
104
+ # target_mask = cv2.resize(target_mask, (self.size, self.size))
105
+
106
+ mask_tensor = (
107
+ torch.from_numpy(target_mask.copy().transpose((2, 0, 1)))
108
+ .float()
109
+ .mul_(1 / 255.0)
110
+ .cuda()
111
+ )
112
+ face_mask_tensor = mask_tensor[0] + mask_tensor[1]
113
+
114
+ soft_face_mask_tensor, _ = smooth_mask(face_mask_tensor.unsqueeze_(0).unsqueeze_(0))
115
+ soft_face_mask_tensor.squeeze_()
116
+
117
+ soft_face_mask = soft_face_mask_tensor.cpu().numpy()
118
+ soft_face_mask = soft_face_mask[:, :, np.newaxis]
119
+
120
+ result = swapped_face * soft_face_mask + target * (1 - soft_face_mask)
121
+ result = result[:, :, ::-1] # .astype(np.uint8)
122
+ return result
src/model_loader.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+ import torch
3
+ from torch.utils import model_zoo
4
+ import requests
5
+ from tqdm import tqdm
6
+ from pathlib import Path
7
+
8
+ from src.FaceDetector.face_detector import FaceDetector
9
+ from src.FaceId.faceid import FaceId
10
+ from src.Generator.fs_networks_fix import Generator_Adain_Upsample
11
+ from src.PostProcess.ParsingModel.model import BiSeNet
12
+ from src.PostProcess.GFPGAN.gfpgan import GFPGANer
13
+ from src.Blend.blend import BlendModule
14
+
15
+
16
+ model = namedtuple("model", ["url", "model"])
17
+
18
+ models = {
19
+ "face_detector": model(
20
+ url="https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/face_detector_scrfd_10g_bnkps.onnx",
21
+ model=FaceDetector,
22
+ ),
23
+ "arcface": model(
24
+ url="https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/arcface_net.jit",
25
+ model=FaceId,
26
+ ),
27
+ "generator_224": model(
28
+ url="https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/simswap_224_latest_net_G.pth",
29
+ model=Generator_Adain_Upsample,
30
+ ),
31
+ "generator_512": model(
32
+ url="https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/simswap_512_390000_net_G.pth",
33
+ model=Generator_Adain_Upsample,
34
+ ),
35
+ "parsing_model": model(
36
+ url="https://github.com/mike9251/simswap-inference-pytorch/releases/download/weights/parsing_model_79999_iter.pth",
37
+ model=BiSeNet,
38
+ ),
39
+ "gfpgan": model(
40
+ url="https://github.com/mike9251/simswap-inference-pytorch/releases/download/v1.1/GFPGANv1.4_ema.pth",
41
+ model=GFPGANer,
42
+ ),
43
+ "blend_module": model(
44
+ url="https://github.com/mike9251/simswap-inference-pytorch/releases/download/v1.2/blend_module.jit",
45
+ model=BlendModule
46
+ )
47
+ }
48
+
49
+
50
+ def get_model(
51
+ model_name: str,
52
+ device: torch.device,
53
+ load_state_dice: bool,
54
+ model_path: Path,
55
+ **kwargs,
56
+ ):
57
+ dst_dir = Path.cwd() / "weights"
58
+ dst_dir.mkdir(exist_ok=True)
59
+
60
+ url = models[model_name].url if not model_path.is_file() else str(model_path)
61
+
62
+ if load_state_dice:
63
+ model = models[model_name].model(**kwargs)
64
+
65
+ if Path(url).is_file():
66
+ state_dict = torch.load(url)
67
+ else:
68
+ state_dict = model_zoo.load_url(
69
+ url,
70
+ model_dir=str(dst_dir),
71
+ progress=True,
72
+ map_location="cpu",
73
+ )
74
+
75
+ model.load_state_dict(state_dict)
76
+
77
+ model.to(device)
78
+ model.eval()
79
+ else:
80
+ dst_path = Path(url)
81
+
82
+ if not dst_path.is_file():
83
+ dst_path = dst_dir / Path(url).name
84
+
85
+ if not dst_path.is_file():
86
+ print(f"Downloading: '{url}' to {dst_path}")
87
+ response = requests.get(url, stream=True)
88
+ if int(response.status_code) == 200:
89
+ file_size = int(response.headers["Content-Length"]) / (2 ** 20)
90
+ chunk_size = 1024
91
+ bar_format = "{desc}: {percentage:3.0f}%|{bar}| {n:3.1f}M/{total:3.1f}M [{elapsed}<{remaining}]"
92
+ with open(dst_path, "wb") as handle:
93
+ with tqdm(total=file_size, bar_format=bar_format) as pbar:
94
+ for data in response.iter_content(chunk_size=chunk_size):
95
+ handle.write(data)
96
+ pbar.update(len(data) / (2 ** 20))
97
+ else:
98
+ raise ValueError(
99
+ f"Couldn't download weights {url}. Specify weights for the '{model_name}' model manually."
100
+ )
101
+
102
+ kwargs.update({"model_path": str(dst_path), "device": device})
103
+
104
+ model = models[model_name].model(**kwargs)
105
+
106
+ return model
src/simswap.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from typing import Iterable, Tuple, Union
5
+ from pathlib import Path
6
+ from torchvision import transforms
7
+ import kornia
8
+ from omegaconf import DictConfig
9
+
10
+ from src.FaceDetector.face_detector import Detection
11
+ from src.FaceAlign.face_align import align_face, inverse_transform_batch
12
+ from src.PostProcess.utils import SoftErosion
13
+ from src.model_loader import get_model
14
+ from src.Misc.types import CheckpointType, FaceAlignmentType
15
+ from src.Misc.utils import tensor2img
16
+
17
+
18
+ class SimSwap:
19
+ def __init__(
20
+ self,
21
+ config: DictConfig,
22
+ id_image: Union[np.ndarray, None] = None,
23
+ specific_image: Union[np.ndarray, None] = None,
24
+ ):
25
+
26
+ self.id_image: Union[np.ndarray, None] = id_image
27
+ self.id_latent: Union[torch.Tensor, None] = None
28
+ self.specific_id_image: Union[np.ndarray, None] = specific_image
29
+ self.specific_latent: Union[torch.Tensor, None] = None
30
+
31
+ self.use_mask: Union[bool, None] = True
32
+ self.crop_size: Union[int, None] = None
33
+ self.checkpoint_type: Union[CheckpointType, None] = None
34
+ self.face_alignment_type: Union[FaceAlignmentType, None] = None
35
+ self.smooth_mask_iter: Union[int, None] = None
36
+ self.smooth_mask_kernel_size: Union[int, None] = None
37
+ self.smooth_mask_threshold: Union[float, None] = None
38
+ self.face_detector_threshold: Union[float, None] = None
39
+ self.specific_latent_match_threshold: Union[float, None] = None
40
+ self.device = torch.device(config.device)
41
+
42
+ self.set_parameters(config)
43
+
44
+ # For BiSeNet and for official_224 SimSwap
45
+ self.to_tensor_normalize = transforms.Compose(
46
+ [
47
+ transforms.ToTensor(),
48
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
49
+ ]
50
+ )
51
+
52
+ # For SimSwap models trained with the updated code
53
+ self.to_tensor = transforms.ToTensor()
54
+
55
+ self.face_detector = get_model(
56
+ "face_detector",
57
+ device=self.device,
58
+ load_state_dice=False,
59
+ model_path=Path(config.face_detector_weights),
60
+ det_thresh=self.face_detector_threshold,
61
+ det_size=(640, 640),
62
+ mode="ffhq",
63
+ )
64
+
65
+ self.face_id_net = get_model(
66
+ "arcface",
67
+ device=self.device,
68
+ load_state_dice=False,
69
+ model_path=Path(config.face_id_weights),
70
+ )
71
+
72
+ self.bise_net = get_model(
73
+ "parsing_model",
74
+ device=self.device,
75
+ load_state_dice=True,
76
+ model_path=Path(config.parsing_model_weights),
77
+ n_classes=19,
78
+ )
79
+
80
+ gen_model = "generator_512" if self.crop_size == 512 else "generator_224"
81
+ self.simswap_net = get_model(
82
+ gen_model,
83
+ device=self.device,
84
+ load_state_dice=True,
85
+ model_path=Path(config.simswap_weights),
86
+ input_nc=3,
87
+ output_nc=3,
88
+ latent_size=512,
89
+ n_blocks=9,
90
+ deep=True if self.crop_size == 512 else False,
91
+ use_last_act=True
92
+ if self.checkpoint_type == CheckpointType.OFFICIAL_224
93
+ else False,
94
+ )
95
+
96
+ self.blend = get_model(
97
+ "blend_module",
98
+ device=self.device,
99
+ load_state_dice=False,
100
+ model_path=Path(config.blend_module_weights)
101
+ )
102
+
103
+ self.enhance_output = config.enhance_output
104
+ if config.enhance_output:
105
+ self.gfpgan_net = get_model(
106
+ "gfpgan",
107
+ device=self.device,
108
+ load_state_dice=True,
109
+ model_path=Path(config.gfpgan_weights)
110
+ )
111
+
112
+ def set_parameters(self, config) -> None:
113
+ self.set_crop_size(config.crop_size)
114
+ self.set_checkpoint_type(config.checkpoint_type)
115
+ self.set_face_alignment_type(config.face_alignment_type)
116
+ self.set_face_detector_threshold(config.face_detector_threshold)
117
+ self.set_specific_latent_match_threshold(config.specific_latent_match_threshold)
118
+ self.set_smooth_mask_kernel_size(config.smooth_mask_kernel_size)
119
+ self.set_smooth_mask_threshold(config.smooth_mask_threshold)
120
+ self.set_smooth_mask_iter(config.smooth_mask_iter)
121
+
122
+ def set_crop_size(self, crop_size: int) -> None:
123
+ if crop_size < 0:
124
+ raise "Invalid crop_size! Must be a positive value."
125
+
126
+ self.crop_size = crop_size
127
+
128
+ def set_checkpoint_type(self, checkpoint_type: str) -> None:
129
+ type = CheckpointType(checkpoint_type)
130
+ if type not in (CheckpointType.OFFICIAL_224, CheckpointType.UNOFFICIAL):
131
+ raise "Invalid checkpoint_type! Must be one of the predefined values."
132
+
133
+ self.checkpoint_type = type
134
+
135
+ def set_face_alignment_type(self, face_alignment_type: str) -> None:
136
+ type = FaceAlignmentType(face_alignment_type)
137
+ if type not in (
138
+ FaceAlignmentType.FFHQ,
139
+ FaceAlignmentType.DEFAULT,
140
+ ):
141
+ raise "Invalid face_alignment_type! Must be one of the predefined values."
142
+
143
+ self.face_alignment_type = type
144
+
145
+ def set_face_detector_threshold(self, face_detector_threshold: float) -> None:
146
+ if face_detector_threshold < 0.0 or face_detector_threshold > 1.0:
147
+ raise "Invalid face_detector_threshold! Must be a positive value in range [0.0...1.0]."
148
+
149
+ self.face_detector_threshold = face_detector_threshold
150
+
151
+ def set_specific_latent_match_threshold(
152
+ self, specific_latent_match_threshold: float
153
+ ) -> None:
154
+ if specific_latent_match_threshold < 0.0:
155
+ raise "Invalid specific_latent_match_th! Must be a positive value."
156
+
157
+ self.specific_latent_match_threshold = specific_latent_match_threshold
158
+
159
+ def re_initialize_soft_mask(self):
160
+ self.smooth_mask = SoftErosion(kernel_size=self.smooth_mask_kernel_size,
161
+ threshold=self.smooth_mask_threshold,
162
+ iterations=self.smooth_mask_iter).to(self.device)
163
+
164
+ def set_smooth_mask_kernel_size(self, smooth_mask_kernel_size: int) -> None:
165
+ if smooth_mask_kernel_size < 0:
166
+ raise "Invalid smooth_mask_kernel_size! Must be a positive value."
167
+ smooth_mask_kernel_size += 1 if smooth_mask_kernel_size % 2 == 0 else 0
168
+ self.smooth_mask_kernel_size = smooth_mask_kernel_size
169
+ self.re_initialize_soft_mask()
170
+
171
+ def set_smooth_mask_threshold(self, smooth_mask_threshold: int) -> None:
172
+ if smooth_mask_threshold < 0 or smooth_mask_threshold > 1.0:
173
+ raise "Invalid smooth_mask_threshold! Must be within 0...1 range."
174
+ self.smooth_mask_threshold = smooth_mask_threshold
175
+ self.re_initialize_soft_mask()
176
+
177
+ def set_smooth_mask_iter(self, smooth_mask_iter: float) -> None:
178
+ if smooth_mask_iter < 0:
179
+ raise "Invalid smooth_mask_iter! Must be a positive value.."
180
+ self.smooth_mask_iter = smooth_mask_iter
181
+ self.re_initialize_soft_mask()
182
+
183
+ def run_detect_align(self, image: np.ndarray, for_id: bool = False) -> Tuple[Union[Iterable[np.ndarray], None],
184
+ Union[Iterable[np.ndarray], None],
185
+ np.ndarray]:
186
+ detection: Detection = self.face_detector(image)
187
+
188
+ if detection.bbox is None:
189
+ if for_id:
190
+ raise "Can't detect a face! Please change the ID image!"
191
+ return None, None, detection.score
192
+
193
+ kps = detection.key_points
194
+
195
+ if for_id:
196
+ max_score_ind = np.argmax(detection.score, axis=0)
197
+ kps = detection.key_points[max_score_ind]
198
+ kps = kps[None, ...]
199
+
200
+ align_imgs, transforms = align_face(
201
+ image,
202
+ kps,
203
+ crop_size=self.crop_size,
204
+ mode="ffhq"
205
+ if self.face_alignment_type == FaceAlignmentType.FFHQ
206
+ else "none",
207
+ )
208
+
209
+ return align_imgs, transforms, detection.score
210
+
211
+ def __call__(self, att_image: np.ndarray) -> np.ndarray:
212
+ if self.id_latent is None:
213
+ align_id_imgs, id_transforms, _ = self.run_detect_align(
214
+ self.id_image, for_id=True
215
+ )
216
+ # normalize=True, because official SimSwap model trained with normalized id_lattent
217
+ self.id_latent: torch.Tensor = self.face_id_net(
218
+ align_id_imgs, normalize=True
219
+ )
220
+
221
+ if self.specific_id_image is not None and self.specific_latent is None:
222
+ align_specific_imgs, specific_transforms, _ = self.run_detect_align(
223
+ self.specific_id_image, for_id=True
224
+ )
225
+ self.specific_latent: torch.Tensor = self.face_id_net(
226
+ align_specific_imgs, normalize=False
227
+ )
228
+
229
+ # for_id=False, because we want to get all faces
230
+ align_att_imgs, att_transforms, att_detection_score = self.run_detect_align(
231
+ att_image, for_id=False
232
+ )
233
+
234
+ if align_att_imgs is None and att_transforms is None:
235
+ return att_image
236
+
237
+ # Select specific crop from the target image
238
+ if self.specific_latent is not None:
239
+ att_latent: torch.Tensor = self.face_id_net(align_att_imgs, normalize=False)
240
+ latent_dist = torch.mean(
241
+ F.mse_loss(
242
+ att_latent,
243
+ self.specific_latent.repeat(att_latent.shape[0], 1),
244
+ reduction="none",
245
+ ),
246
+ dim=-1,
247
+ )
248
+
249
+ att_detection_score = torch.tensor(
250
+ att_detection_score, device=latent_dist.device
251
+ )
252
+
253
+ min_index = torch.argmin(latent_dist * att_detection_score)
254
+ min_value = latent_dist[min_index]
255
+
256
+ if min_value < self.specific_latent_match_threshold:
257
+ align_att_imgs = [align_att_imgs[min_index]]
258
+ att_transforms = [att_transforms[min_index]]
259
+ else:
260
+ return att_image
261
+
262
+ swapped_img: torch.Tensor = self.simswap_net(align_att_imgs, self.id_latent)
263
+
264
+ if self.enhance_output:
265
+ swapped_img = self.gfpgan_net.enhance(swapped_img, weight=0.5)
266
+
267
+ # Put all crops/transformations into a batch
268
+ align_att_img_batch_for_parsing_model: torch.Tensor = torch.stack(
269
+ [self.to_tensor_normalize(x) for x in align_att_imgs], dim=0
270
+ )
271
+ align_att_img_batch_for_parsing_model = (
272
+ align_att_img_batch_for_parsing_model.to(self.device)
273
+ )
274
+
275
+ att_transforms: torch.Tensor = torch.stack(
276
+ [torch.tensor(x).float() for x in att_transforms], dim=0
277
+ )
278
+ att_transforms = att_transforms.to(self.device, non_blocking=True)
279
+
280
+ align_att_img_batch: torch.Tensor = torch.stack(
281
+ [self.to_tensor(x) for x in align_att_imgs], dim=0
282
+ )
283
+ align_att_img_batch = align_att_img_batch.to(self.device, non_blocking=True)
284
+
285
+ # Get face masks for the attribute image
286
+ face_mask, ignore_mask_ids = self.bise_net.get_mask(
287
+ align_att_img_batch_for_parsing_model, self.crop_size
288
+ )
289
+
290
+ inv_att_transforms: torch.Tensor = inverse_transform_batch(att_transforms)
291
+
292
+ soft_face_mask, _ = self.smooth_mask(face_mask)
293
+
294
+ swapped_img[ignore_mask_ids, ...] = align_att_img_batch[ignore_mask_ids, ...]
295
+
296
+ frame_size = (att_image.shape[0], att_image.shape[1])
297
+
298
+ att_image = self.to_tensor(att_image).to(self.device, non_blocking=True).unsqueeze(0)
299
+
300
+ target_image = kornia.geometry.transform.warp_affine(
301
+ swapped_img,
302
+ inv_att_transforms,
303
+ frame_size,
304
+ mode="bilinear",
305
+ padding_mode="border",
306
+ align_corners=True,
307
+ fill_value=torch.zeros(3),
308
+ )
309
+
310
+ soft_face_mask = kornia.geometry.transform.warp_affine(
311
+ soft_face_mask,
312
+ inv_att_transforms,
313
+ frame_size,
314
+ mode="bilinear",
315
+ padding_mode="zeros",
316
+ align_corners=True,
317
+ fill_value=torch.zeros(3),
318
+ )
319
+
320
+ result = self.blend(target_image, soft_face_mask, att_image)
321
+
322
+ return tensor2img(result)