TF Lite
ameerazam08 commited on
Commit
a5c5b03
1 Parent(s): fbeb913

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +8 -0
  2. .gitignore +199 -0
  3. README-zh.md +137 -0
  4. README.md +137 -0
  5. checkpoints/.gitkeep +0 -0
  6. data_gen/eg3d/convert_to_eg3d_convention.py +146 -0
  7. data_gen/runs/binarizer_nerf.py +335 -0
  8. data_gen/runs/nerf/process_guide.md +49 -0
  9. data_gen/runs/nerf/run.sh +51 -0
  10. data_gen/utils/mp_feature_extractors/face_landmarker.py +130 -0
  11. data_gen/utils/mp_feature_extractors/face_landmarker.task +3 -0
  12. data_gen/utils/mp_feature_extractors/mp_segmenter.py +274 -0
  13. data_gen/utils/mp_feature_extractors/selfie_multiclass_256x256.tflite +3 -0
  14. data_gen/utils/path_converter.py +24 -0
  15. data_gen/utils/process_audio/extract_hubert.py +95 -0
  16. data_gen/utils/process_audio/extract_mel_f0.py +148 -0
  17. data_gen/utils/process_audio/resample_audio_to_16k.py +49 -0
  18. data_gen/utils/process_image/extract_lm2d.py +197 -0
  19. data_gen/utils/process_image/extract_segment_imgs.py +114 -0
  20. data_gen/utils/process_image/fit_3dmm_landmark.py +369 -0
  21. data_gen/utils/process_video/euler2quaterion.py +35 -0
  22. data_gen/utils/process_video/extract_blink.py +50 -0
  23. data_gen/utils/process_video/extract_lm2d.py +164 -0
  24. data_gen/utils/process_video/extract_segment_imgs.py +500 -0
  25. data_gen/utils/process_video/fit_3dmm_landmark.py +565 -0
  26. data_gen/utils/process_video/inpaint_torso_imgs.py +193 -0
  27. data_gen/utils/process_video/resample_video_to_25fps_resize_to_512.py +87 -0
  28. data_gen/utils/process_video/split_video_to_imgs.py +53 -0
  29. data_util/face3d_helper.py +309 -0
  30. deep_3drecon/BFM/.gitkeep +0 -0
  31. deep_3drecon/bfm_left_eye_faces.npy +3 -0
  32. deep_3drecon/bfm_right_eye_faces.npy +3 -0
  33. deep_3drecon/deep_3drecon_models/bfm.py +426 -0
  34. deep_3drecon/ncc_code.npy +3 -0
  35. deep_3drecon/secc_renderer.py +78 -0
  36. deep_3drecon/util/mesh_renderer.py +131 -0
  37. docs/prepare_env/install_guide-zh.md +35 -0
  38. docs/prepare_env/install_guide.md +34 -0
  39. docs/prepare_env/requirements.txt +75 -0
  40. inference/app_real3dportrait.py +244 -0
  41. inference/edit_secc.py +147 -0
  42. inference/infer_utils.py +154 -0
  43. inference/real3d_infer.py +542 -0
  44. insta.sh +18 -0
  45. modules/audio2motion/cnn_models.py +359 -0
  46. modules/audio2motion/flow_base.py +838 -0
  47. modules/audio2motion/multi_length_disc.py +340 -0
  48. modules/audio2motion/transformer_base.py +988 -0
  49. modules/audio2motion/transformer_models.py +208 -0
  50. modules/audio2motion/utils.py +29 -0
.gitattributes CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data_gen/utils/mp_feature_extractors/face_landmarker.task filter=lfs diff=lfs merge=lfs -text
37
+ pytorch3d/.github/bundle_adjust.gif filter=lfs diff=lfs merge=lfs -text
38
+ pytorch3d/.github/camera_position_teapot.gif filter=lfs diff=lfs merge=lfs -text
39
+ pytorch3d/.github/fit_nerf.gif filter=lfs diff=lfs merge=lfs -text
40
+ pytorch3d/.github/fit_textured_volume.gif filter=lfs diff=lfs merge=lfs -text
41
+ pytorch3d/.github/implicitron_config.gif filter=lfs diff=lfs merge=lfs -text
42
+ pytorch3d/.github/nerf_project_logo.gif filter=lfs diff=lfs merge=lfs -text
43
+ pytorch3d/docs/notes/assets/batch_modes.gif filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # big files
2
+ data_util/face_tracking/3DMM/01_MorphableModel.mat
3
+ data_util/face_tracking/3DMM/3DMM_info.npy
4
+
5
+ !/deep_3drecon/BFM/.gitkeep
6
+ deep_3drecon/BFM/Exp_Pca.bin
7
+ deep_3drecon/BFM/01_MorphableModel.mat
8
+ deep_3drecon/BFM/BFM_model_front.mat
9
+ deep_3drecon/network/FaceReconModel.pb
10
+ deep_3drecon/checkpoints/*
11
+
12
+ .vscode
13
+ ### Project ignore
14
+ /checkpoints/*
15
+ !/checkpoints/.gitkeep
16
+ /data/*
17
+ !/data/.gitkeep
18
+ infer_out
19
+ rsync
20
+ .idea
21
+ .DS_Store
22
+ bak
23
+ tmp
24
+ *.tar.gz
25
+ mos
26
+ nbs
27
+ /configs_usr/*
28
+ !/configs_usr/.gitkeep
29
+ /egs_usr/*
30
+ !/egs_usr/.gitkeep
31
+ /rnnoise
32
+ #/usr/*
33
+ #!/usr/.gitkeep
34
+ scripts_usr
35
+
36
+ # Created by .ignore support plugin (hsz.mobi)
37
+ ### Python template
38
+ # Byte-compiled / optimized / DLL files
39
+ __pycache__/
40
+ *.py[cod]
41
+ *$py.class
42
+
43
+ # C extensions
44
+ *.so
45
+
46
+ # Distribution / packaging
47
+ .Python
48
+ build/
49
+ develop-eggs/
50
+ dist/
51
+ downloads/
52
+ eggs/
53
+ .eggs/
54
+ lib/
55
+ lib64/
56
+ parts/
57
+ sdist/
58
+ var/
59
+ wheels/
60
+ pip-wheel-metadata/
61
+ share/python-wheels/
62
+ *.egg-info/
63
+ .installed.cfg
64
+ *.egg
65
+ MANIFEST
66
+
67
+ # PyInstaller
68
+ # Usually these files are written by a python script from a template
69
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
70
+ *.manifest
71
+ *.spec
72
+
73
+ # Installer logs
74
+ pip-log.txt
75
+ pip-delete-this-directory.txt
76
+
77
+ # Unit test / coverage reports
78
+ htmlcov/
79
+ .tox/
80
+ .nox/
81
+ .coverage
82
+ .coverage.*
83
+ .cache
84
+ nosetests.xml
85
+ coverage.xml
86
+ *.cover
87
+ .hypothesis/
88
+ .pytest_cache/
89
+
90
+ # Translations
91
+ *.mo
92
+ *.pot
93
+
94
+ # Django stuff:
95
+ *.log
96
+ local_settings.py
97
+ db.sqlite3
98
+ db.sqlite3-journal
99
+
100
+ # Flask stuff:
101
+ instance/
102
+ .webassets-cache
103
+
104
+ # Scrapy stuff:
105
+ .scrapy
106
+
107
+ # Sphinx documentation
108
+ docs/_build/
109
+
110
+ # PyBuilder
111
+ target/
112
+
113
+ # Jupyter Notebook
114
+ .ipynb_checkpoints
115
+
116
+ # IPython
117
+ profile_default/
118
+ ipython_config.py
119
+
120
+ # pyenv
121
+ .python-version
122
+
123
+ # pipenv
124
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
125
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
126
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
127
+ # install all needed dependencies.
128
+ #Pipfile.lock
129
+
130
+ # celery beat schedule file
131
+ celerybeat-schedule
132
+
133
+ # SageMath parsed files
134
+ *.sage.py
135
+
136
+ # Environments
137
+ .env
138
+ .venv
139
+ env/
140
+ venv/
141
+ ENV/
142
+ env.bak/
143
+ venv.bak/
144
+
145
+ # Spyder project settings
146
+ .spyderproject
147
+ .spyproject
148
+
149
+ # Rope project settings
150
+ .ropeproject
151
+
152
+ # mkdocs documentation
153
+ /site
154
+
155
+ # mypy
156
+ .mypy_cache/
157
+ .dmypy.json
158
+ dmypy.json
159
+
160
+ # Pyre type checker
161
+ .pyre/
162
+ data_util/deepspeech_features/deepspeech-0.9.2-models.pbmm
163
+ deep_3drecon/mesh_renderer/bazel-bin
164
+ deep_3drecon/mesh_renderer/bazel-mesh_renderer
165
+ deep_3drecon/mesh_renderer/bazel-out
166
+ deep_3drecon/mesh_renderer/bazel-testlogs
167
+
168
+ .nfs*
169
+ infer_outs/*
170
+
171
+ *.pth
172
+ venv_113/*
173
+ *.pt
174
+ experiments/trials
175
+ flame_3drecon/*
176
+
177
+ temp/
178
+ /kill.sh
179
+ /datasets
180
+ data_util/imagenet_classes.txt
181
+ process_data_May.sh
182
+ /env_prepare_reproduce.md
183
+ /my_debug.py
184
+
185
+ utils/metrics/shape_predictor_68_face_landmarks.dat
186
+ *.mp4
187
+ _torchshow/
188
+ *.png
189
+ *.jpg
190
+
191
+ *.mrc
192
+
193
+ deep_3drecon/BFM/BFM_exp_idx.mat
194
+ deep_3drecon/BFM/BFM_front_idx.mat
195
+ deep_3drecon/BFM/facemodel_info.mat
196
+ deep_3drecon/BFM/index_mp468_from_mesh35709.npy
197
+ deep_3drecon/BFM/mediapipe_in_bfm53201.npy
198
+ deep_3drecon/BFM/std_exp.txt
199
+ !data/raw/examples/*
README-zh.md ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Real3D-Portrait: One-shot Realistic 3D Talking Portrait Synthesis | ICLR 2024 Spotlight
2
+ [![arXiv](https://img.shields.io/badge/arXiv-Paper-%3CCOLOR%3E.svg)](https://arxiv.org/abs/2401.08503)| [![GitHub Stars](https://img.shields.io/github/stars/yerfor/Real3DPortrait
3
+ )](https://github.com/yerfor/Real3DPortrait) | [English Readme](./README.md)
4
+
5
+ 这个仓库是Real3D-Portrait的官方PyTorch实现, 用于实现单参考图(one-shot)、高视频真实度(video reality)的虚拟人视频合成。您可以访问我们的[项目页面](https://real3dportrait.github.io/)以观看Demo视频, 阅读我们的[论文](https://arxiv.org/pdf/2401.08503.pdf)以了解技术细节。
6
+
7
+ <p align="center">
8
+ <br>
9
+ <img src="assets/real3dportrait.png" width="100%"/>
10
+ <br>
11
+ </p>
12
+
13
+ # 快速上手!
14
+ ## 安装环境
15
+ 请参照[环境配置文档](docs/prepare_env/install_guide-zh.md),配置Conda环境`real3dportrait`
16
+ ## 下载预训练与第三方模型
17
+ ### 3DMM BFM模型
18
+ 下载3DMM BFM模型:[Google Drive](https://drive.google.com/drive/folders/1o4t5YIw7w4cMUN4bgU9nPf6IyWVG1bEk?usp=sharing) 或 [BaiduYun Disk](https://pan.baidu.com/s/1aqv1z_qZ23Vp2VP4uxxblQ?pwd=m9q5 ) 提取码: m9q5
19
+
20
+
21
+ 下载完成后,放置全部的文件到`deep_3drecon/BFM`里,文件结构如下:
22
+ ```
23
+ deep_3drecon/BFM/
24
+ ├── 01_MorphableModel.mat
25
+ ├── BFM_exp_idx.mat
26
+ ├── BFM_front_idx.mat
27
+ ├── BFM_model_front.mat
28
+ ├── Exp_Pca.bin
29
+ ├── facemodel_info.mat
30
+ ├── index_mp468_from_mesh35709.npy
31
+ ├── mediapipe_in_bfm53201.npy
32
+ └── std_exp.txt
33
+ ```
34
+
35
+ ### 预训练模型
36
+ 下载预训练的Real3D-Portrait:[Google Drive](https://drive.google.com/drive/folders/1MAveJf7RvJ-Opg1f5qhLdoRoC_Gc6nD9?usp=sharing) 或 [BaiduYun Disk](https://pan.baidu.com/s/1Mjmbn0UtA1Zm9owZ7zWNgQ?pwd=6x4f ) 提取码: 6x4f
37
+
38
+ 下载完成后,放置全部的文件到`checkpoints`里并解压,文件结构如下:
39
+ ```
40
+ checkpoints/
41
+ ├── 240126_real3dportrait_orig
42
+ │ ├── audio2secc_vae
43
+ │ │ ├── config.yaml
44
+ │ │ └── model_ckpt_steps_400000.ckpt
45
+ │ └── secc2plane_torso_orig
46
+ │ ├── config.yaml
47
+ │ └── model_ckpt_steps_100000.ckpt
48
+ └── pretrained_ckpts
49
+ └── mit_b0.pth
50
+ ```
51
+
52
+ ## 推理测试
53
+ 我们目前提供了**命令行(CLI)**与**Gradio WebUI**推理方式,并将在未来提供Google Colab方式。我们同时支持音频驱动(Audio-Driven)与视频驱动(Video-Driven):
54
+
55
+ - 音频驱动场景下,需要至少提供`source image`与`driving audio`
56
+ - 视频驱动场景下,需要至少提供`source image`与`driving expression video`
57
+
58
+ ### Gradio WebUI推理
59
+ 启动Gradio WebUI,按照提示上传素材,点击`Generate`按钮即可推理:
60
+ ```bash
61
+ python inference/app_real3dportrait.py
62
+ ```
63
+
64
+ ### 命令行推理
65
+ 首先,切换至项目根目录并启用Conda环境:
66
+ ```bash
67
+ cd <Real3DPortraitRoot>
68
+ conda activate real3dportrait
69
+ export PYTHON_PATH=./
70
+ ```
71
+ 音频驱动场景下,需要至少提供source image与driving audio,推理指令:
72
+ ```bash
73
+ python inference/real3d_infer.py \
74
+ --src_img <PATH_TO_SOURCE_IMAGE> \
75
+ --drv_aud <PATH_TO_AUDIO> \
76
+ --drv_pose <PATH_TO_POSE_VIDEO, OPTIONAL> \
77
+ --bg_img <PATH_TO_BACKGROUND_IMAGE, OPTIONAL> \
78
+ --out_name <PATH_TO_OUTPUT_VIDEO, OPTIONAL>
79
+ ```
80
+ 视频驱动场景下,需要至少提供source image与driving expression video(作为drv_aud参数),推理指令:
81
+ ```bash
82
+ python inference/real3d_infer.py \
83
+ --src_img <PATH_TO_SOURCE_IMAGE> \
84
+ --drv_aud <PATH_TO_EXP_VIDEO> \
85
+ --drv_pose <PATH_TO_POSE_VIDEO, OPTIONAL> \
86
+ --bg_img <PATH_TO_BACKGROUND_IMAGE, OPTIONAL> \
87
+ --out_name <PATH_TO_OUTPUT_VIDEO, OPTIONAL>
88
+ ```
89
+ 一些可选参数注释:
90
+ - `--drv_pose` 指定时提供了运动pose信息,不指定则为静态运动
91
+ - `--bg_img` 指定时提供了背景信息,不指定则为source image提取的背景
92
+ - `--mouth_amp` 嘴部张幅参数,值越大张幅越大
93
+ - `--map_to_init_pose` 值为`True`时,首帧的pose将被映射到source pose,后续帧也作相同变换
94
+ - `--temperature` 代表audio2motion的采样温度,值越大结果越多样,但同时精确度越低
95
+ - `--out_name` 不指定时,结果将保存在`infer_out/tmp/`中
96
+ - `--out_mode` 值为`final`时,只输出说话人视频;值为`concat_debug`时,同时输出一些可视化的中间结果
97
+
98
+ 指令示例:
99
+ ```bash
100
+ python inference/real3d_infer.py \
101
+ --src_img data/raw/examples/Macron.png \
102
+ --drv_aud data/raw/examples/Obama_5s.wav \
103
+ --drv_pose data/raw/examples/May_5s.mp4 \
104
+ --bg_img data/raw/examples/bg.png \
105
+ --out_name output.mp4 \
106
+ --out_mode concat_debug
107
+ ```
108
+
109
+ ## ToDo
110
+ - [x] **Release Pre-trained weights of Real3D-Portrait.**
111
+ - [x] **Release Inference Code of Real3D-Portrait.**
112
+ - [x] **Release Gradio Demo of Real3D-Portrait..**
113
+ - [ ] **Release Google Colab of Real3D-Portrait..**
114
+ - [ ] **Release Training Code of Real3D-Portrait.**
115
+
116
+ # 引用我们
117
+ 如果这个仓库对你有帮助,请考虑引用我们��工作:
118
+ ```
119
+ @article{ye2024real3d,
120
+ title={Real3D-Portrait: One-shot Realistic 3D Talking Portrait Synthesis},
121
+ author={Ye, Zhenhui and Zhong, Tianyun and Ren, Yi and Yang, Jiaqi and Li, Weichuang and Huang, Jiawei and Jiang, Ziyue and He, Jinzheng and Huang, Rongjie and Liu, Jinglin and others},
122
+ journal={arXiv preprint arXiv:2401.08503},
123
+ year={2024}
124
+ }
125
+ @article{ye2023geneface++,
126
+ title={GeneFace++: Generalized and Stable Real-Time Audio-Driven 3D Talking Face Generation},
127
+ author={Ye, Zhenhui and He, Jinzheng and Jiang, Ziyue and Huang, Rongjie and Huang, Jiawei and Liu, Jinglin and Ren, Yi and Yin, Xiang and Ma, Zejun and Zhao, Zhou},
128
+ journal={arXiv preprint arXiv:2305.00787},
129
+ year={2023}
130
+ }
131
+ @article{ye2023geneface,
132
+ title={GeneFace: Generalized and High-Fidelity Audio-Driven 3D Talking Face Synthesis},
133
+ author={Ye, Zhenhui and Jiang, Ziyue and Ren, Yi and Liu, Jinglin and He, Jinzheng and Zhao, Zhou},
134
+ journal={arXiv preprint arXiv:2301.13430},
135
+ year={2023}
136
+ }
137
+ ```
README.md ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Real3D-Portrait: One-shot Realistic 3D Talking Portrait Synthesis | ICLR 2024 Spotlight
2
+ [![arXiv](https://img.shields.io/badge/arXiv-Paper-%3CCOLOR%3E.svg)](https://arxiv.org/abs/2401.08503)| [![GitHub Stars](https://img.shields.io/github/stars/yerfor/Real3DPortrait
3
+ )](https://github.com/yerfor/Real3DPortrait) | [中文文档](./README-zh.md)
4
+
5
+ This is the official repo of Real3D-Portrait with Pytorch implementation, for one-shot and high video reality talking portrait synthesis. You can visit our [Demo Page](https://real3dportrait.github.io/) for watching demo videos, and read our [Paper](https://arxiv.org/pdf/2401.08503.pdf) for technical details.
6
+
7
+ <p align="center">
8
+ <br>
9
+ <img src="assets/real3dportrait.png" width="100%"/>
10
+ <br>
11
+ </p>
12
+
13
+ # Quick Start!
14
+ ## Environment Installation
15
+ Please refer to [Installation Guide](docs/prepare_env/install_guide.md), prepare a Conda environment `real3dportrait`.
16
+ ## Download Pre-trained & Third-Party Models
17
+ ### 3DMM BFM Model
18
+ Download 3DMM BFM Model from [Google Drive](https://drive.google.com/drive/folders/1o4t5YIw7w4cMUN4bgU9nPf6IyWVG1bEk?usp=sharing) or [BaiduYun Disk](https://pan.baidu.com/s/1aqv1z_qZ23Vp2VP4uxxblQ?pwd=m9q5 ) with Password m9q5.
19
+
20
+
21
+ Put all the files in `deep_3drecon/BFM`, the file structure will be like this:
22
+ ```
23
+ deep_3drecon/BFM/
24
+ ├── 01_MorphableModel.mat
25
+ ├── BFM_exp_idx.mat
26
+ ├── BFM_front_idx.mat
27
+ ├── BFM_model_front.mat
28
+ ├── Exp_Pca.bin
29
+ ├── facemodel_info.mat
30
+ ├── index_mp468_from_mesh35709.npy
31
+ ├── mediapipe_in_bfm53201.npy
32
+ └── std_exp.txt
33
+ ```
34
+
35
+ ### Pre-trained Real3D-Portrait
36
+ Download Pre-trained Real3D-Portrait:[Google Drive](https://drive.google.com/drive/folders/1MAveJf7RvJ-Opg1f5qhLdoRoC_Gc6nD9?usp=sharing) or [BaiduYun Disk](https://pan.baidu.com/s/1Mjmbn0UtA1Zm9owZ7zWNgQ?pwd=6x4f ) with Password 6x4f
37
+
38
+ Put the zip files in `checkpoints` and unzip them, the file structure will be like this:
39
+ ```
40
+ checkpoints/
41
+ ├── 240126_real3dportrait_orig
42
+ │ ├── audio2secc_vae
43
+ │ │ ├── config.yaml
44
+ │ │ └── model_ckpt_steps_400000.ckpt
45
+ │ └── secc2plane_torso_orig
46
+ │ ├── config.yaml
47
+ │ └── model_ckpt_steps_100000.ckpt
48
+ └── pretrained_ckpts
49
+ └── mit_b0.pth
50
+ ```
51
+
52
+ ## Inference
53
+ Currently, we provide **CLI** and **Gradio WebUI** for inference, and Google Colab will be provided in the future. We support both Audio-Driven and Video-Driven methods:
54
+
55
+ - For audio-driven, at least prepare `source image` and `driving audio`
56
+ - For video-driven, at least prepare `source image` and `driving expression video`
57
+
58
+ ### Gradio WebUI
59
+ Run Gradio WebUI demo, upload resouces in webpage,click `Generate` button to inference:
60
+ ```bash
61
+ python inference/app_real3dportrait.py
62
+ ```
63
+
64
+ ### CLI Inference
65
+ Firstly, switch to project folder and activate conda environment:
66
+ ```bash
67
+ cd <Real3DPortraitRoot>
68
+ conda activate real3dportrait
69
+ export PYTHON_PATH=./
70
+ ```
71
+ For audio-driven, provide source image and driving audio:
72
+ ```bash
73
+ python inference/real3d_infer.py \
74
+ --src_img <PATH_TO_SOURCE_IMAGE> \
75
+ --drv_aud <PATH_TO_AUDIO> \
76
+ --drv_pose <PATH_TO_POSE_VIDEO, OPTIONAL> \
77
+ --bg_img <PATH_TO_BACKGROUND_IMAGE, OPTIONAL> \
78
+ --out_name <PATH_TO_OUTPUT_VIDEO, OPTIONAL>
79
+ ```
80
+ For video-driven, provide source image and driving expression video(as `--drv_aud` parameter):
81
+ ```bash
82
+ python inference/real3d_infer.py \
83
+ --src_img <PATH_TO_SOURCE_IMAGE> \
84
+ --drv_aud <PATH_TO_EXP_VIDEO> \
85
+ --drv_pose <PATH_TO_POSE_VIDEO, OPTIONAL> \
86
+ --bg_img <PATH_TO_BACKGROUND_IMAGE, OPTIONAL> \
87
+ --out_name <PATH_TO_OUTPUT_VIDEO, OPTIONAL>
88
+ ```
89
+ Some optional parameters:
90
+ - `--drv_pose` provide motion pose information, default to be static poses
91
+ - `--bg_img` provide background information, default to be image extracted from source
92
+ - `--mouth_amp` mouth amplitude, higher value leads to wider mouth
93
+ - `--map_to_init_pose` when set to `True`, the initial pose will be mapped to source pose, and other poses will be equally transformed
94
+ - `--temperature` stands for the sampling temperature of audio2motion, higher for more diverse results at the expense of lower accuracy
95
+ - `--out_name` When not assigned, the results will be stored at `infer_out/tmp/`.
96
+ - `--out_mode` When `final`, only outputs the final result; when `concat_debug`, also outputs visualization of several intermediate process.
97
+
98
+ Commandline example:
99
+ ```bash
100
+ python inference/real3d_infer.py \
101
+ --src_img data/raw/examples/Macron.png \
102
+ --drv_aud data/raw/examples/Obama_5s.wav \
103
+ --drv_pose data/raw/examples/May_5s.mp4 \
104
+ --bg_img data/raw/examples/bg.png \
105
+ --out_name output.mp4 \
106
+ --out_mode concat_debug
107
+ ```
108
+
109
+ # ToDo
110
+ - [x] **Release Pre-trained weights of Real3D-Portrait.**
111
+ - [x] **Release Inference Code of Real3D-Portrait.**
112
+ - [x] **Release Gradio Demo of Real3D-Portrait..**
113
+ - [ ] **Release Google Colab of Real3D-Portrait..**
114
+ - [ ] **Release Training Code of Real3D-Portrait.**
115
+
116
+ # Citation
117
+ If you found this repo helpful to your work, please consider cite us:
118
+ ```
119
+ @article{ye2024real3d,
120
+ title={Real3D-Portrait: One-shot Realistic 3D Talking Portrait Synthesis},
121
+ author={Ye, Zhenhui and Zhong, Tianyun and Ren, Yi and Yang, Jiaqi and Li, Weichuang and Huang, Jiawei and Jiang, Ziyue and He, Jinzheng and Huang, Rongjie and Liu, Jinglin and others},
122
+ journal={arXiv preprint arXiv:2401.08503},
123
+ year={2024}
124
+ }
125
+ @article{ye2023geneface++,
126
+ title={GeneFace++: Generalized and Stable Real-Time Audio-Driven 3D Talking Face Generation},
127
+ author={Ye, Zhenhui and He, Jinzheng and Jiang, Ziyue and Huang, Rongjie and Huang, Jiawei and Liu, Jinglin and Ren, Yi and Yin, Xiang and Ma, Zejun and Zhao, Zhou},
128
+ journal={arXiv preprint arXiv:2305.00787},
129
+ year={2023}
130
+ }
131
+ @article{ye2023geneface,
132
+ title={GeneFace: Generalized and High-Fidelity Audio-Driven 3D Talking Face Synthesis},
133
+ author={Ye, Zhenhui and Jiang, Ziyue and Ren, Yi and Liu, Jinglin and He, Jinzheng and Zhao, Zhou},
134
+ journal={arXiv preprint arXiv:2301.13430},
135
+ year={2023}
136
+ }
137
+ ```
checkpoints/.gitkeep ADDED
File without changes
data_gen/eg3d/convert_to_eg3d_convention.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import copy
4
+ from utils.commons.tensor_utils import convert_to_tensor, convert_to_np
5
+ from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
6
+
7
+
8
+ def _fix_intrinsics(intrinsics):
9
+ """
10
+ intrinsics: [3,3], not batch-wise
11
+ """
12
+ # unnormalized normalized
13
+
14
+ # [[ f_x, s=0, x_0] [[ f_x/size_x, s=0, x_0/size_x=0.5]
15
+ # [ 0, f_y, y_0] -> [ 0, f_y/size_y, y_0/size_y=0.5]
16
+ # [ 0, 0, 1 ]] [ 0, 0, 1 ]]
17
+ intrinsics = np.array(intrinsics).copy()
18
+ assert intrinsics.shape == (3, 3), intrinsics
19
+ intrinsics[0,0] = 2985.29/700
20
+ intrinsics[1,1] = 2985.29/700
21
+ intrinsics[0,2] = 1/2
22
+ intrinsics[1,2] = 1/2
23
+ assert intrinsics[0,1] == 0
24
+ assert intrinsics[2,2] == 1
25
+ assert intrinsics[1,0] == 0
26
+ assert intrinsics[2,0] == 0
27
+ assert intrinsics[2,1] == 0
28
+ return intrinsics
29
+
30
+ # Used in original submission
31
+ def _fix_pose_orig(pose):
32
+ """
33
+ pose: [4,4], not batch-wise
34
+ """
35
+ pose = np.array(pose).copy()
36
+ location = pose[:3, 3]
37
+ radius = np.linalg.norm(location)
38
+ pose[:3, 3] = pose[:3, 3]/radius * 2.7
39
+ return pose
40
+
41
+
42
+ def get_eg3d_convention_camera_pose_intrinsic(item):
43
+ """
44
+ item: a dict during binarize
45
+
46
+ """
47
+ if item['euler'].ndim == 1:
48
+ angle = convert_to_tensor(copy.copy(item['euler']))
49
+ trans = copy.deepcopy(item['trans'])
50
+
51
+ # handle the difference of euler axis between eg3d and ours
52
+ # see data_gen/process_ffhq_for_eg3d/transplant_eg3d_ckpt_into_our_convention.ipynb
53
+ # angle += torch.tensor([0, 3.1415926535, 3.1415926535], device=angle.device)
54
+ R = ParametricFaceModel.compute_rotation(angle.unsqueeze(0))[0].cpu().numpy()
55
+ trans[2] += -10
56
+ c = -np.dot(R, trans)
57
+ pose = np.eye(4)
58
+ pose[:3,:3] = R
59
+ c *= 0.27 # normalize camera radius
60
+ c[1] += 0.006 # additional offset used in submission
61
+ c[2] += 0.161 # additional offset used in submission
62
+ pose[0,3] = c[0]
63
+ pose[1,3] = c[1]
64
+ pose[2,3] = c[2]
65
+
66
+ focal = 2985.29 # = 1015*1024/224*(300/466.285),
67
+ # todo: 如果修改了fit 3dmm阶段的camera intrinsic,这里也要跟着改
68
+ pp = 512#112
69
+ w = 1024#224
70
+ h = 1024#224
71
+
72
+ K = np.eye(3)
73
+ K[0][0] = focal
74
+ K[1][1] = focal
75
+ K[0][2] = w/2.0
76
+ K[1][2] = h/2.0
77
+ convention_K = _fix_intrinsics(K)
78
+
79
+ Rot = np.eye(3)
80
+ Rot[0, 0] = 1
81
+ Rot[1, 1] = -1
82
+ Rot[2, 2] = -1
83
+ pose[:3, :3] = np.dot(pose[:3, :3], Rot) # permute axes
84
+ convention_pose = _fix_pose_orig(pose)
85
+
86
+ item['c2w'] = pose
87
+ item['convention_c2w'] = convention_pose
88
+ item['intrinsics'] = convention_K
89
+ return item
90
+ else:
91
+ num_samples = len(item['euler'])
92
+ eulers_all = convert_to_tensor(copy.deepcopy(item['euler'])) # [B, 3]
93
+ trans_all = copy.deepcopy(item['trans']) # [B, 3]
94
+
95
+ # handle the difference of euler axis between eg3d and ours
96
+ # see data_gen/process_ffhq_for_eg3d/transplant_eg3d_ckpt_into_our_convention.ipynb
97
+ # eulers_all += torch.tensor([0, 3.1415926535, 3.1415926535], device=eulers_all.device).unsqueeze(0).repeat([eulers_all.shape[0],1])
98
+
99
+ intrinsics = []
100
+ poses = []
101
+ convention_poses = []
102
+ for i in range(num_samples):
103
+ angle = eulers_all[i]
104
+ trans = trans_all[i]
105
+ R = ParametricFaceModel.compute_rotation(angle.unsqueeze(0))[0].cpu().numpy()
106
+ trans[2] += -10
107
+ c = -np.dot(R, trans)
108
+ pose = np.eye(4)
109
+ pose[:3,:3] = R
110
+ c *= 0.27 # normalize camera radius
111
+ c[1] += 0.006 # additional offset used in submission
112
+ c[2] += 0.161 # additional offset used in submission
113
+ pose[0,3] = c[0]
114
+ pose[1,3] = c[1]
115
+ pose[2,3] = c[2]
116
+
117
+ focal = 2985.29 # = 1015*1024/224*(300/466.285),
118
+ # todo: 如果修改了fit 3dmm阶段的camera intrinsic,这里也要跟着改
119
+ pp = 512#112
120
+ w = 1024#224
121
+ h = 1024#224
122
+
123
+ K = np.eye(3)
124
+ K[0][0] = focal
125
+ K[1][1] = focal
126
+ K[0][2] = w/2.0
127
+ K[1][2] = h/2.0
128
+ convention_K = _fix_intrinsics(K)
129
+ intrinsics.append(convention_K)
130
+
131
+ Rot = np.eye(3)
132
+ Rot[0, 0] = 1
133
+ Rot[1, 1] = -1
134
+ Rot[2, 2] = -1
135
+ pose[:3, :3] = np.dot(pose[:3, :3], Rot)
136
+ convention_pose = _fix_pose_orig(pose)
137
+ convention_poses.append(convention_pose)
138
+ poses.append(pose)
139
+
140
+ intrinsics = np.stack(intrinsics) # [B, 3, 3]
141
+ poses = np.stack(poses) # [B, 4, 4]
142
+ convention_poses = np.stack(convention_poses) # [B, 4, 4]
143
+ item['intrinsics'] = intrinsics
144
+ item['c2w'] = poses
145
+ item['convention_c2w'] = convention_poses
146
+ return item
data_gen/runs/binarizer_nerf.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import math
4
+ import json
5
+ import imageio
6
+ import torch
7
+ import tqdm
8
+ import cv2
9
+
10
+ from data_util.face3d_helper import Face3DHelper
11
+ from utils.commons.euler2rot import euler_trans_2_c2w, c2w_to_euler_trans
12
+ from data_gen.utils.process_video.euler2quaterion import euler2quaterion, quaterion2euler
13
+ from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
14
+
15
+
16
+ def euler2rot(euler_angle):
17
+ batch_size = euler_angle.shape[0]
18
+ theta = euler_angle[:, 0].reshape(-1, 1, 1)
19
+ phi = euler_angle[:, 1].reshape(-1, 1, 1)
20
+ psi = euler_angle[:, 2].reshape(-1, 1, 1)
21
+ one = torch.ones(batch_size, 1, 1).to(euler_angle.device)
22
+ zero = torch.zeros(batch_size, 1, 1).to(euler_angle.device)
23
+ rot_x = torch.cat((
24
+ torch.cat((one, zero, zero), 1),
25
+ torch.cat((zero, theta.cos(), theta.sin()), 1),
26
+ torch.cat((zero, -theta.sin(), theta.cos()), 1),
27
+ ), 2)
28
+ rot_y = torch.cat((
29
+ torch.cat((phi.cos(), zero, -phi.sin()), 1),
30
+ torch.cat((zero, one, zero), 1),
31
+ torch.cat((phi.sin(), zero, phi.cos()), 1),
32
+ ), 2)
33
+ rot_z = torch.cat((
34
+ torch.cat((psi.cos(), -psi.sin(), zero), 1),
35
+ torch.cat((psi.sin(), psi.cos(), zero), 1),
36
+ torch.cat((zero, zero, one), 1)
37
+ ), 2)
38
+ return torch.bmm(rot_x, torch.bmm(rot_y, rot_z))
39
+
40
+
41
+ def rot2euler(rot_mat):
42
+ batch_size = len(rot_mat)
43
+ # we assert that y in in [-0.5pi, 0.5pi]
44
+ cos_y = torch.sqrt(rot_mat[:, 1, 2] * rot_mat[:, 1, 2] + rot_mat[:, 2, 2] * rot_mat[:, 2, 2])
45
+ theta_x = torch.atan2(-rot_mat[:, 1, 2], rot_mat[:, 2, 2])
46
+ theta_y = torch.atan2(rot_mat[:, 2, 0], cos_y)
47
+ theta_z = torch.atan2(rot_mat[:, 0, 1], rot_mat[:, 0, 0])
48
+ euler_angles = torch.zeros([batch_size, 3])
49
+ euler_angles[:, 0] = theta_x
50
+ euler_angles[:, 1] = theta_y
51
+ euler_angles[:, 2] = theta_z
52
+ return euler_angles
53
+
54
+ index_lm68_from_lm468 = [127,234,93,132,58,136,150,176,152,400,379,365,288,361,323,454,356,70,63,105,66,107,336,296,334,293,300,168,197,5,4,75,97,2,326,305,
55
+ 33,160,158,133,153,144,362,385,387,263,373,380,61,40,37,0,267,270,291,321,314,17,84,91,78,81,13,311,308,402,14,178]
56
+
57
+ def plot_lm2d(lm2d):
58
+ WH = 512
59
+ img = np.ones([WH, WH, 3], dtype=np.uint8) * 255
60
+
61
+ for i in range(len(lm2d)):
62
+ x, y = lm2d[i]
63
+ color = (255,0,0)
64
+ img = cv2.circle(img, center=(int(x),int(y)), radius=3, color=color, thickness=-1)
65
+ font = cv2.FONT_HERSHEY_SIMPLEX
66
+ for i in range(len(lm2d)):
67
+ x, y = lm2d[i]
68
+ img = cv2.putText(img, f"{i}", org=(int(x),int(y)), fontFace=font, fontScale=0.3, color=(255,0,0))
69
+ return img
70
+
71
+ def get_face_rect(lms, h, w):
72
+ """
73
+ lms: [68, 2]
74
+ h, w: int
75
+ return: [4,]
76
+ """
77
+ assert len(lms) == 68
78
+ # min_x, max_x = np.min(lms, 0)[0], np.max(lms, 0)[0]
79
+ min_x, max_x = np.min(lms[:, 0]), np.max(lms[:, 0])
80
+ cx = int((min_x+max_x)/2.0)
81
+ cy = int(lms[27, 1])
82
+ h_w = int((max_x-cx)*1.5)
83
+ h_h = int((lms[8, 1]-cy)*1.15)
84
+ rect_x = cx - h_w
85
+ rect_y = cy - h_h
86
+ if rect_x < 0:
87
+ rect_x = 0
88
+ if rect_y < 0:
89
+ rect_y = 0
90
+ rect_w = min(w-1-rect_x, 2*h_w)
91
+ rect_h = min(h-1-rect_y, 2*h_h)
92
+ # rect = np.array((rect_x, rect_y, rect_w, rect_h), dtype=np.int32)
93
+ # rect = [rect_x, rect_y, rect_w, rect_h]
94
+ rect = [rect_x, rect_x + rect_w, rect_y, rect_y + rect_h] # min_j, max_j, min_i, max_i
95
+ return rect # this x is width, y is height
96
+
97
+ def get_lip_rect(lms, h, w):
98
+ """
99
+ lms: [68, 2]
100
+ h, w: int
101
+ return: [4,]
102
+ """
103
+ # this x is width, y is height
104
+ # for lms, lms[:, 0] is width, lms[:, 1] is height
105
+ assert len(lms) == 68
106
+ lips = slice(48, 60)
107
+ lms = lms[lips]
108
+ min_x, max_x = np.min(lms[:, 0]), np.max(lms[:, 0])
109
+ min_y, max_y = np.min(lms[:, 1]), np.max(lms[:, 1])
110
+ cx = int((min_x+max_x)/2.0)
111
+ cy = int((min_y+max_y)/2.0)
112
+ h_w = int((max_x-cx)*1.2)
113
+ h_h = int((max_y-cy)*1.2)
114
+
115
+ h_w = max(h_w, h_h)
116
+ h_h = h_w
117
+
118
+ rect_x = cx - h_w
119
+ rect_y = cy - h_h
120
+ rect_w = 2*h_w
121
+ rect_h = 2*h_h
122
+ if rect_x < 0:
123
+ rect_x = 0
124
+ if rect_y < 0:
125
+ rect_y = 0
126
+
127
+ if rect_x + rect_w > w:
128
+ rect_x = w - rect_w
129
+ if rect_y + rect_h > h:
130
+ rect_y = h - rect_h
131
+
132
+ rect = [rect_x, rect_x + rect_w, rect_y, rect_y + rect_h] # min_j, max_j, min_i, max_i
133
+ return rect # this x is width, y is height
134
+
135
+
136
+ # def get_lip_rect(lms, h, w):
137
+ # """
138
+ # lms: [68, 2]
139
+ # h, w: int
140
+ # return: [4,]
141
+ # """
142
+ # assert len(lms) == 68
143
+ # lips = slice(48, 60)
144
+ # # this x is width, y is height
145
+ # xmin, xmax = int(lms[lips, 1].min()), int(lms[lips, 1].max())
146
+ # ymin, ymax = int(lms[lips, 0].min()), int(lms[lips, 0].max())
147
+ # # padding to H == W
148
+ # cx = (xmin + xmax) // 2
149
+ # cy = (ymin + ymax) // 2
150
+ # l = max(xmax - xmin, ymax - ymin) // 2
151
+ # xmin = max(0, cx - l)
152
+ # xmax = min(h, cx + l)
153
+ # ymin = max(0, cy - l)
154
+ # ymax = min(w, cy + l)
155
+ # lip_rect = [xmin, xmax, ymin, ymax]
156
+ # return lip_rect
157
+
158
+ def get_win_conds(conds, idx, smo_win_size=8, pad_option='zero'):
159
+ """
160
+ conds: [b, t=16, h=29]
161
+ idx: long, time index of the selected frame
162
+ """
163
+ idx = max(0, idx)
164
+ idx = min(idx, conds.shape[0]-1)
165
+ smo_half_win_size = smo_win_size//2
166
+ left_i = idx - smo_half_win_size
167
+ right_i = idx + (smo_win_size - smo_half_win_size)
168
+ pad_left, pad_right = 0, 0
169
+ if left_i < 0:
170
+ pad_left = -left_i
171
+ left_i = 0
172
+ if right_i > conds.shape[0]:
173
+ pad_right = right_i - conds.shape[0]
174
+ right_i = conds.shape[0]
175
+ conds_win = conds[left_i:right_i]
176
+ if pad_left > 0:
177
+ if pad_option == 'zero':
178
+ conds_win = np.concatenate([np.zeros_like(conds_win)[:pad_left], conds_win], axis=0)
179
+ elif pad_option == 'edge':
180
+ edge_value = conds[0][np.newaxis, ...]
181
+ conds_win = np.concatenate([edge_value] * pad_left + [conds_win], axis=0)
182
+ else:
183
+ raise NotImplementedError
184
+ if pad_right > 0:
185
+ if pad_option == 'zero':
186
+ conds_win = np.concatenate([conds_win, np.zeros_like(conds_win)[:pad_right]], axis=0)
187
+ elif pad_option == 'edge':
188
+ edge_value = conds[-1][np.newaxis, ...]
189
+ conds_win = np.concatenate([conds_win] + [edge_value] * pad_right , axis=0)
190
+ else:
191
+ raise NotImplementedError
192
+ assert conds_win.shape[0] == smo_win_size
193
+ return conds_win
194
+
195
+
196
+ def load_processed_data(processed_dir):
197
+ # load necessary files
198
+ background_img_name = os.path.join(processed_dir, "bg.jpg")
199
+ assert os.path.exists(background_img_name)
200
+ head_img_dir = os.path.join(processed_dir, "head_imgs")
201
+ torso_img_dir = os.path.join(processed_dir, "inpaint_torso_imgs")
202
+ gt_img_dir = os.path.join(processed_dir, "gt_imgs")
203
+
204
+ hubert_npy_name = os.path.join(processed_dir, "aud_hubert.npy")
205
+ mel_f0_npy_name = os.path.join(processed_dir, "aud_mel_f0.npy")
206
+ coeff_npy_name = os.path.join(processed_dir, "coeff_fit_mp.npy")
207
+ lm2d_npy_name = os.path.join(processed_dir, "lms_2d.npy")
208
+
209
+ ret_dict = {}
210
+
211
+ ret_dict['bg_img'] = imageio.imread(background_img_name)
212
+ ret_dict['H'], ret_dict['W'] = ret_dict['bg_img'].shape[:2]
213
+ ret_dict['focal'], ret_dict['cx'], ret_dict['cy'] = face_model.focal, face_model.center, face_model.center
214
+
215
+ print("loading lm2d coeff ...")
216
+ lm2d_arr = np.load(lm2d_npy_name)
217
+ face_rect_lst = []
218
+ lip_rect_lst = []
219
+ for lm2d in lm2d_arr:
220
+ if len(lm2d) in [468, 478]:
221
+ lm2d = lm2d[index_lm68_from_lm468]
222
+ face_rect = get_face_rect(lm2d, ret_dict['H'], ret_dict['W'])
223
+ lip_rect = get_lip_rect(lm2d, ret_dict['H'], ret_dict['W'])
224
+ face_rect_lst.append(face_rect)
225
+ lip_rect_lst.append(lip_rect)
226
+ face_rects = np.stack(face_rect_lst, axis=0) # [T, 4]
227
+
228
+ print("loading fitted 3dmm coeff ...")
229
+ coeff_dict = np.load(coeff_npy_name, allow_pickle=True).tolist()
230
+ identity_arr = coeff_dict['id']
231
+ exp_arr = coeff_dict['exp']
232
+ ret_dict['id'] = identity_arr
233
+ ret_dict['exp'] = exp_arr
234
+ euler_arr = ret_dict['euler'] = coeff_dict['euler']
235
+ trans_arr = ret_dict['trans'] = coeff_dict['trans']
236
+ print("calculating lm3d ...")
237
+ idexp_lm3d_arr = face3d_helper.reconstruct_idexp_lm3d(torch.from_numpy(identity_arr), torch.from_numpy(exp_arr)).cpu().numpy().reshape([-1, 68*3])
238
+ len_motion = len(idexp_lm3d_arr)
239
+ video_idexp_lm3d_mean = idexp_lm3d_arr.mean(axis=0)
240
+ video_idexp_lm3d_std = idexp_lm3d_arr.std(axis=0)
241
+ ret_dict['idexp_lm3d'] = idexp_lm3d_arr
242
+ ret_dict['idexp_lm3d_mean'] = video_idexp_lm3d_mean
243
+ ret_dict['idexp_lm3d_std'] = video_idexp_lm3d_std
244
+
245
+ # now we convert the euler_trans from deep3d convention to adnerf convention
246
+ eulers = torch.FloatTensor(euler_arr)
247
+ trans = torch.FloatTensor(trans_arr)
248
+ rots = face_model.compute_rotation(eulers) # rotation matrix is a better intermediate for convention-transplan than euler
249
+
250
+ # handle the camera pose to geneface's convention
251
+ trans[:, 2] = 10 - trans[:, 2] # 抵消fit阶段的to_camera操作,即trans[...,2] = 10 - trans[...,2]
252
+ rots = rots.permute(0, 2, 1)
253
+ trans[:, 2] = - trans[:,2] # 因为intrinsic proj不同
254
+ # below is the NeRF camera preprocessing strategy, see `save_transforms` in data_util/process.py
255
+ trans = trans / 10.0
256
+ rots_inv = rots.permute(0, 2, 1)
257
+ trans_inv = - torch.bmm(rots_inv, trans.unsqueeze(2))
258
+
259
+ pose = torch.eye(4, dtype=torch.float32).unsqueeze(0).repeat([len_motion, 1, 1]) # [T, 4, 4]
260
+ pose[:, :3, :3] = rots_inv
261
+ pose[:, :3, 3] = trans_inv[:, :, 0]
262
+ c2w_transform_matrices = pose.numpy()
263
+
264
+ # process the audio features used for postnet training
265
+ print("loading hubert ...")
266
+ hubert_features = np.load(hubert_npy_name)
267
+ print("loading Mel and F0 ...")
268
+ mel_f0_features = np.load(mel_f0_npy_name, allow_pickle=True).tolist()
269
+
270
+ ret_dict['hubert'] = hubert_features
271
+ ret_dict['mel'] = mel_f0_features['mel']
272
+ ret_dict['f0'] = mel_f0_features['f0']
273
+
274
+ # obtaining train samples
275
+ frame_indices = list(range(len_motion))
276
+ num_train = len_motion // 11 * 10
277
+ train_indices = frame_indices[:num_train]
278
+ val_indices = frame_indices[num_train:]
279
+
280
+ for split in ['train', 'val']:
281
+ if split == 'train':
282
+ indices = train_indices
283
+ samples = []
284
+ ret_dict['train_samples'] = samples
285
+ elif split == 'val':
286
+ indices = val_indices
287
+ samples = []
288
+ ret_dict['val_samples'] = samples
289
+
290
+ for idx in indices:
291
+ sample = {}
292
+ sample['idx'] = idx
293
+ sample['head_img_fname'] = os.path.join(head_img_dir,f"{idx:08d}.png")
294
+ sample['torso_img_fname'] = os.path.join(torso_img_dir,f"{idx:08d}.png")
295
+ sample['gt_img_fname'] = os.path.join(gt_img_dir,f"{idx:08d}.jpg")
296
+ # assert os.path.exists(sample['head_img_fname']) and os.path.exists(sample['torso_img_fname']) and os.path.exists(sample['gt_img_fname'])
297
+ sample['face_rect'] = face_rects[idx]
298
+ sample['lip_rect'] = lip_rect_lst[idx]
299
+ sample['c2w'] = c2w_transform_matrices[idx]
300
+ samples.append(sample)
301
+ return ret_dict
302
+
303
+
304
+ class Binarizer:
305
+ def __init__(self):
306
+ self.data_dir = 'data/'
307
+
308
+ def parse(self, video_id):
309
+ processed_dir = os.path.join(self.data_dir, 'processed/videos', video_id)
310
+ binary_dir = os.path.join(self.data_dir, 'binary/videos', video_id)
311
+ out_fname = os.path.join(binary_dir, "trainval_dataset.npy")
312
+ os.makedirs(binary_dir, exist_ok=True)
313
+ ret = load_processed_data(processed_dir)
314
+ mel_name = os.path.join(processed_dir, 'aud_mel_f0.npy')
315
+ mel_f0_dict = np.load(mel_name, allow_pickle=True).tolist()
316
+ ret.update(mel_f0_dict)
317
+ np.save(out_fname, ret, allow_pickle=True)
318
+
319
+
320
+
321
+ if __name__ == '__main__':
322
+ from argparse import ArgumentParser
323
+ parser = ArgumentParser()
324
+ parser.add_argument('--video_id', type=str, default='May', help='')
325
+ args = parser.parse_args()
326
+ ### Process Single Long Audio for NeRF dataset
327
+ video_id = args.video_id
328
+ face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM',
329
+ camera_distance=10, focal=1015)
330
+ face_model.to("cpu")
331
+ face3d_helper = Face3DHelper()
332
+
333
+ binarizer = Binarizer()
334
+ binarizer.parse(video_id)
335
+ print(f"Binarization for {video_id} Done!")
data_gen/runs/nerf/process_guide.md ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 温馨提示:第一次执行可以先一步步跑完下面的命令行,把环境跑通后,之后可以直接运行同目录的run.sh,一键完成下面的所有步骤。
2
+
3
+ # Step0. 将视频Crop到512x512分辨率,25FPS,确保每一帧都有目标人脸
4
+ ```
5
+ ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 data/raw/videos/${VIDEO_ID}_512.mp4
6
+ mv data/raw/videos/${VIDEO_ID}.mp4 data/raw/videos/${VIDEO_ID}_to_rm.mp4
7
+ mv data/raw/videos/${VIDEO_ID}_512.mp4 data/raw/videos/${VIDEO_ID}.mp4
8
+ ```
9
+ # step1: 提取音频特征, 如mel, f0, hubuert, esperanto
10
+ ```
11
+ export CUDA_VISIBLE_DEVICES=0
12
+ export VIDEO_ID=May
13
+ mkdir -p data/processed/videos/${VIDEO_ID}
14
+ ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -f wav -ar 16000 data/processed/videos/${VIDEO_ID}/aud.wav
15
+ python data_gen/utils/process_audio/extract_hubert.py --video_id=${VIDEO_ID}
16
+ python data_gen/utils/process_audio/extract_mel_f0.py --video_id=${VIDEO_ID}
17
+ ```
18
+
19
+ # Step2. 提取图片
20
+ ```
21
+ export VIDEO_ID=May
22
+ export CUDA_VISIBLE_DEVICES=0
23
+ mkdir -p data/processed/videos/${VIDEO_ID}/gt_imgs
24
+ ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 -start_number 0 data/processed/videos/${VIDEO_ID}/gt_imgs/%08d.jpg
25
+ python data_gen/utils/process_video/extract_segment_imgs.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 # extract image, segmap, and background
26
+ ```
27
+
28
+ # Step3. 提取lm2d_mediapipe
29
+ ### 提取2D landmark用于之后Fit 3DMM
30
+ ### num_workers是本机上的CPU worker数量;total_process是使用的机器数;process_id是本机的编号
31
+
32
+ ```
33
+ export VIDEO_ID=May
34
+ python data_gen/utils/process_video/extract_lm2d.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4
35
+ ```
36
+
37
+ # Step3. fit 3dmm
38
+ ```
39
+ export VIDEO_ID=May
40
+ export CUDA_VISIBLE_DEVICES=0
41
+ python data_gen/utils/process_video/fit_3dmm_landmark.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 --reset --debug --id_mode=global
42
+ ```
43
+
44
+ # Step4. Binarize
45
+ ```
46
+ export VIDEO_ID=May
47
+ python data_gen/runs/binarizer_nerf.py --video_id=${VIDEO_ID}
48
+ ```
49
+ 可以看到在`data/binary/videos/Mayssss`目录下得到了数据集。
data_gen/runs/nerf/run.sh ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # usage: CUDA_VISIBLE_DEVICES=0 bash data_gen/runs/nerf/run.sh <VIDEO_ID>
2
+ # please place video to data/raw/videos/${VIDEO_ID}.mp4
3
+ VIDEO_ID=$1
4
+ echo Processing $VIDEO_ID
5
+
6
+ echo Resizing the video to 512x512
7
+ ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 -y data/raw/videos/${VIDEO_ID}_512.mp4
8
+ mv data/raw/videos/${VIDEO_ID}.mp4 data/raw/videos/${VIDEO_ID}_to_rm.mp4
9
+ mv data/raw/videos/${VIDEO_ID}_512.mp4 data/raw/videos/${VIDEO_ID}.mp4
10
+ echo Done
11
+ echo The old video is moved to data/raw/videos/${VIDEO_ID}.mp4 data/raw/videos/${VIDEO_ID}_to_rm.mp4
12
+
13
+ echo mkdir -p data/processed/videos/${VIDEO_ID}
14
+ mkdir -p data/processed/videos/${VIDEO_ID}
15
+ echo Done
16
+
17
+ # extract audio file from the training video
18
+ echo ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -f wav -ar 16000 -v quiet -y data/processed/videos/${VIDEO_ID}/aud.wav
19
+ ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -f wav -ar 16000 -v quiet -y data/processed/videos/${VIDEO_ID}/aud.wav
20
+ echo Done
21
+
22
+ # extract hubert_mel_f0 from audio
23
+ echo python data_gen/utils/process_audio/extract_hubert.py --video_id=${VIDEO_ID}
24
+ python data_gen/utils/process_audio/extract_hubert.py --video_id=${VIDEO_ID}
25
+ echo python data_gen/utils/process_audio/extract_mel_f0.py --video_id=${VIDEO_ID}
26
+ python data_gen/utils/process_audio/extract_mel_f0.py --video_id=${VIDEO_ID}
27
+ echo Done
28
+
29
+ # extract segment images
30
+ echo mkdir -p data/processed/videos/${VIDEO_ID}/gt_imgs
31
+ mkdir -p data/processed/videos/${VIDEO_ID}/gt_imgs
32
+ echo ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 -start_number 0 -v quiet data/processed/videos/${VIDEO_ID}/gt_imgs/%08d.jpg
33
+ ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 -start_number 0 -v quiet data/processed/videos/${VIDEO_ID}/gt_imgs/%08d.jpg
34
+ echo Done
35
+
36
+ echo python data_gen/utils/process_video/extract_segment_imgs.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 # extract image, segmap, and background
37
+ python data_gen/utils/process_video/extract_segment_imgs.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 # extract image, segmap, and background
38
+ echo Done
39
+
40
+ echo python data_gen/utils/process_video/extract_lm2d.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4
41
+ python data_gen/utils/process_video/extract_lm2d.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4
42
+ echo Done
43
+
44
+ pkill -f void*
45
+ echo python data_gen/utils/process_video/fit_3dmm_landmark.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 --reset --debug --id_mode=global
46
+ python data_gen/utils/process_video/fit_3dmm_landmark.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 --reset --debug --id_mode=global
47
+ echo Done
48
+
49
+ echo python data_gen/runs/binarizer_nerf.py --video_id=${VIDEO_ID}
50
+ python data_gen/runs/binarizer_nerf.py --video_id=${VIDEO_ID}
51
+ echo Done
data_gen/utils/mp_feature_extractors/face_landmarker.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import mediapipe as mp
2
+ from mediapipe.tasks import python
3
+ from mediapipe.tasks.python import vision
4
+ import numpy as np
5
+ import cv2
6
+ import os
7
+ import copy
8
+
9
+ # simplified mediapipe ldm at https://github.com/k-m-irfan/simplified_mediapipe_face_landmarks
10
+ index_lm141_from_lm478 = [70,63,105,66,107,55,65,52,53,46] + [300,293,334,296,336,285,295,282,283,276] + [33,246,161,160,159,158,157,173,133,155,154,153,145,144,163,7] + [263,466,388,387,386,385,384,398,362,382,381,380,374,373,390,249] + [78,191,80,81,82,13,312,311,310,415,308,324,318,402,317,14,87,178,88,95] + [61,185,40,39,37,0,267,269,270,409,291,375,321,405,314,17,84,181,91,146] + [10,338,297,332,284,251,389,356,454,323,361,288,397,365,379,378,400,377,152,148,176,149,150,136,172,58,132,93,234,127,162,21,54,103,67,109] + [468,469,470,471,472] + [473,474,475,476,477] + [64,4,294]
11
+ # lm141 without iris
12
+ index_lm131_from_lm478 = [70,63,105,66,107,55,65,52,53,46] + [300,293,334,296,336,285,295,282,283,276] + [33,246,161,160,159,158,157,173,133,155,154,153,145,144,163,7] + [263,466,388,387,386,385,384,398,362,382,381,380,374,373,390,249] + [78,191,80,81,82,13,312,311,310,415,308,324,318,402,317,14,87,178,88,95] + [61,185,40,39,37,0,267,269,270,409,291,375,321,405,314,17,84,181,91,146] + [10,338,297,332,284,251,389,356,454,323,361,288,397,365,379,378,400,377,152,148,176,149,150,136,172,58,132,93,234,127,162,21,54,103,67,109] + [64,4,294]
13
+
14
+ # face alignment lm68
15
+ index_lm68_from_lm478 = [127,234,93,132,58,136,150,176,152,400,379,365,288,361,323,454,356,70,63,105,66,107,336,296,334,293,300,168,197,5,4,75,97,2,326,305,
16
+ 33,160,158,133,153,144,362,385,387,263,373,380,61,40,37,0,267,270,291,321,314,17,84,91,78,81,13,311,308,402,14,178]
17
+ # used for weights for key parts
18
+ unmatch_mask_from_lm478 = [ 93, 127, 132, 234, 323, 356, 361, 454]
19
+ index_eye_from_lm478 = [33,246,161,160,159,158,157,173,133,155,154,153,145,144,163,7] + [263,466,388,387,386,385,384,398,362,382,381,380,374,373,390,249]
20
+ index_innerlip_from_lm478 = [78,191,80,81,82,13,312,311,310,415,308,324,318,402,317,14,87,178,88,95]
21
+ index_outerlip_from_lm478 = [61,185,40,39,37,0,267,269,270,409,291,375,321,405,314,17,84,181,91,146]
22
+ index_withinmouth_from_lm478 = [76, 62] + [184, 183, 74, 72, 73, 41, 72, 38, 11, 12, 302, 268, 303, 271, 304, 272, 408, 407] + [292, 306] + [325, 307, 319, 320, 403, 404, 316, 315, 15, 16, 86, 85, 179, 180, 89, 90, 96, 77]
23
+ index_mouth_from_lm478 = index_innerlip_from_lm478 + index_outerlip_from_lm478 + index_withinmouth_from_lm478
24
+
25
+ index_yaw_from_lm68 = list(range(0, 17))
26
+ index_brow_from_lm68 = list(range(17, 27))
27
+ index_nose_from_lm68 = list(range(27, 36))
28
+ index_eye_from_lm68 = list(range(36, 48))
29
+ index_mouth_from_lm68 = list(range(48, 68))
30
+
31
+
32
+ def read_video_to_frames(video_name):
33
+ frames = []
34
+ cap = cv2.VideoCapture(video_name)
35
+ while cap.isOpened():
36
+ ret, frame_bgr = cap.read()
37
+ if frame_bgr is None:
38
+ break
39
+ frames.append(frame_bgr)
40
+ frames = np.stack(frames)
41
+ frames = np.flip(frames, -1) # BGR ==> RGB
42
+ return frames
43
+
44
+ class MediapipeLandmarker:
45
+ def __init__(self):
46
+ model_path = 'data_gen/utils/mp_feature_extractors/face_landmarker.task'
47
+ if not os.path.exists(model_path):
48
+ os.makedirs(os.path.dirname(model_path), exist_ok=True)
49
+ print("downloading face_landmarker model from mediapipe...")
50
+ model_url = 'https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/latest/face_landmarker.task'
51
+ os.system(f"wget {model_url}")
52
+ os.system(f"mv face_landmarker.task {model_path}")
53
+ print("download success")
54
+ base_options = python.BaseOptions(model_asset_path=model_path)
55
+ self.image_mode_options = vision.FaceLandmarkerOptions(base_options=base_options,
56
+ running_mode=vision.RunningMode.IMAGE, # IMAGE, VIDEO, LIVE_STREAM
57
+ num_faces=1)
58
+ self.video_mode_options = vision.FaceLandmarkerOptions(base_options=base_options,
59
+ running_mode=vision.RunningMode.VIDEO, # IMAGE, VIDEO, LIVE_STREAM
60
+ num_faces=1)
61
+
62
+ def extract_lm478_from_img_name(self, img_name):
63
+ img = cv2.imread(img_name)
64
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
65
+ img_lm478 = self.extract_lm478_from_img(img)
66
+ return img_lm478
67
+
68
+ def extract_lm478_from_img(self, img):
69
+ img_landmarker = vision.FaceLandmarker.create_from_options(self.image_mode_options)
70
+ frame = mp.Image(image_format=mp.ImageFormat.SRGB, data=img.astype(np.uint8))
71
+ img_face_landmarker_result = img_landmarker.detect(image=frame)
72
+ img_ldm_i = img_face_landmarker_result.face_landmarks[0]
73
+ img_face_landmarks = np.array([[l.x, l.y, l.z] for l in img_ldm_i])
74
+ H, W, _ = img.shape
75
+ img_lm478 = np.array(img_face_landmarks)[:, :2] * np.array([W, H]).reshape([1,2]) # [478, 2]
76
+ return img_lm478
77
+
78
+ def extract_lm478_from_video_name(self, video_name, fps=25, anti_smooth_factor=2):
79
+ frames = read_video_to_frames(video_name)
80
+ img_lm478, vid_lm478 = self.extract_lm478_from_frames(frames, fps, anti_smooth_factor)
81
+ return img_lm478, vid_lm478
82
+
83
+ def extract_lm478_from_frames(self, frames, fps=25, anti_smooth_factor=20):
84
+ """
85
+ frames: RGB, uint8
86
+ anti_smooth_factor: float, 对video模式的interval进行修改, 1代表无修改, 越大越接近image mode
87
+ """
88
+ img_mpldms = []
89
+ vid_mpldms = []
90
+ img_landmarker = vision.FaceLandmarker.create_from_options(self.image_mode_options)
91
+ vid_landmarker = vision.FaceLandmarker.create_from_options(self.video_mode_options)
92
+
93
+ for i in range(len(frames)):
94
+ frame = mp.Image(image_format=mp.ImageFormat.SRGB, data=frames[i].astype(np.uint8))
95
+ img_face_landmarker_result = img_landmarker.detect(image=frame)
96
+ vid_face_landmarker_result = vid_landmarker.detect_for_video(image=frame, timestamp_ms=int((1000/fps)*anti_smooth_factor*i))
97
+ try:
98
+ img_ldm_i = img_face_landmarker_result.face_landmarks[0]
99
+ vid_ldm_i = vid_face_landmarker_result.face_landmarks[0]
100
+ except:
101
+ print(f"Warning: failed detect ldm in idx={i}, use previous frame results.")
102
+ img_face_landmarks = np.array([[l.x, l.y, l.z] for l in img_ldm_i])
103
+ vid_face_landmarks = np.array([[l.x, l.y, l.z] for l in vid_ldm_i])
104
+ img_mpldms.append(img_face_landmarks)
105
+ vid_mpldms.append(vid_face_landmarks)
106
+ img_lm478 = np.stack(img_mpldms)[..., :2]
107
+ vid_lm478 = np.stack(vid_mpldms)[..., :2]
108
+ bs, H, W, _ = frames.shape
109
+ img_lm478 = np.array(img_lm478)[..., :2] * np.array([W, H]).reshape([1,1,2]) # [T, 478, 2]
110
+ vid_lm478 = np.array(vid_lm478)[..., :2] * np.array([W, H]).reshape([1,1,2]) # [T, 478, 2]
111
+ return img_lm478, vid_lm478
112
+
113
+ def combine_vid_img_lm478_to_lm68(self, img_lm478, vid_lm478):
114
+ img_lm68 = img_lm478[:, index_lm68_from_lm478]
115
+ vid_lm68 = vid_lm478[:, index_lm68_from_lm478]
116
+ combined_lm68 = copy.deepcopy(img_lm68)
117
+ combined_lm68[:, index_yaw_from_lm68] = vid_lm68[:, index_yaw_from_lm68]
118
+ combined_lm68[:, index_brow_from_lm68] = vid_lm68[:, index_brow_from_lm68]
119
+ combined_lm68[:, index_nose_from_lm68] = vid_lm68[:, index_nose_from_lm68]
120
+ return combined_lm68
121
+
122
+ def combine_vid_img_lm478_to_lm478(self, img_lm478, vid_lm478):
123
+ combined_lm478 = copy.deepcopy(vid_lm478)
124
+ combined_lm478[:, index_mouth_from_lm478] = img_lm478[:, index_mouth_from_lm478]
125
+ combined_lm478[:, index_eye_from_lm478] = img_lm478[:, index_eye_from_lm478]
126
+ return combined_lm478
127
+
128
+ if __name__ == '__main__':
129
+ landmarker = MediapipeLandmarker()
130
+ ret = landmarker.extract_lm478_from_video_name("00000.mp4")
data_gen/utils/mp_feature_extractors/face_landmarker.task ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64184e229b263107bc2b804c6625db1341ff2bb731874b0bcc2fe6544e0bc9ff
3
+ size 3758596
data_gen/utils/mp_feature_extractors/mp_segmenter.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import copy
3
+ import numpy as np
4
+ import tqdm
5
+ import mediapipe as mp
6
+ import torch
7
+ from mediapipe.tasks import python
8
+ from mediapipe.tasks.python import vision
9
+ from utils.commons.multiprocess_utils import multiprocess_run_tqdm, multiprocess_run
10
+ from utils.commons.tensor_utils import convert_to_np
11
+ from sklearn.neighbors import NearestNeighbors
12
+
13
+ def scatter_np(condition_img, classSeg=5):
14
+ # def scatter(condition_img, classSeg=19, label_size=(512, 512)):
15
+ batch, c, height, width = condition_img.shape
16
+ # if height != label_size[0] or width != label_size[1]:
17
+ # condition_img= F.interpolate(condition_img, size=label_size, mode='nearest')
18
+ input_label = np.zeros([batch, classSeg, condition_img.shape[2], condition_img.shape[3]]).astype(np.int_)
19
+ # input_label = torch.zeros(batch, classSeg, *label_size, device=condition_img.device)
20
+ np.put_along_axis(input_label, condition_img, 1, 1)
21
+ return input_label
22
+
23
+ def scatter(condition_img, classSeg=19):
24
+ # def scatter(condition_img, classSeg=19, label_size=(512, 512)):
25
+ batch, c, height, width = condition_img.size()
26
+ # if height != label_size[0] or width != label_size[1]:
27
+ # condition_img= F.interpolate(condition_img, size=label_size, mode='nearest')
28
+ input_label = torch.zeros(batch, classSeg, condition_img.shape[2], condition_img.shape[3], device=condition_img.device)
29
+ # input_label = torch.zeros(batch, classSeg, *label_size, device=condition_img.device)
30
+ return input_label.scatter_(1, condition_img.long(), 1)
31
+
32
+ def encode_segmap_mask_to_image(segmap):
33
+ # rgb
34
+ _,h,w = segmap.shape
35
+ encoded_img = np.ones([h,w,3],dtype=np.uint8) * 255
36
+ colors = [(255,255,255),(255,255,0),(255,0,255),(0,255,255),(255,0,0),(0,255,0)]
37
+ for i, color in enumerate(colors):
38
+ mask = segmap[i].astype(int)
39
+ index = np.where(mask != 0)
40
+ encoded_img[index[0], index[1], :] = np.array(color)
41
+ return encoded_img.astype(np.uint8)
42
+
43
+ def decode_segmap_mask_from_image(encoded_img):
44
+ # rgb
45
+ colors = [(255,255,255),(255,255,0),(255,0,255),(0,255,255),(255,0,0),(0,255,0)]
46
+ bg = (encoded_img[..., 0] == 255) & (encoded_img[..., 1] == 255) & (encoded_img[..., 2] == 255)
47
+ hair = (encoded_img[..., 0] == 255) & (encoded_img[..., 1] == 255) & (encoded_img[..., 2] == 0)
48
+ body_skin = (encoded_img[..., 0] == 255) & (encoded_img[..., 1] == 0) & (encoded_img[..., 2] == 255)
49
+ face_skin = (encoded_img[..., 0] == 0) & (encoded_img[..., 1] == 255) & (encoded_img[..., 2] == 255)
50
+ clothes = (encoded_img[..., 0] == 255) & (encoded_img[..., 1] == 0) & (encoded_img[..., 2] == 0)
51
+ others = (encoded_img[..., 0] == 0) & (encoded_img[..., 1] == 255) & (encoded_img[..., 2] == 0)
52
+ segmap = np.stack([bg, hair, body_skin, face_skin, clothes, others], axis=0)
53
+ return segmap.astype(np.uint8)
54
+
55
+ def read_video_frame(video_name, frame_id):
56
+ # https://blog.csdn.net/bby1987/article/details/108923361
57
+ # frame_num = video_capture.get(cv2.CAP_PROP_FRAME_COUNT) # ==> 总帧数
58
+ # fps = video_capture.get(cv2.CAP_PROP_FPS) # ==> 帧率
59
+ # width = video_capture.get(cv2.CAP_PROP_FRAME_WIDTH) # ==> 视频宽度
60
+ # height = video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT) # ==> 视频高度
61
+ # pos = video_capture.get(cv2.CAP_PROP_POS_FRAMES) # ==> 句柄位置
62
+ # video_capture.set(cv2.CAP_PROP_POS_FRAMES, 1000) # ==> 设置句柄位置
63
+ # pos = video_capture.get(cv2.CAP_PROP_POS_FRAMES) # ==> 此时 pos = 1000.0
64
+ # video_capture.release()
65
+ vr = cv2.VideoCapture(video_name)
66
+ vr.set(cv2.CAP_PROP_POS_FRAMES, frame_id)
67
+ _, frame = vr.read()
68
+ return frame
69
+
70
+ def decode_segmap_mask_from_segmap_video_frame(video_frame):
71
+ # video_frame: 0~255 BGR, obtained by read_video_frame
72
+ def assign_values(array):
73
+ remainder = array % 40 # 计算数组中每个值与40的余数
74
+ assigned_values = np.where(remainder <= 20, array - remainder, array + (40 - remainder))
75
+ return assigned_values
76
+ segmap = video_frame.mean(-1)
77
+ segmap = assign_values(segmap) // 40 # [H, W] with value 0~5
78
+ segmap_mask = scatter_np(segmap[None, None, ...], classSeg=6)[0] # [6, H, W]
79
+ return segmap.astype(np.uint8)
80
+
81
+ def extract_background(img_lst, segmap_lst=None):
82
+ """
83
+ img_lst: list of rgb ndarray
84
+ """
85
+ # only use 1/20 images
86
+ num_frames = len(img_lst)
87
+ img_lst = img_lst[::20] if num_frames > 20 else img_lst[0:1]
88
+
89
+ if segmap_lst is not None:
90
+ segmap_lst = segmap_lst[::20] if num_frames > 20 else segmap_lst[0:1]
91
+ assert len(img_lst) == len(segmap_lst)
92
+ # get H/W
93
+ h, w = img_lst[0].shape[:2]
94
+
95
+ # nearest neighbors
96
+ all_xys = np.mgrid[0:h, 0:w].reshape(2, -1).transpose()
97
+ distss = []
98
+ for idx, img in enumerate(img_lst):
99
+ if segmap_lst is not None:
100
+ segmap = segmap_lst[idx]
101
+ else:
102
+ segmap = seg_model._cal_seg_map(img)
103
+ bg = (segmap[0]).astype(bool)
104
+ fg_xys = np.stack(np.nonzero(~bg)).transpose(1, 0)
105
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
106
+ dists, _ = nbrs.kneighbors(all_xys)
107
+ distss.append(dists)
108
+
109
+ distss = np.stack(distss)
110
+ max_dist = np.max(distss, 0)
111
+ max_id = np.argmax(distss, 0)
112
+
113
+ bc_pixs = max_dist > 10 # 5
114
+ bc_pixs_id = np.nonzero(bc_pixs)
115
+ bc_ids = max_id[bc_pixs]
116
+
117
+ num_pixs = distss.shape[1]
118
+ imgs = np.stack(img_lst).reshape(-1, num_pixs, 3)
119
+
120
+ bg_img = np.zeros((h*w, 3), dtype=np.uint8)
121
+ bg_img[bc_pixs_id, :] = imgs[bc_ids, bc_pixs_id, :]
122
+ bg_img = bg_img.reshape(h, w, 3)
123
+
124
+ max_dist = max_dist.reshape(h, w)
125
+ bc_pixs = max_dist > 10 # 5
126
+ bg_xys = np.stack(np.nonzero(~bc_pixs)).transpose()
127
+ fg_xys = np.stack(np.nonzero(bc_pixs)).transpose()
128
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
129
+ distances, indices = nbrs.kneighbors(bg_xys)
130
+ bg_fg_xys = fg_xys[indices[:, 0]]
131
+ bg_img[bg_xys[:, 0], bg_xys[:, 1], :] = bg_img[bg_fg_xys[:, 0], bg_fg_xys[:, 1], :]
132
+ return bg_img
133
+
134
+
135
+ class MediapipeSegmenter:
136
+ def __init__(self):
137
+ model_path = 'data_gen/utils/mp_feature_extractors/selfie_multiclass_256x256.tflite'
138
+ if not os.path.exists(model_path):
139
+ os.makedirs(os.path.dirname(model_path), exist_ok=True)
140
+ print("downloading segmenter model from mediapipe...")
141
+ os.system(f"wget https://storage.googleapis.com/mediapipe-models/image_segmenter/selfie_multiclass_256x256/float32/latest/selfie_multiclass_256x256.tflite")
142
+ os.system(f"mv selfie_multiclass_256x256.tflite {model_path}")
143
+ print("download success")
144
+ base_options = python.BaseOptions(model_asset_path=model_path)
145
+ self.options = vision.ImageSegmenterOptions(base_options=base_options,running_mode=vision.RunningMode.IMAGE, output_category_mask=True)
146
+ self.video_options = vision.ImageSegmenterOptions(base_options=base_options,running_mode=vision.RunningMode.VIDEO, output_category_mask=True)
147
+
148
+ def _cal_seg_map_for_video(self, imgs, segmenter=None, return_onehot_mask=True, return_segmap_image=True, debug_fill=False):
149
+ segmenter = vision.ImageSegmenter.create_from_options(self.video_options) if segmenter is None else segmenter
150
+ assert return_onehot_mask or return_segmap_image # you should at least return one
151
+ segmap_masks = []
152
+ segmap_images = []
153
+ for i in tqdm.trange(len(imgs), desc="extracting segmaps from a video..."):
154
+ # for i in range(len(imgs)):
155
+ img = imgs[i]
156
+ mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)
157
+ out = segmenter.segment_for_video(mp_image, 40 * i)
158
+ segmap = out.category_mask.numpy_view().copy() # [H, W]
159
+ if debug_fill:
160
+ # print(f'segmap {segmap}')
161
+ for x in range(-80 + 1, 0):
162
+ for y in range(200, 350):
163
+ segmap[x][y] = 4
164
+
165
+ if return_onehot_mask:
166
+ segmap_mask = scatter_np(segmap[None, None, ...], classSeg=6)[0] # [6, H, W]
167
+ segmap_masks.append(segmap_mask)
168
+ if return_segmap_image:
169
+ segmap_image = segmap[:, :, None].repeat(3, 2).astype(float)
170
+ segmap_image = (segmap_image * 40).astype(np.uint8)
171
+ segmap_images.append(segmap_image)
172
+
173
+ if return_onehot_mask and return_segmap_image:
174
+ return segmap_masks, segmap_images
175
+ elif return_onehot_mask:
176
+ return segmap_masks
177
+ elif return_segmap_image:
178
+ return segmap_images
179
+
180
+ def _cal_seg_map(self, img, segmenter=None, return_onehot_mask=True):
181
+ """
182
+ segmenter: vision.ImageSegmenter.create_from_options(options)
183
+ img: numpy, [H, W, 3], 0~255
184
+ segmap: [C, H, W]
185
+ 0 - background
186
+ 1 - hair
187
+ 2 - body-skin
188
+ 3 - face-skin
189
+ 4 - clothes
190
+ 5 - others (accessories)
191
+ """
192
+ assert img.ndim == 3
193
+ segmenter = vision.ImageSegmenter.create_from_options(self.options) if segmenter is None else segmenter
194
+ image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)
195
+ out = segmenter.segment(image)
196
+ segmap = out.category_mask.numpy_view().copy() # [H, W]
197
+ if return_onehot_mask:
198
+ segmap = scatter_np(segmap[None, None, ...], classSeg=6)[0] # [6, H, W]
199
+ return segmap
200
+
201
+ def _seg_out_img_with_segmap(self, img, segmap, mode='head'):
202
+ """
203
+ img: [h,w,c], img is in 0~255, np
204
+ """
205
+ #
206
+ img = copy.deepcopy(img)
207
+ if mode == 'head':
208
+ selected_mask = segmap[[1,3,5] , :, :].sum(axis=0)[None,:] > 0.5 # glasses 也属于others
209
+ img[~selected_mask.repeat(3,axis=0).transpose(1,2,0)] = 0 # (-1,-1,-1) denotes black in our [-1,1] convention
210
+ # selected_mask = segmap[[1,3] , :, :].sum(dim=0, keepdim=True) > 0.5
211
+ elif mode == 'person':
212
+ selected_mask = segmap[[1,2,3,4,5], :, :].sum(axis=0)[None,:] > 0.5
213
+ img[~selected_mask.repeat(3,axis=0).transpose(1,2,0)] = 0 # (-1,-1,-1) denotes black in our [-1,1] convention
214
+ elif mode == 'torso':
215
+ selected_mask = segmap[[2,4], :, :].sum(axis=0)[None,:] > 0.5
216
+ img[~selected_mask.repeat(3,axis=0).transpose(1,2,0)] = 0 # (-1,-1,-1) denotes black in our [-1,1] convention
217
+ elif mode == 'torso_with_bg':
218
+ selected_mask = segmap[[0, 2,4], :, :].sum(axis=0)[None,:] > 0.5
219
+ img[~selected_mask.repeat(3,axis=0).transpose(1,2,0)] = 0 # (-1,-1,-1) denotes black in our [-1,1] convention
220
+ elif mode == 'bg':
221
+ selected_mask = segmap[[0], :, :].sum(axis=0)[None,:] > 0.5 # only seg out 0, which means background
222
+ img[~selected_mask.repeat(3,axis=0).transpose(1,2,0)] = 0 # (-1,-1,-1) denotes black in our [-1,1] convention
223
+ elif mode == 'full':
224
+ pass
225
+ else:
226
+ raise NotImplementedError()
227
+ return img, selected_mask
228
+
229
+ def _seg_out_img(self, img, segmenter=None, mode='head'):
230
+ """
231
+ imgs [H, W, 3] 0-255
232
+ return : person_img [B, 3, H, W]
233
+ """
234
+ segmenter = vision.ImageSegmenter.create_from_options(self.options) if segmenter is None else segmenter
235
+ segmap = self._cal_seg_map(img, segmenter=segmenter, return_onehot_mask=True) # [B, 19, H, W]
236
+ return self._seg_out_img_with_segmap(img, segmap, mode=mode)
237
+
238
+ def seg_out_imgs(self, img, mode='head'):
239
+ """
240
+ api for pytorch img, -1~1
241
+ img: [B, 3, H, W], -1~1
242
+ """
243
+ device = img.device
244
+ img = convert_to_np(img.permute(0, 2, 3, 1)) # [B, H, W, 3]
245
+ img = ((img + 1) * 127.5).astype(np.uint8)
246
+ img_lst = [copy.deepcopy(img[i]) for i in range(len(img))]
247
+ out_lst = []
248
+ for im in img_lst:
249
+ out = self._seg_out_img(im, mode=mode)
250
+ out_lst.append(out)
251
+ seg_imgs = np.stack(out_lst) # [B, H, W, 3]
252
+ seg_imgs = (seg_imgs - 127.5) / 127.5
253
+ seg_imgs = torch.from_numpy(seg_imgs).permute(0, 3, 1, 2).to(device)
254
+ return seg_imgs
255
+
256
+ if __name__ == '__main__':
257
+ import imageio, cv2, tqdm
258
+ import torchshow as ts
259
+ img = imageio.imread("1.png")
260
+ img = cv2.resize(img, (512,512))
261
+
262
+ seg_model = MediapipeSegmenter()
263
+ img = torch.tensor(img).unsqueeze(0).repeat([1, 1, 1, 1]).permute(0, 3,1,2)
264
+ img = (img-127.5)/127.5
265
+ out = seg_model.seg_out_imgs(img, 'torso')
266
+ ts.save(out,"torso.png")
267
+ out = seg_model.seg_out_imgs(img, 'head')
268
+ ts.save(out,"head.png")
269
+ out = seg_model.seg_out_imgs(img, 'bg')
270
+ ts.save(out,"bg.png")
271
+ img = convert_to_np(img.permute(0, 2, 3, 1)) # [B, H, W, 3]
272
+ img = ((img + 1) * 127.5).astype(np.uint8)
273
+ bg = extract_background(img)
274
+ ts.save(bg,"bg2.png")
data_gen/utils/mp_feature_extractors/selfie_multiclass_256x256.tflite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c6748b1253a99067ef71f7e26ca71096cd449baefa8f101900ea23016507e0e0
3
+ size 16371837
data_gen/utils/path_converter.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+
4
+ class PathConverter():
5
+ def __init__(self):
6
+ self.prefixs = {
7
+ "vid": "/video/",
8
+ "gt": "/gt_imgs/",
9
+ "head": "/head_imgs/",
10
+ "torso": "/torso_imgs/",
11
+ "person": "/person_imgs/",
12
+ "torso_with_bg": "/torso_with_bg_imgs/",
13
+ "single_bg": "/bg_img/",
14
+ "bg": "/bg_imgs/",
15
+ "segmaps": "/segmaps/",
16
+ "inpaint_torso": "/inpaint_torso_imgs/",
17
+ "com": "/com_imgs/",
18
+ "inpaint_torso_with_com_bg": "/inpaint_torso_with_com_bg_imgs/",
19
+ }
20
+
21
+ def to(self, path: str, old_pattern: str, new_pattern: str):
22
+ return path.replace(self.prefixs[old_pattern], self.prefixs[new_pattern], 1)
23
+
24
+ pc = PathConverter()
data_gen/utils/process_audio/extract_hubert.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import Wav2Vec2Processor, HubertModel
2
+ import soundfile as sf
3
+ import numpy as np
4
+ import torch
5
+ import os
6
+ from utils.commons.hparams import set_hparams, hparams
7
+
8
+
9
+ wav2vec2_processor = None
10
+ hubert_model = None
11
+
12
+
13
+ def get_hubert_from_16k_wav(wav_16k_name):
14
+ speech_16k, _ = sf.read(wav_16k_name)
15
+ hubert = get_hubert_from_16k_speech(speech_16k)
16
+ return hubert
17
+
18
+ @torch.no_grad()
19
+ def get_hubert_from_16k_speech(speech, device="cuda:0"):
20
+ global hubert_model, wav2vec2_processor
21
+ local_path = '/home/tiger/.cache/huggingface/hub/models--facebook--hubert-large-ls960-ft/snapshots/ece5fabbf034c1073acae96d5401b25be96709d8'
22
+ if hubert_model is None:
23
+ print("Loading the HuBERT Model...")
24
+ if os.path.exists(local_path):
25
+ hubert_model = HubertModel.from_pretrained(local_path)
26
+ else:
27
+ hubert_model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
28
+ hubert_model = hubert_model.to(device)
29
+ if wav2vec2_processor is None:
30
+ print("Loading the Wav2Vec2 Processor...")
31
+ if os.path.exists(local_path):
32
+ wav2vec2_processor = Wav2Vec2Processor.from_pretrained(local_path)
33
+ else:
34
+ wav2vec2_processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
35
+
36
+ if speech.ndim ==2:
37
+ speech = speech[:, 0] # [T, 2] ==> [T,]
38
+
39
+ input_values_all = wav2vec2_processor(speech, return_tensors="pt", sampling_rate=16000).input_values # [1, T]
40
+ input_values_all = input_values_all.to(device)
41
+ # For long audio sequence, due to the memory limitation, we cannot process them in one run
42
+ # HuBERT process the wav with a CNN of stride [5,2,2,2,2,2], making a stride of 320
43
+ # Besides, the kernel is [10,3,3,3,3,2,2], making 400 a fundamental unit to get 1 time step.
44
+ # So the CNN is euqal to a big Conv1D with kernel k=400 and stride s=320
45
+ # We have the equation to calculate out time step: T = floor((t-k)/s)
46
+ # To prevent overlap, we set each clip length of (K+S*(N-1)), where N is the expected length T of this clip
47
+ # The start point of next clip should roll back with a length of (kernel-stride) so it is stride * N
48
+ kernel = 400
49
+ stride = 320
50
+ clip_length = stride * 1000
51
+ num_iter = input_values_all.shape[1] // clip_length
52
+ expected_T = (input_values_all.shape[1] - (kernel-stride)) // stride
53
+ res_lst = []
54
+ for i in range(num_iter):
55
+ if i == 0:
56
+ start_idx = 0
57
+ end_idx = clip_length - stride + kernel
58
+ else:
59
+ start_idx = clip_length * i
60
+ end_idx = start_idx + (clip_length - stride + kernel)
61
+ input_values = input_values_all[:, start_idx: end_idx]
62
+ hidden_states = hubert_model.forward(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024]
63
+ res_lst.append(hidden_states[0])
64
+ if num_iter > 0:
65
+ input_values = input_values_all[:, clip_length * num_iter:]
66
+ else:
67
+ input_values = input_values_all
68
+
69
+ if input_values.shape[1] >= kernel: # if the last batch is shorter than kernel_size, skip it
70
+ hidden_states = hubert_model(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024]
71
+ res_lst.append(hidden_states[0])
72
+ ret = torch.cat(res_lst, dim=0).cpu() # [T, 1024]
73
+
74
+ assert abs(ret.shape[0] - expected_T) <= 1
75
+ if ret.shape[0] < expected_T: # if skipping the last short
76
+ ret = torch.cat([ret, ret[:, -1:, :].repeat([1,expected_T-ret.shape[0],1])], dim=1)
77
+ else:
78
+ ret = ret[:expected_T]
79
+
80
+ return ret
81
+
82
+
83
+ if __name__ == '__main__':
84
+ from argparse import ArgumentParser
85
+ parser = ArgumentParser()
86
+ parser.add_argument('--video_id', type=str, default='May', help='')
87
+ args = parser.parse_args()
88
+ ### Process Single Long Audio for NeRF dataset
89
+ person_id = args.video_id
90
+ wav_16k_name = f"data/processed/videos/{person_id}/aud.wav"
91
+ hubert_npy_name = f"data/processed/videos/{person_id}/aud_hubert.npy"
92
+ speech_16k, _ = sf.read(wav_16k_name)
93
+ hubert_hidden = get_hubert_from_16k_speech(speech_16k)
94
+ np.save(hubert_npy_name, hubert_hidden.detach().numpy())
95
+ print(f"Saved at {hubert_npy_name}")
data_gen/utils/process_audio/extract_mel_f0.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import glob
4
+ import os
5
+ import tqdm
6
+ import librosa
7
+ import parselmouth
8
+ from utils.commons.pitch_utils import f0_to_coarse
9
+ from utils.commons.multiprocess_utils import multiprocess_run_tqdm
10
+ from utils.commons.os_utils import multiprocess_glob
11
+ from utils.audio.io import save_wav
12
+
13
+ from moviepy.editor import VideoFileClip
14
+ from utils.commons.hparams import hparams, set_hparams
15
+
16
+ def resample_wav(wav_name, out_name, sr=16000):
17
+ wav_raw, sr = librosa.core.load(wav_name, sr=sr)
18
+ save_wav(wav_raw, out_name, sr)
19
+
20
+ def split_wav(mp4_name, wav_name=None):
21
+ if wav_name is None:
22
+ wav_name = mp4_name.replace(".mp4", ".wav").replace("/video/", "/audio/")
23
+ if os.path.exists(wav_name):
24
+ return wav_name
25
+ os.makedirs(os.path.dirname(wav_name), exist_ok=True)
26
+
27
+ video = VideoFileClip(mp4_name,verbose=False)
28
+ dur = video.duration
29
+ audio = video.audio
30
+ assert audio is not None
31
+ audio.write_audiofile(wav_name,fps=16000,verbose=False,logger=None)
32
+ return wav_name
33
+
34
+ def librosa_pad_lr(x, fsize, fshift, pad_sides=1):
35
+ '''compute right padding (final frame) or both sides padding (first and final frames)
36
+ '''
37
+ assert pad_sides in (1, 2)
38
+ # return int(fsize // 2)
39
+ pad = (x.shape[0] // fshift + 1) * fshift - x.shape[0]
40
+ if pad_sides == 1:
41
+ return 0, pad
42
+ else:
43
+ return pad // 2, pad // 2 + pad % 2
44
+
45
+ def extract_mel_from_fname(wav_path,
46
+ fft_size=512,
47
+ hop_size=320,
48
+ win_length=512,
49
+ window="hann",
50
+ num_mels=80,
51
+ fmin=80,
52
+ fmax=7600,
53
+ eps=1e-6,
54
+ sample_rate=16000,
55
+ min_level_db=-100):
56
+ if isinstance(wav_path, str):
57
+ wav, _ = librosa.core.load(wav_path, sr=sample_rate)
58
+ else:
59
+ wav = wav_path
60
+
61
+ # get amplitude spectrogram
62
+ x_stft = librosa.stft(wav, n_fft=fft_size, hop_length=hop_size,
63
+ win_length=win_length, window=window, center=False)
64
+ spc = np.abs(x_stft) # (n_bins, T)
65
+
66
+ # get mel basis
67
+ fmin = 0 if fmin == -1 else fmin
68
+ fmax = sample_rate / 2 if fmax == -1 else fmax
69
+ mel_basis = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax)
70
+ mel = mel_basis @ spc
71
+
72
+ mel = np.log10(np.maximum(eps, mel)) # (n_mel_bins, T)
73
+ mel = mel.T
74
+
75
+ l_pad, r_pad = librosa_pad_lr(wav, fft_size, hop_size, 1)
76
+ wav = np.pad(wav, (l_pad, r_pad), mode='constant', constant_values=0.0)
77
+
78
+ return wav.T, mel
79
+
80
+ def extract_f0_from_wav_and_mel(wav, mel,
81
+ hop_size=320,
82
+ audio_sample_rate=16000,
83
+ ):
84
+ time_step = hop_size / audio_sample_rate * 1000
85
+ f0_min = 80
86
+ f0_max = 750
87
+ f0 = parselmouth.Sound(wav, audio_sample_rate).to_pitch_ac(
88
+ time_step=time_step / 1000, voicing_threshold=0.6,
89
+ pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency']
90
+
91
+ delta_l = len(mel) - len(f0)
92
+ assert np.abs(delta_l) <= 8
93
+ if delta_l > 0:
94
+ f0 = np.concatenate([f0, [f0[-1]] * delta_l], 0)
95
+ f0 = f0[:len(mel)]
96
+ pitch_coarse = f0_to_coarse(f0)
97
+ return f0, pitch_coarse
98
+
99
+
100
+ def extract_mel_f0_from_fname(wav_name=None, out_name=None):
101
+ try:
102
+ out_name = wav_name.replace(".wav", "_mel_f0.npy").replace("/audio/", "/mel_f0/")
103
+ os.makedirs(os.path.dirname(out_name), exist_ok=True)
104
+
105
+ wav, mel = extract_mel_from_fname(wav_name)
106
+ f0, f0_coarse = extract_f0_from_wav_and_mel(wav, mel)
107
+ out_dict = {
108
+ "mel": mel, # [T, 80]
109
+ "f0": f0,
110
+ }
111
+ np.save(out_name, out_dict)
112
+ except Exception as e:
113
+ print(e)
114
+
115
+ def extract_mel_f0_from_video_name(mp4_name, wav_name=None, out_name=None):
116
+ if mp4_name.endswith(".mp4"):
117
+ wav_name = split_wav(mp4_name, wav_name)
118
+ if out_name is None:
119
+ out_name = mp4_name.replace(".mp4", "_mel_f0.npy").replace("/video/", "/mel_f0/")
120
+ elif mp4_name.endswith(".wav"):
121
+ wav_name = mp4_name
122
+ if out_name is None:
123
+ out_name = mp4_name.replace(".wav", "_mel_f0.npy").replace("/audio/", "/mel_f0/")
124
+
125
+ os.makedirs(os.path.dirname(out_name), exist_ok=True)
126
+
127
+ wav, mel = extract_mel_from_fname(wav_name)
128
+
129
+ f0, f0_coarse = extract_f0_from_wav_and_mel(wav, mel)
130
+ out_dict = {
131
+ "mel": mel, # [T, 80]
132
+ "f0": f0,
133
+ }
134
+ np.save(out_name, out_dict)
135
+
136
+
137
+ if __name__ == '__main__':
138
+ from argparse import ArgumentParser
139
+ parser = ArgumentParser()
140
+ parser.add_argument('--video_id', type=str, default='May', help='')
141
+ args = parser.parse_args()
142
+ ### Process Single Long Audio for NeRF dataset
143
+ person_id = args.video_id
144
+
145
+ wav_16k_name = f"data/processed/videos/{person_id}/aud.wav"
146
+ out_name = f"data/processed/videos/{person_id}/aud_mel_f0.npy"
147
+ extract_mel_f0_from_video_name(wav_16k_name, out_name)
148
+ print(f"Saved at {out_name}")
data_gen/utils/process_audio/resample_audio_to_16k.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, glob
2
+ from utils.commons.os_utils import multiprocess_glob
3
+ from utils.commons.multiprocess_utils import multiprocess_run_tqdm
4
+
5
+
6
+ def extract_wav16k_job(audio_name:str):
7
+ out_path = audio_name.replace("/audio_raw/","/audio/",1)
8
+ assert out_path != audio_name # prevent inplace
9
+ os.makedirs(os.path.dirname(out_path), exist_ok=True)
10
+ ffmpeg_path = "/usr/bin/ffmpeg"
11
+
12
+ cmd = f'{ffmpeg_path} -i {audio_name} -ar 16000 -v quiet -y {out_path}'
13
+ os.system(cmd)
14
+
15
+ if __name__ == '__main__':
16
+ import argparse, glob, tqdm, random
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument("--aud_dir", default='/home/tiger/datasets/raw/CMLR/audio_raw/')
19
+ parser.add_argument("--ds_name", default='CMLR')
20
+ parser.add_argument("--num_workers", default=64, type=int)
21
+ parser.add_argument("--process_id", default=0, type=int)
22
+ parser.add_argument("--total_process", default=1, type=int)
23
+ args = parser.parse_args()
24
+ print(f"args {args}")
25
+
26
+ aud_dir = args.aud_dir
27
+ ds_name = args.ds_name
28
+ if ds_name in ['CMLR']:
29
+ aud_name_pattern = os.path.join(aud_dir, "*/*/*.wav")
30
+ aud_names = multiprocess_glob(aud_name_pattern)
31
+ else:
32
+ raise NotImplementedError()
33
+ aud_names = sorted(aud_names)
34
+ print(f"total audio number : {len(aud_names)}")
35
+ print(f"first {aud_names[0]} last {aud_names[-1]}")
36
+ # exit()
37
+ process_id = args.process_id
38
+ total_process = args.total_process
39
+ if total_process > 1:
40
+ assert process_id <= total_process -1
41
+ num_samples_per_process = len(aud_names) // total_process
42
+ if process_id == total_process:
43
+ aud_names = aud_names[process_id * num_samples_per_process : ]
44
+ else:
45
+ aud_names = aud_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
46
+
47
+ for i, res in multiprocess_run_tqdm(extract_wav16k_job, aud_names, num_workers=args.num_workers, desc="resampling videos"):
48
+ pass
49
+
data_gen/utils/process_image/extract_lm2d.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["OMP_NUM_THREADS"] = "1"
3
+ import sys
4
+
5
+ import glob
6
+ import cv2
7
+ import tqdm
8
+ import numpy as np
9
+ from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker
10
+ from utils.commons.multiprocess_utils import multiprocess_run_tqdm
11
+ import warnings
12
+ warnings.filterwarnings('ignore')
13
+
14
+ import random
15
+ random.seed(42)
16
+
17
+ import pickle
18
+ import json
19
+ import gzip
20
+ from typing import Any
21
+
22
+ def load_file(filename, is_gzip: bool = False, is_json: bool = False) -> Any:
23
+ if is_json:
24
+ if is_gzip:
25
+ with gzip.open(filename, "r", encoding="utf-8") as f:
26
+ loaded_object = json.load(f)
27
+ return loaded_object
28
+ else:
29
+ with open(filename, "r", encoding="utf-8") as f:
30
+ loaded_object = json.load(f)
31
+ return loaded_object
32
+ else:
33
+ if is_gzip:
34
+ with gzip.open(filename, "rb") as f:
35
+ loaded_object = pickle.load(f)
36
+ return loaded_object
37
+ else:
38
+ with open(filename, "rb") as f:
39
+ loaded_object = pickle.load(f)
40
+ return loaded_object
41
+
42
+ def save_file(filename, content, is_gzip: bool = False, is_json: bool = False) -> None:
43
+ if is_json:
44
+ if is_gzip:
45
+ with gzip.open(filename, "w", encoding="utf-8") as f:
46
+ json.dump(content, f)
47
+ else:
48
+ with open(filename, "w", encoding="utf-8") as f:
49
+ json.dump(content, f)
50
+ else:
51
+ if is_gzip:
52
+ with gzip.open(filename, "wb") as f:
53
+ pickle.dump(content, f)
54
+ else:
55
+ with open(filename, "wb") as f:
56
+ pickle.dump(content, f)
57
+
58
+ face_landmarker = None
59
+
60
+ def extract_lms_mediapipe_job(img):
61
+ if img is None:
62
+ return None
63
+ global face_landmarker
64
+ if face_landmarker is None:
65
+ face_landmarker = MediapipeLandmarker()
66
+ lm478 = face_landmarker.extract_lm478_from_img(img)
67
+ return lm478
68
+
69
+ def extract_landmark_job(img_name):
70
+ try:
71
+ # if img_name == 'datasets/PanoHeadGen/raw/images/multi_view/chunk_0/seed0000002.png':
72
+ # print(1)
73
+ # input()
74
+ out_name = img_name.replace("/images_512/", "/lms_2d/").replace(".png","_lms.npy")
75
+ if os.path.exists(out_name):
76
+ print("out exists, skip...")
77
+ return
78
+ try:
79
+ os.makedirs(os.path.dirname(out_name), exist_ok=True)
80
+ except:
81
+ pass
82
+ img = cv2.imread(img_name)[:,:,::-1]
83
+
84
+ if img is not None:
85
+ lm468 = extract_lms_mediapipe_job(img)
86
+ if lm468 is not None:
87
+ np.save(out_name, lm468)
88
+ # print("Hahaha, solve one item!!!")
89
+ except Exception as e:
90
+ print(e)
91
+ pass
92
+
93
+ def out_exist_job(img_name):
94
+ out_name = img_name.replace("/images_512/", "/lms_2d/").replace(".png","_lms.npy")
95
+ if os.path.exists(out_name):
96
+ return None
97
+ else:
98
+ return img_name
99
+
100
+ # def get_todo_img_names(img_names):
101
+ # todo_img_names = []
102
+ # for i, res in multiprocess_run_tqdm(out_exist_job, img_names, num_workers=64):
103
+ # if res is not None:
104
+ # todo_img_names.append(res)
105
+ # return todo_img_names
106
+
107
+
108
+ if __name__ == '__main__':
109
+ import argparse, glob, tqdm, random
110
+ parser = argparse.ArgumentParser()
111
+ parser.add_argument("--img_dir", default='/home/tiger/datasets/raw/FFHQ/images_512/')
112
+ parser.add_argument("--ds_name", default='FFHQ')
113
+ parser.add_argument("--num_workers", default=64, type=int)
114
+ parser.add_argument("--process_id", default=0, type=int)
115
+ parser.add_argument("--total_process", default=1, type=int)
116
+ parser.add_argument("--reset", action='store_true')
117
+ parser.add_argument("--img_names_file", default="img_names.pkl", type=str)
118
+ parser.add_argument("--load_img_names", action="store_true")
119
+
120
+ args = parser.parse_args()
121
+ print(f"args {args}")
122
+ img_dir = args.img_dir
123
+ img_names_file = os.path.join(img_dir, args.img_names_file)
124
+ if args.load_img_names:
125
+ img_names = load_file(img_names_file)
126
+ print(f"load image names from {img_names_file}")
127
+ else:
128
+ if args.ds_name == 'FFHQ_MV':
129
+ img_name_pattern1 = os.path.join(img_dir, "ref_imgs/*.png")
130
+ img_names1 = glob.glob(img_name_pattern1)
131
+ img_name_pattern2 = os.path.join(img_dir, "mv_imgs/*.png")
132
+ img_names2 = glob.glob(img_name_pattern2)
133
+ img_names = img_names1 + img_names2
134
+ img_names = sorted(img_names)
135
+ elif args.ds_name == 'FFHQ':
136
+ img_name_pattern = os.path.join(img_dir, "*.png")
137
+ img_names = glob.glob(img_name_pattern)
138
+ img_names = sorted(img_names)
139
+ elif args.ds_name == "PanoHeadGen":
140
+ # img_name_patterns = ["ref/*/*.png", "multi_view/*/*.png", "reverse/*/*.png"]
141
+ img_name_patterns = ["ref/*/*.png"]
142
+ img_names = []
143
+ for img_name_pattern in img_name_patterns:
144
+ img_name_pattern_full = os.path.join(img_dir, img_name_pattern)
145
+ img_names_part = glob.glob(img_name_pattern_full)
146
+ img_names.extend(img_names_part)
147
+ img_names = sorted(img_names)
148
+
149
+ # save image names
150
+ if not args.load_img_names:
151
+ save_file(img_names_file, img_names)
152
+ print(f"save image names in {img_names_file}")
153
+
154
+ print(f"total images number: {len(img_names)}")
155
+
156
+
157
+ process_id = args.process_id
158
+ total_process = args.total_process
159
+ if total_process > 1:
160
+ assert process_id <= total_process -1
161
+ num_samples_per_process = len(img_names) // total_process
162
+ if process_id == total_process:
163
+ img_names = img_names[process_id * num_samples_per_process : ]
164
+ else:
165
+ img_names = img_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
166
+
167
+ # if not args.reset:
168
+ # img_names = get_todo_img_names(img_names)
169
+
170
+
171
+ print(f"todo_image {img_names[:10]}")
172
+ print(f"processing images number in this process: {len(img_names)}")
173
+ # print(f"todo images number: {len(img_names)}")
174
+ # input()
175
+ # exit()
176
+
177
+ if args.num_workers == 1:
178
+ index = 0
179
+ for img_name in tqdm.tqdm(img_names, desc=f"Root process {args.process_id}: extracting MP-based landmark2d"):
180
+ try:
181
+ extract_landmark_job(img_name)
182
+ except Exception as e:
183
+ print(e)
184
+ pass
185
+ if index % max(1, int(len(img_names) * 0.003)) == 0:
186
+ print(f"processed {index} / {len(img_names)}")
187
+ sys.stdout.flush()
188
+ index += 1
189
+ else:
190
+ for i, res in multiprocess_run_tqdm(
191
+ extract_landmark_job, img_names,
192
+ num_workers=args.num_workers,
193
+ desc=f"Root {args.process_id}: extracing MP-based landmark2d"):
194
+ # if index % max(1, int(len(img_names) * 0.003)) == 0:
195
+ print(f"processed {i+1} / {len(img_names)}")
196
+ sys.stdout.flush()
197
+ print(f"Root {args.process_id}: Finished extracting.")
data_gen/utils/process_image/extract_segment_imgs.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["OMP_NUM_THREADS"] = "1"
3
+
4
+ import glob
5
+ import cv2
6
+ import tqdm
7
+ import numpy as np
8
+ import PIL
9
+ from utils.commons.tensor_utils import convert_to_np
10
+ import torch
11
+ import mediapipe as mp
12
+ from utils.commons.multiprocess_utils import multiprocess_run_tqdm
13
+ from data_gen.utils.mp_feature_extractors.mp_segmenter import MediapipeSegmenter
14
+ from data_gen.utils.process_video.extract_segment_imgs import inpaint_torso_job, extract_background, save_rgb_image_to_path
15
+ seg_model = MediapipeSegmenter()
16
+
17
+
18
+ def extract_segment_job(img_name):
19
+ try:
20
+ img = cv2.imread(img_name)
21
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
22
+
23
+ segmap = seg_model._cal_seg_map(img)
24
+ bg_img = extract_background([img], [segmap])
25
+ out_img_name = img_name.replace("/images_512/",f"/bg_img/").replace(".mp4", ".jpg")
26
+ save_rgb_image_to_path(bg_img, out_img_name)
27
+
28
+ com_img = img.copy()
29
+ bg_part = segmap[0].astype(bool)[..., None].repeat(3,axis=-1)
30
+ com_img[bg_part] = bg_img[bg_part]
31
+ out_img_name = img_name.replace("/images_512/",f"/com_imgs/")
32
+ save_rgb_image_to_path(com_img, out_img_name)
33
+
34
+ for mode in ['head', 'torso', 'person', 'torso_with_bg', 'bg']:
35
+ out_img, _ = seg_model._seg_out_img_with_segmap(img, segmap, mode=mode)
36
+ out_img_name = img_name.replace("/images_512/",f"/{mode}_imgs/")
37
+ out_img = cv2.cvtColor(out_img, cv2.COLOR_RGB2BGR)
38
+ try: os.makedirs(os.path.dirname(out_img_name), exist_ok=True)
39
+ except: pass
40
+ cv2.imwrite(out_img_name, out_img)
41
+
42
+ inpaint_torso_img, inpaint_torso_with_bg_img, _, _ = inpaint_torso_job(img, segmap)
43
+ out_img_name = img_name.replace("/images_512/",f"/inpaint_torso_imgs/")
44
+ save_rgb_image_to_path(inpaint_torso_img, out_img_name)
45
+ inpaint_torso_with_bg_img[bg_part] = bg_img[bg_part]
46
+ out_img_name = img_name.replace("/images_512/",f"/inpaint_torso_with_com_bg_imgs/")
47
+ save_rgb_image_to_path(inpaint_torso_with_bg_img, out_img_name)
48
+ return 0
49
+ except Exception as e:
50
+ print(e)
51
+ return 1
52
+
53
+ def out_exist_job(img_name):
54
+ out_name1 = img_name.replace("/images_512/", "/head_imgs/")
55
+ out_name2 = img_name.replace("/images_512/", "/com_imgs/")
56
+ out_name3 = img_name.replace("/images_512/", "/inpaint_torso_with_com_bg_imgs/")
57
+
58
+ if os.path.exists(out_name1) and os.path.exists(out_name2) and os.path.exists(out_name3):
59
+ return None
60
+ else:
61
+ return img_name
62
+
63
+ def get_todo_img_names(img_names):
64
+ todo_img_names = []
65
+ for i, res in multiprocess_run_tqdm(out_exist_job, img_names, num_workers=64):
66
+ if res is not None:
67
+ todo_img_names.append(res)
68
+ return todo_img_names
69
+
70
+
71
+ if __name__ == '__main__':
72
+ import argparse, glob, tqdm, random
73
+ parser = argparse.ArgumentParser()
74
+ parser.add_argument("--img_dir", default='./images_512')
75
+ # parser.add_argument("--img_dir", default='/home/tiger/datasets/raw/FFHQ/images_512')
76
+ parser.add_argument("--ds_name", default='FFHQ')
77
+ parser.add_argument("--num_workers", default=1, type=int)
78
+ parser.add_argument("--seed", default=0, type=int)
79
+ parser.add_argument("--process_id", default=0, type=int)
80
+ parser.add_argument("--total_process", default=1, type=int)
81
+ parser.add_argument("--reset", action='store_true')
82
+
83
+ args = parser.parse_args()
84
+ img_dir = args.img_dir
85
+ if args.ds_name == 'FFHQ_MV':
86
+ img_name_pattern1 = os.path.join(img_dir, "ref_imgs/*.png")
87
+ img_names1 = glob.glob(img_name_pattern1)
88
+ img_name_pattern2 = os.path.join(img_dir, "mv_imgs/*.png")
89
+ img_names2 = glob.glob(img_name_pattern2)
90
+ img_names = img_names1 + img_names2
91
+ elif args.ds_name == 'FFHQ':
92
+ img_name_pattern = os.path.join(img_dir, "*.png")
93
+ img_names = glob.glob(img_name_pattern)
94
+
95
+ img_names = sorted(img_names)
96
+ random.seed(args.seed)
97
+ random.shuffle(img_names)
98
+
99
+ process_id = args.process_id
100
+ total_process = args.total_process
101
+ if total_process > 1:
102
+ assert process_id <= total_process -1
103
+ num_samples_per_process = len(img_names) // total_process
104
+ if process_id == total_process:
105
+ img_names = img_names[process_id * num_samples_per_process : ]
106
+ else:
107
+ img_names = img_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
108
+
109
+ if not args.reset:
110
+ img_names = get_todo_img_names(img_names)
111
+ print(f"todo images number: {len(img_names)}")
112
+
113
+ for vid_name in multiprocess_run_tqdm(extract_segment_job ,img_names, desc=f"Root process {args.process_id}: extracting segment images", num_workers=args.num_workers):
114
+ pass
data_gen/utils/process_image/fit_3dmm_landmark.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from numpy.core.numeric import require
2
+ from numpy.lib.function_base import quantile
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import copy
6
+ import numpy as np
7
+
8
+ import os
9
+ import sys
10
+ import cv2
11
+ import argparse
12
+ import tqdm
13
+ from utils.commons.multiprocess_utils import multiprocess_run_tqdm
14
+ from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker
15
+
16
+ from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
17
+ import pickle
18
+
19
+ face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM',
20
+ camera_distance=10, focal=1015, keypoint_mode='mediapipe')
21
+ face_model.to("cuda")
22
+
23
+
24
+ index_lm68_from_lm468 = [127,234,93,132,58,136,150,176,152,400,379,365,288,361,323,454,356,70,63,105,66,107,336,296,334,293,300,168,197,5,4,75,97,2,326,305,
25
+ 33,160,158,133,153,144,362,385,387,263,373,380,61,40,37,0,267,270,291,321,314,17,84,91,78,81,13,311,308,402,14,178]
26
+
27
+ dir_path = os.path.dirname(os.path.realpath(__file__))
28
+
29
+ LAMBDA_REG_ID = 0.3
30
+ LAMBDA_REG_EXP = 0.05
31
+
32
+ def save_file(name, content):
33
+ with open(name, "wb") as f:
34
+ pickle.dump(content, f)
35
+
36
+ def load_file(name):
37
+ with open(name, "rb") as f:
38
+ content = pickle.load(f)
39
+ return content
40
+
41
+ def cal_lan_loss_mp(proj_lan, gt_lan):
42
+ # [B, 68, 2]
43
+ loss = (proj_lan - gt_lan).pow(2)
44
+ # loss = (proj_lan - gt_lan).abs()
45
+ unmatch_mask = [ 93, 127, 132, 234, 323, 356, 361, 454]
46
+ eye = [33,246,161,160,159,158,157,173,133,155,154,153,145,144,163,7] + [263,466,388,387,386,385,384,398,362,382,381,380,374,373,390,249]
47
+ inner_lip = [78,191,80,81,82,13,312,311,310,415,308,324,318,402,317,14,87,178,88,95]
48
+ outer_lip = [61,185,40,39,37,0,267,269,270,409,291,375,321,405,314,17,84,181,91,146]
49
+ weights = torch.ones_like(loss)
50
+ weights[:, eye] = 5
51
+ weights[:, inner_lip] = 2
52
+ weights[:, outer_lip] = 2
53
+ weights[:, unmatch_mask] = 0
54
+ loss = loss * weights
55
+ return torch.mean(loss)
56
+
57
+ def cal_lan_loss(proj_lan, gt_lan):
58
+ # [B, 68, 2]
59
+ loss = (proj_lan - gt_lan)** 2
60
+ # use the ldm weights from deep3drecon, see deep_3drecon/deep_3drecon_models/losses.py
61
+ weights = torch.zeros_like(loss)
62
+ weights = torch.ones_like(loss)
63
+ weights[:, 36:48, :] = 3 # eye 12 points
64
+ weights[:, -8:, :] = 3 # inner lip 8 points
65
+ weights[:, 28:31, :] = 3 # nose 3 points
66
+ loss = loss * weights
67
+ return torch.mean(loss)
68
+
69
+ def set_requires_grad(tensor_list):
70
+ for tensor in tensor_list:
71
+ tensor.requires_grad = True
72
+
73
+ def read_video_to_frames(img_name):
74
+ frames = []
75
+ cap = cv2.VideoCapture(img_name)
76
+ while cap.isOpened():
77
+ ret, frame_bgr = cap.read()
78
+ if frame_bgr is None:
79
+ break
80
+ frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
81
+ frames.append(frame_rgb)
82
+ return np.stack(frames)
83
+
84
+ @torch.enable_grad()
85
+ def fit_3dmm_for_a_image(img_name, debug=False, keypoint_mode='mediapipe', device="cuda:0", save=True):
86
+ img = cv2.imread(img_name)
87
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
88
+ img_h, img_w = img.shape[0], img.shape[0]
89
+ assert img_h == img_w
90
+ num_frames = 1
91
+
92
+ lm_name = img_name.replace("/images_512/", "/lms_2d/").replace(".png", "_lms.npy")
93
+ if lm_name.endswith('_lms.npy') and os.path.exists(lm_name):
94
+ lms = np.load(lm_name)
95
+ else:
96
+ # print("lms_2d file not found, try to extract it from image...")
97
+ try:
98
+ landmarker = MediapipeLandmarker()
99
+ lms = landmarker.extract_lm478_from_img_name(img_name)
100
+ # lms = landmarker.extract_lm478_from_img(img)
101
+ except Exception as e:
102
+ print(e)
103
+ return
104
+ if lms is None:
105
+ print("get None lms_2d, please check whether each frame has one head, exiting...")
106
+ return
107
+ lms = lms[:468].reshape([468,2])
108
+ lms = torch.FloatTensor(lms).to(device=device)
109
+ lms[..., 1] = img_h - lms[..., 1] # flip the height axis
110
+
111
+ if keypoint_mode == 'mediapipe':
112
+ cal_lan_loss_fn = cal_lan_loss_mp
113
+ out_name = img_name.replace("/images_512/", "/coeff_fit_mp/").replace(".png", "_coeff_fit_mp.npy")
114
+ else:
115
+ cal_lan_loss_fn = cal_lan_loss
116
+ out_name = img_name.replace("/images_512/", "/coeff_fit_lm68/").replace(".png", "_coeff_fit_lm68.npy")
117
+ try:
118
+ os.makedirs(os.path.dirname(out_name), exist_ok=True)
119
+ except:
120
+ pass
121
+
122
+ id_dim, exp_dim = 80, 64
123
+ sel_ids = np.arange(0, num_frames, 40)
124
+ sel_num = sel_ids.shape[0]
125
+ arg_focal = face_model.focal
126
+
127
+ h = w = face_model.center * 2
128
+ img_scale_factor = img_h / h
129
+ lms /= img_scale_factor
130
+ cxy = torch.tensor((w / 2.0, h / 2.0), dtype=torch.float).to(device=device)
131
+
132
+ id_para = lms.new_zeros((num_frames, id_dim), requires_grad=True) # lms.new_zeros((1, id_dim), requires_grad=True)
133
+ exp_para = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
134
+ euler_angle = lms.new_zeros((num_frames, 3), requires_grad=True)
135
+ trans = lms.new_zeros((num_frames, 3), requires_grad=True)
136
+
137
+ focal_length = lms.new_zeros(1, requires_grad=True)
138
+ focal_length.data += arg_focal
139
+
140
+ set_requires_grad([id_para, exp_para, euler_angle, trans])
141
+
142
+ optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=.1)
143
+ optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=.1)
144
+
145
+ # 其他参数初始化,先训练euler和trans
146
+ for _ in range(200):
147
+ proj_geo = face_model.compute_for_landmark_fit(
148
+ id_para, exp_para, euler_angle, trans)
149
+ loss_lan = cal_lan_loss_fn(proj_geo[:, :, :2], lms.detach())
150
+ loss = loss_lan
151
+ optimizer_frame.zero_grad()
152
+ loss.backward()
153
+ optimizer_frame.step()
154
+ # print(f"loss_lan: {loss_lan.item():.2f}, euler_abs_mean: {euler_angle.abs().mean().item():.4f}, euler_std: {euler_angle.std().item():.4f}, euler_min: {euler_angle.min().item():.4f}, euler_max: {euler_angle.max().item():.4f}")
155
+ # print(f"trans_z_mean: {trans[...,2].mean().item():.4f}, trans_z_std: {trans[...,2].std().item():.4f}, trans_min: {trans[...,2].min().item():.4f}, trans_max: {trans[...,2].max().item():.4f}")
156
+
157
+ for param_group in optimizer_frame.param_groups:
158
+ param_group['lr'] = 0.1
159
+
160
+ # "jointly roughly training id exp euler trans"
161
+ for _ in range(200):
162
+ proj_geo = face_model.compute_for_landmark_fit(
163
+ id_para, exp_para, euler_angle, trans)
164
+ loss_lan = cal_lan_loss_fn(
165
+ proj_geo[:, :, :2], lms.detach())
166
+ loss_regid = torch.mean(id_para*id_para) # 正则化
167
+ loss_regexp = torch.mean(exp_para * exp_para)
168
+
169
+ loss = loss_lan + loss_regid * LAMBDA_REG_ID + loss_regexp * LAMBDA_REG_EXP
170
+ optimizer_idexp.zero_grad()
171
+ optimizer_frame.zero_grad()
172
+ loss.backward()
173
+ optimizer_idexp.step()
174
+ optimizer_frame.step()
175
+ # print(f"loss_lan: {loss_lan.item():.2f}, loss_reg_id: {loss_regid.item():.2f},loss_reg_exp: {loss_regexp.item():.2f},")
176
+ # print(f"euler_abs_mean: {euler_angle.abs().mean().item():.4f}, euler_std: {euler_angle.std().item():.4f}, euler_min: {euler_angle.min().item():.4f}, euler_max: {euler_angle.max().item():.4f}")
177
+ # print(f"trans_z_mean: {trans[...,2].mean().item():.4f}, trans_z_std: {trans[...,2].std().item():.4f}, trans_min: {trans[...,2].min().item():.4f}, trans_max: {trans[...,2].max().item():.4f}")
178
+
179
+ # start fine training, intialize from the roughly trained results
180
+ id_para_ = lms.new_zeros((num_frames, id_dim), requires_grad=True)
181
+ id_para_.data = id_para.data.clone()
182
+ id_para = id_para_
183
+ exp_para_ = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
184
+ exp_para_.data = exp_para.data.clone()
185
+ exp_para = exp_para_
186
+ euler_angle_ = lms.new_zeros((num_frames, 3), requires_grad=True)
187
+ euler_angle_.data = euler_angle.data.clone()
188
+ euler_angle = euler_angle_
189
+ trans_ = lms.new_zeros((num_frames, 3), requires_grad=True)
190
+ trans_.data = trans.data.clone()
191
+ trans = trans_
192
+
193
+ batch_size = 1
194
+
195
+ # "fine fitting the 3DMM in batches"
196
+ for i in range(int((num_frames-1)/batch_size+1)):
197
+ if (i+1)*batch_size > num_frames:
198
+ start_n = num_frames-batch_size
199
+ sel_ids = np.arange(max(num_frames-batch_size,0), num_frames)
200
+ else:
201
+ start_n = i*batch_size
202
+ sel_ids = np.arange(i*batch_size, i*batch_size+batch_size)
203
+ sel_lms = lms[sel_ids]
204
+
205
+ sel_id_para = id_para.new_zeros(
206
+ (batch_size, id_dim), requires_grad=True)
207
+ sel_id_para.data = id_para[sel_ids].clone()
208
+ sel_exp_para = exp_para.new_zeros(
209
+ (batch_size, exp_dim), requires_grad=True)
210
+ sel_exp_para.data = exp_para[sel_ids].clone()
211
+ sel_euler_angle = euler_angle.new_zeros(
212
+ (batch_size, 3), requires_grad=True)
213
+ sel_euler_angle.data = euler_angle[sel_ids].clone()
214
+ sel_trans = trans.new_zeros((batch_size, 3), requires_grad=True)
215
+ sel_trans.data = trans[sel_ids].clone()
216
+
217
+ set_requires_grad([sel_id_para, sel_exp_para, sel_euler_angle, sel_trans])
218
+ optimizer_cur_batch = torch.optim.Adam(
219
+ [sel_id_para, sel_exp_para, sel_euler_angle, sel_trans], lr=0.005)
220
+
221
+ for j in range(50):
222
+ proj_geo = face_model.compute_for_landmark_fit(
223
+ sel_id_para, sel_exp_para, sel_euler_angle, sel_trans)
224
+ loss_lan = cal_lan_loss_fn(
225
+ proj_geo[:, :, :2], lms.unsqueeze(0).detach())
226
+
227
+ loss_regid = torch.mean(sel_id_para*sel_id_para) # 正则化
228
+ loss_regexp = torch.mean(sel_exp_para*sel_exp_para)
229
+ loss = loss_lan + loss_regid * LAMBDA_REG_ID + loss_regexp * LAMBDA_REG_EXP
230
+ optimizer_cur_batch.zero_grad()
231
+ loss.backward()
232
+ optimizer_cur_batch.step()
233
+ print(f"batch {i} | loss_lan: {loss_lan.item():.2f}, loss_reg_id: {loss_regid.item():.2f},loss_reg_exp: {loss_regexp.item():.2f}")
234
+ id_para[sel_ids].data = sel_id_para.data.clone()
235
+ exp_para[sel_ids].data = sel_exp_para.data.clone()
236
+ euler_angle[sel_ids].data = sel_euler_angle.data.clone()
237
+ trans[sel_ids].data = sel_trans.data.clone()
238
+
239
+ coeff_dict = {'id': id_para.detach().cpu().numpy(), 'exp': exp_para.detach().cpu().numpy(),
240
+ 'euler': euler_angle.detach().cpu().numpy(), 'trans': trans.detach().cpu().numpy()}
241
+ if save:
242
+ np.save(out_name, coeff_dict, allow_pickle=True)
243
+
244
+ if debug:
245
+ import imageio
246
+ debug_name = img_name.replace("/images_512/", "/coeff_fit_mp_debug/").replace(".png", "_debug.png").replace(".jpg", "_debug.jpg")
247
+ try: os.makedirs(os.path.dirname(debug_name), exist_ok=True)
248
+ except: pass
249
+ proj_geo = face_model.compute_for_landmark_fit(id_para, exp_para, euler_angle, trans)
250
+ lm68s = proj_geo[:,:,:2].detach().cpu().numpy() # [T, 68,2]
251
+ lm68s = lm68s * img_scale_factor
252
+ lms = lms * img_scale_factor
253
+ lm68s[..., 1] = img_h - lm68s[..., 1] # flip the height axis
254
+ lms[..., 1] = img_h - lms[..., 1] # flip the height axis
255
+ lm68s = lm68s.astype(int)
256
+ lm68s = lm68s.reshape([-1,2])
257
+ lms = lms.cpu().numpy().astype(int).reshape([-1,2])
258
+ for lm in lm68s:
259
+ img = cv2.circle(img, lm, 1, (0, 0, 255), thickness=-1)
260
+ for gt_lm in lms:
261
+ img = cv2.circle(img, gt_lm, 2, (255, 0, 0), thickness=1)
262
+ imageio.imwrite(debug_name, img)
263
+ print(f"debug img saved at {debug_name}")
264
+ return coeff_dict
265
+
266
+ def out_exist_job(vid_name):
267
+ out_name = vid_name.replace("/images_512/", "/coeff_fit_mp/").replace(".png","_coeff_fit_mp.npy")
268
+ # if os.path.exists(out_name) or not os.path.exists(lms_name):
269
+ if os.path.exists(out_name):
270
+ return None
271
+ else:
272
+ return vid_name
273
+
274
+ def get_todo_img_names(img_names):
275
+ todo_img_names = []
276
+ for i, res in multiprocess_run_tqdm(out_exist_job, img_names, num_workers=16):
277
+ if res is not None:
278
+ todo_img_names.append(res)
279
+ return todo_img_names
280
+
281
+
282
+ if __name__ == '__main__':
283
+ import argparse, glob, tqdm
284
+ parser = argparse.ArgumentParser()
285
+ parser.add_argument("--img_dir", default='/home/tiger/datasets/raw/FFHQ/images_512')
286
+ parser.add_argument("--ds_name", default='FFHQ')
287
+ parser.add_argument("--seed", default=0, type=int)
288
+ parser.add_argument("--process_id", default=0, type=int)
289
+ parser.add_argument("--total_process", default=1, type=int)
290
+ parser.add_argument("--keypoint_mode", default='mediapipe', type=str)
291
+ parser.add_argument("--debug", action='store_true')
292
+ parser.add_argument("--reset", action='store_true')
293
+ parser.add_argument("--device", default="cuda:0", type=str)
294
+ parser.add_argument("--output_log", action='store_true')
295
+ parser.add_argument("--load_names", action="store_true")
296
+
297
+ args = parser.parse_args()
298
+ img_dir = args.img_dir
299
+ load_names = args.load_names
300
+
301
+ print(f"args {args}")
302
+
303
+ if args.ds_name == 'single_img':
304
+ img_names = [img_dir]
305
+ else:
306
+ img_names_path = os.path.join(img_dir, "img_dir.pkl")
307
+ if os.path.exists(img_names_path) and load_names:
308
+ print(f"loading vid names from {img_names_path}")
309
+ img_names = load_file(img_names_path)
310
+ else:
311
+ if args.ds_name == 'FFHQ_MV':
312
+ img_name_pattern1 = os.path.join(img_dir, "ref_imgs/*.png")
313
+ img_names1 = glob.glob(img_name_pattern1)
314
+ img_name_pattern2 = os.path.join(img_dir, "mv_imgs/*.png")
315
+ img_names2 = glob.glob(img_name_pattern2)
316
+ img_names = img_names1 + img_names2
317
+ img_names = sorted(img_names)
318
+ elif args.ds_name == 'FFHQ':
319
+ img_name_pattern = os.path.join(img_dir, "*.png")
320
+ img_names = glob.glob(img_name_pattern)
321
+ img_names = sorted(img_names)
322
+ elif args.ds_name == "PanoHeadGen":
323
+ img_name_patterns = ["ref/*/*.png"]
324
+ img_names = []
325
+ for img_name_pattern in img_name_patterns:
326
+ img_name_pattern_full = os.path.join(img_dir, img_name_pattern)
327
+ img_names_part = glob.glob(img_name_pattern_full)
328
+ img_names.extend(img_names_part)
329
+ img_names = sorted(img_names)
330
+ print(f"saving image names to {img_names_path}")
331
+ save_file(img_names_path, img_names)
332
+
333
+ # import random
334
+ # random.seed(args.seed)
335
+ # random.shuffle(img_names)
336
+
337
+ face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM',
338
+ camera_distance=10, focal=1015, keypoint_mode=args.keypoint_mode)
339
+ face_model.to(torch.device(args.device))
340
+
341
+ process_id = args.process_id
342
+ total_process = args.total_process
343
+ if total_process > 1:
344
+ assert process_id <= total_process -1 and process_id >= 0
345
+ num_samples_per_process = len(img_names) // total_process
346
+ if process_id == total_process:
347
+ img_names = img_names[process_id * num_samples_per_process : ]
348
+ else:
349
+ img_names = img_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
350
+ print(f"image names number (before fileter): {len(img_names)}")
351
+
352
+
353
+ if not args.reset:
354
+ img_names = get_todo_img_names(img_names)
355
+
356
+ print(f"image names number (after fileter): {len(img_names)}")
357
+ for i in tqdm.trange(len(img_names), desc=f"process {process_id}: fitting 3dmm ..."):
358
+ img_name = img_names[i]
359
+ try:
360
+ fit_3dmm_for_a_image(img_name, args.debug, device=args.device)
361
+ except Exception as e:
362
+ print(img_name, e)
363
+ if args.output_log and i % max(int(len(img_names) * 0.003), 1) == 0:
364
+ print(f"process {process_id}: {i + 1} / {len(img_names)} done")
365
+ sys.stdout.flush()
366
+ sys.stderr.flush()
367
+
368
+ print(f"process {process_id}: fitting 3dmm all done")
369
+
data_gen/utils/process_video/euler2quaterion.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import math
4
+ import numba
5
+ from scipy.spatial.transform import Rotation as R
6
+
7
+ def euler2quaterion(euler, use_radian=True):
8
+ """
9
+ euler: np.array, [batch, 3]
10
+ return: the quaterion, np.array, [batch, 4]
11
+ """
12
+ r = R.from_euler('xyz',euler, degrees=not use_radian)
13
+ return r.as_quat()
14
+
15
+ def quaterion2euler(quat, use_radian=True):
16
+ """
17
+ quat: np.array, [batch, 4]
18
+ return: the euler, np.array, [batch, 3]
19
+ """
20
+ r = R.from_quat(quat)
21
+ return r.as_euler('xyz', degrees=not use_radian)
22
+
23
+ def rot2quaterion(rot):
24
+ r = R.from_matrix(rot)
25
+ return r.as_quat()
26
+
27
+ def quaterion2rot(quat):
28
+ r = R.from_quat(quat)
29
+ return r.as_matrix()
30
+
31
+ if __name__ == '__main__':
32
+ euler = np.array([89.999,89.999,89.999] * 100).reshape([100,3])
33
+ q = euler2quaterion(euler, use_radian=False)
34
+ e = quaterion2euler(q, use_radian=False)
35
+ print(" ")
data_gen/utils/process_video/extract_blink.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from data_util.face3d_helper import Face3DHelper
3
+ from utils.commons.tensor_utils import convert_to_tensor
4
+
5
+ def polygon_area(x, y):
6
+ """
7
+ x: [T, K=6]
8
+ y: [T, K=6]
9
+ return: [T,]
10
+ """
11
+ x_ = x - x.mean(axis=-1, keepdims=True)
12
+ y_ = y - y.mean(axis=-1, keepdims=True)
13
+ correction = x_[:,-1] * y_[:,0] - y_[:,-1]* x_[:,0]
14
+ main_area = (x_[:,:-1] * y_[:,1:]).sum(axis=-1) - (y_[:,:-1] * x_[:,1:]).sum(axis=-1)
15
+ return 0.5 * np.abs(main_area + correction)
16
+
17
+ def get_eye_area_percent(id, exp, face3d_helper):
18
+ id = convert_to_tensor(id)
19
+ exp = convert_to_tensor(exp)
20
+ cano_lm3d = face3d_helper.reconstruct_cano_lm3d(id, exp)
21
+ cano_lm2d = (cano_lm3d[..., :2] + 1) / 2
22
+ lms = cano_lm2d.cpu().numpy()
23
+ eyes_left = slice(36, 42)
24
+ eyes_right = slice(42, 48)
25
+ area_left = polygon_area(lms[:, eyes_left, 0], lms[:, eyes_left, 1])
26
+ area_right = polygon_area(lms[:, eyes_right, 0], lms[:, eyes_right, 1])
27
+ # area percentage of two eyes of the whole image...
28
+ area_percent = (area_left + area_right) / 1 * 100 # recommend threshold is 0.25%
29
+ return area_percent # [T,]
30
+
31
+
32
+ if __name__ == '__main__':
33
+ import numpy as np
34
+ import imageio
35
+ import cv2
36
+ import torch
37
+ from data_gen.utils.process_video.extract_lm2d import extract_lms_mediapipe_job, read_video_to_frames, index_lm68_from_lm468
38
+ from data_gen.utils.process_video.fit_3dmm_landmark import fit_3dmm_for_a_video
39
+ from data_util.face3d_helper import Face3DHelper
40
+
41
+ face3d_helper = Face3DHelper()
42
+ video_name = 'data/raw/videos/May_10s.mp4'
43
+ frames = read_video_to_frames(video_name)
44
+ coeff = fit_3dmm_for_a_video(video_name, save=False)
45
+ area_percent = get_eye_area_percent(torch.tensor(coeff['id']), torch.tensor(coeff['exp']), face3d_helper)
46
+ writer = imageio.get_writer("1.mp4", fps=25)
47
+ for idx, frame in enumerate(frames):
48
+ frame = cv2.putText(frame, f"{area_percent[idx]:.2f}", org=(128,128), fontFace=cv2.FONT_HERSHEY_COMPLEX, fontScale=1, color=(255,0,0), thickness=1)
49
+ writer.append_data(frame)
50
+ writer.close()
data_gen/utils/process_video/extract_lm2d.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["OMP_NUM_THREADS"] = "1"
3
+ import sys
4
+ import glob
5
+ import cv2
6
+ import pickle
7
+ import tqdm
8
+ import numpy as np
9
+ import mediapipe as mp
10
+ from utils.commons.multiprocess_utils import multiprocess_run_tqdm
11
+ from utils.commons.os_utils import multiprocess_glob
12
+ from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker
13
+ import warnings
14
+ import traceback
15
+
16
+ warnings.filterwarnings('ignore')
17
+
18
+ """
19
+ 基于Face_aligment的lm68已被弃用,因为其:
20
+ 1. 对眼睛部位的预测精度极低
21
+ 2. 无法在大偏转角度时准确预测被遮挡的下颚线, 导致大角度时3dmm的GT label就是有问题的, 从而影响性能
22
+ 我们目前转而使用基于mediapipe的lm68
23
+ """
24
+ # def extract_landmarks(ori_imgs_dir):
25
+
26
+ # print(f'[INFO] ===== extract face landmarks from {ori_imgs_dir} =====')
27
+
28
+ # fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False)
29
+ # image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.png'))
30
+ # for image_path in tqdm.tqdm(image_paths):
31
+ # out_name = image_path.replace("/images_512/", "/lms_2d/").replace(".png",".lms")
32
+ # if os.path.exists(out_name):
33
+ # continue
34
+ # input = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3]
35
+ # input = cv2.cvtColor(input, cv2.COLOR_BGR2RGB)
36
+ # preds = fa.get_landmarks(input)
37
+ # if preds is None:
38
+ # print(f"Skip {image_path} for no face detected")
39
+ # continue
40
+ # if len(preds) > 0:
41
+ # lands = preds[0].reshape(-1, 2)[:,:2]
42
+ # os.makedirs(os.path.dirname(out_name), exist_ok=True)
43
+ # np.savetxt(out_name, lands, '%f')
44
+ # del fa
45
+ # print(f'[INFO] ===== extracted face landmarks =====')
46
+
47
+ def save_file(name, content):
48
+ with open(name, "wb") as f:
49
+ pickle.dump(content, f)
50
+
51
+ def load_file(name):
52
+ with open(name, "rb") as f:
53
+ content = pickle.load(f)
54
+ return content
55
+
56
+
57
+ face_landmarker = None
58
+
59
+ def extract_landmark_job(video_name, nerf=False):
60
+ try:
61
+ if nerf:
62
+ out_name = video_name.replace("/raw/", "/processed/").replace(".mp4","/lms_2d.npy")
63
+ else:
64
+ out_name = video_name.replace("/video/", "/lms_2d/").replace(".mp4","_lms.npy")
65
+ if os.path.exists(out_name):
66
+ # print("out exists, skip...")
67
+ return
68
+ try:
69
+ os.makedirs(os.path.dirname(out_name), exist_ok=True)
70
+ except:
71
+ pass
72
+ global face_landmarker
73
+ if face_landmarker is None:
74
+ face_landmarker = MediapipeLandmarker()
75
+ img_lm478, vid_lm478 = face_landmarker.extract_lm478_from_video_name(video_name)
76
+ lm478 = face_landmarker.combine_vid_img_lm478_to_lm478(img_lm478, vid_lm478)
77
+ np.save(out_name, lm478)
78
+ return True
79
+ # print("Hahaha, solve one item!!!")
80
+ except Exception as e:
81
+ traceback.print_exc()
82
+ return False
83
+
84
+ def out_exist_job(vid_name):
85
+ out_name = vid_name.replace("/video/", "/lms_2d/").replace(".mp4","_lms.npy")
86
+ if os.path.exists(out_name):
87
+ return None
88
+ else:
89
+ return vid_name
90
+
91
+ def get_todo_vid_names(vid_names):
92
+ if len(vid_names) == 1: # nerf
93
+ return vid_names
94
+ todo_vid_names = []
95
+ for i, res in multiprocess_run_tqdm(out_exist_job, vid_names, num_workers=128):
96
+ if res is not None:
97
+ todo_vid_names.append(res)
98
+ return todo_vid_names
99
+
100
+ if __name__ == '__main__':
101
+ import argparse, glob, tqdm, random
102
+ parser = argparse.ArgumentParser()
103
+ parser.add_argument("--vid_dir", default='nerf')
104
+ parser.add_argument("--ds_name", default='data/raw/videos/May.mp4')
105
+ parser.add_argument("--num_workers", default=2, type=int)
106
+ parser.add_argument("--process_id", default=0, type=int)
107
+ parser.add_argument("--total_process", default=1, type=int)
108
+ parser.add_argument("--reset", action="store_true")
109
+ parser.add_argument("--load_names", action="store_true")
110
+
111
+ args = parser.parse_args()
112
+ vid_dir = args.vid_dir
113
+ ds_name = args.ds_name
114
+ load_names = args.load_names
115
+
116
+ if ds_name.lower() == 'nerf': # 处理单个视频
117
+ vid_names = [vid_dir]
118
+ out_names = [video_name.replace("/raw/", "/processed/").replace(".mp4","/lms_2d.npy") for video_name in vid_names]
119
+ else: # 处理整个数据集
120
+ if ds_name in ['lrs3_trainval']:
121
+ vid_name_pattern = os.path.join(vid_dir, "*/*.mp4")
122
+ elif ds_name in ['TH1KH_512', 'CelebV-HQ']:
123
+ vid_name_pattern = os.path.join(vid_dir, "*.mp4")
124
+ elif ds_name in ['lrs2', 'lrs3', 'voxceleb2', 'CMLR']:
125
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
126
+ elif ds_name in ["RAVDESS", 'VFHQ']:
127
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4")
128
+ else:
129
+ raise NotImplementedError()
130
+
131
+ vid_names_path = os.path.join(vid_dir, "vid_names.pkl")
132
+ if os.path.exists(vid_names_path) and load_names:
133
+ print(f"loading vid names from {vid_names_path}")
134
+ vid_names = load_file(vid_names_path)
135
+ else:
136
+ vid_names = multiprocess_glob(vid_name_pattern)
137
+ vid_names = sorted(vid_names)
138
+ if not load_names:
139
+ print(f"saving vid names to {vid_names_path}")
140
+ save_file(vid_names_path, vid_names)
141
+ out_names = [video_name.replace("/video/", "/lms_2d/").replace(".mp4","_lms.npy") for video_name in vid_names]
142
+
143
+ process_id = args.process_id
144
+ total_process = args.total_process
145
+ if total_process > 1:
146
+ assert process_id <= total_process -1
147
+ num_samples_per_process = len(vid_names) // total_process
148
+ if process_id == total_process:
149
+ vid_names = vid_names[process_id * num_samples_per_process : ]
150
+ else:
151
+ vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
152
+
153
+ if not args.reset:
154
+ vid_names = get_todo_vid_names(vid_names)
155
+ print(f"todo videos number: {len(vid_names)}")
156
+
157
+ fail_cnt = 0
158
+ job_args = [(vid_name, ds_name=='nerf') for vid_name in vid_names]
159
+ for (i, res) in multiprocess_run_tqdm(extract_landmark_job, job_args, num_workers=args.num_workers, desc=f"Root {args.process_id}: extracing MP-based landmark2d"):
160
+ if res is False:
161
+ fail_cnt += 1
162
+ print(f"finished {i + 1} / {len(vid_names)} = {(i + 1) / len(vid_names):.4f}, failed {fail_cnt} / {i + 1} = {fail_cnt / (i + 1):.4f}")
163
+ sys.stdout.flush()
164
+ pass
data_gen/utils/process_video/extract_segment_imgs.py ADDED
@@ -0,0 +1,500 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["OMP_NUM_THREADS"] = "1"
3
+ import random
4
+ import glob
5
+ import cv2
6
+ import tqdm
7
+ import numpy as np
8
+ import PIL
9
+ from utils.commons.tensor_utils import convert_to_np
10
+ from utils.commons.os_utils import multiprocess_glob
11
+ import pickle
12
+ import torch
13
+ import mediapipe as mp
14
+ import traceback
15
+ import multiprocessing
16
+ from utils.commons.multiprocess_utils import multiprocess_run_tqdm
17
+ from scipy.ndimage import binary_erosion, binary_dilation
18
+ from sklearn.neighbors import NearestNeighbors
19
+ from mediapipe.tasks.python import vision
20
+ from data_gen.utils.mp_feature_extractors.mp_segmenter import MediapipeSegmenter, encode_segmap_mask_to_image, decode_segmap_mask_from_image
21
+
22
+ seg_model = None
23
+ segmenter = None
24
+ mat_model = None
25
+ lama_model = None
26
+ lama_config = None
27
+
28
+ from data_gen.utils.process_video.split_video_to_imgs import extract_img_job
29
+
30
+ BG_NAME_MAP = {
31
+ "knn": "",
32
+ "mat": "_mat",
33
+ "ddnm": "_ddnm",
34
+ "lama": "_lama",
35
+ }
36
+ FRAME_SELECT_INTERVAL = 5
37
+ SIM_METHOD = "mse"
38
+ SIM_THRESHOLD = 3
39
+
40
+ def save_file(name, content):
41
+ with open(name, "wb") as f:
42
+ pickle.dump(content, f)
43
+
44
+ def load_file(name):
45
+ with open(name, "rb") as f:
46
+ content = pickle.load(f)
47
+ return content
48
+
49
+ def save_rgb_alpha_image_to_path(img, alpha, img_path):
50
+ try: os.makedirs(os.path.dirname(img_path), exist_ok=True)
51
+ except: pass
52
+ cv2.imwrite(img_path, np.concatenate([cv2.cvtColor(img, cv2.COLOR_RGB2BGR), alpha], axis=-1))
53
+
54
+ def save_rgb_image_to_path(img, img_path):
55
+ try: os.makedirs(os.path.dirname(img_path), exist_ok=True)
56
+ except: pass
57
+ cv2.imwrite(img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
58
+
59
+ def load_rgb_image_to_path(img_path):
60
+ return cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
61
+
62
+ def image_similarity(x: np.ndarray, y: np.ndarray, method="mse"):
63
+ if method == "mse":
64
+ return np.mean((x - y) ** 2)
65
+ else:
66
+ raise NotImplementedError
67
+
68
+ def extract_background(img_lst, segmap_mask_lst=None, method="knn", device='cpu', mix_bg=True):
69
+ """
70
+ img_lst: list of rgb ndarray
71
+ method: "knn", "mat" or "ddnm"
72
+ """
73
+ # only use 1/20 images
74
+ global segmenter
75
+ global seg_model
76
+ global mat_model
77
+ global lama_model
78
+ global lama_config
79
+
80
+ assert len(img_lst) > 0
81
+ if segmap_mask_lst is not None:
82
+ assert len(segmap_mask_lst) == len(img_lst)
83
+ else:
84
+ del segmenter
85
+ del seg_model
86
+ seg_model = MediapipeSegmenter()
87
+ segmenter = vision.ImageSegmenter.create_from_options(seg_model.video_options)
88
+
89
+ def get_segmap_mask(img_lst, segmap_mask_lst, index):
90
+ if segmap_mask_lst is not None:
91
+ segmap = segmap_mask_lst[index]
92
+ else:
93
+ segmap = seg_model._cal_seg_map(img_lst[index], segmenter=segmenter)
94
+ return segmap
95
+
96
+ if method == "knn":
97
+ num_frames = len(img_lst)
98
+ img_lst = img_lst[::FRAME_SELECT_INTERVAL] if num_frames > FRAME_SELECT_INTERVAL else img_lst[0:1]
99
+
100
+ if segmap_mask_lst is not None:
101
+ segmap_mask_lst = segmap_mask_lst[::FRAME_SELECT_INTERVAL] if num_frames > FRAME_SELECT_INTERVAL else segmap_mask_lst[0:1]
102
+ assert len(img_lst) == len(segmap_mask_lst)
103
+ # get H/W
104
+ h, w = img_lst[0].shape[:2]
105
+
106
+ # nearest neighbors
107
+ all_xys = np.mgrid[0:h, 0:w].reshape(2, -1).transpose() # [512*512, 2] coordinate grid
108
+ distss = []
109
+ for idx, img in enumerate(img_lst):
110
+ segmap = get_segmap_mask(img_lst=img_lst, segmap_mask_lst=segmap_mask_lst, index=idx)
111
+ bg = (segmap[0]).astype(bool) # [h,w] bool mask
112
+ fg_xys = np.stack(np.nonzero(~bg)).transpose(1, 0) # [N_nonbg,2] coordinate of non-bg pixels
113
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
114
+ dists, _ = nbrs.kneighbors(all_xys) # [512*512, 1] distance to nearest non-bg pixel
115
+ distss.append(dists)
116
+
117
+ distss = np.stack(distss) # [B, 512*512, 1]
118
+ max_dist = np.max(distss, 0) # [512*512, 1]
119
+ max_id = np.argmax(distss, 0) # id of frame
120
+
121
+ bc_pixs = max_dist > 10 # 在各个frame有一个出现过是bg的pixel,bg标准是离最近的non-bg pixel距离大于10
122
+ bc_pixs_id = np.nonzero(bc_pixs)
123
+ bc_ids = max_id[bc_pixs]
124
+
125
+ num_pixs = distss.shape[1]
126
+ imgs = np.stack(img_lst).reshape(-1, num_pixs, 3)
127
+
128
+ bg_img = np.zeros((h*w, 3), dtype=np.uint8)
129
+ bg_img[bc_pixs_id, :] = imgs[bc_ids, bc_pixs_id, :] # 对那些铁bg的pixel,直接去对应的image里面采样
130
+ bg_img = bg_img.reshape(h, w, 3)
131
+
132
+ max_dist = max_dist.reshape(h, w)
133
+ bc_pixs = max_dist > 10 # 5
134
+ bg_xys = np.stack(np.nonzero(~bc_pixs)).transpose()
135
+ fg_xys = np.stack(np.nonzero(bc_pixs)).transpose()
136
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
137
+ distances, indices = nbrs.kneighbors(bg_xys) # 对non-bg img,���KNN找最近的bg pixel
138
+ bg_fg_xys = fg_xys[indices[:, 0]]
139
+ bg_img[bg_xys[:, 0], bg_xys[:, 1], :] = bg_img[bg_fg_xys[:, 0], bg_fg_xys[:, 1], :]
140
+ else:
141
+ raise NotImplementedError # deperated
142
+
143
+ return bg_img
144
+
145
+ def inpaint_torso_job(gt_img, segmap):
146
+ bg_part = (segmap[0]).astype(bool)
147
+ head_part = (segmap[1] + segmap[3] + segmap[5]).astype(bool)
148
+ neck_part = (segmap[2]).astype(bool)
149
+ torso_part = (segmap[4]).astype(bool)
150
+ img = gt_img.copy()
151
+ img[head_part] = 0
152
+
153
+ # torso part "vertical" in-painting...
154
+ L = 8 + 1
155
+ torso_coords = np.stack(np.nonzero(torso_part), axis=-1) # [M, 2]
156
+ # lexsort: sort 2D coords first by y then by x,
157
+ # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
158
+ inds = np.lexsort((torso_coords[:, 0], torso_coords[:, 1]))
159
+ torso_coords = torso_coords[inds]
160
+ # choose the top pixel for each column
161
+ u, uid, ucnt = np.unique(torso_coords[:, 1], return_index=True, return_counts=True)
162
+ top_torso_coords = torso_coords[uid] # [m, 2]
163
+ # only keep top-is-head pixels
164
+ top_torso_coords_up = top_torso_coords.copy() - np.array([1, 0]) # [N, 2]
165
+ mask = head_part[tuple(top_torso_coords_up.T)]
166
+ if mask.any():
167
+ top_torso_coords = top_torso_coords[mask]
168
+ # get the color
169
+ top_torso_colors = gt_img[tuple(top_torso_coords.T)] # [m, 3]
170
+ # construct inpaint coords (vertically up, or minus in x)
171
+ inpaint_torso_coords = top_torso_coords[None].repeat(L, 0) # [L, m, 2]
172
+ inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
173
+ inpaint_torso_coords += inpaint_offsets
174
+ inpaint_torso_coords = inpaint_torso_coords.reshape(-1, 2) # [Lm, 2]
175
+ inpaint_torso_colors = top_torso_colors[None].repeat(L, 0) # [L, m, 3]
176
+ darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
177
+ inpaint_torso_colors = (inpaint_torso_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
178
+ # set color
179
+ img[tuple(inpaint_torso_coords.T)] = inpaint_torso_colors
180
+ inpaint_torso_mask = np.zeros_like(img[..., 0]).astype(bool)
181
+ inpaint_torso_mask[tuple(inpaint_torso_coords.T)] = True
182
+ else:
183
+ inpaint_torso_mask = None
184
+
185
+ # neck part "vertical" in-painting...
186
+ push_down = 4
187
+ L = 48 + push_down + 1
188
+ neck_part = binary_dilation(neck_part, structure=np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=bool), iterations=3)
189
+ neck_coords = np.stack(np.nonzero(neck_part), axis=-1) # [M, 2]
190
+ # lexsort: sort 2D coords first by y then by x,
191
+ # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
192
+ inds = np.lexsort((neck_coords[:, 0], neck_coords[:, 1]))
193
+ neck_coords = neck_coords[inds]
194
+ # choose the top pixel for each column
195
+ u, uid, ucnt = np.unique(neck_coords[:, 1], return_index=True, return_counts=True)
196
+ top_neck_coords = neck_coords[uid] # [m, 2]
197
+ # only keep top-is-head pixels
198
+ top_neck_coords_up = top_neck_coords.copy() - np.array([1, 0])
199
+ mask = head_part[tuple(top_neck_coords_up.T)]
200
+ top_neck_coords = top_neck_coords[mask]
201
+ # push these top down for 4 pixels to make the neck inpainting more natural...
202
+ offset_down = np.minimum(ucnt[mask] - 1, push_down)
203
+ top_neck_coords += np.stack([offset_down, np.zeros_like(offset_down)], axis=-1)
204
+ # get the color
205
+ top_neck_colors = gt_img[tuple(top_neck_coords.T)] # [m, 3]
206
+ # construct inpaint coords (vertically up, or minus in x)
207
+ inpaint_neck_coords = top_neck_coords[None].repeat(L, 0) # [L, m, 2]
208
+ inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
209
+ inpaint_neck_coords += inpaint_offsets
210
+ inpaint_neck_coords = inpaint_neck_coords.reshape(-1, 2) # [Lm, 2]
211
+ inpaint_neck_colors = top_neck_colors[None].repeat(L, 0) # [L, m, 3]
212
+ darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
213
+ inpaint_neck_colors = (inpaint_neck_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
214
+ # set color
215
+ img[tuple(inpaint_neck_coords.T)] = inpaint_neck_colors
216
+ # apply blurring to the inpaint area to avoid vertical-line artifects...
217
+ inpaint_mask = np.zeros_like(img[..., 0]).astype(bool)
218
+ inpaint_mask[tuple(inpaint_neck_coords.T)] = True
219
+
220
+ blur_img = img.copy()
221
+ blur_img = cv2.GaussianBlur(blur_img, (5, 5), cv2.BORDER_DEFAULT)
222
+ img[inpaint_mask] = blur_img[inpaint_mask]
223
+
224
+ # set mask
225
+ torso_img_mask = (neck_part | torso_part | inpaint_mask)
226
+ torso_with_bg_img_mask = (bg_part | neck_part | torso_part | inpaint_mask)
227
+ if inpaint_torso_mask is not None:
228
+ torso_img_mask = torso_img_mask | inpaint_torso_mask
229
+ torso_with_bg_img_mask = torso_with_bg_img_mask | inpaint_torso_mask
230
+
231
+ torso_img = img.copy()
232
+ torso_img[~torso_img_mask] = 0
233
+ torso_with_bg_img = img.copy()
234
+ torso_img[~torso_with_bg_img_mask] = 0
235
+
236
+ return torso_img, torso_img_mask, torso_with_bg_img, torso_with_bg_img_mask
237
+
238
+
239
+ def extract_segment_job(video_name, nerf=False, idx=None, total=None, background_method='knn', device="cpu", total_gpus=0, mix_bg=True):
240
+ global segmenter
241
+ global seg_model
242
+ del segmenter
243
+ del seg_model
244
+ seg_model = MediapipeSegmenter()
245
+ segmenter = vision.ImageSegmenter.create_from_options(seg_model.video_options)
246
+ try:
247
+ if "cuda" in device:
248
+ # determine which cuda index from subprocess id
249
+ pname = multiprocessing.current_process().name
250
+ pid = int(pname.rsplit("-", 1)[-1]) - 1
251
+ cuda_id = pid % total_gpus
252
+ device = f"cuda:{cuda_id}"
253
+
254
+ if nerf: # single video
255
+ raw_img_dir = video_name.replace(".mp4", "/gt_imgs/").replace("/raw/","/processed/")
256
+ else: # whole dataset
257
+ raw_img_dir = video_name.replace(".mp4", "").replace("/video/", "/gt_imgs/")
258
+ if not os.path.exists(raw_img_dir):
259
+ extract_img_job(video_name, raw_img_dir) # use ffmpeg to split video into imgs
260
+
261
+ img_names = glob.glob(os.path.join(raw_img_dir, "*.jpg"))
262
+
263
+ img_lst = []
264
+
265
+ for img_name in img_names:
266
+ img = cv2.imread(img_name)
267
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
268
+ img_lst.append(img)
269
+
270
+ segmap_mask_lst, segmap_image_lst = seg_model._cal_seg_map_for_video(img_lst, segmenter=segmenter, return_onehot_mask=True, return_segmap_image=True)
271
+ del segmap_image_lst
272
+ # for i in range(len(img_lst)):
273
+ for i in tqdm.trange(len(img_lst), desc='generating segment images using segmaps...'):
274
+ img_name = img_names[i]
275
+ segmap = segmap_mask_lst[i]
276
+ img = img_lst[i]
277
+ out_img_name = img_name.replace("/gt_imgs/", "/segmaps/").replace(".jpg", ".png") # 存成jpg的话,pixel value会有误差
278
+ try: os.makedirs(os.path.dirname(out_img_name), exist_ok=True)
279
+ except: pass
280
+ encoded_segmap = encode_segmap_mask_to_image(segmap)
281
+ save_rgb_image_to_path(encoded_segmap, out_img_name)
282
+
283
+ for mode in ['head', 'torso', 'person', 'bg']:
284
+ out_img, mask = seg_model._seg_out_img_with_segmap(img, segmap, mode=mode)
285
+ img_alpha = 255 * np.ones((img.shape[0], img.shape[1], 1), dtype=np.uint8) # alpha
286
+ mask = mask[0][..., None]
287
+ img_alpha[~mask] = 0
288
+ out_img_name = img_name.replace("/gt_imgs/", f"/{mode}_imgs/").replace(".jpg", ".png")
289
+ save_rgb_alpha_image_to_path(out_img, img_alpha, out_img_name)
290
+
291
+ inpaint_torso_img, inpaint_torso_img_mask, inpaint_torso_with_bg_img, inpaint_torso_with_bg_img_mask = inpaint_torso_job(img, segmap)
292
+ img_alpha = 255 * np.ones((img.shape[0], img.shape[1], 1), dtype=np.uint8) # alpha
293
+ img_alpha[~inpaint_torso_img_mask[..., None]] = 0
294
+ out_img_name = img_name.replace("/gt_imgs/", f"/inpaint_torso_imgs/").replace(".jpg", ".png")
295
+ save_rgb_alpha_image_to_path(inpaint_torso_img, img_alpha, out_img_name)
296
+
297
+ bg_prefix_name = f"bg{BG_NAME_MAP[background_method]}"
298
+ bg_img = extract_background(img_lst, segmap_mask_lst, method=background_method, device=device, mix_bg=mix_bg)
299
+ if nerf:
300
+ out_img_name = video_name.replace("/raw/", "/processed/").replace(".mp4", f"/{bg_prefix_name}.jpg")
301
+ else:
302
+ out_img_name = video_name.replace("/video/", f"/{bg_prefix_name}_img/").replace(".mp4", ".jpg")
303
+ save_rgb_image_to_path(bg_img, out_img_name)
304
+
305
+ com_prefix_name = f"com{BG_NAME_MAP[background_method]}"
306
+ for i, img_name in enumerate(img_names):
307
+ com_img = img_lst[i].copy()
308
+ segmap = segmap_mask_lst[i]
309
+ bg_part = segmap[0].astype(bool)[..., None].repeat(3,axis=-1)
310
+ com_img[bg_part] = bg_img[bg_part]
311
+ out_img_name = img_name.replace("/gt_imgs/", f"/{com_prefix_name}_imgs/")
312
+ save_rgb_image_to_path(com_img, out_img_name)
313
+ return 0
314
+ except Exception as e:
315
+ print(str(type(e)), e)
316
+ traceback.print_exc(e)
317
+ return 1
318
+
319
+ # def check_bg_img_job_finished(raw_img_dir, bg_name, com_dir):
320
+ # img_names = glob.glob(os.path.join(raw_img_dir, "*.jpg"))
321
+ # com_names = glob.glob(os.path.join(com_dir, "*.jpg"))
322
+ # return len(img_names) == len(com_names) and os.path.exists(bg_name)
323
+
324
+ # extract background and combined image
325
+ # need pre-processed "gt_imgs" and "segmaps"
326
+ def extract_bg_img_job(video_name, nerf=False, idx=None, total=None, background_method='knn', device="cpu", total_gpus=0, mix_bg=True):
327
+ try:
328
+ bg_prefix_name = f"bg{BG_NAME_MAP[background_method]}"
329
+ com_prefix_name = f"com{BG_NAME_MAP[background_method]}"
330
+
331
+ if "cuda" in device:
332
+ # determine which cuda index from subprocess id
333
+ pname = multiprocessing.current_process().name
334
+ pid = int(pname.rsplit("-", 1)[-1]) - 1
335
+ cuda_id = pid % total_gpus
336
+ device = f"cuda:{cuda_id}"
337
+
338
+ if nerf: # single video
339
+ raw_img_dir = video_name.replace(".mp4", "/gt_imgs/").replace("/raw/","/processed/")
340
+ else: # whole dataset
341
+ raw_img_dir = video_name.replace(".mp4", "").replace("/video/", "/gt_imgs/")
342
+ if nerf:
343
+ bg_name = video_name.replace("/raw/", "/processed/").replace(".mp4", f"/{bg_prefix_name}.jpg")
344
+ else:
345
+ bg_name = video_name.replace("/video/", f"/{bg_prefix_name}_img/").replace(".mp4", ".jpg")
346
+ # com_dir = raw_img_dir.replace("/gt_imgs/", f"/{com_prefix_name}_imgs/")
347
+ # if check_bg_img_job_finished(raw_img_dir=raw_img_dir, bg_name=bg_name, com_dir=com_dir):
348
+ # print(f"Already finished, skip {raw_img_dir} ")
349
+ # return 0
350
+
351
+ img_names = glob.glob(os.path.join(raw_img_dir, "*.jpg"))
352
+ img_lst = []
353
+ for img_name in img_names:
354
+ img = cv2.imread(img_name)
355
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
356
+ img_lst.append(img)
357
+
358
+ segmap_mask_lst = []
359
+ for img_name in img_names:
360
+ segmap_img_name = img_name.replace("/gt_imgs/", "/segmaps/").replace(".jpg", ".png")
361
+ segmap_img = load_rgb_image_to_path(segmap_img_name)
362
+
363
+ segmap_mask = decode_segmap_mask_from_image(segmap_img)
364
+ segmap_mask_lst.append(segmap_mask)
365
+
366
+ bg_img = extract_background(img_lst, segmap_mask_lst, method=background_method, device=device, mix_bg=mix_bg)
367
+ save_rgb_image_to_path(bg_img, bg_name)
368
+
369
+ for i, img_name in enumerate(img_names):
370
+ com_img = img_lst[i].copy()
371
+ segmap = segmap_mask_lst[i]
372
+ bg_part = segmap[0].astype(bool)[..., None].repeat(3, axis=-1)
373
+ com_img[bg_part] = bg_img[bg_part]
374
+ com_name = img_name.replace("/gt_imgs/", f"/{com_prefix_name}_imgs/")
375
+ save_rgb_image_to_path(com_img, com_name)
376
+ return 0
377
+
378
+ except Exception as e:
379
+ print(str(type(e)), e)
380
+ traceback.print_exc(e)
381
+ return 1
382
+
383
+ def out_exist_job(vid_name, background_method='knn', only_bg_img=False):
384
+ com_prefix_name = f"com{BG_NAME_MAP[background_method]}"
385
+ img_dir = vid_name.replace("/video/", "/gt_imgs/").replace(".mp4", "")
386
+ out_dir1 = img_dir.replace("/gt_imgs/", "/head_imgs/")
387
+ out_dir2 = img_dir.replace("/gt_imgs/", f"/{com_prefix_name}_imgs/")
388
+
389
+ if not only_bg_img:
390
+ if os.path.exists(img_dir) and os.path.exists(out_dir1) and os.path.exists(out_dir1) and os.path.exists(out_dir2) :
391
+ num_frames = len(os.listdir(img_dir))
392
+ if len(os.listdir(out_dir1)) == num_frames and len(os.listdir(out_dir2)) == num_frames:
393
+ return None
394
+ else:
395
+ return vid_name
396
+ else:
397
+ return vid_name
398
+ else:
399
+ if os.path.exists(img_dir) and os.path.exists(out_dir2):
400
+ num_frames = len(os.listdir(img_dir))
401
+ if len(os.listdir(out_dir2)) == num_frames:
402
+ return None
403
+ else:
404
+ return vid_name
405
+ else:
406
+ return vid_name
407
+
408
+ def get_todo_vid_names(vid_names, background_method='knn', only_bg_img=False):
409
+ if len(vid_names) == 1: # nerf
410
+ return vid_names
411
+ todo_vid_names = []
412
+ fn_args = [(vid_name, background_method, only_bg_img) for vid_name in vid_names]
413
+ for i, res in multiprocess_run_tqdm(out_exist_job, fn_args, num_workers=16, desc="checking todo videos..."):
414
+ if res is not None:
415
+ todo_vid_names.append(res)
416
+ return todo_vid_names
417
+
418
+ if __name__ == '__main__':
419
+ import argparse, glob, tqdm, random
420
+ parser = argparse.ArgumentParser()
421
+ parser.add_argument("--vid_dir", default='/home/tiger/datasets/raw/CelebV-HQ/video')
422
+ parser.add_argument("--ds_name", default='CelebV-HQ')
423
+ parser.add_argument("--num_workers", default=48, type=int)
424
+ parser.add_argument("--seed", default=0, type=int)
425
+ parser.add_argument("--process_id", default=0, type=int)
426
+ parser.add_argument("--total_process", default=1, type=int)
427
+ parser.add_argument("--reset", action='store_true')
428
+ parser.add_argument("--load_names", action="store_true")
429
+ parser.add_argument("--background_method", choices=['knn', 'mat', 'ddnm', 'lama'], type=str, default='knn')
430
+ parser.add_argument("--total_gpus", default=0, type=int) # zero gpus means utilizing cpu
431
+ parser.add_argument("--only_bg_img", action="store_true")
432
+ parser.add_argument("--no_mix_bg", action="store_true")
433
+
434
+ args = parser.parse_args()
435
+ vid_dir = args.vid_dir
436
+ ds_name = args.ds_name
437
+ load_names = args.load_names
438
+ background_method = args.background_method
439
+ total_gpus = args.total_gpus
440
+ only_bg_img = args.only_bg_img
441
+ mix_bg = not args.no_mix_bg
442
+
443
+ devices = os.environ.get('CUDA_VISIBLE_DEVICES', '').split(",")
444
+ for d in devices[:total_gpus]:
445
+ os.system(f'pkill -f "voidgpu{d}"')
446
+
447
+ if ds_name.lower() == 'nerf': # 处理单个视频
448
+ vid_names = [vid_dir]
449
+ out_names = [video_name.replace("/raw/", "/processed/").replace(".mp4","_lms.npy") for video_name in vid_names]
450
+ else: # 处理整个数据集
451
+ if ds_name in ['lrs3_trainval']:
452
+ vid_name_pattern = os.path.join(vid_dir, "*/*.mp4")
453
+ elif ds_name in ['TH1KH_512', 'CelebV-HQ']:
454
+ vid_name_pattern = os.path.join(vid_dir, "*.mp4")
455
+ elif ds_name in ['lrs2', 'lrs3', 'voxceleb2']:
456
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
457
+ elif ds_name in ["RAVDESS", 'VFHQ']:
458
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4")
459
+ else:
460
+ raise NotImplementedError()
461
+
462
+ vid_names_path = os.path.join(vid_dir, "vid_names.pkl")
463
+ if os.path.exists(vid_names_path) and load_names:
464
+ print(f"loading vid names from {vid_names_path}")
465
+ vid_names = load_file(vid_names_path)
466
+ else:
467
+ vid_names = multiprocess_glob(vid_name_pattern)
468
+ vid_names = sorted(vid_names)
469
+ print(f"saving vid names to {vid_names_path}")
470
+ save_file(vid_names_path, vid_names)
471
+
472
+ vid_names = sorted(vid_names)
473
+ random.seed(args.seed)
474
+ random.shuffle(vid_names)
475
+
476
+ process_id = args.process_id
477
+ total_process = args.total_process
478
+ if total_process > 1:
479
+ assert process_id <= total_process -1
480
+ num_samples_per_process = len(vid_names) // total_process
481
+ if process_id == total_process:
482
+ vid_names = vid_names[process_id * num_samples_per_process : ]
483
+ else:
484
+ vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
485
+
486
+ if not args.reset:
487
+ vid_names = get_todo_vid_names(vid_names, background_method, only_bg_img)
488
+ print(f"todo videos number: {len(vid_names)}")
489
+ # exit()
490
+
491
+ device = "cuda" if total_gpus > 0 else "cpu"
492
+ if only_bg_img:
493
+ extract_job = extract_bg_img_job
494
+ fn_args = [(vid_name,ds_name=='nerf',i,len(vid_names), background_method, device, total_gpus, mix_bg) for i, vid_name in enumerate(vid_names)]
495
+ else:
496
+ extract_job = extract_segment_job
497
+ fn_args = [(vid_name,ds_name=='nerf',i,len(vid_names), background_method, device, total_gpus, mix_bg) for i, vid_name in enumerate(vid_names)]
498
+
499
+ for vid_name in multiprocess_run_tqdm(extract_job, fn_args, desc=f"Root process {args.process_id}: segment images", num_workers=args.num_workers):
500
+ pass
data_gen/utils/process_video/fit_3dmm_landmark.py ADDED
@@ -0,0 +1,565 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This is a script for efficienct 3DMM coefficient extraction.
2
+ # It could reconstruct accurate 3D face in real-time.
3
+ # It is built upon BFM 2009 model and mediapipe landmark extractor.
4
+ # It is authored by ZhenhuiYe (zhenhuiye@zju.edu.cn), free to contact him for any suggestion on improvement!
5
+
6
+ from numpy.core.numeric import require
7
+ from numpy.lib.function_base import quantile
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import copy
11
+ import numpy as np
12
+
13
+ import random
14
+ import pickle
15
+ import os
16
+ import sys
17
+ import cv2
18
+ import argparse
19
+ import tqdm
20
+ from utils.commons.multiprocess_utils import multiprocess_run_tqdm
21
+ from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker, read_video_to_frames
22
+ from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
23
+ from deep_3drecon.secc_renderer import SECC_Renderer
24
+ from utils.commons.os_utils import multiprocess_glob
25
+
26
+
27
+ face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM',
28
+ camera_distance=10, focal=1015, keypoint_mode='mediapipe')
29
+ face_model.to(torch.device("cuda:0"))
30
+
31
+ dir_path = os.path.dirname(os.path.realpath(__file__))
32
+
33
+
34
+ def draw_axes(img, pitch, yaw, roll, tx, ty, size=50):
35
+ # yaw = -yaw
36
+ pitch = - pitch
37
+ roll = - roll
38
+ rotation_matrix = cv2.Rodrigues(np.array([pitch, yaw, roll]))[0].astype(np.float64)
39
+ axes_points = np.array([
40
+ [1, 0, 0, 0],
41
+ [0, 1, 0, 0],
42
+ [0, 0, 1, 0]
43
+ ], dtype=np.float64)
44
+ axes_points = rotation_matrix @ axes_points
45
+ axes_points = (axes_points[:2, :] * size).astype(int)
46
+ axes_points[0, :] = axes_points[0, :] + tx
47
+ axes_points[1, :] = axes_points[1, :] + ty
48
+
49
+ new_img = img.copy()
50
+ cv2.line(new_img, tuple(axes_points[:, 3].ravel()), tuple(axes_points[:, 0].ravel()), (255, 0, 0), 3)
51
+ cv2.line(new_img, tuple(axes_points[:, 3].ravel()), tuple(axes_points[:, 1].ravel()), (0, 255, 0), 3)
52
+ cv2.line(new_img, tuple(axes_points[:, 3].ravel()), tuple(axes_points[:, 2].ravel()), (0, 0, 255), 3)
53
+ return new_img
54
+
55
+ def save_file(name, content):
56
+ with open(name, "wb") as f:
57
+ pickle.dump(content, f)
58
+
59
+ def load_file(name):
60
+ with open(name, "rb") as f:
61
+ content = pickle.load(f)
62
+ return content
63
+
64
+ def cal_lap_loss(in_tensor):
65
+ # [T, 68, 2]
66
+ t = in_tensor.shape[0]
67
+ in_tensor = in_tensor.reshape([t, -1]).permute(1,0).unsqueeze(1) # [c, 1, t]
68
+ in_tensor = torch.cat([in_tensor[:, :, 0:1], in_tensor, in_tensor[:, :, -1:]], dim=-1)
69
+ lap_kernel = torch.Tensor((-0.5, 1.0, -0.5)).reshape([1,1,3]).float().to(in_tensor.device) # [1, 1, kw]
70
+ loss_lap = 0
71
+
72
+ out_tensor = F.conv1d(in_tensor, lap_kernel)
73
+ loss_lap += torch.mean(out_tensor**2)
74
+ return loss_lap
75
+
76
+ def cal_vel_loss(ldm):
77
+ # [B, 68, 2]
78
+ vel = ldm[1:] - ldm[:-1]
79
+ return torch.mean(torch.abs(vel))
80
+
81
+ def cal_lan_loss(proj_lan, gt_lan):
82
+ # [B, 68, 2]
83
+ loss = (proj_lan - gt_lan)** 2
84
+ # use the ldm weights from deep3drecon, see deep_3drecon/deep_3drecon_models/losses.py
85
+ weights = torch.zeros_like(loss)
86
+ weights = torch.ones_like(loss)
87
+ weights[:, 36:48, :] = 3 # eye 12 points
88
+ weights[:, -8:, :] = 3 # inner lip 8 points
89
+ weights[:, 28:31, :] = 3 # nose 3 points
90
+ loss = loss * weights
91
+ return torch.mean(loss)
92
+
93
+ def cal_lan_loss_mp(proj_lan, gt_lan, mean:bool=True):
94
+ # [B, 68, 2]
95
+ loss = (proj_lan - gt_lan).pow(2)
96
+ # loss = (proj_lan - gt_lan).abs()
97
+ unmatch_mask = [ 93, 127, 132, 234, 323, 356, 361, 454]
98
+ upper_eye = [161,160,159,158,157] + [388,387,386,385,384]
99
+ eye = [33,246,161,160,159,158,157,173,133,155,154,153,145,144,163,7] + [263,466,388,387,386,385,384,398,362,382,381,380,374,373,390,249]
100
+ inner_lip = [78,191,80,81,82,13,312,311,310,415,308,324,318,402,317,14,87,178,88,95]
101
+ outer_lip = [61,185,40,39,37,0,267,269,270,409,291,375,321,405,314,17,84,181,91,146]
102
+ weights = torch.ones_like(loss)
103
+ weights[:, eye] = 3
104
+ weights[:, upper_eye] = 20
105
+ weights[:, inner_lip] = 5
106
+ weights[:, outer_lip] = 5
107
+ weights[:, unmatch_mask] = 0
108
+ loss = loss * weights
109
+ if mean:
110
+ loss = torch.mean(loss)
111
+ return loss
112
+
113
+ def cal_acceleration_loss(trans):
114
+ vel = trans[1:] - trans[:-1]
115
+ acc = vel[1:] - vel[:-1]
116
+ return torch.mean(torch.abs(acc))
117
+
118
+ def cal_acceleration_ldm_loss(ldm):
119
+ # [B, 68, 2]
120
+ vel = ldm[1:] - ldm[:-1]
121
+ acc = vel[1:] - vel[:-1]
122
+ lip_weight = 0.25 # we dont want smooth the lip too much
123
+ acc[48:68] *= lip_weight
124
+ return torch.mean(torch.abs(acc))
125
+
126
+ def set_requires_grad(tensor_list):
127
+ for tensor in tensor_list:
128
+ tensor.requires_grad = True
129
+
130
+ @torch.enable_grad()
131
+ def fit_3dmm_for_a_video(
132
+ video_name,
133
+ nerf=False, # use the file name convention for GeneFace++
134
+ id_mode='global',
135
+ debug=False,
136
+ keypoint_mode='mediapipe',
137
+ large_yaw_threshold=9999999.9,
138
+ save=True
139
+ ) -> bool: # True: good, False: bad
140
+ assert video_name.endswith(".mp4"), "this function only support video as input"
141
+ if id_mode == 'global':
142
+ LAMBDA_REG_ID = 0.2
143
+ LAMBDA_REG_EXP = 0.6
144
+ LAMBDA_REG_LAP = 1.0
145
+ LAMBDA_REG_VEL_ID = 0.0 # laplcaian is all you need for temporal consistency
146
+ LAMBDA_REG_VEL_EXP = 0.0 # laplcaian is all you need for temporal consistency
147
+ else:
148
+ LAMBDA_REG_ID = 0.3
149
+ LAMBDA_REG_EXP = 0.05
150
+ LAMBDA_REG_LAP = 1.0
151
+ LAMBDA_REG_VEL_ID = 0.0 # laplcaian is all you need for temporal consistency
152
+ LAMBDA_REG_VEL_EXP = 0.0 # laplcaian is all you need for temporal consistency
153
+
154
+ frames = read_video_to_frames(video_name) # [T, H, W, 3]
155
+ img_h, img_w = frames.shape[1], frames.shape[2]
156
+ assert img_h == img_w
157
+ num_frames = len(frames)
158
+
159
+ if nerf: # single video
160
+ lm_name = video_name.replace("/raw/", "/processed/").replace(".mp4","/lms_2d.npy")
161
+ else:
162
+ lm_name = video_name.replace("/video/", "/lms_2d/").replace(".mp4", "_lms.npy")
163
+
164
+ if os.path.exists(lm_name):
165
+ lms = np.load(lm_name)
166
+ else:
167
+ print(f"lms_2d file not found, try to extract it from video... {lm_name}")
168
+ try:
169
+ landmarker = MediapipeLandmarker()
170
+ img_lm478, vid_lm478 = landmarker.extract_lm478_from_frames(frames, anti_smooth_factor=20)
171
+ lms = landmarker.combine_vid_img_lm478_to_lm478(img_lm478, vid_lm478)
172
+ except Exception as e:
173
+ print(e)
174
+ return False
175
+ if lms is None:
176
+ print(f"get None lms_2d, please check whether each frame has one head, exiting... {lm_name}")
177
+ return False
178
+ lms = lms[:, :468, :]
179
+ lms = torch.FloatTensor(lms).cuda()
180
+ lms[..., 1] = img_h - lms[..., 1] # flip the height axis
181
+
182
+ if keypoint_mode == 'mediapipe':
183
+ # default
184
+ cal_lan_loss_fn = cal_lan_loss_mp
185
+ if nerf: # single video
186
+ out_name = video_name.replace("/raw/", "/processed/").replace(".mp4", "/coeff_fit_mp.npy")
187
+ else:
188
+ out_name = video_name.replace("/video/", "/coeff_fit_mp/").replace(".mp4", "_coeff_fit_mp.npy")
189
+ else:
190
+ # lm68 is less accurate than mp
191
+ cal_lan_loss_fn = cal_lan_loss
192
+ if nerf: # single video
193
+ out_name = video_name.replace("/raw/", "/processed/").replace(".mp4", "_coeff_fit_lm68.npy")
194
+ else:
195
+ out_name = video_name.replace("/video/", "/coeff_fit_lm68/").replace(".mp4", "_coeff_fit_lm68.npy")
196
+ try:
197
+ os.makedirs(os.path.dirname(out_name), exist_ok=True)
198
+ except:
199
+ pass
200
+
201
+ id_dim, exp_dim = 80, 64
202
+ sel_ids = np.arange(0, num_frames, 40)
203
+
204
+ h = w = face_model.center * 2
205
+ img_scale_factor = img_h / h
206
+ lms /= img_scale_factor # rescale lms into [0,224]
207
+
208
+ if id_mode == 'global':
209
+ # default choice by GeneFace++ and later works
210
+ id_para = lms.new_zeros((1, id_dim), requires_grad=True)
211
+ elif id_mode == 'finegrained':
212
+ # legacy choice by GeneFace1 (ICLR 2023)
213
+ id_para = lms.new_zeros((num_frames, id_dim), requires_grad=True)
214
+ else: raise NotImplementedError(f"id mode {id_mode} not supported! we only support global or finegrained.")
215
+ exp_para = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
216
+ euler_angle = lms.new_zeros((num_frames, 3), requires_grad=True)
217
+ trans = lms.new_zeros((num_frames, 3), requires_grad=True)
218
+
219
+ set_requires_grad([id_para, exp_para, euler_angle, trans])
220
+
221
+ optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=.1)
222
+ optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=.1)
223
+
224
+ # 其他参数初始化,先训练euler和trans
225
+ for _ in range(200):
226
+ if id_mode == 'global':
227
+ proj_geo = face_model.compute_for_landmark_fit(
228
+ id_para.expand((num_frames, id_dim)), exp_para, euler_angle, trans)
229
+ else:
230
+ proj_geo = face_model.compute_for_landmark_fit(
231
+ id_para, exp_para, euler_angle, trans)
232
+ loss_lan = cal_lan_loss_fn(proj_geo[:, :, :2], lms.detach())
233
+ loss = loss_lan
234
+ optimizer_frame.zero_grad()
235
+ loss.backward()
236
+ optimizer_frame.step()
237
+
238
+ # print(f"loss_lan: {loss_lan.item():.2f}, euler_abs_mean: {euler_angle.abs().mean().item():.4f}, euler_std: {euler_angle.std().item():.4f}, euler_min: {euler_angle.min().item():.4f}, euler_max: {euler_angle.max().item():.4f}")
239
+ # print(f"trans_z_mean: {trans[...,2].mean().item():.4f}, trans_z_std: {trans[...,2].std().item():.4f}, trans_min: {trans[...,2].min().item():.4f}, trans_max: {trans[...,2].max().item():.4f}")
240
+
241
+ for param_group in optimizer_frame.param_groups:
242
+ param_group['lr'] = 0.1
243
+
244
+ # "jointly roughly training id exp euler trans"
245
+ for _ in range(200):
246
+ ret = {}
247
+ if id_mode == 'global':
248
+ proj_geo = face_model.compute_for_landmark_fit(
249
+ id_para.expand((num_frames, id_dim)), exp_para, euler_angle, trans, ret)
250
+ else:
251
+ proj_geo = face_model.compute_for_landmark_fit(
252
+ id_para, exp_para, euler_angle, trans, ret)
253
+ loss_lan = cal_lan_loss_fn(
254
+ proj_geo[:, :, :2], lms.detach())
255
+ # loss_lap = cal_lap_loss(proj_geo)
256
+ # laplacian对euler影响不大,但是对trans的提升很大
257
+ loss_lap = cal_lap_loss(id_para) + cal_lap_loss(exp_para) + cal_lap_loss(euler_angle) * 0.3 + cal_lap_loss(trans) * 0.3
258
+
259
+ loss_regid = torch.mean(id_para*id_para) # 正则化
260
+ loss_regexp = torch.mean(exp_para * exp_para)
261
+
262
+ loss_vel_id = cal_vel_loss(id_para)
263
+ loss_vel_exp = cal_vel_loss(exp_para)
264
+ loss = loss_lan + loss_regid * LAMBDA_REG_ID + loss_regexp * LAMBDA_REG_EXP + loss_vel_id * LAMBDA_REG_VEL_ID + loss_vel_exp * LAMBDA_REG_VEL_EXP + loss_lap * LAMBDA_REG_LAP
265
+ optimizer_idexp.zero_grad()
266
+ optimizer_frame.zero_grad()
267
+ loss.backward()
268
+ optimizer_idexp.step()
269
+ optimizer_frame.step()
270
+
271
+ # print(f"loss_lan: {loss_lan.item():.2f}, loss_reg_id: {loss_regid.item():.2f},loss_reg_exp: {loss_regexp.item():.2f},")
272
+ # print(f"euler_abs_mean: {euler_angle.abs().mean().item():.4f}, euler_std: {euler_angle.std().item():.4f}, euler_min: {euler_angle.min().item():.4f}, euler_max: {euler_angle.max().item():.4f}")
273
+ # print(f"trans_z_mean: {trans[...,2].mean().item():.4f}, trans_z_std: {trans[...,2].std().item():.4f}, trans_min: {trans[...,2].min().item():.4f}, trans_max: {trans[...,2].max().item():.4f}")
274
+
275
+ # start fine training, intialize from the roughly trained results
276
+ if id_mode == 'global':
277
+ id_para_ = lms.new_zeros((1, id_dim), requires_grad=False)
278
+ else:
279
+ id_para_ = lms.new_zeros((num_frames, id_dim), requires_grad=True)
280
+ id_para_.data = id_para.data.clone()
281
+ id_para = id_para_
282
+ exp_para_ = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
283
+ exp_para_.data = exp_para.data.clone()
284
+ exp_para = exp_para_
285
+ euler_angle_ = lms.new_zeros((num_frames, 3), requires_grad=True)
286
+ euler_angle_.data = euler_angle.data.clone()
287
+ euler_angle = euler_angle_
288
+ trans_ = lms.new_zeros((num_frames, 3), requires_grad=True)
289
+ trans_.data = trans.data.clone()
290
+ trans = trans_
291
+
292
+ batch_size = 50
293
+ # "fine fitting the 3DMM in batches"
294
+ for i in range(int((num_frames-1)/batch_size+1)):
295
+ if (i+1)*batch_size > num_frames:
296
+ start_n = num_frames-batch_size
297
+ sel_ids = np.arange(max(num_frames-batch_size,0), num_frames)
298
+ else:
299
+ start_n = i*batch_size
300
+ sel_ids = np.arange(i*batch_size, i*batch_size+batch_size)
301
+ sel_lms = lms[sel_ids]
302
+
303
+ if id_mode == 'global':
304
+ sel_id_para = id_para.expand((sel_ids.shape[0], id_dim))
305
+ else:
306
+ sel_id_para = id_para.new_zeros((batch_size, id_dim), requires_grad=True)
307
+ sel_id_para.data = id_para[sel_ids].clone()
308
+ sel_exp_para = exp_para.new_zeros(
309
+ (batch_size, exp_dim), requires_grad=True)
310
+ sel_exp_para.data = exp_para[sel_ids].clone()
311
+ sel_euler_angle = euler_angle.new_zeros(
312
+ (batch_size, 3), requires_grad=True)
313
+ sel_euler_angle.data = euler_angle[sel_ids].clone()
314
+ sel_trans = trans.new_zeros((batch_size, 3), requires_grad=True)
315
+ sel_trans.data = trans[sel_ids].clone()
316
+
317
+ if id_mode == 'global':
318
+ set_requires_grad([sel_exp_para, sel_euler_angle, sel_trans])
319
+ optimizer_cur_batch = torch.optim.Adam(
320
+ [sel_exp_para, sel_euler_angle, sel_trans], lr=0.005)
321
+ else:
322
+ set_requires_grad([sel_id_para, sel_exp_para, sel_euler_angle, sel_trans])
323
+ optimizer_cur_batch = torch.optim.Adam(
324
+ [sel_id_para, sel_exp_para, sel_euler_angle, sel_trans], lr=0.005)
325
+
326
+ for j in range(50):
327
+ ret = {}
328
+ proj_geo = face_model.compute_for_landmark_fit(
329
+ sel_id_para, sel_exp_para, sel_euler_angle, sel_trans, ret)
330
+ loss_lan = cal_lan_loss_fn(
331
+ proj_geo[:, :, :2], lms[sel_ids].detach())
332
+
333
+ # loss_lap = cal_lap_loss(proj_geo)
334
+ loss_lap = cal_lap_loss(sel_id_para) + cal_lap_loss(sel_exp_para) + cal_lap_loss(sel_euler_angle) * 0.3 + cal_lap_loss(sel_trans) * 0.3
335
+ loss_vel_id = cal_vel_loss(sel_id_para)
336
+ loss_vel_exp = cal_vel_loss(sel_exp_para)
337
+ log_dict = {
338
+ 'loss_vel_id': loss_vel_id,
339
+ 'loss_vel_exp': loss_vel_exp,
340
+ 'loss_vel_euler': cal_vel_loss(sel_euler_angle),
341
+ 'loss_vel_trans': cal_vel_loss(sel_trans),
342
+ }
343
+ loss_regid = torch.mean(sel_id_para*sel_id_para) # 正则化
344
+ loss_regexp = torch.mean(sel_exp_para*sel_exp_para)
345
+ loss = loss_lan + loss_regid * LAMBDA_REG_ID + loss_regexp * LAMBDA_REG_EXP + loss_lap * LAMBDA_REG_LAP + loss_vel_id * LAMBDA_REG_VEL_ID + loss_vel_exp * LAMBDA_REG_VEL_EXP
346
+
347
+ optimizer_cur_batch.zero_grad()
348
+ loss.backward()
349
+ optimizer_cur_batch.step()
350
+
351
+ if debug:
352
+ print(f"batch {i} | loss_lan: {loss_lan.item():.2f}, loss_reg_id: {loss_regid.item():.2f},loss_reg_exp: {loss_regexp.item():.2f},loss_lap_ldm:{loss_lap.item():.4f}")
353
+ print("|--------" + ', '.join([f"{k}: {v:.4f}" for k,v in log_dict.items()]))
354
+ if id_mode != 'global':
355
+ id_para[sel_ids].data = sel_id_para.data.clone()
356
+ exp_para[sel_ids].data = sel_exp_para.data.clone()
357
+ euler_angle[sel_ids].data = sel_euler_angle.data.clone()
358
+ trans[sel_ids].data = sel_trans.data.clone()
359
+
360
+ coeff_dict = {'id': id_para.detach().cpu().numpy(), 'exp': exp_para.detach().cpu().numpy(),
361
+ 'euler': euler_angle.detach().cpu().numpy(), 'trans': trans.detach().cpu().numpy()}
362
+
363
+ # filter data by side-view pose
364
+ # bad_yaw = False
365
+ # yaws = [] # not so accurate
366
+ # for index in range(coeff_dict["trans"].shape[0]):
367
+ # yaw = coeff_dict["euler"][index][1]
368
+ # yaw = np.abs(yaw)
369
+ # yaws.append(yaw)
370
+ # if yaw > large_yaw_threshold:
371
+ # bad_yaw = True
372
+
373
+ if debug:
374
+ import imageio
375
+ from utils.visualization.vis_cam3d.camera_pose_visualizer import CameraPoseVisualizer
376
+ from data_util.face3d_helper import Face3DHelper
377
+ from data_gen.utils.process_video.extract_blink import get_eye_area_percent
378
+ face3d_helper = Face3DHelper('deep_3drecon/BFM', keypoint_mode='mediapipe')
379
+
380
+ t = coeff_dict['exp'].shape[0]
381
+ if len(coeff_dict['id']) == 1:
382
+ coeff_dict['id'] = np.repeat(coeff_dict['id'], t, axis=0)
383
+ idexp_lm3d = face3d_helper.reconstruct_idexp_lm3d_np(coeff_dict['id'], coeff_dict['exp']).reshape([t, -1])
384
+ cano_lm3d = idexp_lm3d / 10 + face3d_helper.key_mean_shape.squeeze().reshape([1, -1]).cpu().numpy()
385
+ cano_lm3d = cano_lm3d.reshape([t, -1, 3])
386
+ WH = 512
387
+ cano_lm3d = (cano_lm3d * WH/2 + WH/2).astype(int)
388
+
389
+ with torch.no_grad():
390
+ rot = ParametricFaceModel.compute_rotation(euler_angle)
391
+ extrinsic = torch.zeros([rot.shape[0], 4, 4]).to(rot.device)
392
+ extrinsic[:, :3,:3] = rot
393
+ extrinsic[:, :3, 3] = trans # / 10
394
+ extrinsic[:, 3, 3] = 1
395
+ extrinsic = extrinsic.cpu().numpy()
396
+
397
+ xy_camera_visualizer = CameraPoseVisualizer(xlim=[extrinsic[:,0,3].min().item()-0.5,extrinsic[:,0,3].max().item()+0.5],ylim=[extrinsic[:,1,3].min().item()-0.5,extrinsic[:,1,3].max().item()+0.5], zlim=[extrinsic[:,2,3].min().item()-0.5,extrinsic[:,2,3].max().item()+0.5], view_mode='xy')
398
+ xz_camera_visualizer = CameraPoseVisualizer(xlim=[extrinsic[:,0,3].min().item()-0.5,extrinsic[:,0,3].max().item()+0.5],ylim=[extrinsic[:,1,3].min().item()-0.5,extrinsic[:,1,3].max().item()+0.5], zlim=[extrinsic[:,2,3].min().item()-0.5,extrinsic[:,2,3].max().item()+0.5], view_mode='xz')
399
+
400
+ if nerf:
401
+ debug_name = video_name.replace("/raw/", "/processed/").replace(".mp4", "/debug_fit_3dmm.mp4")
402
+ else:
403
+ debug_name = video_name.replace("/video/", "/coeff_fit_debug/").replace(".mp4", "_debug.mp4")
404
+ try:
405
+ os.makedirs(os.path.dirname(debug_name), exist_ok=True)
406
+ except: pass
407
+ writer = imageio.get_writer(debug_name, fps=25)
408
+ if id_mode == 'global':
409
+ id_para = id_para.repeat([exp_para.shape[0], 1])
410
+ proj_geo = face_model.compute_for_landmark_fit(id_para, exp_para, euler_angle, trans)
411
+ lm68s = proj_geo[:,:,:2].detach().cpu().numpy() # [T, 68,2]
412
+ lm68s = lm68s * img_scale_factor
413
+ lms = lms * img_scale_factor
414
+ lm68s[..., 1] = img_h - lm68s[..., 1] # flip the height axis
415
+ lms[..., 1] = img_h - lms[..., 1] # flip the height axis
416
+ lm68s = lm68s.astype(int)
417
+ for i in tqdm.trange(min(250, len(frames)), desc=f'rendering debug video to {debug_name}..'):
418
+ xy_cam3d_img = xy_camera_visualizer.extrinsic2pyramid(extrinsic[i], focal_len_scaled=0.25)
419
+ xy_cam3d_img = cv2.resize(xy_cam3d_img, (512,512))
420
+ xz_cam3d_img = xz_camera_visualizer.extrinsic2pyramid(extrinsic[i], focal_len_scaled=0.25)
421
+ xz_cam3d_img = cv2.resize(xz_cam3d_img, (512,512))
422
+
423
+ img = copy.deepcopy(frames[i])
424
+ img2 = copy.deepcopy(frames[i])
425
+
426
+ img = draw_axes(img, euler_angle[i,0].item(), euler_angle[i,1].item(), euler_angle[i,2].item(), lm68s[i][4][0].item(), lm68s[i, 4][1].item(), size=50)
427
+
428
+ gt_lm_color = (255, 0, 0)
429
+
430
+ for lm in lm68s[i]:
431
+ img = cv2.circle(img, lm, 1, (0, 0, 255), thickness=-1) # blue
432
+ for gt_lm in lms[i]:
433
+ img2 = cv2.circle(img2, gt_lm.cpu().numpy().astype(int), 2, gt_lm_color, thickness=1)
434
+
435
+ cano_lm3d_img = np.ones([WH, WH, 3], dtype=np.uint8) * 255
436
+ for j in range(len(cano_lm3d[i])):
437
+ x, y, _ = cano_lm3d[i, j]
438
+ color = (255,0,0)
439
+ cano_lm3d_img = cv2.circle(cano_lm3d_img, center=(x,y), radius=3, color=color, thickness=-1)
440
+ cano_lm3d_img = cv2.flip(cano_lm3d_img, 0)
441
+
442
+ _, secc_img = secc_renderer(id_para[0:1], exp_para[i:i+1], euler_angle[i:i+1]*0, trans[i:i+1]*0)
443
+ secc_img = (secc_img +1)*127.5
444
+ secc_img = F.interpolate(secc_img, size=(img_h, img_w))
445
+ secc_img = secc_img.permute(0, 2,3,1).int().cpu().numpy()[0]
446
+ out_img1 = np.concatenate([img, img2, secc_img], axis=1).astype(np.uint8)
447
+ font = cv2.FONT_HERSHEY_SIMPLEX
448
+ out_img2 = np.concatenate([xy_cam3d_img, xz_cam3d_img, cano_lm3d_img], axis=1).astype(np.uint8)
449
+ out_img = np.concatenate([out_img1, out_img2], axis=0)
450
+ writer.append_data(out_img)
451
+ writer.close()
452
+
453
+ # if bad_yaw:
454
+ # print(f"Skip {video_name} due to TOO LARGE YAW")
455
+ # return False
456
+
457
+ if save:
458
+ np.save(out_name, coeff_dict, allow_pickle=True)
459
+ return coeff_dict
460
+
461
+ def out_exist_job(vid_name):
462
+ out_name = vid_name.replace("/video/", "/coeff_fit_mp/").replace(".mp4","_coeff_fit_mp.npy")
463
+ lms_name = vid_name.replace("/video/", "/lms_2d/").replace(".mp4","_lms.npy")
464
+ if os.path.exists(out_name) or not os.path.exists(lms_name):
465
+ return None
466
+ else:
467
+ return vid_name
468
+
469
+ def get_todo_vid_names(vid_names):
470
+ if len(vid_names) == 1: # single video, nerf
471
+ return vid_names
472
+ todo_vid_names = []
473
+ for i, res in multiprocess_run_tqdm(out_exist_job, vid_names, num_workers=16):
474
+ if res is not None:
475
+ todo_vid_names.append(res)
476
+ return todo_vid_names
477
+
478
+
479
+ if __name__ == '__main__':
480
+ import argparse, glob, tqdm
481
+ parser = argparse.ArgumentParser()
482
+ # parser.add_argument("--vid_dir", default='/home/tiger/datasets/raw/CelebV-HQ/video')
483
+ parser.add_argument("--vid_dir", default='data/raw/videos/May_10s.mp4')
484
+ parser.add_argument("--ds_name", default='nerf') # 'nerf' | 'CelebV-HQ' | 'TH1KH_512' | etc
485
+ parser.add_argument("--seed", default=0, type=int)
486
+ parser.add_argument("--process_id", default=0, type=int)
487
+ parser.add_argument("--total_process", default=1, type=int)
488
+ parser.add_argument("--id_mode", default='global', type=str) # global | finegrained
489
+ parser.add_argument("--keypoint_mode", default='mediapipe', type=str)
490
+ parser.add_argument("--large_yaw_threshold", default=9999999.9, type=float) # could be 0.7
491
+ parser.add_argument("--debug", action='store_true')
492
+ parser.add_argument("--reset", action='store_true')
493
+ parser.add_argument("--load_names", action="store_true")
494
+
495
+ args = parser.parse_args()
496
+ vid_dir = args.vid_dir
497
+ ds_name = args.ds_name
498
+ load_names = args.load_names
499
+
500
+ print(f"args {args}")
501
+
502
+ if ds_name.lower() == 'nerf': # 处理单个视频
503
+ vid_names = [vid_dir]
504
+ out_names = [video_name.replace("/raw/", "/processed/").replace(".mp4","_coeff_fit_mp.npy") for video_name in vid_names]
505
+ else: # 处理整个数据集
506
+ if ds_name in ['lrs3_trainval']:
507
+ vid_name_pattern = os.path.join(vid_dir, "*/*.mp4")
508
+ elif ds_name in ['TH1KH_512', 'CelebV-HQ']:
509
+ vid_name_pattern = os.path.join(vid_dir, "*.mp4")
510
+ elif ds_name in ['lrs2', 'lrs3', 'voxceleb2', 'CMLR']:
511
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
512
+ elif ds_name in ["RAVDESS", 'VFHQ']:
513
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4")
514
+ else:
515
+ raise NotImplementedError()
516
+
517
+ vid_names_path = os.path.join(vid_dir, "vid_names.pkl")
518
+ if os.path.exists(vid_names_path) and load_names:
519
+ print(f"loading vid names from {vid_names_path}")
520
+ vid_names = load_file(vid_names_path)
521
+ else:
522
+ vid_names = multiprocess_glob(vid_name_pattern)
523
+ vid_names = sorted(vid_names)
524
+ print(f"saving vid names to {vid_names_path}")
525
+ save_file(vid_names_path, vid_names)
526
+ out_names = [video_name.replace("/video/", "/coeff_fit_mp/").replace(".mp4","_coeff_fit_mp.npy") for video_name in vid_names]
527
+
528
+ print(vid_names[:10])
529
+ random.seed(args.seed)
530
+ random.shuffle(vid_names)
531
+
532
+ face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM',
533
+ camera_distance=10, focal=1015, keypoint_mode=args.keypoint_mode)
534
+ face_model.to(torch.device("cuda:0"))
535
+ secc_renderer = SECC_Renderer(512)
536
+ secc_renderer.to("cuda:0")
537
+
538
+ process_id = args.process_id
539
+ total_process = args.total_process
540
+ if total_process > 1:
541
+ assert process_id <= total_process -1
542
+ num_samples_per_process = len(vid_names) // total_process
543
+ if process_id == total_process:
544
+ vid_names = vid_names[process_id * num_samples_per_process : ]
545
+ else:
546
+ vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
547
+
548
+ if not args.reset:
549
+ vid_names = get_todo_vid_names(vid_names)
550
+
551
+ failed_img_names = []
552
+ for i in tqdm.trange(len(vid_names), desc=f"process {process_id}: fitting 3dmm ..."):
553
+ img_name = vid_names[i]
554
+ try:
555
+ is_person_specific_data = ds_name=='nerf'
556
+ success = fit_3dmm_for_a_video(img_name, is_person_specific_data, args.id_mode, args.debug, large_yaw_threshold=args.large_yaw_threshold)
557
+ if not success:
558
+ failed_img_names.append(img_name)
559
+ except Exception as e:
560
+ print(img_name, e)
561
+ failed_img_names.append(img_name)
562
+ print(f"finished {i + 1} / {len(vid_names)} = {(i + 1) / len(vid_names):.4f}, failed {len(failed_img_names)} / {i + 1} = {len(failed_img_names) / (i + 1):.4f}")
563
+ sys.stdout.flush()
564
+ print(f"all failed image names: {failed_img_names}")
565
+ print(f"All finished!")
data_gen/utils/process_video/inpaint_torso_imgs.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import os
3
+ import numpy as np
4
+ from utils.commons.multiprocess_utils import multiprocess_run_tqdm
5
+ from scipy.ndimage import binary_erosion, binary_dilation
6
+
7
+ from tasks.eg3ds.loss_utils.segment_loss.mp_segmenter import MediapipeSegmenter
8
+ seg_model = MediapipeSegmenter()
9
+
10
+ def inpaint_torso_job(video_name, idx=None, total=None):
11
+ raw_img_dir = video_name.replace(".mp4", "").replace("/video/","/gt_imgs/")
12
+ img_names = glob.glob(os.path.join(raw_img_dir, "*.jpg"))
13
+
14
+ for image_path in tqdm.tqdm(img_names):
15
+ # read ori image
16
+ ori_image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3]
17
+ segmap = seg_model._cal_seg_map(cv2.cvtColor(ori_image, cv2.COLOR_BGR2RGB))
18
+ head_part = (segmap[1] + segmap[3] + segmap[5]).astype(np.bool)
19
+ torso_part = (segmap[4]).astype(np.bool)
20
+ neck_part = (segmap[2]).astype(np.bool)
21
+ bg_part = segmap[0].astype(np.bool)
22
+ head_image = cv2.imread(image_path.replace("/gt_imgs/", "/head_imgs/"), cv2.IMREAD_UNCHANGED) # [H, W, 3]
23
+ torso_image = cv2.imread(image_path.replace("/gt_imgs/", "/torso_imgs/"), cv2.IMREAD_UNCHANGED) # [H, W, 3]
24
+ bg_image = cv2.imread(image_path.replace("/gt_imgs/", "/bg_imgs/"), cv2.IMREAD_UNCHANGED) # [H, W, 3]
25
+
26
+ # head_part = (head_image[...,0] != 0) & (head_image[...,1] != 0) & (head_image[...,2] != 0)
27
+ # torso_part = (torso_image[...,0] != 0) & (torso_image[...,1] != 0) & (torso_image[...,2] != 0)
28
+ # bg_part = (bg_image[...,0] != 0) & (bg_image[...,1] != 0) & (bg_image[...,2] != 0)
29
+
30
+ # get gt image
31
+ gt_image = ori_image.copy()
32
+ gt_image[bg_part] = bg_image[bg_part]
33
+ cv2.imwrite(image_path.replace('ori_imgs', 'gt_imgs'), gt_image)
34
+
35
+ # get torso image
36
+ torso_image = gt_image.copy() # rgb
37
+ torso_image[head_part] = 0
38
+ torso_alpha = 255 * np.ones((gt_image.shape[0], gt_image.shape[1], 1), dtype=np.uint8) # alpha
39
+
40
+ # torso part "vertical" in-painting...
41
+ L = 8 + 1
42
+ torso_coords = np.stack(np.nonzero(torso_part), axis=-1) # [M, 2]
43
+ # lexsort: sort 2D coords first by y then by x,
44
+ # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
45
+ inds = np.lexsort((torso_coords[:, 0], torso_coords[:, 1]))
46
+ torso_coords = torso_coords[inds]
47
+ # choose the top pixel for each column
48
+ u, uid, ucnt = np.unique(torso_coords[:, 1], return_index=True, return_counts=True)
49
+ top_torso_coords = torso_coords[uid] # [m, 2]
50
+ # only keep top-is-head pixels
51
+ top_torso_coords_up = top_torso_coords.copy() - np.array([1, 0]) # [N, 2]
52
+ mask = head_part[tuple(top_torso_coords_up.T)]
53
+ if mask.any():
54
+ top_torso_coords = top_torso_coords[mask]
55
+ # get the color
56
+ top_torso_colors = gt_image[tuple(top_torso_coords.T)] # [m, 3]
57
+ # construct inpaint coords (vertically up, or minus in x)
58
+ inpaint_torso_coords = top_torso_coords[None].repeat(L, 0) # [L, m, 2]
59
+ inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
60
+ inpaint_torso_coords += inpaint_offsets
61
+ inpaint_torso_coords = inpaint_torso_coords.reshape(-1, 2) # [Lm, 2]
62
+ inpaint_torso_colors = top_torso_colors[None].repeat(L, 0) # [L, m, 3]
63
+ darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
64
+ inpaint_torso_colors = (inpaint_torso_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
65
+ # set color
66
+ torso_image[tuple(inpaint_torso_coords.T)] = inpaint_torso_colors
67
+
68
+ inpaint_torso_mask = np.zeros_like(torso_image[..., 0]).astype(bool)
69
+ inpaint_torso_mask[tuple(inpaint_torso_coords.T)] = True
70
+ else:
71
+ inpaint_torso_mask = None
72
+
73
+ # neck part "vertical" in-painting...
74
+ push_down = 4
75
+ L = 48 + push_down + 1
76
+
77
+ neck_part = binary_dilation(neck_part, structure=np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=bool), iterations=3)
78
+
79
+ neck_coords = np.stack(np.nonzero(neck_part), axis=-1) # [M, 2]
80
+ # lexsort: sort 2D coords first by y then by x,
81
+ # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
82
+ inds = np.lexsort((neck_coords[:, 0], neck_coords[:, 1]))
83
+ neck_coords = neck_coords[inds]
84
+ # choose the top pixel for each column
85
+ u, uid, ucnt = np.unique(neck_coords[:, 1], return_index=True, return_counts=True)
86
+ top_neck_coords = neck_coords[uid] # [m, 2]
87
+ # only keep top-is-head pixels
88
+ top_neck_coords_up = top_neck_coords.copy() - np.array([1, 0])
89
+ mask = head_part[tuple(top_neck_coords_up.T)]
90
+
91
+ top_neck_coords = top_neck_coords[mask]
92
+ # push these top down for 4 pixels to make the neck inpainting more natural...
93
+ offset_down = np.minimum(ucnt[mask] - 1, push_down)
94
+ top_neck_coords += np.stack([offset_down, np.zeros_like(offset_down)], axis=-1)
95
+ # get the color
96
+ top_neck_colors = gt_image[tuple(top_neck_coords.T)] # [m, 3]
97
+ # construct inpaint coords (vertically up, or minus in x)
98
+ inpaint_neck_coords = top_neck_coords[None].repeat(L, 0) # [L, m, 2]
99
+ inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
100
+ inpaint_neck_coords += inpaint_offsets
101
+ inpaint_neck_coords = inpaint_neck_coords.reshape(-1, 2) # [Lm, 2]
102
+ inpaint_neck_colors = top_neck_colors[None].repeat(L, 0) # [L, m, 3]
103
+ darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
104
+ inpaint_neck_colors = (inpaint_neck_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
105
+ # set color
106
+ torso_image[tuple(inpaint_neck_coords.T)] = inpaint_neck_colors
107
+
108
+ # apply blurring to the inpaint area to avoid vertical-line artifects...
109
+ inpaint_mask = np.zeros_like(torso_image[..., 0]).astype(bool)
110
+ inpaint_mask[tuple(inpaint_neck_coords.T)] = True
111
+
112
+ blur_img = torso_image.copy()
113
+ blur_img = cv2.GaussianBlur(blur_img, (5, 5), cv2.BORDER_DEFAULT)
114
+
115
+ torso_image[inpaint_mask] = blur_img[inpaint_mask]
116
+
117
+ # set mask
118
+ mask = (neck_part | torso_part | inpaint_mask)
119
+ if inpaint_torso_mask is not None:
120
+ mask = mask | inpaint_torso_mask
121
+ torso_image[~mask] = 0
122
+ torso_alpha[~mask] = 0
123
+
124
+ cv2.imwrite("0.png", np.concatenate([torso_image, torso_alpha], axis=-1))
125
+
126
+ print(f'[INFO] ===== extracted torso and gt images =====')
127
+
128
+
129
+ def out_exist_job(vid_name):
130
+ out_dir1 = vid_name.replace("/video/", "/inpaint_torso_imgs/").replace(".mp4","")
131
+ out_dir2 = vid_name.replace("/video/", "/inpaint_torso_with_bg_imgs/").replace(".mp4","")
132
+ out_dir3 = vid_name.replace("/video/", "/torso_imgs/").replace(".mp4","")
133
+ out_dir4 = vid_name.replace("/video/", "/torso_with_bg_imgs/").replace(".mp4","")
134
+
135
+ if os.path.exists(out_dir1) and os.path.exists(out_dir1) and os.path.exists(out_dir2) and os.path.exists(out_dir3) and os.path.exists(out_dir4):
136
+ num_frames = len(os.listdir(out_dir1))
137
+ if len(os.listdir(out_dir1)) == num_frames and len(os.listdir(out_dir2)) == num_frames and len(os.listdir(out_dir3)) == num_frames and len(os.listdir(out_dir4)) == num_frames:
138
+ return None
139
+ else:
140
+ return vid_name
141
+ else:
142
+ return vid_name
143
+
144
+ def get_todo_vid_names(vid_names):
145
+ todo_vid_names = []
146
+ for i, res in multiprocess_run_tqdm(out_exist_job, vid_names, num_workers=16):
147
+ if res is not None:
148
+ todo_vid_names.append(res)
149
+ return todo_vid_names
150
+
151
+ if __name__ == '__main__':
152
+ import argparse, glob, tqdm, random
153
+ parser = argparse.ArgumentParser()
154
+ parser.add_argument("--vid_dir", default='/home/tiger/datasets/raw/CelebV-HQ/video')
155
+ parser.add_argument("--ds_name", default='CelebV-HQ')
156
+ parser.add_argument("--num_workers", default=48, type=int)
157
+ parser.add_argument("--seed", default=0, type=int)
158
+ parser.add_argument("--process_id", default=0, type=int)
159
+ parser.add_argument("--total_process", default=1, type=int)
160
+ parser.add_argument("--reset", action='store_true')
161
+
162
+ inpaint_torso_job('/home/tiger/datasets/raw/CelebV-HQ/video/dgdEr-mXQT4_8.mp4')
163
+ # args = parser.parse_args()
164
+ # vid_dir = args.vid_dir
165
+ # ds_name = args.ds_name
166
+ # if ds_name in ['lrs3_trainval']:
167
+ # mp4_name_pattern = os.path.join(vid_dir, "*/*.mp4")
168
+ # if ds_name in ['TH1KH_512', 'CelebV-HQ']:
169
+ # vid_names = glob.glob(os.path.join(vid_dir, "*.mp4"))
170
+ # elif ds_name in ['lrs2', 'lrs3', 'voxceleb2']:
171
+ # vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
172
+ # vid_names = glob.glob(vid_name_pattern)
173
+ # vid_names = sorted(vid_names)
174
+ # random.seed(args.seed)
175
+ # random.shuffle(vid_names)
176
+
177
+ # process_id = args.process_id
178
+ # total_process = args.total_process
179
+ # if total_process > 1:
180
+ # assert process_id <= total_process -1
181
+ # num_samples_per_process = len(vid_names) // total_process
182
+ # if process_id == total_process:
183
+ # vid_names = vid_names[process_id * num_samples_per_process : ]
184
+ # else:
185
+ # vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
186
+
187
+ # if not args.reset:
188
+ # vid_names = get_todo_vid_names(vid_names)
189
+ # print(f"todo videos number: {len(vid_names)}")
190
+
191
+ # fn_args = [(vid_name,i,len(vid_names)) for i, vid_name in enumerate(vid_names)]
192
+ # for vid_name in multiprocess_run_tqdm(inpaint_torso_job ,fn_args, desc=f"Root process {args.process_id}: extracting segment images", num_workers=args.num_workers):
193
+ # pass
data_gen/utils/process_video/resample_video_to_25fps_resize_to_512.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, glob
2
+ import cv2
3
+ from utils.commons.os_utils import multiprocess_glob
4
+ from utils.commons.multiprocess_utils import multiprocess_run_tqdm
5
+
6
+ def get_video_infos(video_path):
7
+ vid_cap = cv2.VideoCapture(video_path)
8
+ height = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
9
+ width = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
10
+ fps = vid_cap.get(cv2.CAP_PROP_FPS)
11
+ total_frames = int(vid_cap.get(cv2.CAP_PROP_FRAME_COUNT))
12
+ return {'height': height, 'width': width, 'fps': fps, 'total_frames':total_frames}
13
+
14
+ def extract_img_job(video_name:str):
15
+ out_path = video_name.replace("/video_raw/","/video/",1)
16
+ os.makedirs(os.path.dirname(out_path), exist_ok=True)
17
+ ffmpeg_path = "/usr/bin/ffmpeg"
18
+ vid_info = get_video_infos(video_name)
19
+ assert vid_info['width'] == vid_info['height']
20
+ cmd = f'{ffmpeg_path} -i {video_name} -vf fps={25},scale=w=512:h=512 -q:v 1 -c:v libx264 -pix_fmt yuv420p -b:v 2000k -v quiet -y {out_path}'
21
+ os.system(cmd)
22
+
23
+ def extract_img_job_crop(video_name:str):
24
+ out_path = video_name.replace("/video_raw/","/video/",1)
25
+ os.makedirs(os.path.dirname(out_path), exist_ok=True)
26
+ ffmpeg_path = "/usr/bin/ffmpeg"
27
+ vid_info = get_video_infos(video_name)
28
+ wh = min(vid_info['width'], vid_info['height'])
29
+ cmd = f'{ffmpeg_path} -i {video_name} -vf fps={25},crop={wh}:{wh},scale=w=512:h=512 -q:v 1 -c:v libx264 -pix_fmt yuv420p -b:v 2000k -v quiet -y {out_path}'
30
+ os.system(cmd)
31
+
32
+ def extract_img_job_crop_ravdess(video_name:str):
33
+ out_path = video_name.replace("/video_raw/","/video/",1)
34
+ os.makedirs(os.path.dirname(out_path), exist_ok=True)
35
+ ffmpeg_path = "/usr/bin/ffmpeg"
36
+ cmd = f'{ffmpeg_path} -i {video_name} -vf fps={25},crop=720:720,scale=w=512:h=512 -q:v 1 -c:v libx264 -pix_fmt yuv420p -b:v 2000k -v quiet -y {out_path}'
37
+ os.system(cmd)
38
+
39
+ if __name__ == '__main__':
40
+ import argparse, glob, tqdm, random
41
+ parser = argparse.ArgumentParser()
42
+ parser.add_argument("--vid_dir", default='/home/tiger/datasets/raw/CelebV-HQ/video_raw/')
43
+ parser.add_argument("--ds_name", default='CelebV-HQ')
44
+ parser.add_argument("--num_workers", default=32, type=int)
45
+ parser.add_argument("--process_id", default=0, type=int)
46
+ parser.add_argument("--total_process", default=1, type=int)
47
+ args = parser.parse_args()
48
+ print(f"args {args}")
49
+
50
+ vid_dir = args.vid_dir
51
+ ds_name = args.ds_name
52
+ if ds_name in ['lrs3_trainval']:
53
+ mp4_name_pattern = os.path.join(vid_dir, "*/*.mp4")
54
+ elif ds_name in ['TH1KH_512', 'CelebV-HQ']:
55
+ vid_names = multiprocess_glob(os.path.join(vid_dir, "*.mp4"))
56
+ elif ds_name in ['lrs2', 'lrs3', 'voxceleb2', 'CMLR']:
57
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
58
+ vid_names = multiprocess_glob(vid_name_pattern)
59
+ elif ds_name in ["RAVDESS", 'VFHQ']:
60
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4")
61
+ vid_names = multiprocess_glob(vid_name_pattern)
62
+ else:
63
+ raise NotImplementedError()
64
+ vid_names = sorted(vid_names)
65
+ print(f"total video number : {len(vid_names)}")
66
+ print(f"first {vid_names[0]} last {vid_names[-1]}")
67
+ # exit()
68
+ process_id = args.process_id
69
+ total_process = args.total_process
70
+ if total_process > 1:
71
+ assert process_id <= total_process -1
72
+ num_samples_per_process = len(vid_names) // total_process
73
+ if process_id == total_process:
74
+ vid_names = vid_names[process_id * num_samples_per_process : ]
75
+ else:
76
+ vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
77
+
78
+ if ds_name == "RAVDESS":
79
+ for i, res in multiprocess_run_tqdm(extract_img_job_crop_ravdess, vid_names, num_workers=args.num_workers, desc="resampling videos"):
80
+ pass
81
+ elif ds_name == "CMLR":
82
+ for i, res in multiprocess_run_tqdm(extract_img_job_crop, vid_names, num_workers=args.num_workers, desc="resampling videos"):
83
+ pass
84
+ else:
85
+ for i, res in multiprocess_run_tqdm(extract_img_job, vid_names, num_workers=args.num_workers, desc="resampling videos"):
86
+ pass
87
+
data_gen/utils/process_video/split_video_to_imgs.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, glob
2
+ from utils.commons.multiprocess_utils import multiprocess_run_tqdm
3
+
4
+ from data_gen.utils.path_converter import PathConverter, pc
5
+
6
+ # mp4_names = glob.glob("/home/tiger/datasets/raw/CelebV-HQ/video/*.mp4")
7
+
8
+ def extract_img_job(video_name, raw_img_dir=None):
9
+ if raw_img_dir is not None:
10
+ out_path = raw_img_dir
11
+ else:
12
+ out_path = pc.to(video_name.replace(".mp4", ""), "vid", "gt")
13
+ os.makedirs(out_path, exist_ok=True)
14
+ ffmpeg_path = "/usr/bin/ffmpeg"
15
+ cmd = f'{ffmpeg_path} -i {video_name} -vf fps={25},scale=w=512:h=512 -qmin 1 -q:v 1 -start_number 0 -v quiet {os.path.join(out_path, "%8d.jpg")}'
16
+ os.system(cmd)
17
+
18
+ if __name__ == '__main__':
19
+ import argparse, glob, tqdm, random
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument("--vid_dir", default='/home/tiger/datasets/raw/CelebV-HQ/video')
22
+ parser.add_argument("--ds_name", default='CelebV-HQ')
23
+ parser.add_argument("--num_workers", default=64, type=int)
24
+ parser.add_argument("--process_id", default=0, type=int)
25
+ parser.add_argument("--total_process", default=1, type=int)
26
+ args = parser.parse_args()
27
+ vid_dir = args.vid_dir
28
+ ds_name = args.ds_name
29
+ if ds_name in ['lrs3_trainval']:
30
+ mp4_name_pattern = os.path.join(vid_dir, "*/*.mp4")
31
+ elif ds_name in ['TH1KH_512', 'CelebV-HQ']:
32
+ vid_names = glob.glob(os.path.join(vid_dir, "*.mp4"))
33
+ elif ds_name in ['lrs2', 'lrs3', 'voxceleb2']:
34
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
35
+ vid_names = glob.glob(vid_name_pattern)
36
+ elif ds_name in ["RAVDESS", 'VFHQ']:
37
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4")
38
+ vid_names = glob.glob(vid_name_pattern)
39
+ vid_names = sorted(vid_names)
40
+
41
+ process_id = args.process_id
42
+ total_process = args.total_process
43
+ if total_process > 1:
44
+ assert process_id <= total_process -1
45
+ num_samples_per_process = len(vid_names) // total_process
46
+ if process_id == total_process:
47
+ vid_names = vid_names[process_id * num_samples_per_process : ]
48
+ else:
49
+ vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
50
+
51
+ for i, res in multiprocess_run_tqdm(extract_img_job, vid_names, num_workers=args.num_workers, desc="extracting images"):
52
+ pass
53
+
data_util/face3d_helper.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ from scipy.io import loadmat
6
+
7
+ from deep_3drecon.deep_3drecon_models.bfm import perspective_projection
8
+
9
+
10
+ class Face3DHelper(nn.Module):
11
+ def __init__(self, bfm_dir='deep_3drecon/BFM', keypoint_mode='lm68', use_gpu=True):
12
+ super().__init__()
13
+ self.keypoint_mode = keypoint_mode # lm68 | mediapipe
14
+ self.bfm_dir = bfm_dir
15
+ self.load_3dmm()
16
+ if use_gpu: self.to("cuda")
17
+
18
+ def load_3dmm(self):
19
+ model = loadmat(os.path.join(self.bfm_dir, "BFM_model_front.mat"))
20
+ self.register_buffer('mean_shape',torch.from_numpy(model['meanshape'].transpose()).float()) # mean face shape. [3*N, 1], N=35709, xyz=3, ==> 3*N=107127
21
+ mean_shape = self.mean_shape.reshape([-1, 3])
22
+ # re-center
23
+ mean_shape = mean_shape - torch.mean(mean_shape, dim=0, keepdims=True)
24
+ self.mean_shape = mean_shape.reshape([-1, 1])
25
+ self.register_buffer('id_base',torch.from_numpy(model['idBase']).float()) # identity basis. [3*N,80], we have 80 eigen faces for identity
26
+ self.register_buffer('exp_base',torch.from_numpy(model['exBase']).float()) # expression basis. [3*N,64], we have 64 eigen faces for expression
27
+
28
+ self.register_buffer('mean_texure',torch.from_numpy(model['meantex'].transpose()).float()) # mean face texture. [3*N,1] (0-255)
29
+ self.register_buffer('tex_base',torch.from_numpy(model['texBase']).float()) # texture basis. [3*N,80], rgb=3
30
+
31
+ self.register_buffer('point_buf',torch.from_numpy(model['point_buf']).float()) # triangle indices for each vertex that lies in. starts from 1. [N,8] (1-F)
32
+ self.register_buffer('face_buf',torch.from_numpy(model['tri']).float()) # vertex indices in each triangle. starts from 1. [F,3] (1-N)
33
+ if self.keypoint_mode == 'mediapipe':
34
+ self.register_buffer('key_points', torch.from_numpy(np.load("deep_3drecon/BFM/index_mp468_from_mesh35709.npy").astype(np.int64)))
35
+ unmatch_mask = self.key_points < 0
36
+ self.key_points[unmatch_mask] = 0
37
+ else:
38
+ self.register_buffer('key_points',torch.from_numpy(model['keypoints'].squeeze().astype(np.int_)).long()) # vertex indices of 68 facial landmarks. starts from 1. [68,1]
39
+
40
+
41
+ self.register_buffer('key_mean_shape',self.mean_shape.reshape([-1,3])[self.key_points,:])
42
+ self.register_buffer('key_id_base', self.id_base.reshape([-1,3,80])[self.key_points, :, :].reshape([-1,80]))
43
+ self.register_buffer('key_exp_base', self.exp_base.reshape([-1,3,64])[self.key_points, :, :].reshape([-1,64]))
44
+ self.key_id_base_np = self.key_id_base.cpu().numpy()
45
+ self.key_exp_base_np = self.key_exp_base.cpu().numpy()
46
+
47
+ self.register_buffer('persc_proj', torch.tensor(perspective_projection(focal=1015, center=112)))
48
+ def split_coeff(self, coeff):
49
+ """
50
+ coeff: Tensor[B, T, c=257] or [T, c=257]
51
+ """
52
+ ret_dict = {
53
+ 'identity': coeff[..., :80], # identity, [b, t, c=80]
54
+ 'expression': coeff[..., 80:144], # expression, [b, t, c=80]
55
+ 'texture': coeff[..., 144:224], # texture, [b, t, c=80]
56
+ 'euler': coeff[..., 224:227], # euler euler for pose, [b, t, c=3]
57
+ 'translation': coeff[..., 254:257], # translation, [b, t, c=3]
58
+ 'gamma': coeff[..., 227:254] # lighting, [b, t, c=27]
59
+ }
60
+ return ret_dict
61
+
62
+ def reconstruct_face_mesh(self, id_coeff, exp_coeff):
63
+ """
64
+ Generate a pose-independent 3D face mesh!
65
+ id_coeff: Tensor[T, c=80]
66
+ exp_coeff: Tensor[T, c=64]
67
+ """
68
+ id_coeff = id_coeff.to(self.key_id_base.device)
69
+ exp_coeff = exp_coeff.to(self.key_id_base.device)
70
+ mean_face = self.mean_shape.squeeze().reshape([1, -1]) # [3N, 1] ==> [1, 3N]
71
+ id_base, exp_base = self.id_base, self.exp_base # [3*N, C]
72
+ identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3N] ==> [t,3N]
73
+ expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3N] ==> [t,3N]
74
+
75
+ face = mean_face + identity_diff_face + expression_diff_face # [t,3N]
76
+ face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
77
+ # re-centering the face with mean_xyz, so the face will be in [-1, 1]
78
+ # mean_xyz = self.mean_shape.squeeze().reshape([-1,3]).mean(dim=0) # [1, 3]
79
+ # face_mesh = face - mean_xyz.unsqueeze(0) # [t,N,3]
80
+ return face
81
+
82
+ def reconstruct_cano_lm3d(self, id_coeff, exp_coeff):
83
+ """
84
+ Generate 3D landmark with keypoint base!
85
+ id_coeff: Tensor[T, c=80]
86
+ exp_coeff: Tensor[T, c=64]
87
+ """
88
+ id_coeff = id_coeff.to(self.key_id_base.device)
89
+ exp_coeff = exp_coeff.to(self.key_id_base.device)
90
+ mean_face = self.key_mean_shape.squeeze().reshape([1, -1]) # [3*68, 1] ==> [1, 3*68]
91
+ id_base, exp_base = self.key_id_base, self.key_exp_base # [3*68, C]
92
+ identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
93
+ expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
94
+
95
+ face = mean_face + identity_diff_face + expression_diff_face # [t,3N]
96
+ face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
97
+ # re-centering the face with mean_xyz, so the face will be in [-1, 1]
98
+ # mean_xyz = self.key_mean_shape.squeeze().reshape([-1,3]).mean(dim=0) # [1, 3]
99
+ # lm3d = face - mean_xyz.unsqueeze(0) # [t,N,3]
100
+ return face
101
+
102
+ def reconstruct_lm3d(self, id_coeff, exp_coeff, euler, trans, to_camera=True):
103
+ """
104
+ Generate 3D landmark with keypoint base!
105
+ id_coeff: Tensor[T, c=80]
106
+ exp_coeff: Tensor[T, c=64]
107
+ """
108
+ id_coeff = id_coeff.to(self.key_id_base.device)
109
+ exp_coeff = exp_coeff.to(self.key_id_base.device)
110
+ mean_face = self.key_mean_shape.squeeze().reshape([1, -1]) # [3*68, 1] ==> [1, 3*68]
111
+ id_base, exp_base = self.key_id_base, self.key_exp_base # [3*68, C]
112
+ identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
113
+ expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
114
+
115
+ face = mean_face + identity_diff_face + expression_diff_face # [t,3N]
116
+ face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
117
+ # re-centering the face with mean_xyz, so the face will be in [-1, 1]
118
+ rot = self.compute_rotation(euler)
119
+ # transform
120
+ lm3d = face @ rot + trans.unsqueeze(1) # [t, N, 3]
121
+ # to camera
122
+ if to_camera:
123
+ lm3d[...,-1] = 10 - lm3d[...,-1]
124
+ return lm3d
125
+
126
+ def reconstruct_lm2d_nerf(self, id_coeff, exp_coeff, euler, trans):
127
+ lm2d = self.reconstruct_lm2d(id_coeff, exp_coeff, euler, trans, to_camera=False)
128
+ lm2d[..., 0] = 1 - lm2d[..., 0]
129
+ lm2d[..., 1] = 1 - lm2d[..., 1]
130
+ return lm2d
131
+
132
+ def reconstruct_lm2d(self, id_coeff, exp_coeff, euler, trans, to_camera=True):
133
+ """
134
+ Generate 3D landmark with keypoint base!
135
+ id_coeff: Tensor[T, c=80]
136
+ exp_coeff: Tensor[T, c=64]
137
+ """
138
+ is_btc_flag = True if id_coeff.ndim == 3 else False
139
+ if is_btc_flag:
140
+ b,t,_ = id_coeff.shape
141
+ id_coeff = id_coeff.reshape([b*t,-1])
142
+ exp_coeff = exp_coeff.reshape([b*t,-1])
143
+ euler = euler.reshape([b*t,-1])
144
+ trans = trans.reshape([b*t,-1])
145
+ id_coeff = id_coeff.to(self.key_id_base.device)
146
+ exp_coeff = exp_coeff.to(self.key_id_base.device)
147
+ mean_face = self.key_mean_shape.squeeze().reshape([1, -1]) # [3*68, 1] ==> [1, 3*68]
148
+ id_base, exp_base = self.key_id_base, self.key_exp_base # [3*68, C]
149
+ identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
150
+ expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
151
+
152
+ face = mean_face + identity_diff_face + expression_diff_face # [t,3N]
153
+ face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
154
+ # re-centering the face with mean_xyz, so the face will be in [-1, 1]
155
+ rot = self.compute_rotation(euler)
156
+ # transform
157
+ lm3d = face @ rot + trans.unsqueeze(1) # [t, N, 3]
158
+ # to camera
159
+ if to_camera:
160
+ lm3d[...,-1] = 10 - lm3d[...,-1]
161
+ # to image_plane
162
+ lm3d = lm3d @ self.persc_proj
163
+ lm2d = lm3d[..., :2] / lm3d[..., 2:]
164
+ # flip
165
+ lm2d[..., 1] = 224 - lm2d[..., 1]
166
+ lm2d /= 224
167
+ if is_btc_flag:
168
+ return lm2d.reshape([b,t,-1,2])
169
+ return lm2d
170
+
171
+ def compute_rotation(self, euler):
172
+ """
173
+ Return:
174
+ rot -- torch.tensor, size (B, 3, 3) pts @ trans_mat
175
+
176
+ Parameters:
177
+ euler -- torch.tensor, size (B, 3), radian
178
+ """
179
+
180
+ batch_size = euler.shape[0]
181
+ euler = euler.to(self.key_id_base.device)
182
+ ones = torch.ones([batch_size, 1]).to(self.key_id_base.device)
183
+ zeros = torch.zeros([batch_size, 1]).to(self.key_id_base.device)
184
+ x, y, z = euler[:, :1], euler[:, 1:2], euler[:, 2:],
185
+
186
+ rot_x = torch.cat([
187
+ ones, zeros, zeros,
188
+ zeros, torch.cos(x), -torch.sin(x),
189
+ zeros, torch.sin(x), torch.cos(x)
190
+ ], dim=1).reshape([batch_size, 3, 3])
191
+
192
+ rot_y = torch.cat([
193
+ torch.cos(y), zeros, torch.sin(y),
194
+ zeros, ones, zeros,
195
+ -torch.sin(y), zeros, torch.cos(y)
196
+ ], dim=1).reshape([batch_size, 3, 3])
197
+
198
+ rot_z = torch.cat([
199
+ torch.cos(z), -torch.sin(z), zeros,
200
+ torch.sin(z), torch.cos(z), zeros,
201
+ zeros, zeros, ones
202
+ ], dim=1).reshape([batch_size, 3, 3])
203
+
204
+ rot = rot_z @ rot_y @ rot_x
205
+ return rot.permute(0, 2, 1)
206
+
207
+ def reconstruct_idexp_lm3d(self, id_coeff, exp_coeff):
208
+ """
209
+ Generate 3D landmark with keypoint base!
210
+ id_coeff: Tensor[T, c=80]
211
+ exp_coeff: Tensor[T, c=64]
212
+ """
213
+ id_coeff = id_coeff.to(self.key_id_base.device)
214
+ exp_coeff = exp_coeff.to(self.key_id_base.device)
215
+ id_base, exp_base = self.key_id_base, self.key_exp_base # [3*68, C]
216
+ identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
217
+ expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
218
+
219
+ face = identity_diff_face + expression_diff_face # [t,3N]
220
+ face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
221
+ lm3d = face * 10
222
+ return lm3d
223
+
224
+ def reconstruct_idexp_lm3d_np(self, id_coeff, exp_coeff):
225
+ """
226
+ Generate 3D landmark with keypoint base!
227
+ id_coeff: Tensor[T, c=80]
228
+ exp_coeff: Tensor[T, c=64]
229
+ """
230
+ id_base, exp_base = self.key_id_base_np, self.key_exp_base_np # [3*68, C]
231
+ identity_diff_face = np.dot(id_coeff, id_base.T) # [t,c],[c,3*68] ==> [t,3*68]
232
+ expression_diff_face = np.dot(exp_coeff, exp_base.T) # [t,c],[c,3*68] ==> [t,3*68]
233
+
234
+ face = identity_diff_face + expression_diff_face # [t,3N]
235
+ face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
236
+ lm3d = face * 10
237
+ return lm3d
238
+
239
+ def get_eye_mouth_lm_from_lm3d(self, lm3d):
240
+ eye_lm = lm3d[:, 17:48] # [T, 31, 3]
241
+ mouth_lm = lm3d[:, 48:68] # [T, 20, 3]
242
+ return eye_lm, mouth_lm
243
+
244
+ def get_eye_mouth_lm_from_lm3d_batch(self, lm3d):
245
+ eye_lm = lm3d[:, :, 17:48] # [T, 31, 3]
246
+ mouth_lm = lm3d[:, :, 48:68] # [T, 20, 3]
247
+ return eye_lm, mouth_lm
248
+
249
+ def close_mouth_for_idexp_lm3d(self, idexp_lm3d, freeze_as_first_frame=True):
250
+ idexp_lm3d = idexp_lm3d.reshape([-1, 68,3])
251
+ num_frames = idexp_lm3d.shape[0]
252
+ eps = 0.0
253
+ # [n_landmarks=68,xyz=3], x 代表左右,y代表上下,z代表深度
254
+ idexp_lm3d[:,49:54, 1] = (idexp_lm3d[:,49:54, 1] + idexp_lm3d[:,range(59,54,-1), 1])/2 + eps * 2
255
+ idexp_lm3d[:,range(59,54,-1), 1] = (idexp_lm3d[:,49:54, 1] + idexp_lm3d[:,range(59,54,-1), 1])/2 - eps * 2
256
+
257
+ idexp_lm3d[:,61:64, 1] = (idexp_lm3d[:,61:64, 1] + idexp_lm3d[:,range(67,64,-1), 1])/2 + eps
258
+ idexp_lm3d[:,range(67,64,-1), 1] = (idexp_lm3d[:,61:64, 1] + idexp_lm3d[:,range(67,64,-1), 1])/2 - eps
259
+
260
+ idexp_lm3d[:,49:54, 1] += (0.03 - idexp_lm3d[:,49:54, 1].mean(dim=1) + idexp_lm3d[:,61:64, 1].mean(dim=1)).unsqueeze(1).repeat([1,5])
261
+ idexp_lm3d[:,range(59,54,-1), 1] += (-0.03 - idexp_lm3d[:,range(59,54,-1), 1].mean(dim=1) + idexp_lm3d[:,range(67,64,-1), 1].mean(dim=1)).unsqueeze(1).repeat([1,5])
262
+
263
+ if freeze_as_first_frame:
264
+ idexp_lm3d[:, 48:68,] = idexp_lm3d[0, 48:68].unsqueeze(0).clone().repeat([num_frames, 1,1])*0
265
+ return idexp_lm3d.cpu()
266
+
267
+ def close_eyes_for_idexp_lm3d(self, idexp_lm3d):
268
+ idexp_lm3d = idexp_lm3d.reshape([-1, 68,3])
269
+ eps = 0.003
270
+ idexp_lm3d[:,37:39, 1] = (idexp_lm3d[:,37:39, 1] + idexp_lm3d[:,range(41,39,-1), 1])/2 + eps
271
+ idexp_lm3d[:,range(41,39,-1), 1] = (idexp_lm3d[:,37:39, 1] + idexp_lm3d[:,range(41,39,-1), 1])/2 - eps
272
+
273
+ idexp_lm3d[:,43:45, 1] = (idexp_lm3d[:,43:45, 1] + idexp_lm3d[:,range(47,45,-1), 1])/2 + eps
274
+ idexp_lm3d[:,range(47,45,-1), 1] = (idexp_lm3d[:,43:45, 1] + idexp_lm3d[:,range(47,45,-1), 1])/2 - eps
275
+
276
+ return idexp_lm3d
277
+
278
+ if __name__ == '__main__':
279
+ import cv2
280
+
281
+ font = cv2.FONT_HERSHEY_SIMPLEX
282
+
283
+ face_mesh_helper = Face3DHelper('deep_3drecon/BFM')
284
+ coeff_npy = 'data/coeff_fit_mp/crop_nana_003_coeff_fit_mp.npy'
285
+ coeff_dict = np.load(coeff_npy, allow_pickle=True).tolist()
286
+ lm3d = face_mesh_helper.reconstruct_lm2d(torch.tensor(coeff_dict['id']).cuda(), torch.tensor(coeff_dict['exp']).cuda(), torch.tensor(coeff_dict['euler']).cuda(), torch.tensor(coeff_dict['trans']).cuda() )
287
+
288
+ WH = 512
289
+ lm3d = (lm3d * WH).cpu().int().numpy()
290
+ eye_idx = list(range(36,48))
291
+ mouth_idx = list(range(48,68))
292
+ import imageio
293
+ debug_name = 'debug_lm3d.mp4'
294
+ writer = imageio.get_writer(debug_name, fps=25)
295
+ for i_img in range(len(lm3d)):
296
+ lm2d = lm3d[i_img ,:, :2] # [68, 2]
297
+ img = np.ones([WH, WH, 3], dtype=np.uint8) * 255
298
+ for i in range(len(lm2d)):
299
+ x, y = lm2d[i]
300
+ if i in eye_idx:
301
+ color = (0,0,255)
302
+ elif i in mouth_idx:
303
+ color = (0,255,0)
304
+ else:
305
+ color = (255,0,0)
306
+ img = cv2.circle(img, center=(x,y), radius=3, color=color, thickness=-1)
307
+ img = cv2.putText(img, f"{i}", org=(x,y), fontFace=font, fontScale=0.3, color=(255,0,0))
308
+ writer.append_data(img)
309
+ writer.close()
deep_3drecon/BFM/.gitkeep ADDED
File without changes
deep_3drecon/bfm_left_eye_faces.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9651756ea2c0fac069a1edf858ed1f125eddc358fa74c529a370c1e7b5730d28
3
+ size 4680
deep_3drecon/bfm_right_eye_faces.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:28cb5bbacf578d30a3d5006ec28c617fe5a3ecaeeeb87d9433a884e0f0301a2e
3
+ size 4648
deep_3drecon/deep_3drecon_models/bfm.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This script defines the parametric 3d face model for Deep3DFaceRecon_pytorch
2
+ """
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from scipy.io import loadmat
8
+ import os
9
+ # from utils.commons.tensor_utils import convert_like
10
+
11
+
12
+ def perspective_projection(focal, center):
13
+ # return p.T (N, 3) @ (3, 3)
14
+ return np.array([
15
+ focal, 0, center,
16
+ 0, focal, center,
17
+ 0, 0, 1
18
+ ]).reshape([3, 3]).astype(np.float32).transpose() # 注意这里的transpose!
19
+
20
+ class SH:
21
+ def __init__(self):
22
+ self.a = [np.pi, 2 * np.pi / np.sqrt(3.), 2 * np.pi / np.sqrt(8.)]
23
+ self.c = [1/np.sqrt(4 * np.pi), np.sqrt(3.) / np.sqrt(4 * np.pi), 3 * np.sqrt(5.) / np.sqrt(12 * np.pi)]
24
+
25
+
26
+
27
+ class ParametricFaceModel:
28
+ def __init__(self,
29
+ bfm_folder='./BFM',
30
+ recenter=True,
31
+ camera_distance=10.,
32
+ init_lit=np.array([
33
+ 0.8, 0, 0, 0, 0, 0, 0, 0, 0
34
+ ]),
35
+ focal=1015.,
36
+ center=112.,
37
+ is_train=True,
38
+ default_name='BFM_model_front.mat',
39
+ keypoint_mode='mediapipe'):
40
+
41
+ model = loadmat(os.path.join(bfm_folder, default_name))
42
+ # mean face shape. [3*N,1]
43
+ self.mean_shape = model['meanshape'].astype(np.float32)
44
+ # identity basis. [3*N,80]
45
+ self.id_base = model['idBase'].astype(np.float32)
46
+ # expression basis. [3*N,64]
47
+ self.exp_base = model['exBase'].astype(np.float32)
48
+ # mean face texture. [3*N,1] (0-255)
49
+ self.mean_tex = model['meantex'].astype(np.float32)
50
+ # texture basis. [3*N,80]
51
+ self.tex_base = model['texBase'].astype(np.float32)
52
+ # face indices for each vertex that lies in. starts from 0. [N,8]
53
+ self.point_buf = model['point_buf'].astype(np.int64) - 1
54
+ # vertex indices for each face. starts from 0. [F,3]
55
+ self.face_buf = model['tri'].astype(np.int64) - 1
56
+ # vertex indices for 68 landmarks. starts from 0. [68,1]
57
+ if keypoint_mode == 'mediapipe':
58
+ self.keypoints = np.load("deep_3drecon/BFM/index_mp468_from_mesh35709.npy").astype(np.int64)
59
+ unmatch_mask = self.keypoints < 0
60
+ self.keypoints[unmatch_mask] = 0
61
+ else:
62
+ self.keypoints = np.squeeze(model['keypoints']).astype(np.int64) - 1
63
+
64
+ if is_train:
65
+ # vertex indices for small face region to compute photometric error. starts from 0.
66
+ self.front_mask = np.squeeze(model['frontmask2_idx']).astype(np.int64) - 1
67
+ # vertex indices for each face from small face region. starts from 0. [f,3]
68
+ self.front_face_buf = model['tri_mask2'].astype(np.int64) - 1
69
+ # vertex indices for pre-defined skin region to compute reflectance loss
70
+ self.skin_mask = np.squeeze(model['skinmask'])
71
+
72
+ if recenter:
73
+ mean_shape = self.mean_shape.reshape([-1, 3])
74
+ mean_shape = mean_shape - np.mean(mean_shape, axis=0, keepdims=True)
75
+ self.mean_shape = mean_shape.reshape([-1, 1])
76
+
77
+ self.key_mean_shape = self.mean_shape.reshape([-1, 3])[self.keypoints, :].reshape([-1, 3])
78
+ self.key_id_base = self.id_base.reshape([-1, 3,80])[self.keypoints, :].reshape([-1, 80])
79
+ self.key_exp_base = self.exp_base.reshape([-1, 3, 64])[self.keypoints, :].reshape([-1, 64])
80
+
81
+ self.focal = focal
82
+ self.center = center
83
+ self.persc_proj = perspective_projection(focal, center)
84
+ self.device = 'cpu'
85
+ self.camera_distance = camera_distance
86
+ self.SH = SH()
87
+ self.init_lit = init_lit.reshape([1, 1, -1]).astype(np.float32)
88
+
89
+ self.initialized = False
90
+
91
+ def to(self, device):
92
+ self.device = device
93
+ for key, value in self.__dict__.items():
94
+ if type(value).__module__ == np.__name__:
95
+ setattr(self, key, torch.tensor(value).to(device))
96
+ self.initialized = True
97
+ return self
98
+
99
+ def compute_shape(self, id_coeff, exp_coeff):
100
+ """
101
+ Return:
102
+ face_shape -- torch.tensor, size (B, N, 3)
103
+
104
+ Parameters:
105
+ id_coeff -- torch.tensor, size (B, 80), identity coeffs
106
+ exp_coeff -- torch.tensor, size (B, 64), expression coeffs
107
+ """
108
+ batch_size = id_coeff.shape[0]
109
+ id_part = torch.einsum('ij,aj->ai', self.id_base, id_coeff)
110
+ exp_part = torch.einsum('ij,aj->ai', self.exp_base, exp_coeff)
111
+ face_shape = id_part + exp_part + self.mean_shape.reshape([1, -1])
112
+ return face_shape.reshape([batch_size, -1, 3])
113
+
114
+ def compute_key_shape(self, id_coeff, exp_coeff):
115
+ """
116
+ Return:
117
+ face_shape -- torch.tensor, size (B, N, 3)
118
+
119
+ Parameters:
120
+ id_coeff -- torch.tensor, size (B, 80), identity coeffs
121
+ exp_coeff -- torch.tensor, size (B, 64), expression coeffs
122
+ """
123
+ batch_size = id_coeff.shape[0]
124
+ id_part = torch.einsum('ij,aj->ai', self.key_id_base, id_coeff)
125
+ exp_part = torch.einsum('ij,aj->ai', self.key_exp_base, exp_coeff)
126
+ face_shape = id_part + exp_part + self.key_mean_shape.reshape([1, -1])
127
+ return face_shape.reshape([batch_size, -1, 3])
128
+
129
+ def compute_texture(self, tex_coeff, normalize=True):
130
+ """
131
+ Return:
132
+ face_texture -- torch.tensor, size (B, N, 3), in RGB order, range (0, 1.)
133
+
134
+ Parameters:
135
+ tex_coeff -- torch.tensor, size (B, 80)
136
+ """
137
+ batch_size = tex_coeff.shape[0]
138
+ face_texture = torch.einsum('ij,aj->ai', self.tex_base, tex_coeff) + self.mean_tex
139
+ if normalize:
140
+ face_texture = face_texture / 255.
141
+ return face_texture.reshape([batch_size, -1, 3])
142
+
143
+
144
+ def compute_norm(self, face_shape):
145
+ """
146
+ Return:
147
+ vertex_norm -- torch.tensor, size (B, N, 3)
148
+
149
+ Parameters:
150
+ face_shape -- torch.tensor, size (B, N, 3)
151
+ """
152
+
153
+ v1 = face_shape[:, self.face_buf[:, 0]]
154
+ v2 = face_shape[:, self.face_buf[:, 1]]
155
+ v3 = face_shape[:, self.face_buf[:, 2]]
156
+ e1 = v1 - v2
157
+ e2 = v2 - v3
158
+ face_norm = torch.cross(e1, e2, dim=-1)
159
+ face_norm = F.normalize(face_norm, dim=-1, p=2)
160
+ face_norm = torch.cat([face_norm, torch.zeros(face_norm.shape[0], 1, 3).to(self.device)], dim=1)
161
+
162
+ vertex_norm = torch.sum(face_norm[:, self.point_buf], dim=2)
163
+ vertex_norm = F.normalize(vertex_norm, dim=-1, p=2)
164
+ return vertex_norm
165
+
166
+
167
+ def compute_color(self, face_texture, face_norm, gamma):
168
+ """
169
+ Return:
170
+ face_color -- torch.tensor, size (B, N, 3), range (0, 1.)
171
+
172
+ Parameters:
173
+ face_texture -- torch.tensor, size (B, N, 3), from texture model, range (0, 1.)
174
+ face_norm -- torch.tensor, size (B, N, 3), rotated face normal
175
+ gamma -- torch.tensor, size (B, 27), SH coeffs
176
+ """
177
+ batch_size = gamma.shape[0]
178
+ v_num = face_texture.shape[1]
179
+ a, c = self.SH.a, self.SH.c
180
+ gamma = gamma.reshape([batch_size, 3, 9])
181
+ gamma = gamma + self.init_lit
182
+ gamma = gamma.permute(0, 2, 1)
183
+ Y = torch.cat([
184
+ a[0] * c[0] * torch.ones_like(face_norm[..., :1]).to(self.device),
185
+ -a[1] * c[1] * face_norm[..., 1:2],
186
+ a[1] * c[1] * face_norm[..., 2:],
187
+ -a[1] * c[1] * face_norm[..., :1],
188
+ a[2] * c[2] * face_norm[..., :1] * face_norm[..., 1:2],
189
+ -a[2] * c[2] * face_norm[..., 1:2] * face_norm[..., 2:],
190
+ 0.5 * a[2] * c[2] / np.sqrt(3.) * (3 * face_norm[..., 2:] ** 2 - 1),
191
+ -a[2] * c[2] * face_norm[..., :1] * face_norm[..., 2:],
192
+ 0.5 * a[2] * c[2] * (face_norm[..., :1] ** 2 - face_norm[..., 1:2] ** 2)
193
+ ], dim=-1)
194
+ r = Y @ gamma[..., :1]
195
+ g = Y @ gamma[..., 1:2]
196
+ b = Y @ gamma[..., 2:]
197
+ face_color = torch.cat([r, g, b], dim=-1) * face_texture
198
+ return face_color
199
+
200
+ @staticmethod
201
+ def compute_rotation(angles, device='cpu'):
202
+ """
203
+ Return:
204
+ rot -- torch.tensor, size (B, 3, 3) pts @ trans_mat
205
+
206
+ Parameters:
207
+ angles -- torch.tensor, size (B, 3), radian
208
+ """
209
+
210
+ batch_size = angles.shape[0]
211
+ angles = angles.to(device)
212
+ ones = torch.ones([batch_size, 1]).to(device)
213
+ zeros = torch.zeros([batch_size, 1]).to(device)
214
+ x, y, z = angles[:, :1], angles[:, 1:2], angles[:, 2:],
215
+
216
+ rot_x = torch.cat([
217
+ ones, zeros, zeros,
218
+ zeros, torch.cos(x), -torch.sin(x),
219
+ zeros, torch.sin(x), torch.cos(x)
220
+ ], dim=1).reshape([batch_size, 3, 3])
221
+
222
+ rot_y = torch.cat([
223
+ torch.cos(y), zeros, torch.sin(y),
224
+ zeros, ones, zeros,
225
+ -torch.sin(y), zeros, torch.cos(y)
226
+ ], dim=1).reshape([batch_size, 3, 3])
227
+
228
+ rot_z = torch.cat([
229
+ torch.cos(z), -torch.sin(z), zeros,
230
+ torch.sin(z), torch.cos(z), zeros,
231
+ zeros, zeros, ones
232
+ ], dim=1).reshape([batch_size, 3, 3])
233
+
234
+ rot = rot_z @ rot_y @ rot_x
235
+ return rot.permute(0, 2, 1)
236
+
237
+
238
+ def to_camera(self, face_shape):
239
+ face_shape[..., -1] = self.camera_distance - face_shape[..., -1] # reverse the depth axis, add a fixed offset of length
240
+ return face_shape
241
+
242
+ def to_image(self, face_shape):
243
+ """
244
+ Return:
245
+ face_proj -- torch.tensor, size (B, N, 2), y direction is opposite to v direction
246
+
247
+ Parameters:
248
+ face_shape -- torch.tensor, size (B, N, 3)
249
+ """
250
+ # to image_plane
251
+ face_proj = face_shape @ self.persc_proj
252
+ face_proj = face_proj[..., :2] / face_proj[..., 2:]
253
+
254
+ return face_proj
255
+
256
+
257
+ def transform(self, face_shape, rot, trans):
258
+ """
259
+ Return:
260
+ face_shape -- torch.tensor, size (B, N, 3) pts @ rot + trans
261
+
262
+ Parameters:
263
+ face_shape -- torch.tensor, si≥ze (B, N, 3)
264
+ rot -- torch.tensor, size (B, 3, 3)
265
+ trans -- torch.tensor, size (B, 3)
266
+ """
267
+ return face_shape @ rot + trans.unsqueeze(1)
268
+
269
+
270
+ def get_landmarks(self, face_proj):
271
+ """
272
+ Return:
273
+ face_lms -- torch.tensor, size (B, 68, 2)
274
+
275
+ Parameters:
276
+ face_proj -- torch.tensor, size (B, N, 2)
277
+ """
278
+ return face_proj[:, self.keypoints]
279
+
280
+ def split_coeff(self, coeffs):
281
+ """
282
+ Return:
283
+ coeffs_dict -- a dict of torch.tensors
284
+
285
+ Parameters:
286
+ coeffs -- torch.tensor, size (B, 256)
287
+ """
288
+ id_coeffs = coeffs[:, :80]
289
+ exp_coeffs = coeffs[:, 80: 144]
290
+ tex_coeffs = coeffs[:, 144: 224]
291
+ angles = coeffs[:, 224: 227]
292
+ gammas = coeffs[:, 227: 254]
293
+ translations = coeffs[:, 254:]
294
+ return {
295
+ 'id': id_coeffs,
296
+ 'exp': exp_coeffs,
297
+ 'tex': tex_coeffs,
298
+ 'angle': angles,
299
+ 'gamma': gammas,
300
+ 'trans': translations
301
+ }
302
+ def compute_for_render(self, coeffs):
303
+ """
304
+ Return:
305
+ face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
306
+ face_color -- torch.tensor, size (B, N, 3), in RGB order
307
+ landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
308
+ Parameters:
309
+ coeffs -- torch.tensor, size (B, 257)
310
+ """
311
+ coef_dict = self.split_coeff(coeffs)
312
+ face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp'])
313
+ rotation = self.compute_rotation(coef_dict['angle'], device=self.device)
314
+
315
+
316
+ face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans'])
317
+ face_vertex = self.to_camera(face_shape_transformed)
318
+
319
+ face_proj = self.to_image(face_vertex)
320
+ landmark = self.get_landmarks(face_proj)
321
+
322
+ face_texture = self.compute_texture(coef_dict['tex'])
323
+ face_norm = self.compute_norm(face_shape)
324
+ face_norm_roted = face_norm @ rotation
325
+ face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma'])
326
+
327
+ return face_vertex, face_texture, face_color, landmark
328
+
329
+ def compute_face_vertex(self, id, exp, angle, trans):
330
+ """
331
+ Return:
332
+ face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
333
+ face_color -- torch.tensor, size (B, N, 3), in RGB order
334
+ landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
335
+ Parameters:
336
+ coeffs -- torch.tensor, size (B, 257)
337
+ """
338
+ if not self.initialized:
339
+ self.to(id.device)
340
+ face_shape = self.compute_shape(id, exp)
341
+ rotation = self.compute_rotation(angle, device=self.device)
342
+ face_shape_transformed = self.transform(face_shape, rotation, trans)
343
+ face_vertex = self.to_camera(face_shape_transformed)
344
+ return face_vertex
345
+
346
+ def compute_for_landmark_fit(self, id, exp, angles, trans, ret=None):
347
+ """
348
+ Return:
349
+ face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
350
+ face_color -- torch.tensor, size (B, N, 3), in RGB order
351
+ landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
352
+ Parameters:
353
+ coeffs -- torch.tensor, size (B, 257)
354
+ """
355
+ face_shape = self.compute_key_shape(id, exp)
356
+ rotation = self.compute_rotation(angles, device=self.device)
357
+
358
+ face_shape_transformed = self.transform(face_shape, rotation, trans)
359
+ face_vertex = self.to_camera(face_shape_transformed)
360
+
361
+ face_proj = self.to_image(face_vertex)
362
+ landmark = face_proj
363
+ return landmark
364
+
365
+ def compute_for_landmark_fit_nerf(self, id, exp, angles, trans, ret=None):
366
+ """
367
+ Return:
368
+ face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
369
+ face_color -- torch.tensor, size (B, N, 3), in RGB order
370
+ landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
371
+ Parameters:
372
+ coeffs -- torch.tensor, size (B, 257)
373
+ """
374
+ face_shape = self.compute_key_shape(id, exp)
375
+ rotation = self.compute_rotation(angles, device=self.device)
376
+
377
+ face_shape_transformed = self.transform(face_shape, rotation, trans)
378
+ face_vertex = face_shape_transformed # no to_camera
379
+
380
+ face_proj = self.to_image(face_vertex)
381
+ landmark = face_proj
382
+ return landmark
383
+
384
+ # def compute_for_landmark_fit(self, id, exp, angles, trans, ret={}):
385
+ # """
386
+ # Return:
387
+ # face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
388
+ # face_color -- torch.tensor, size (B, N, 3), in RGB order
389
+ # landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
390
+ # Parameters:
391
+ # coeffs -- torch.tensor, size (B, 257)
392
+ # """
393
+ # face_shape = self.compute_shape(id, exp)
394
+ # rotation = self.compute_rotation(angles)
395
+
396
+ # face_shape_transformed = self.transform(face_shape, rotation, trans)
397
+ # face_vertex = self.to_camera(face_shape_transformed)
398
+
399
+ # face_proj = self.to_image(face_vertex)
400
+ # landmark = self.get_landmarks(face_proj)
401
+ # return landmark
402
+
403
+ def compute_for_render_fit(self, id, exp, angles, trans, tex, gamma):
404
+ """
405
+ Return:
406
+ face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
407
+ face_color -- torch.tensor, size (B, N, 3), in RGB order
408
+ landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
409
+ Parameters:
410
+ coeffs -- torch.tensor, size (B, 257)
411
+ """
412
+ face_shape = self.compute_shape(id, exp)
413
+ rotation = self.compute_rotation(angles, device=self.device)
414
+
415
+ face_shape_transformed = self.transform(face_shape, rotation, trans)
416
+ face_vertex = self.to_camera(face_shape_transformed)
417
+
418
+ face_proj = self.to_image(face_vertex)
419
+ landmark = self.get_landmarks(face_proj)
420
+
421
+ face_texture = self.compute_texture(tex)
422
+ face_norm = self.compute_norm(face_shape)
423
+ face_norm_roted = face_norm @ rotation
424
+ face_color = self.compute_color(face_texture, face_norm_roted, gamma)
425
+
426
+ return face_color, face_vertex, landmark
deep_3drecon/ncc_code.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:da54a620c0981d43cc9f30b3d8b3f5d4beb0ec0e27127a1ef3fb62ea50913609
3
+ size 428636
deep_3drecon/secc_renderer.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from einops import rearrange
5
+
6
+ from deep_3drecon.util.mesh_renderer import MeshRenderer
7
+ from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
8
+
9
+
10
+ class SECC_Renderer(nn.Module):
11
+ def __init__(self, rasterize_size=None, device="cuda"):
12
+ super().__init__()
13
+ self.face_model = ParametricFaceModel('deep_3drecon/BFM')
14
+ self.fov = 2 * np.arctan(self.face_model.center / self.face_model.focal) * 180 / np.pi
15
+ self.znear = 5.
16
+ self.zfar = 15.
17
+ if rasterize_size is None:
18
+ rasterize_size = 2*self.face_model.center
19
+ self.face_renderer = MeshRenderer(rasterize_fov=self.fov, znear=self.znear, zfar=self.zfar, rasterize_size=rasterize_size, use_opengl=False).cuda()
20
+ face_feat = np.load("deep_3drecon/ncc_code.npy", allow_pickle=True)
21
+ self.face_feat = torch.tensor(face_feat.T).unsqueeze(0).to(device=device)
22
+
23
+ del_index_re = np.load('deep_3drecon/bfm_right_eye_faces.npy')
24
+ del_index_re = del_index_re - 1
25
+ del_index_le = np.load('deep_3drecon/bfm_left_eye_faces.npy')
26
+ del_index_le = del_index_le - 1
27
+ face_buf_list = []
28
+ for i in range(self.face_model.face_buf.shape[0]):
29
+ if i not in del_index_re and i not in del_index_le:
30
+ face_buf_list.append(self.face_model.face_buf[i])
31
+ face_buf_arr = np.array(face_buf_list)
32
+ self.face_buf = torch.tensor(face_buf_arr).to(device=device)
33
+
34
+ def forward(self, id, exp, euler, trans):
35
+ """
36
+ id, exp, euler, euler: [B, C] or [B, T, C]
37
+ return:
38
+ MASK: [B, 1, 512, 512], value[0. or 1.0], 1.0 denotes is face
39
+ SECC MAP: [B, 3, 512, 512], value[0~1]
40
+ if input is BTC format, return [B, C, T, H, W]
41
+ """
42
+ bs = id.shape[0]
43
+ is_btc_flag = id.ndim == 3
44
+ if is_btc_flag:
45
+ t = id.shape[1]
46
+ bs = bs * t
47
+ id, exp, euler, trans = id.reshape([bs,-1]), exp.reshape([bs,-1]), euler.reshape([bs,-1]), trans.reshape([bs,-1])
48
+
49
+ face_vertex = self.face_model.compute_face_vertex(id, exp, euler, trans)
50
+ face_mask, _, secc_face = self.face_renderer(
51
+ face_vertex, self.face_buf.unsqueeze(0).repeat([bs, 1, 1]), feat=self.face_feat.repeat([bs,1,1]))
52
+ secc_face = (secc_face - 0.5) / 0.5 # scale to -1~1
53
+
54
+ if is_btc_flag:
55
+ bs = bs // t
56
+ face_mask = rearrange(face_mask, "(n t) c h w -> n c t h w", n=bs, t=t)
57
+ secc_face = rearrange(secc_face, "(n t) c h w -> n c t h w", n=bs, t=t)
58
+ return face_mask, secc_face
59
+
60
+
61
+ if __name__ == '__main__':
62
+ import imageio
63
+
64
+ renderer = SECC_Renderer(rasterize_size=512)
65
+ ret = np.load("data/processed/videos/May/vid_coeff_fit.npy", allow_pickle=True).tolist()
66
+ idx = 6
67
+ id = torch.tensor(ret['id']).cuda()[idx:idx+1]
68
+ exp = torch.tensor(ret['exp']).cuda()[idx:idx+1]
69
+ angle = torch.tensor(ret['euler']).cuda()[idx:idx+1]
70
+ trans = torch.tensor(ret['trans']).cuda()[idx:idx+1]
71
+ mask, secc = renderer(id, exp, angle*0, trans*0) # [1, 1, 512, 512], [1, 3, 512, 512]
72
+
73
+ out_mask = mask[0].permute(1,2,0)
74
+ out_mask = (out_mask * 127.5 + 127.5).int().cpu().numpy()
75
+ imageio.imwrite("out_mask.png", out_mask)
76
+ out_img = secc[0].permute(1,2,0)
77
+ out_img = (out_img * 127.5 + 127.5).int().cpu().numpy()
78
+ imageio.imwrite("out_secc.png", out_img)
deep_3drecon/util/mesh_renderer.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This script is the differentiable renderer for Deep3DFaceRecon_pytorch
2
+ Attention, antialiasing step is missing in current version.
3
+ """
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import kornia
7
+ from kornia.geometry.camera import pixel2cam
8
+ import numpy as np
9
+ from typing import List
10
+ from scipy.io import loadmat
11
+ from torch import nn
12
+ import traceback
13
+
14
+ try:
15
+ import pytorch3d.ops
16
+ from pytorch3d.structures import Meshes
17
+ from pytorch3d.renderer import (
18
+ look_at_view_transform,
19
+ FoVPerspectiveCameras,
20
+ DirectionalLights,
21
+ RasterizationSettings,
22
+ MeshRenderer,
23
+ MeshRasterizer,
24
+ SoftPhongShader,
25
+ TexturesUV,
26
+ )
27
+ except:
28
+ traceback.print_exc()
29
+ # def ndc_projection(x=0.1, n=1.0, f=50.0):
30
+ # return np.array([[n/x, 0, 0, 0],
31
+ # [ 0, n/-x, 0, 0],
32
+ # [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)],
33
+ # [ 0, 0, -1, 0]]).astype(np.float32)
34
+
35
+ class MeshRenderer(nn.Module):
36
+ def __init__(self,
37
+ rasterize_fov,
38
+ znear=0.1,
39
+ zfar=10,
40
+ rasterize_size=224,**args):
41
+ super(MeshRenderer, self).__init__()
42
+
43
+ # x = np.tan(np.deg2rad(rasterize_fov * 0.5)) * znear
44
+ # self.ndc_proj = torch.tensor(ndc_projection(x=x, n=znear, f=zfar)).matmul(
45
+ # torch.diag(torch.tensor([1., -1, -1, 1])))
46
+ self.rasterize_size = rasterize_size
47
+ self.fov = rasterize_fov
48
+ self.znear = znear
49
+ self.zfar = zfar
50
+
51
+ self.rasterizer = None
52
+
53
+ def forward(self, vertex, tri, feat=None):
54
+ """
55
+ Return:
56
+ mask -- torch.tensor, size (B, 1, H, W)
57
+ depth -- torch.tensor, size (B, 1, H, W)
58
+ features(optional) -- torch.tensor, size (B, C, H, W) if feat is not None
59
+
60
+ Parameters:
61
+ vertex -- torch.tensor, size (B, N, 3)
62
+ tri -- torch.tensor, size (B, M, 3) or (M, 3), triangles
63
+ feat(optional) -- torch.tensor, size (B, N ,C), features
64
+ """
65
+ device = vertex.device
66
+ rsize = int(self.rasterize_size)
67
+ # ndc_proj = self.ndc_proj.to(device)
68
+ # trans to homogeneous coordinates of 3d vertices, the direction of y is the same as v
69
+ if vertex.shape[-1] == 3:
70
+ vertex = torch.cat([vertex, torch.ones([*vertex.shape[:2], 1]).to(device)], dim=-1)
71
+ vertex[..., 0] = -vertex[..., 0]
72
+
73
+
74
+ # vertex_ndc = vertex @ ndc_proj.t()
75
+ if self.rasterizer is None:
76
+ self.rasterizer = MeshRasterizer()
77
+ print("create rasterizer on device cuda:%d"%device.index)
78
+
79
+ # ranges = None
80
+ # if isinstance(tri, List) or len(tri.shape) == 3:
81
+ # vum = vertex_ndc.shape[1]
82
+ # fnum = torch.tensor([f.shape[0] for f in tri]).unsqueeze(1).to(device)
83
+ # fstartidx = torch.cumsum(fnum, dim=0) - fnum
84
+ # ranges = torch.cat([fstartidx, fnum], axis=1).type(torch.int32).cpu()
85
+ # for i in range(tri.shape[0]):
86
+ # tri[i] = tri[i] + i*vum
87
+ # vertex_ndc = torch.cat(vertex_ndc, dim=0)
88
+ # tri = torch.cat(tri, dim=0)
89
+
90
+ # for range_mode vetex: [B*N, 4], tri: [B*M, 3], for instance_mode vetex: [B, N, 4], tri: [M, 3]
91
+ tri = tri.type(torch.int32).contiguous()
92
+
93
+ # rasterize
94
+ cameras = FoVPerspectiveCameras(
95
+ device=device,
96
+ fov=self.fov,
97
+ znear=self.znear,
98
+ zfar=self.zfar,
99
+ )
100
+
101
+ raster_settings = RasterizationSettings(
102
+ image_size=rsize
103
+ )
104
+
105
+ # print(vertex.shape, tri.shape)
106
+ if tri.ndim == 2:
107
+ tri = tri.unsqueeze(0)
108
+ mesh = Meshes(vertex.contiguous()[...,:3], tri)
109
+
110
+ fragments = self.rasterizer(mesh, cameras = cameras, raster_settings = raster_settings)
111
+ rast_out = fragments.pix_to_face.squeeze(-1)
112
+ depth = fragments.zbuf
113
+
114
+ # render depth
115
+ depth = depth.permute(0, 3, 1, 2)
116
+ mask = (rast_out > 0).float().unsqueeze(1)
117
+ depth = mask * depth
118
+
119
+
120
+ image = None
121
+ if feat is not None:
122
+ attributes = feat.reshape(-1,3)[mesh.faces_packed()]
123
+ image = pytorch3d.ops.interpolate_face_attributes(fragments.pix_to_face,
124
+ fragments.bary_coords,
125
+ attributes)
126
+ # print(image.shape)
127
+ image = image.squeeze(-2).permute(0, 3, 1, 2)
128
+ image = mask * image
129
+
130
+ return mask, depth, image
131
+
docs/prepare_env/install_guide-zh.md ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 环境配置
2
+ [English Doc](./install_guide.md)
3
+
4
+ 本文档陈述了搭建Real3D-Portrait Python环境的步骤,我们使用了Conda来管理依赖。
5
+
6
+ 以下配置已在 A100/V100 + CUDA11.7 中进行了验证。
7
+
8
+
9
+ # 1. 安装CUDA
10
+ 我们推荐安装CUDA `11.7`,其他CUDA版本(例如`10.2`、`12.x`)也可能有效。
11
+
12
+ # 2. 安装Python依赖
13
+ ```
14
+ cd <Real3DPortraitRoot>
15
+ source <CondaRoot>/bin/activate
16
+ conda create -n real3dportrait python=3.9
17
+ conda activate real3dportrait
18
+ conda install conda-forge::ffmpeg # ffmpeg with libx264 codec to turn images to video
19
+
20
+ # 我们推荐安装torch2.0.1+cuda11.7.
21
+ conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
22
+
23
+ # 从源代码安装,需要比较长的时间 (如果遇到各种time-out问题,建议使用代理)
24
+ pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable"
25
+
26
+ # MMCV安装
27
+ pip install cython
28
+ pip install openmim==0.3.9
29
+ mim install mmcv==2.1.0 # 使用mim来加速mmcv安装
30
+
31
+ # 其他依赖项
32
+ pip install -r docs/prepare_env/requirements.txt -v
33
+
34
+ ```
35
+
docs/prepare_env/install_guide.md ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Prepare the Environment
2
+ [中文文档](./install_guide-zh.md)
3
+
4
+ This guide is about building a python environment for Real3D-Portrait with Conda.
5
+
6
+ The following installation process is verified in A100/V100 + CUDA11.7.
7
+
8
+
9
+ # 1. Install CUDA
10
+ We recommend to install CUDA `11.7` (which is verified in various types of GPUs), but other CUDA versions (such as `10.2`, `12.x`) may also work well.
11
+
12
+ # 2. Install Python Packages
13
+ ```
14
+ cd <Real3DPortraitRoot>
15
+ source <CondaRoot>/bin/activate
16
+ conda create -n real3dportrait python=3.9
17
+ conda activate real3dportrait
18
+ conda install conda-forge::ffmpeg # ffmpeg with libx264 codec to turn images to video
19
+
20
+ ### We recommend torch2.0.1+cuda11.7.
21
+ conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
22
+
23
+ # Build from source, it may take a long time (Proxy is recommended if encountering the time-out problem)
24
+ pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable"
25
+
26
+ # MMCV for some network structure
27
+ pip install cython
28
+ pip install openmim==0.3.9
29
+ mim install mmcv==2.1.0 # use mim to speed up installation for mmcv
30
+
31
+ # other dependencies
32
+ pip install -r docs/prepare_env/requirements.txt -v
33
+
34
+ ```
docs/prepare_env/requirements.txt ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Cython
2
+ numpy # ==1.23.0
3
+ numba==0.56.4
4
+ pandas
5
+ transformers
6
+ scipy==1.11.1 # required by cal_fid. https://github.com/mseitzer/pytorch-fid/issues/103
7
+ scikit-learn
8
+ scikit-image
9
+ # tensorflow # you can flexible it, this is gpu version
10
+ tensorboard
11
+ tensorboardX
12
+ python_speech_features
13
+ resampy
14
+ opencv_python
15
+ face_alignment
16
+ matplotlib
17
+ configargparse
18
+ librosa==0.9.2
19
+ praat-parselmouth # ==0.4.3
20
+ trimesh
21
+ kornia==0.5.0
22
+ PyMCubes
23
+ lpips
24
+ setuptools # ==59.5.0
25
+ ffmpeg-python
26
+ moviepy
27
+ dearpygui
28
+ ninja
29
+ # pyaudio # for extract esperanto
30
+ mediapipe
31
+ protobuf
32
+ decord
33
+ soundfile
34
+ pillow
35
+ # torch # it's better to install torch with conda
36
+ av
37
+ timm
38
+ pretrainedmodels
39
+ faiss-cpu # for fast nearest camera pose retriveal
40
+ einops
41
+ # mmcv # use mim install is faster
42
+
43
+ # conditional flow matching
44
+ beartype
45
+ torchode
46
+ torchdiffeq
47
+
48
+ # tts
49
+ cython
50
+ textgrid
51
+ pyloudnorm
52
+ websocket-client
53
+ pyworld==0.2.1rc0
54
+ pypinyin==0.42.0
55
+ webrtcvad
56
+ torchshow
57
+
58
+ # cal spk sim
59
+ s3prl
60
+ fire
61
+
62
+ # cal LMD
63
+ dlib
64
+
65
+ # debug
66
+ ipykernel
67
+
68
+ # lama
69
+ hydra-core
70
+ pytorch_lightning
71
+ setproctitle
72
+
73
+ # Gradio GUI
74
+ httpx==0.23.3
75
+ gradio==4.16.0
inference/app_real3dportrait.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ import argparse
3
+ import gradio as gr
4
+ from inference.real3d_infer import GeneFace2Infer
5
+ from utils.commons.hparams import hparams
6
+
7
+ class Inferer(GeneFace2Infer):
8
+ def infer_once_args(self, *args, **kargs):
9
+ assert len(kargs) == 0
10
+ keys = [
11
+ 'src_image_name',
12
+ 'drv_audio_name',
13
+ 'drv_pose_name',
14
+ 'bg_image_name',
15
+ 'blink_mode',
16
+ 'temperature',
17
+ 'mouth_amp',
18
+ 'out_mode',
19
+ 'map_to_init_pose',
20
+ 'hold_eye_opened',
21
+ 'head_torso_threshold',
22
+ 'a2m_ckpt',
23
+ 'head_ckpt',
24
+ 'torso_ckpt',
25
+ ]
26
+ inp = {}
27
+ out_name = None
28
+ info = ""
29
+
30
+ try: # try to catch errors and jump to return
31
+ for key_index in range(len(keys)):
32
+ key = keys[key_index]
33
+ inp[key] = args[key_index]
34
+ if '_name' in key:
35
+ inp[key] = inp[key] if inp[key] is not None else ''
36
+
37
+ if inp['src_image_name'] == '':
38
+ info = "Input Error: Source image is REQUIRED!"
39
+ raise ValueError
40
+ if inp['drv_audio_name'] == '' and inp['drv_pose_name'] == '':
41
+ info = "Input Error: At least one of driving audio or video is REQUIRED!"
42
+ raise ValueError
43
+
44
+
45
+ if inp['drv_audio_name'] == '' and inp['drv_pose_name'] != '':
46
+ inp['drv_audio_name'] = inp['drv_pose_name']
47
+ print("No audio input, we use driving pose video for video driving")
48
+
49
+ if inp['drv_pose_name'] == '':
50
+ inp['drv_pose_name'] = 'static'
51
+
52
+ reload_flag = False
53
+ if inp['a2m_ckpt'] != self.audio2secc_dir:
54
+ print("Changes of a2m_ckpt detected, reloading model")
55
+ reload_flag = True
56
+ if inp['head_ckpt'] != self.head_model_dir:
57
+ print("Changes of head_ckpt detected, reloading model")
58
+ reload_flag = True
59
+ if inp['torso_ckpt'] != self.torso_model_dir:
60
+ print("Changes of torso_ckpt detected, reloading model")
61
+ reload_flag = True
62
+
63
+ inp['out_name'] = ''
64
+ inp['seed'] = 42
65
+
66
+ print(f"infer inputs : {inp}")
67
+ if self.secc2video_hparams['htbsr_head_threshold'] != inp['head_torso_threshold']:
68
+ print("Changes of head_torso_threshold detected, reloading model")
69
+ reload_flag = True
70
+
71
+ try:
72
+ if reload_flag:
73
+ self.__init__(inp['a2m_ckpt'], inp['head_ckpt'], inp['torso_ckpt'], inp=inp, device=self.device)
74
+ except Exception as e:
75
+ content = f"{e}"
76
+ info = f"Reload ERROR: {content}"
77
+ raise ValueError
78
+ try:
79
+ out_name = self.infer_once(inp)
80
+ except Exception as e:
81
+ content = f"{e}"
82
+ info = f"Inference ERROR: {content}"
83
+ raise ValueError
84
+ except Exception as e:
85
+ if info == "": # unexpected errors
86
+ content = f"{e}"
87
+ info = f"WebUI ERROR: {content}"
88
+
89
+ # output part
90
+ if len(info) > 0 : # there is errors
91
+ print(info)
92
+ info_gr = gr.update(visible=True, value=info)
93
+ else: # no errors
94
+ info_gr = gr.update(visible=False, value=info)
95
+ if out_name is not None and len(out_name) > 0 and os.path.exists(out_name): # good output
96
+ print(f"Succefully generated in {out_name}")
97
+ video_gr = gr.update(visible=True, value=out_name)
98
+ else:
99
+ print(f"Failed to generate")
100
+ video_gr = gr.update(visible=True, value=out_name)
101
+
102
+ return video_gr, info_gr
103
+
104
+ def toggle_audio_file(choice):
105
+ if choice == False:
106
+ return gr.update(visible=True), gr.update(visible=False)
107
+ else:
108
+ return gr.update(visible=False), gr.update(visible=True)
109
+
110
+ def ref_video_fn(path_of_ref_video):
111
+ if path_of_ref_video is not None:
112
+ return gr.update(value=True)
113
+ else:
114
+ return gr.update(value=False)
115
+
116
+ def real3dportrait_demo(
117
+ audio2secc_dir,
118
+ head_model_dir,
119
+ torso_model_dir,
120
+ device = 'cuda',
121
+ warpfn = None,
122
+ ):
123
+
124
+ sep_line = "-" * 40
125
+
126
+ infer_obj = Inferer(
127
+ audio2secc_dir=audio2secc_dir,
128
+ head_model_dir=head_model_dir,
129
+ torso_model_dir=torso_model_dir,
130
+ device=device,
131
+ )
132
+
133
+ print(sep_line)
134
+ print("Model loading is finished.")
135
+ print(sep_line)
136
+ with gr.Blocks(analytics_enabled=False) as real3dportrait_interface:
137
+ gr.Markdown("\
138
+ <div align='center'> <h2> Real3D-Portrait: One-shot Realistic 3D Talking Portrait Synthesis (ICLR 2024 Spotlight) </span> </h2> \
139
+ <a style='font-size:18px;color: #a0a0a0' href='https://arxiv.org/pdf/2401.08503.pdf'>Arxiv</a> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; \
140
+ <a style='font-size:18px;color: #a0a0a0' href='https://real3dportrait.github.io/'>Homepage</a> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; \
141
+ <a style='font-size:18px;color: #a0a0a0' href='https://baidu.com'> Github </div>")
142
+
143
+ sources = None
144
+ with gr.Row():
145
+ with gr.Column(variant='panel'):
146
+ with gr.Tabs(elem_id="source_image"):
147
+ with gr.TabItem('Upload image'):
148
+ with gr.Row():
149
+ src_image_name = gr.Image(label="Source image (required)", sources=sources, type="filepath", value="data/raw/examples/Macron.png")
150
+ with gr.Tabs(elem_id="driven_audio"):
151
+ with gr.TabItem('Upload audio'):
152
+ with gr.Column(variant='panel'):
153
+ drv_audio_name = gr.Audio(label="Input audio (required for audio-driven)", sources=sources, type="filepath", value="data/raw/examples/Obama_5s.wav")
154
+ with gr.Tabs(elem_id="driven_pose"):
155
+ with gr.TabItem('Upload video'):
156
+ with gr.Column(variant='panel'):
157
+ drv_pose_name = gr.Video(label="Driven Pose (required for video-driven, optional for audio-driven)", sources=sources, value="data/raw/examples/May_5s.mp4")
158
+ with gr.Tabs(elem_id="bg_image"):
159
+ with gr.TabItem('Upload image'):
160
+ with gr.Row():
161
+ bg_image_name = gr.Image(label="Background image (optional)", sources=sources, type="filepath", value="data/raw/examples/bg.png")
162
+
163
+
164
+ with gr.Column(variant='panel'):
165
+ with gr.Tabs(elem_id="checkbox"):
166
+ with gr.TabItem('General Settings'):
167
+ with gr.Column(variant='panel'):
168
+
169
+ blink_mode = gr.Radio(['none', 'period'], value='period', label='blink mode', info="whether to blink periodly") #
170
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.025, label="temperature", value=0.2, info='audio to secc temperature',)
171
+ mouth_amp = gr.Slider(minimum=0.0, maximum=1.0, step=0.025, label="mouth amplitude", value=0.45, info='higher -> mouth will open wider, default to be 0.4',)
172
+ out_mode = gr.Radio(['final', 'concat_debug'], value='final', label='output layout', info="final: only final output ; concat_debug: final output concated with internel features")
173
+ map_to_init_pose = gr.Checkbox(label="Whether to map pose of first frame to initial pose")
174
+ hold_eye_opened = gr.Checkbox(label="Whether to maintain eyes always open")
175
+ head_torso_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.025, label="head torso threshold", value=0.7, info='make it higher if you find ghosting around hair of output, default to be 0.7',)
176
+
177
+ submit = gr.Button('Generate', elem_id="generate", variant='primary')
178
+
179
+ with gr.Tabs(elem_id="genearted_video"):
180
+ info_box = gr.Textbox(label="Error", interactive=False, visible=False)
181
+ gen_video = gr.Video(label="Generated video", format="mp4", visible=True)
182
+ with gr.Column(variant='panel'):
183
+ with gr.Tabs(elem_id="checkbox"):
184
+ with gr.TabItem('Checkpoints'):
185
+ with gr.Column(variant='panel'):
186
+ ckpt_info_box = gr.Textbox(value="Please select \"ckpt\" under the checkpoint folder ", interactive=False, visible=True, show_label=False)
187
+ audio2secc_dir = gr.FileExplorer(glob="checkpoints/**/*.ckpt", value=audio2secc_dir, file_count='single', label='audio2secc model ckpt path or directory')
188
+ head_model_dir = gr.FileExplorer(glob="checkpoints/**/*.ckpt", value=head_model_dir, file_count='single', label='head model ckpt path or directory (will be ignored if torso model is set)')
189
+ torso_model_dir = gr.FileExplorer(glob="checkpoints/**/*.ckpt", value=torso_model_dir, file_count='single', label='torso model ckpt path or directory')
190
+ # audio2secc_dir = gr.Textbox(audio2secc_dir, max_lines=1, label='audio2secc model ckpt path or directory (will be ignored if torso model is set)')
191
+ # head_model_dir = gr.Textbox(head_model_dir, max_lines=1, label='head model ckpt path or directory (will be ignored if torso model is set)')
192
+ # torso_model_dir = gr.Textbox(torso_model_dir, max_lines=1, label='torso model ckpt path or directory')
193
+
194
+
195
+ fn = infer_obj.infer_once_args
196
+ if warpfn:
197
+ fn = warpfn(fn)
198
+ submit.click(
199
+ fn=fn,
200
+ inputs=[
201
+ src_image_name,
202
+ drv_audio_name,
203
+ drv_pose_name,
204
+ bg_image_name,
205
+ blink_mode,
206
+ temperature,
207
+ mouth_amp,
208
+ out_mode,
209
+ map_to_init_pose,
210
+ hold_eye_opened,
211
+ head_torso_threshold,
212
+ audio2secc_dir,
213
+ head_model_dir,
214
+ torso_model_dir,
215
+ ],
216
+ outputs=[
217
+ gen_video,
218
+ info_box,
219
+ ],
220
+ )
221
+
222
+ print(sep_line)
223
+ print("Gradio page is constructed.")
224
+ print(sep_line)
225
+
226
+ return real3dportrait_interface
227
+
228
+ if __name__ == "__main__":
229
+ parser = argparse.ArgumentParser()
230
+ parser.add_argument("--a2m_ckpt", type=str, default='checkpoints/240126_real3dportrait_orig/audio2secc_vae/model_ckpt_steps_400000.ckpt')
231
+ parser.add_argument("--head_ckpt", type=str, default='')
232
+ parser.add_argument("--torso_ckpt", type=str, default='checkpoints/240126_real3dportrait_orig/secc2plane_torso_orig/model_ckpt_steps_100000.ckpt')
233
+ parser.add_argument("--port", type=int, default=None)
234
+ args = parser.parse_args()
235
+ demo = real3dportrait_demo(
236
+ audio2secc_dir=args.a2m_ckpt,
237
+ head_model_dir=args.head_ckpt,
238
+ torso_model_dir=args.torso_ckpt,
239
+ device='cuda:0',
240
+ warpfn=None,
241
+ )
242
+ demo.queue()
243
+ demo.launch(server_port=args.port)
244
+
inference/edit_secc.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import torch
3
+ from utils.commons.image_utils import dilate, erode
4
+ from sklearn.neighbors import NearestNeighbors
5
+ import copy
6
+ import numpy as np
7
+ from utils.commons.meters import Timer
8
+
9
+ def hold_eye_opened_for_secc(img):
10
+ img = img.permute(1,2,0).cpu().numpy()
11
+ img = ((img +1)/2*255).astype(np.uint)
12
+ face_mask = (img[...,0] != 0) & (img[...,1] != 0) & (img[...,2] != 0)
13
+ face_xys = np.stack(np.nonzero(face_mask)).transpose(1, 0) # [N_nonbg,2] coordinate of non-face pixels
14
+ h,w = face_mask.shape
15
+ # get face and eye mask
16
+ left_eye_prior_reigon = np.zeros([h,w], dtype=bool)
17
+ right_eye_prior_reigon = np.zeros([h,w], dtype=bool)
18
+ left_eye_prior_reigon[h//4:h//2, w//4:w//2] = True
19
+ right_eye_prior_reigon[h//4:h//2, w//2:w//4*3] = True
20
+ eye_prior_reigon = left_eye_prior_reigon | right_eye_prior_reigon
21
+ coarse_eye_mask = (~ face_mask) & eye_prior_reigon
22
+ coarse_eye_xys = np.stack(np.nonzero(coarse_eye_mask)).transpose(1, 0) # [N_nonbg,2] coordinate of non-face pixels
23
+
24
+ opened_eye_mask = cv2.imread('inference/os_avatar/opened_eye_mask.png')
25
+ opened_eye_mask = torch.nn.functional.interpolate(torch.tensor(opened_eye_mask).permute(2,0,1).unsqueeze(0), size=(img.shape[0], img.shape[1]), mode='nearest')[0].permute(1,2,0).sum(-1).bool().cpu() # [512,512,3]
26
+ coarse_opened_eye_xys = np.stack(np.nonzero(opened_eye_mask)) # [N_nonbg,2] coordinate of non-face pixels
27
+
28
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(coarse_eye_xys)
29
+ dists, _ = nbrs.kneighbors(coarse_opened_eye_xys) # [512*512, 1] distance to nearest non-bg pixel
30
+ # print(dists.max())
31
+ non_opened_eye_pixs = dists > max(dists.max()*0.75, 4) # 大于这个距离的opened eye部分会被合上
32
+ non_opened_eye_pixs = non_opened_eye_pixs.reshape([-1])
33
+ opened_eye_xys_to_erode = coarse_opened_eye_xys[non_opened_eye_pixs]
34
+ opened_eye_mask[opened_eye_xys_to_erode[...,0], opened_eye_xys_to_erode[...,1]] = False # shrink 将mask在face-eye边界收缩3pixel,为了平滑
35
+
36
+ img[opened_eye_mask] = 0
37
+ return torch.tensor(img.astype(np.float32) / 127.5 - 1).permute(2,0,1)
38
+
39
+
40
+ # def hold_eye_opened_for_secc(img):
41
+ # img = copy.copy(img)
42
+ # eye_mask = cv2.imread('inference/os_avatar/opened_eye_mask.png')
43
+ # eye_mask = torch.nn.functional.interpolate(torch.tensor(eye_mask).permute(2,0,1).unsqueeze(0), size=(img.shape[-2], img.shape[-1]), mode='nearest')[0].bool().to(img.device) # [3,512,512]
44
+ # img[eye_mask] = -1
45
+ # return img
46
+
47
+ def blink_eye_for_secc(img, close_eye_percent=0.5):
48
+ """
49
+ secc_img: [3,h,w], tensor, -1~1
50
+ """
51
+ img = img.permute(1,2,0).cpu().numpy()
52
+ img = ((img +1)/2*255).astype(np.uint)
53
+ assert close_eye_percent <= 1.0 and close_eye_percent >= 0.
54
+ if close_eye_percent == 0: return torch.tensor(img.astype(np.float32) / 127.5 - 1).permute(2,0,1)
55
+ img = copy.deepcopy(img)
56
+ face_mask = (img[...,0] != 0) & (img[...,1] != 0) & (img[...,2] != 0)
57
+ h,w = face_mask.shape
58
+
59
+ # get face and eye mask
60
+ left_eye_prior_reigon = np.zeros([h,w], dtype=bool)
61
+ right_eye_prior_reigon = np.zeros([h,w], dtype=bool)
62
+ left_eye_prior_reigon[h//4:h//2, w//4:w//2] = True
63
+ right_eye_prior_reigon[h//4:h//2, w//2:w//4*3] = True
64
+ eye_prior_reigon = left_eye_prior_reigon | right_eye_prior_reigon
65
+ coarse_eye_mask = (~ face_mask) & eye_prior_reigon
66
+ coarse_left_eye_mask = (~ face_mask) & left_eye_prior_reigon
67
+ coarse_right_eye_mask = (~ face_mask) & right_eye_prior_reigon
68
+ coarse_eye_xys = np.stack(np.nonzero(coarse_eye_mask)).transpose(1, 0) # [N_nonbg,2] coordinate of non-face pixels
69
+ min_h = coarse_eye_xys[:, 0].min()
70
+ max_h = coarse_eye_xys[:, 0].max()
71
+ coarse_left_eye_xys = np.stack(np.nonzero(coarse_left_eye_mask)).transpose(1, 0) # [N_nonbg,2] coordinate of non-face pixels
72
+ left_min_w = coarse_left_eye_xys[:, 1].min()
73
+ left_max_w = coarse_left_eye_xys[:, 1].max()
74
+ coarse_right_eye_xys = np.stack(np.nonzero(coarse_right_eye_mask)).transpose(1, 0) # [N_nonbg,2] coordinate of non-face pixels
75
+ right_min_w = coarse_right_eye_xys[:, 1].min()
76
+ right_max_w = coarse_right_eye_xys[:, 1].max()
77
+
78
+ # 尽力较少需要考虑的face_xyz,以降低KNN的损耗
79
+ left_eye_prior_reigon = np.zeros([h,w], dtype=bool)
80
+ more_room = 4 # 过小会导致一些问题
81
+ left_eye_prior_reigon[min_h-more_room:max_h+more_room, left_min_w-more_room:left_max_w+more_room] = True
82
+ right_eye_prior_reigon = np.zeros([h,w], dtype=bool)
83
+ right_eye_prior_reigon[min_h-more_room:max_h+more_room, right_min_w-more_room:right_max_w+more_room] = True
84
+ eye_prior_reigon = left_eye_prior_reigon | right_eye_prior_reigon
85
+
86
+ around_eye_face_mask = face_mask & eye_prior_reigon
87
+ face_mask = around_eye_face_mask
88
+ face_xys = np.stack(np.nonzero(around_eye_face_mask)).transpose(1, 0) # [N_nonbg,2] coordinate of non-face pixels
89
+
90
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(coarse_eye_xys)
91
+ dists, _ = nbrs.kneighbors(face_xys) # [512*512, 1] distance to nearest non-bg pixel
92
+ face_pixs = dists > 5 # 只有距离最近的eye pixel大于5的才被认为是face,过小会导致一些问题
93
+ face_pixs = face_pixs.reshape([-1])
94
+ face_xys_to_erode = face_xys[~face_pixs]
95
+ face_mask[face_xys_to_erode[...,0], face_xys_to_erode[...,1]] = False # shrink 将mask在face-eye边界收缩3pixel,为了平滑
96
+ eye_mask = (~ face_mask) & eye_prior_reigon
97
+
98
+ h_grid = np.mgrid[0:h, 0:w][0]
99
+ eye_num_pixel_along_w_axis = eye_mask.sum(axis=0)
100
+ eye_mask_along_w_axis = eye_num_pixel_along_w_axis != 0
101
+
102
+ tmp_h_grid = h_grid.copy()
103
+ tmp_h_grid[~eye_mask] = 0
104
+ eye_mean_h_coord_along_w_axis = tmp_h_grid.sum(axis=0) / np.clip(eye_num_pixel_along_w_axis, a_min=1, a_max=h)
105
+ tmp_h_grid = h_grid.copy()
106
+ tmp_h_grid[~eye_mask] = 99999
107
+ eye_min_h_coord_along_w_axis = tmp_h_grid.min(axis=0)
108
+ tmp_h_grid = h_grid.copy()
109
+ tmp_h_grid[~eye_mask] = -99999
110
+ eye_max_h_coord_along_w_axis = tmp_h_grid.max(axis=0)
111
+
112
+ eye_low_h_coord_along_w_axis = close_eye_percent * eye_mean_h_coord_along_w_axis + (1-close_eye_percent) * eye_min_h_coord_along_w_axis # upper eye
113
+ eye_high_h_coord_along_w_axis = close_eye_percent * eye_mean_h_coord_along_w_axis + (1-close_eye_percent) * eye_max_h_coord_along_w_axis # lower eye
114
+
115
+ tmp_h_grid = h_grid.copy()
116
+ tmp_h_grid[~eye_mask] = 99999
117
+ upper_eye_blink_mask = tmp_h_grid <= eye_low_h_coord_along_w_axis
118
+ tmp_h_grid = h_grid.copy()
119
+ tmp_h_grid[~eye_mask] = -99999
120
+ lower_eye_blink_mask = tmp_h_grid >= eye_high_h_coord_along_w_axis
121
+ eye_blink_mask = upper_eye_blink_mask | lower_eye_blink_mask
122
+
123
+ face_xys = np.stack(np.nonzero(around_eye_face_mask)).transpose(1, 0) # [N_nonbg,2] coordinate of non-face pixels
124
+ eye_blink_xys = np.stack(np.nonzero(eye_blink_mask)).transpose(1, 0) # [N_nonbg,hw] coordinate of non-face pixels
125
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(face_xys)
126
+ distances, indices = nbrs.kneighbors(eye_blink_xys)
127
+ bg_fg_xys = face_xys[indices[:, 0]]
128
+ img[eye_blink_xys[:, 0], eye_blink_xys[:, 1], :] = img[bg_fg_xys[:, 0], bg_fg_xys[:, 1], :]
129
+ return torch.tensor(img.astype(np.float32) / 127.5 - 1).permute(2,0,1)
130
+
131
+
132
+ if __name__ == '__main__':
133
+ import imageio
134
+ import tqdm
135
+ img = cv2.imread("assets/cano_secc.png")
136
+ img = img / 127.5 - 1
137
+ img = torch.FloatTensor(img).permute(2, 0, 1)
138
+ fps = 25
139
+ writer = imageio.get_writer('demo_blink.mp4', fps=fps)
140
+
141
+ for i in tqdm.trange(33):
142
+ blink_percent = 0.03 * i
143
+ with Timer("Blink", True):
144
+ out_img = blink_eye_for_secc(img, blink_percent)
145
+ out_img = ((out_img.permute(1,2,0)+1)*127.5).int().numpy()
146
+ writer.append_data(out_img)
147
+ writer.close()
inference/infer_utils.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import librosa
5
+ import numpy as np
6
+ import importlib
7
+ import tqdm
8
+ import copy
9
+ import cv2
10
+ from scipy.spatial.transform import Rotation
11
+
12
+
13
+ def load_img_to_512_hwc_array(img_name):
14
+ img = cv2.imread(img_name)
15
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
16
+ img = cv2.resize(img, (512, 512))
17
+ return img
18
+
19
+ def load_img_to_normalized_512_bchw_tensor(img_name):
20
+ img = load_img_to_512_hwc_array(img_name)
21
+ img = ((torch.tensor(img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2) # [b,c,h,w]
22
+ return img
23
+
24
+ def mirror_index(index, len_seq):
25
+ """
26
+ get mirror index when indexing a sequence and the index is larger than len_pose
27
+ args:
28
+ index: int
29
+ len_pose: int
30
+ return:
31
+ mirror_index: int
32
+ """
33
+ turn = index // len_seq
34
+ res = index % len_seq
35
+ if turn % 2 == 0:
36
+ return res # forward indexing
37
+ else:
38
+ return len_seq - res - 1 # reverse indexing
39
+
40
+ def smooth_camera_sequence(camera, kernel_size=7):
41
+ """
42
+ smooth the camera trajectory (i.e., rotation & translation)...
43
+ args:
44
+ camera: [N, 25] or [N, 16]. np.ndarray
45
+ kernel_size: int
46
+ return:
47
+ smoothed_camera: [N, 25] or [N, 16]. np.ndarray
48
+ """
49
+ # poses: [N, 25], numpy array
50
+ N = camera.shape[0]
51
+ K = kernel_size // 2
52
+ poses = camera[:, :16].reshape([-1, 4, 4]).copy()
53
+ trans = poses[:, :3, 3].copy() # [N, 3]
54
+ rots = poses[:, :3, :3].copy() # [N, 3, 3]
55
+
56
+ for i in range(N):
57
+ start = max(0, i - K)
58
+ end = min(N, i + K + 1)
59
+ poses[i, :3, 3] = trans[start:end].mean(0)
60
+ try:
61
+ poses[i, :3, :3] = Rotation.from_matrix(rots[start:end]).mean().as_matrix()
62
+ except:
63
+ if i == 0:
64
+ poses[i, :3, :3] = rots[i]
65
+ else:
66
+ poses[i, :3, :3] = poses[i-1, :3, :3]
67
+ poses = poses.reshape([-1, 16])
68
+ camera[:, :16] = poses
69
+ return camera
70
+
71
+ def smooth_features_xd(in_tensor, kernel_size=7):
72
+ """
73
+ smooth the feature maps
74
+ args:
75
+ in_tensor: [T, c,h,w] or [T, c1,c2,h,w]
76
+ kernel_size: int
77
+ return:
78
+ out_tensor: [T, c,h,w] or [T, c1,c2,h,w]
79
+ """
80
+ t = in_tensor.shape[0]
81
+ ndim = in_tensor.ndim
82
+ pad = (kernel_size- 1)//2
83
+ in_tensor = torch.cat([torch.flip(in_tensor[0:pad], dims=[0]), in_tensor, torch.flip(in_tensor[t-pad:t], dims=[0])], dim=0)
84
+ if ndim == 2: # tc
85
+ _,c = in_tensor.shape
86
+ in_tensor = in_tensor.permute(1,0).reshape([-1,1,t+2*pad]) # [c, 1, t]
87
+ elif ndim == 4: # tchw
88
+ _,c,h,w = in_tensor.shape
89
+ in_tensor = in_tensor.permute(1,2,3,0).reshape([-1,1,t+2*pad]) # [c, 1, t]
90
+ elif ndim == 5: # tcchw, like deformation
91
+ _,c1,c2, h,w = in_tensor.shape
92
+ in_tensor = in_tensor.permute(1,2,3,4,0).reshape([-1,1,t+2*pad]) # [c, 1, t]
93
+ else: raise NotImplementedError()
94
+ avg_kernel = 1 / kernel_size * torch.Tensor([1.]*kernel_size).reshape([1,1,kernel_size]).float().to(in_tensor.device) # [1, 1, kw]
95
+ out_tensor = F.conv1d(in_tensor, avg_kernel)
96
+ if ndim == 2: # tc
97
+ return out_tensor.reshape([c,t]).permute(1,0)
98
+ elif ndim == 4: # tchw
99
+ return out_tensor.reshape([c,h,w,t]).permute(3,0,1,2)
100
+ elif ndim == 5: # tcchw, like deformation
101
+ return out_tensor.reshape([c1,c2,h,w,t]).permute(4,0,1,2,3)
102
+
103
+
104
+ def extract_audio_motion_from_ref_video(video_name):
105
+ def save_wav16k(audio_name):
106
+ supported_types = ('.wav', '.mp3', '.mp4', '.avi')
107
+ assert audio_name.endswith(supported_types), f"Now we only support {','.join(supported_types)} as audio source!"
108
+ wav16k_name = audio_name[:-4] + '_16k.wav'
109
+ extract_wav_cmd = f"ffmpeg -i {audio_name} -f wav -ar 16000 -v quiet -y {wav16k_name} -y"
110
+ os.system(extract_wav_cmd)
111
+ print(f"Extracted wav file (16khz) from {audio_name} to {wav16k_name}.")
112
+ return wav16k_name
113
+
114
+ def get_f0( wav16k_name):
115
+ from data_gen.process_lrs3.process_audio_mel_f0 import extract_mel_from_fname,extract_f0_from_wav_and_mel
116
+ wav, mel = extract_mel_from_fname(wav16k_name)
117
+ f0, f0_coarse = extract_f0_from_wav_and_mel(wav, mel)
118
+ f0 = f0.reshape([-1,1])
119
+ f0 = torch.tensor(f0)
120
+ return f0
121
+
122
+ def get_hubert(wav16k_name):
123
+ from data_gen.utils.process_audio.extract_hubert import get_hubert_from_16k_wav
124
+ hubert = get_hubert_from_16k_wav(wav16k_name).detach().numpy()
125
+ len_mel = hubert.shape[0]
126
+ x_multiply = 8
127
+ if len_mel % x_multiply == 0:
128
+ num_to_pad = 0
129
+ else:
130
+ num_to_pad = x_multiply - len_mel % x_multiply
131
+ hubert = np.pad(hubert, pad_width=((0,num_to_pad), (0,0)))
132
+ hubert = torch.tensor(hubert)
133
+ return hubert
134
+
135
+ def get_exp(video_name):
136
+ from data_gen.utils.process_video.fit_3dmm_landmark import fit_3dmm_for_a_video
137
+ drv_motion_coeff_dict = fit_3dmm_for_a_video(video_name, save=False)
138
+ exp = torch.tensor(drv_motion_coeff_dict['exp'])
139
+ return exp
140
+
141
+ wav16k_name = save_wav16k(video_name)
142
+ f0 = get_f0(wav16k_name)
143
+ hubert = get_hubert(wav16k_name)
144
+ os.system(f"rm {wav16k_name}")
145
+ exp = get_exp(video_name)
146
+ target_length = min(len(exp), len(hubert)//2, len(f0)//2)
147
+ exp = exp[:target_length]
148
+ f0 = f0[:target_length*2]
149
+ hubert = hubert[:target_length*2]
150
+ return exp.unsqueeze(0), hubert.unsqueeze(0), f0.unsqueeze(0)
151
+
152
+
153
+ if __name__ == '__main__':
154
+ extract_audio_motion_from_ref_video('data/raw/videos/crop_0213.mp4')
inference/real3d_infer.py ADDED
@@ -0,0 +1,542 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import torchshow as ts
5
+ import librosa
6
+ import random
7
+ import time
8
+ import numpy as np
9
+ import importlib
10
+ import tqdm
11
+ import copy
12
+ import cv2
13
+
14
+ # common utils
15
+ from utils.commons.hparams import hparams, set_hparams
16
+ from utils.commons.tensor_utils import move_to_cuda, convert_to_tensor
17
+ from utils.commons.ckpt_utils import load_ckpt, get_last_checkpoint
18
+ # 3DMM-related utils
19
+ from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
20
+ from data_util.face3d_helper import Face3DHelper
21
+ from data_gen.utils.process_image.fit_3dmm_landmark import fit_3dmm_for_a_image
22
+ from data_gen.utils.process_video.fit_3dmm_landmark import fit_3dmm_for_a_video
23
+ from deep_3drecon.secc_renderer import SECC_Renderer
24
+ from data_gen.eg3d.convert_to_eg3d_convention import get_eg3d_convention_camera_pose_intrinsic
25
+ # Face Parsing
26
+ from data_gen.utils.mp_feature_extractors.mp_segmenter import MediapipeSegmenter
27
+ from data_gen.utils.process_video.extract_segment_imgs import inpaint_torso_job, extract_background
28
+ # other inference utils
29
+ from inference.infer_utils import mirror_index, load_img_to_512_hwc_array, load_img_to_normalized_512_bchw_tensor
30
+ from inference.infer_utils import smooth_camera_sequence, smooth_features_xd
31
+ from Real3DPortrait.inference.edit_secc import blink_eye_for_secc
32
+
33
+
34
+ def read_first_frame_from_a_video(vid_name):
35
+ frames = []
36
+ cap = cv2.VideoCapture(vid_name)
37
+ ret, frame_bgr = cap.read()
38
+ frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
39
+ return frame_rgb
40
+
41
+ def analyze_weights_img(gen_output):
42
+ img_raw = gen_output['image_raw']
43
+ mask_005_to_03 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<0.3).repeat([1,3,1,1])
44
+ mask_005_to_05 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<0.5).repeat([1,3,1,1])
45
+ mask_005_to_07 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<0.7).repeat([1,3,1,1])
46
+ mask_005_to_09 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<0.9).repeat([1,3,1,1])
47
+ mask_005_to_10 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<1.0).repeat([1,3,1,1])
48
+
49
+ img_raw_005_to_03 = img_raw.clone()
50
+ img_raw_005_to_03[~mask_005_to_03] = -1
51
+ img_raw_005_to_05 = img_raw.clone()
52
+ img_raw_005_to_05[~mask_005_to_05] = -1
53
+ img_raw_005_to_07 = img_raw.clone()
54
+ img_raw_005_to_07[~mask_005_to_07] = -1
55
+ img_raw_005_to_09 = img_raw.clone()
56
+ img_raw_005_to_09[~mask_005_to_09] = -1
57
+ img_raw_005_to_10 = img_raw.clone()
58
+ img_raw_005_to_10[~mask_005_to_10] = -1
59
+ ts.save([img_raw_005_to_03[0], img_raw_005_to_05[0], img_raw_005_to_07[0], img_raw_005_to_09[0], img_raw_005_to_10[0]])
60
+
61
+ class GeneFace2Infer:
62
+ def __init__(self, audio2secc_dir, head_model_dir, torso_model_dir, device=None, inp=None):
63
+ if device is None:
64
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
65
+ self.device = device
66
+ self.audio2secc_model = self.load_audio2secc(audio2secc_dir)
67
+ self.secc2video_model = self.load_secc2video(head_model_dir, torso_model_dir, inp)
68
+ self.audio2secc_model.to(device).eval()
69
+ self.secc2video_model.to(device).eval()
70
+ self.seg_model = MediapipeSegmenter()
71
+ self.secc_renderer = SECC_Renderer(512)
72
+ self.face3d_helper = Face3DHelper(use_gpu=True, keypoint_mode='lm68')
73
+ self.mp_face3d_helper = Face3DHelper(use_gpu=True, keypoint_mode='mediapipe')
74
+
75
+ def load_audio2secc(self, audio2secc_dir):
76
+ config_name = f"{audio2secc_dir}/config.yaml" if not audio2secc_dir.endswith(".ckpt") else f"{os.path.dirname(audio2secc_dir)}/config.yaml"
77
+ set_hparams(f"{config_name}", print_hparams=False)
78
+ self.audio2secc_dir = audio2secc_dir
79
+ self.audio2secc_hparams = copy.deepcopy(hparams)
80
+ from modules.audio2motion.vae import VAEModel, PitchContourVAEModel
81
+ if self.audio2secc_hparams['audio_type'] == 'hubert':
82
+ audio_in_dim = 1024
83
+ elif self.audio2secc_hparams['audio_type'] == 'mfcc':
84
+ audio_in_dim = 13
85
+
86
+ if 'icl' in hparams['task_cls']:
87
+ self.use_icl_audio2motion = True
88
+ model = InContextAudio2MotionModel(hparams['icl_model_type'], hparams=self.audio2secc_hparams)
89
+ else:
90
+ self.use_icl_audio2motion = False
91
+ if hparams.get("use_pitch", False) is True:
92
+ model = PitchContourVAEModel(hparams, in_out_dim=64, audio_in_dim=audio_in_dim)
93
+ else:
94
+ model = VAEModel(in_out_dim=64, audio_in_dim=audio_in_dim)
95
+ load_ckpt(model, f"{audio2secc_dir}", model_name='model', strict=True)
96
+ return model
97
+
98
+ def load_secc2video(self, head_model_dir, torso_model_dir, inp):
99
+ if inp is None:
100
+ inp = {}
101
+ self.head_model_dir = head_model_dir
102
+ self.torso_model_dir = torso_model_dir
103
+ if torso_model_dir != '':
104
+ if torso_model_dir.endswith(".ckpt"):
105
+ set_hparams(f"{os.path.dirname(torso_model_dir)}/config.yaml", print_hparams=False)
106
+ else:
107
+ set_hparams(f"{torso_model_dir}/config.yaml", print_hparams=False)
108
+ if inp.get('head_torso_threshold', None) is not None:
109
+ hparams['htbsr_head_threshold'] = inp['head_torso_threshold']
110
+ self.secc2video_hparams = copy.deepcopy(hparams)
111
+ from modules.real3d.secc_img2plane_torso import OSAvatarSECC_Img2plane_Torso
112
+ model = OSAvatarSECC_Img2plane_Torso()
113
+ load_ckpt(model, f"{torso_model_dir}", model_name='model', strict=False)
114
+ if head_model_dir != '':
115
+ print("| Warning: Assigned --torso_ckpt which also contains head, but --head_ckpt is also assigned, skipping the --head_ckpt.")
116
+ else:
117
+ from modules.real3d.secc_img2plane_torso import OSAvatarSECC_Img2plane
118
+ if head_model_dir.endswith(".ckpt"):
119
+ set_hparams(f"{os.path.dirname(head_model_dir)}/config.yaml", print_hparams=False)
120
+ else:
121
+ set_hparams(f"{head_model_dir}/config.yaml", print_hparams=False)
122
+ if inp.get('head_torso_threshold', None) is not None:
123
+ hparams['htbsr_head_threshold'] = inp['head_torso_threshold']
124
+ self.secc2video_hparams = copy.deepcopy(hparams)
125
+ model = OSAvatarSECC_Img2plane()
126
+ load_ckpt(model, f"{head_model_dir}", model_name='model', strict=False)
127
+ return model
128
+
129
+ def infer_once(self, inp):
130
+ self.inp = inp
131
+ samples = self.prepare_batch_from_inp(inp)
132
+ seed = inp['seed'] if inp['seed'] is not None else int(time.time())
133
+ random.seed(seed)
134
+ torch.manual_seed(seed)
135
+ np.random.seed(seed)
136
+ out_name = self.forward_system(samples, inp)
137
+ return out_name
138
+
139
+ def prepare_batch_from_inp(self, inp):
140
+ """
141
+ :param inp: {'audio_source_name': (str)}
142
+ :return: a dict that contains the condition feature of NeRF
143
+ """
144
+ sample = {}
145
+ # Process Driving Motion
146
+ if inp['drv_audio_name'][-4:] in ['.wav', '.mp3']:
147
+ self.save_wav16k(inp['drv_audio_name'])
148
+ if self.audio2secc_hparams['audio_type'] == 'hubert':
149
+ hubert = self.get_hubert(self.wav16k_name)
150
+ elif self.audio2secc_hparams['audio_type'] == 'mfcc':
151
+ hubert = self.get_mfcc(self.wav16k_name) / 100
152
+
153
+ f0 = self.get_f0(self.wav16k_name)
154
+ if f0.shape[0] > len(hubert):
155
+ f0 = f0[:len(hubert)]
156
+ else:
157
+ num_to_pad = len(hubert) - len(f0)
158
+ f0 = np.pad(f0, pad_width=((0,num_to_pad), (0,0)))
159
+ t_x = hubert.shape[0]
160
+ x_mask = torch.ones([1, t_x]).float() # mask for audio frames
161
+ y_mask = torch.ones([1, t_x//2]).float() # mask for motion/image frames
162
+ sample.update({
163
+ 'hubert': torch.from_numpy(hubert).float().unsqueeze(0).cuda(),
164
+ 'f0': torch.from_numpy(f0).float().reshape([1,-1]).cuda(),
165
+ 'x_mask': x_mask.cuda(),
166
+ 'y_mask': y_mask.cuda(),
167
+ })
168
+ sample['blink'] = torch.zeros([1, t_x, 1]).long().cuda()
169
+ sample['audio'] = sample['hubert']
170
+ sample['eye_amp'] = torch.ones([1, 1]).cuda() * 1.0
171
+ sample['mouth_amp'] = torch.ones([1, 1]).cuda() * inp['mouth_amp']
172
+ elif inp['drv_audio_name'][-4:] in ['.mp4']:
173
+ drv_motion_coeff_dict = fit_3dmm_for_a_video(inp['drv_audio_name'], save=False)
174
+ drv_motion_coeff_dict = convert_to_tensor(drv_motion_coeff_dict)
175
+ t_x = drv_motion_coeff_dict['exp'].shape[0] * 2
176
+ self.drv_motion_coeff_dict = drv_motion_coeff_dict
177
+ elif inp['drv_audio_name'][-4:] in ['.npy']:
178
+ drv_motion_coeff_dict = np.load(inp['drv_audio_name'], allow_pickle=True).tolist()
179
+ drv_motion_coeff_dict = convert_to_tensor(drv_motion_coeff_dict)
180
+ t_x = drv_motion_coeff_dict['exp'].shape[0] * 2
181
+ self.drv_motion_coeff_dict = drv_motion_coeff_dict
182
+
183
+ # Face Parsing
184
+ image_name = inp['src_image_name']
185
+ if image_name.endswith(".mp4"):
186
+ img = read_first_frame_from_a_video(image_name)
187
+ image_name = inp['src_image_name'] = image_name[:-4] + '.png'
188
+ cv2.imwrite(image_name, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
189
+ sample['ref_gt_img'] = load_img_to_normalized_512_bchw_tensor(image_name).cuda()
190
+ img = load_img_to_512_hwc_array(image_name)
191
+ segmap = self.seg_model._cal_seg_map(img)
192
+ sample['segmap'] = torch.tensor(segmap).float().unsqueeze(0).cuda()
193
+ head_img = self.seg_model._seg_out_img_with_segmap(img, segmap, mode='head')[0]
194
+ sample['ref_head_img'] = ((torch.tensor(head_img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2).cuda() # [b,c,h,w]
195
+ inpaint_torso_img, _, _, _ = inpaint_torso_job(img, segmap)
196
+ sample['ref_torso_img'] = ((torch.tensor(inpaint_torso_img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2).cuda() # [b,c,h,w]
197
+
198
+ if inp['bg_image_name'] == '':
199
+ bg_img = extract_background([img], [segmap], 'knn')
200
+ else:
201
+ bg_img = cv2.imread(inp['bg_image_name'])
202
+ bg_img = cv2.cvtColor(bg_img, cv2.COLOR_BGR2RGB)
203
+ bg_img = cv2.resize(bg_img, (512,512))
204
+ sample['bg_img'] = ((torch.tensor(bg_img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2).cuda() # [b,c,h,w]
205
+
206
+ # 3DMM, get identity code and camera pose
207
+ coeff_dict = fit_3dmm_for_a_image(image_name, save=False)
208
+ assert coeff_dict is not None
209
+ src_id = torch.tensor(coeff_dict['id']).reshape([1,80]).cuda()
210
+ src_exp = torch.tensor(coeff_dict['exp']).reshape([1,64]).cuda()
211
+ src_euler = torch.tensor(coeff_dict['euler']).reshape([1,3]).cuda()
212
+ src_trans = torch.tensor(coeff_dict['trans']).reshape([1,3]).cuda()
213
+ sample['id'] = src_id.repeat([t_x//2,1])
214
+
215
+ # get the src_kp for torso model
216
+ src_kp = self.face3d_helper.reconstruct_lm2d(src_id, src_exp, src_euler, src_trans) # [1, 68, 2]
217
+ src_kp = (src_kp-0.5) / 0.5 # rescale to -1~1
218
+ sample['src_kp'] = torch.clamp(src_kp, -1, 1).repeat([t_x//2,1,1])
219
+
220
+ # get camera pose file
221
+ # random.seed(time.time())
222
+ inp['drv_pose_name'] = inp['drv_pose_name']
223
+ print(f"| To extract pose from {inp['drv_pose_name']}")
224
+
225
+ # extract camera pose
226
+ if inp['drv_pose_name'] == 'static':
227
+ sample['euler'] = torch.tensor(coeff_dict['euler']).reshape([1,3]).cuda().repeat([t_x//2,1]) # default static pose
228
+ sample['trans'] = torch.tensor(coeff_dict['trans']).reshape([1,3]).cuda().repeat([t_x//2,1])
229
+ else: # from file
230
+ if inp['drv_pose_name'].endswith('.mp4'):
231
+ # extract coeff from video
232
+ drv_pose_coeff_dict = fit_3dmm_for_a_video(inp['drv_pose_name'], save=False)
233
+ else:
234
+ # load from npy
235
+ drv_pose_coeff_dict = np.load(inp['drv_pose_name'], allow_pickle=True).tolist()
236
+ print(f"| Extracted pose from {inp['drv_pose_name']}")
237
+ eulers = convert_to_tensor(drv_pose_coeff_dict['euler']).reshape([-1,3]).cuda()
238
+ trans = convert_to_tensor(drv_pose_coeff_dict['trans']).reshape([-1,3]).cuda()
239
+ len_pose = len(eulers)
240
+ index_lst = [mirror_index(i, len_pose) for i in range(t_x//2)]
241
+ sample['euler'] = eulers[index_lst]
242
+ sample['trans'] = trans[index_lst]
243
+
244
+ # fix the z axis
245
+ sample['trans'][:, -1] = sample['trans'][0:1, -1].repeat([sample['trans'].shape[0]])
246
+
247
+ # mapping to the init pose
248
+ if inp.get("map_to_init_pose", 'False') == 'True':
249
+ diff_euler = torch.tensor(coeff_dict['euler']).reshape([1,3]).cuda() - sample['euler'][0:1]
250
+ sample['euler'] = sample['euler'] + diff_euler
251
+ diff_trans = torch.tensor(coeff_dict['trans']).reshape([1,3]).cuda() - sample['trans'][0:1]
252
+ sample['trans'] = sample['trans'] + diff_trans
253
+
254
+ # prepare camera
255
+ camera_ret = get_eg3d_convention_camera_pose_intrinsic({'euler':sample['euler'].cpu(), 'trans':sample['trans'].cpu()})
256
+ c2w, intrinsics = camera_ret['c2w'], camera_ret['intrinsics']
257
+ # smooth camera
258
+ camera_smo_ksize = 7
259
+ camera = np.concatenate([c2w.reshape([-1,16]), intrinsics.reshape([-1,9])], axis=-1)
260
+ camera = smooth_camera_sequence(camera, kernel_size=camera_smo_ksize) # [T, 25]
261
+ camera = torch.tensor(camera).cuda().float()
262
+ sample['camera'] = camera
263
+
264
+ return sample
265
+
266
+ @torch.no_grad()
267
+ def get_hubert(self, wav16k_name):
268
+ from data_gen.utils.process_audio.extract_hubert import get_hubert_from_16k_wav
269
+ hubert = get_hubert_from_16k_wav(wav16k_name).detach().numpy()
270
+ len_mel = hubert.shape[0]
271
+ x_multiply = 8
272
+ if len_mel % x_multiply == 0:
273
+ num_to_pad = 0
274
+ else:
275
+ num_to_pad = x_multiply - len_mel % x_multiply
276
+ hubert = np.pad(hubert, pad_width=((0,num_to_pad), (0,0)))
277
+ return hubert
278
+
279
+ def get_mfcc(self, wav16k_name):
280
+ from utils.audio import librosa_wav2mfcc
281
+ hparams['fft_size'] = 1200
282
+ hparams['win_size'] = 1200
283
+ hparams['hop_size'] = 480
284
+ hparams['audio_num_mel_bins'] = 80
285
+ hparams['fmin'] = 80
286
+ hparams['fmax'] = 12000
287
+ hparams['audio_sample_rate'] = 24000
288
+ mfcc = librosa_wav2mfcc(wav16k_name,
289
+ fft_size=hparams['fft_size'],
290
+ hop_size=hparams['hop_size'],
291
+ win_length=hparams['win_size'],
292
+ num_mels=hparams['audio_num_mel_bins'],
293
+ fmin=hparams['fmin'],
294
+ fmax=hparams['fmax'],
295
+ sample_rate=hparams['audio_sample_rate'],
296
+ center=True)
297
+ mfcc = np.array(mfcc).reshape([-1, 13])
298
+ len_mel = mfcc.shape[0]
299
+ x_multiply = 8
300
+ if len_mel % x_multiply == 0:
301
+ num_to_pad = 0
302
+ else:
303
+ num_to_pad = x_multiply - len_mel % x_multiply
304
+ mfcc = np.pad(mfcc, pad_width=((0,num_to_pad), (0,0)))
305
+ return mfcc
306
+
307
+ @torch.no_grad()
308
+ def forward_audio2secc(self, batch, inp=None):
309
+ if inp['drv_audio_name'][-4:] in ['.wav', '.mp3']:
310
+ # audio-to-exp
311
+ ret = {}
312
+ pred = self.audio2secc_model.forward(batch, ret=ret,train=False, temperature=inp['temperature'],)
313
+ print("| audio-to-motion finished")
314
+ if pred.shape[-1] == 144:
315
+ id = ret['pred'][0][:,:80]
316
+ exp = ret['pred'][0][:,80:]
317
+ else:
318
+ id = batch['id']
319
+ exp = ret['pred'][0]
320
+ if len(id) < len(exp): # happens when use ICL
321
+ id = torch.cat([id, id[0].unsqueeze(0).repeat([len(exp)-len(id),1])])
322
+ batch['id'] = id
323
+ batch['exp'] = exp
324
+ else:
325
+ drv_motion_coeff_dict = self.drv_motion_coeff_dict
326
+ batch['exp'] = torch.FloatTensor(drv_motion_coeff_dict['exp']).cuda()
327
+
328
+ batch = self.get_driving_motion(batch['id'], batch['exp'], batch['euler'], batch['trans'], batch, inp)
329
+ if self.use_icl_audio2motion:
330
+ self.audio2secc_model.empty_context()
331
+ return batch
332
+
333
+ @torch.no_grad()
334
+ def get_driving_motion(self, id, exp, euler, trans, batch, inp):
335
+ zero_eulers = torch.zeros([id.shape[0], 3]).to(id.device)
336
+ zero_trans = torch.zeros([id.shape[0], 3]).to(exp.device)
337
+ # render the secc given the id,exp
338
+ with torch.no_grad():
339
+ chunk_size = 50
340
+ drv_secc_color_lst = []
341
+ num_iters = len(id)//chunk_size if len(id)%chunk_size == 0 else len(id)//chunk_size+1
342
+ for i in tqdm.trange(num_iters, desc="rendering drv secc"):
343
+ torch.cuda.empty_cache()
344
+ face_mask, drv_secc_color = self.secc_renderer(id[i*chunk_size:(i+1)*chunk_size], exp[i*chunk_size:(i+1)*chunk_size], zero_eulers[i*chunk_size:(i+1)*chunk_size], zero_trans[i*chunk_size:(i+1)*chunk_size])
345
+ drv_secc_color_lst.append(drv_secc_color.cpu())
346
+ drv_secc_colors = torch.cat(drv_secc_color_lst, dim=0)
347
+ _, src_secc_color = self.secc_renderer(id[0:1], exp[0:1], zero_eulers[0:1], zero_trans[0:1])
348
+ _, cano_secc_color = self.secc_renderer(id[0:1], exp[0:1]*0, zero_eulers[0:1], zero_trans[0:1])
349
+ batch['drv_secc'] = drv_secc_colors.cuda()
350
+ batch['src_secc'] = src_secc_color.cuda()
351
+ batch['cano_secc'] = cano_secc_color.cuda()
352
+
353
+ # blinking secc
354
+ if inp['blink_mode'] == 'period':
355
+ period = 5 # second
356
+
357
+ for i in tqdm.trange(len(drv_secc_colors),desc="blinking secc"):
358
+ if i % (25*period) == 0:
359
+ blink_dur_frames = random.randint(8, 12)
360
+ for offset in range(blink_dur_frames):
361
+ j = offset + i
362
+ if j >= len(drv_secc_colors)-1: break
363
+ def blink_percent_fn(t, T):
364
+ return -4/T**2 * t**2 + 4/T * t
365
+ blink_percent = blink_percent_fn(offset, blink_dur_frames)
366
+ secc = batch['drv_secc'][j]
367
+ out_secc = blink_eye_for_secc(secc, blink_percent)
368
+ out_secc = out_secc.cuda()
369
+ batch['drv_secc'][j] = out_secc
370
+
371
+ # get the drv_kp for torso model, using the transformed trajectory
372
+ drv_kp = self.face3d_helper.reconstruct_lm2d(id, exp, euler, trans) # [T, 68, 2]
373
+
374
+ drv_kp = (drv_kp-0.5) / 0.5 # rescale to -1~1
375
+ batch['drv_kp'] = torch.clamp(drv_kp, -1, 1)
376
+ return batch
377
+
378
+ @torch.no_grad()
379
+ def forward_secc2video(self, batch, inp=None):
380
+ num_frames = len(batch['drv_secc'])
381
+ camera = batch['camera']
382
+ src_kps = batch['src_kp']
383
+ drv_kps = batch['drv_kp']
384
+ cano_secc_color = batch['cano_secc']
385
+ src_secc_color = batch['src_secc']
386
+ drv_secc_colors = batch['drv_secc']
387
+ ref_img_gt = batch['ref_gt_img']
388
+ ref_img_head = batch['ref_head_img']
389
+ ref_torso_img = batch['ref_torso_img']
390
+ bg_img = batch['bg_img']
391
+ segmap = batch['segmap']
392
+
393
+ # smooth torso drv_kp
394
+ torso_smo_ksize = 7
395
+ drv_kps = smooth_features_xd(drv_kps.reshape([-1, 68*2]), kernel_size=torso_smo_ksize).reshape([-1, 68, 2])
396
+
397
+ # forward renderer
398
+ img_raw_lst = []
399
+ img_lst = []
400
+ depth_img_lst = []
401
+ with torch.no_grad():
402
+ for i in tqdm.trange(num_frames, desc="Real3D-Portrait is rendering frames"):
403
+ kp_src = torch.cat([src_kps[i:i+1].reshape([1, 68, 2]), torch.zeros([1, 68,1]).to(src_kps.device)],dim=-1)
404
+ kp_drv = torch.cat([drv_kps[i:i+1].reshape([1, 68, 2]), torch.zeros([1, 68,1]).to(drv_kps.device)],dim=-1)
405
+ cond={'cond_cano': cano_secc_color,'cond_src': src_secc_color, 'cond_tgt': drv_secc_colors[i:i+1].cuda(),
406
+ 'ref_torso_img': ref_torso_img, 'bg_img': bg_img, 'segmap': segmap,
407
+ 'kp_s': kp_src, 'kp_d': kp_drv}
408
+ if i == 0:
409
+ gen_output = self.secc2video_model.forward(img=ref_img_head, camera=camera[i:i+1], cond=cond, ret={}, cache_backbone=True, use_cached_backbone=False)
410
+ else:
411
+ gen_output = self.secc2video_model.forward(img=ref_img_head, camera=camera[i:i+1], cond=cond, ret={}, cache_backbone=False, use_cached_backbone=True)
412
+ img_lst.append(gen_output['image'])
413
+ img_raw_lst.append(gen_output['image_raw'])
414
+ depth_img_lst.append(gen_output['image_depth'])
415
+
416
+ # save demo video
417
+ depth_imgs = torch.cat(depth_img_lst)
418
+ imgs = torch.cat(img_lst)
419
+ imgs_raw = torch.cat(img_raw_lst)
420
+ secc_img = torch.cat([torch.nn.functional.interpolate(drv_secc_colors[i:i+1], (512,512)) for i in range(num_frames)])
421
+
422
+ if inp['out_mode'] == 'concat_debug':
423
+ secc_img = secc_img.cpu()
424
+ secc_img = ((secc_img + 1) * 127.5).permute(0, 2, 3, 1).int().numpy()
425
+
426
+ depth_img = F.interpolate(depth_imgs, (512,512)).cpu()
427
+ depth_img = depth_img.repeat([1,3,1,1])
428
+ depth_img = (depth_img - depth_img.min()) / (depth_img.max() - depth_img.min())
429
+ depth_img = depth_img * 2 - 1
430
+ depth_img = depth_img.clamp(-1,1)
431
+
432
+ secc_img = secc_img / 127.5 - 1
433
+ secc_img = torch.from_numpy(secc_img).permute(0, 3, 1, 2)
434
+ imgs = torch.cat([ref_img_gt.repeat([imgs.shape[0],1,1,1]).cpu(), secc_img, F.interpolate(imgs_raw, (512,512)).cpu(), depth_img, imgs.cpu()], dim=-1)
435
+ elif inp['out_mode'] == 'final':
436
+ imgs = imgs.cpu()
437
+ elif inp['out_mode'] == 'debug':
438
+ raise NotImplementedError("to do: save separate videos")
439
+ imgs = imgs.clamp(-1,1)
440
+
441
+ import imageio
442
+ debug_name = 'demo.mp4'
443
+ out_imgs = ((imgs.permute(0, 2, 3, 1) + 1)/2 * 255).int().cpu().numpy().astype(np.uint8)
444
+ writer = imageio.get_writer(debug_name, fps=25, format='FFMPEG', codec='h264')
445
+
446
+ for i in tqdm.trange(len(out_imgs), desc="Imageio is saving video"):
447
+ writer.append_data(out_imgs[i])
448
+ writer.close()
449
+
450
+ out_fname = 'infer_out/tmp/' + os.path.basename(inp['src_image_name'])[:-4] + '_' + os.path.basename(inp['drv_pose_name'])[:-4] + '.mp4' if inp['out_name'] == '' else inp['out_name']
451
+ try:
452
+ os.makedirs(os.path.dirname(out_fname), exist_ok=True)
453
+ except: pass
454
+ if inp['drv_audio_name'][-4:] in ['.wav', '.mp3']:
455
+ os.system(f"ffmpeg -i {debug_name} -i {self.wav16k_name} -y -v quiet -shortest {out_fname}")
456
+ os.system(f"rm {debug_name}")
457
+ os.system(f"rm {self.wav16k_name}")
458
+ else:
459
+ ret = os.system(f"ffmpeg -i {debug_name} -i {inp['drv_audio_name']} -map 0:v -map 1:a -y -v quiet -shortest {out_fname}")
460
+ if ret != 0: # 没有成功从drv_audio_name里面提取到音频, 则直接输出无音频轨道的纯视频
461
+ os.system(f"mv {debug_name} {out_fname}")
462
+ print(f"Saved at {out_fname}")
463
+ return out_fname
464
+
465
+ @torch.no_grad()
466
+ def forward_system(self, batch, inp):
467
+ self.forward_audio2secc(batch, inp)
468
+ out_fname = self.forward_secc2video(batch, inp)
469
+ return out_fname
470
+
471
+ @classmethod
472
+ def example_run(cls, inp=None):
473
+ inp_tmp = {
474
+ 'drv_audio_name': 'data/raw/val_wavs/zozo.wav',
475
+ 'src_image_name': 'data/raw/val_imgs/Macron.png'
476
+ }
477
+ if inp is not None:
478
+ inp_tmp.update(inp)
479
+ inp = inp_tmp
480
+
481
+ infer_instance = cls(inp['a2m_ckpt'], inp['head_ckpt'], inp['torso_ckpt'], inp=inp)
482
+ infer_instance.infer_once(inp)
483
+
484
+ ##############
485
+ # IO-related
486
+ ##############
487
+ def save_wav16k(self, audio_name):
488
+ supported_types = ('.wav', '.mp3', '.mp4', '.avi')
489
+ assert audio_name.endswith(supported_types), f"Now we only support {','.join(supported_types)} as audio source!"
490
+ wav16k_name = audio_name[:-4] + '_16k.wav'
491
+ self.wav16k_name = wav16k_name
492
+ extract_wav_cmd = f"ffmpeg -i {audio_name} -f wav -ar 16000 -v quiet -y {wav16k_name} -y"
493
+ os.system(extract_wav_cmd)
494
+ print(f"Extracted wav file (16khz) from {audio_name} to {wav16k_name}.")
495
+
496
+ def get_f0(self, wav16k_name):
497
+ from data_gen.utils.process_audio.extract_mel_f0 import extract_mel_from_fname, extract_f0_from_wav_and_mel
498
+ wav, mel = extract_mel_from_fname(self.wav16k_name)
499
+ f0, f0_coarse = extract_f0_from_wav_and_mel(wav, mel)
500
+ f0 = f0.reshape([-1,1])
501
+ return f0
502
+
503
+ if __name__ == '__main__':
504
+ import argparse, glob, tqdm
505
+ parser = argparse.ArgumentParser()
506
+ parser.add_argument("--a2m_ckpt", default='checkpoints/240126_real3dportrait_orig/audio2secc_vae', type=str)
507
+ parser.add_argument("--head_ckpt", default='', type=str)
508
+ parser.add_argument("--torso_ckpt", default='checkpoints/240126_real3dportrait_orig/secc2plane_torso_orig', type=str)
509
+ parser.add_argument("--src_img", default='', type=str) # data/raw/examples/Macron.png
510
+ parser.add_argument("--bg_img", default='', type=str) # data/raw/examples/bg.png
511
+ parser.add_argument("--drv_aud", default='', type=str) # data/raw/examples/Obama_5s.wav
512
+ parser.add_argument("--drv_pose", default='static', type=str) # data/raw/examples/May_5s.mp4
513
+ parser.add_argument("--blink_mode", default='none', type=str) # none | period
514
+ parser.add_argument("--temperature", default=0.2, type=float) # sampling temperature in audio2motion, higher -> more diverse, less accurate
515
+ parser.add_argument("--mouth_amp", default=0.45, type=float) # scale of predicted mouth, enabled in audio-driven
516
+ parser.add_argument("--head_torso_threshold", default=0.9, type=float, help="0.1~1.0, turn up this value if the hair is translucent")
517
+ parser.add_argument("--out_name", default='') # output filename
518
+ parser.add_argument("--out_mode", default='final') # final: only output talking head video; concat_debug: talking head with internel features
519
+ parser.add_argument("--map_to_init_pose", default='True') # whether to map the pose of first frame to source image
520
+ parser.add_argument("--seed", default=None, type=int) # random seed, default None to use time.time()
521
+
522
+ args = parser.parse_args()
523
+
524
+ inp = {
525
+ 'a2m_ckpt': args.a2m_ckpt,
526
+ 'head_ckpt': args.head_ckpt,
527
+ 'torso_ckpt': args.torso_ckpt,
528
+ 'src_image_name': args.src_img,
529
+ 'bg_image_name': args.bg_img,
530
+ 'drv_audio_name': args.drv_aud,
531
+ 'drv_pose_name': args.drv_pose,
532
+ 'blink_mode': args.blink_mode,
533
+ 'temperature': args.temperature,
534
+ 'mouth_amp': args.mouth_amp,
535
+ 'out_name': args.out_name,
536
+ 'out_mode': args.out_mode,
537
+ 'map_to_init_pose': args.map_to_init_pose,
538
+ 'head_torso_threshold': args.head_torso_threshold,
539
+ 'seed': args.seed,
540
+ }
541
+
542
+ GeneFace2Infer.example_run(inp)
insta.sh ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ #conda create -n real3dportrait python=3.9
3
+ #conda activate real3dportrait
4
+ conda install conda-forge::ffmpeg # ffmpeg with libx264 codec to turn images to video
5
+
6
+ ### We recommend torch2.0.1+cuda11.7.
7
+ conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
8
+
9
+ # Build from source, it may take a long time (Proxy is recommended if encountering the time-out problem)
10
+ pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable"
11
+
12
+ # MMCV for some network structure
13
+ pip install cython
14
+ pip install openmim==0.3.9
15
+ mim install mmcv==2.1.0 # use mim to speed up installation for mmcv
16
+
17
+ # other dependencies
18
+ pip install -r docs/prepare_env/requirements.txt -v
modules/audio2motion/cnn_models.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ def init_weights_func(m):
7
+ classname = m.__class__.__name__
8
+ if classname.find("Conv1d") != -1:
9
+ torch.nn.init.xavier_uniform_(m.weight)
10
+
11
+
12
+ class LambdaLayer(nn.Module):
13
+ def __init__(self, lambd):
14
+ super(LambdaLayer, self).__init__()
15
+ self.lambd = lambd
16
+
17
+ def forward(self, x):
18
+ return self.lambd(x)
19
+
20
+
21
+ class LayerNorm(torch.nn.LayerNorm):
22
+ """Layer normalization module.
23
+ :param int nout: output dim size
24
+ :param int dim: dimension to be normalized
25
+ """
26
+
27
+ def __init__(self, nout, dim=-1, eps=1e-5):
28
+ """Construct an LayerNorm object."""
29
+ super(LayerNorm, self).__init__(nout, eps=eps)
30
+ self.dim = dim
31
+
32
+ def forward(self, x):
33
+ """Apply layer normalization.
34
+ :param torch.Tensor x: input tensor
35
+ :return: layer normalized tensor
36
+ :rtype torch.Tensor
37
+ """
38
+ if self.dim == -1:
39
+ return super(LayerNorm, self).forward(x)
40
+ return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
41
+
42
+
43
+
44
+ class ResidualBlock(nn.Module):
45
+ """Implements conv->PReLU->norm n-times"""
46
+
47
+ def __init__(self, channels, kernel_size, dilation, n=2, norm_type='bn', dropout=0.0,
48
+ c_multiple=2, ln_eps=1e-12, bias=False):
49
+ super(ResidualBlock, self).__init__()
50
+
51
+ if norm_type == 'bn':
52
+ norm_builder = lambda: nn.BatchNorm1d(channels)
53
+ elif norm_type == 'in':
54
+ norm_builder = lambda: nn.InstanceNorm1d(channels, affine=True)
55
+ elif norm_type == 'gn':
56
+ norm_builder = lambda: nn.GroupNorm(8, channels)
57
+ elif norm_type == 'ln':
58
+ norm_builder = lambda: LayerNorm(channels, dim=1, eps=ln_eps)
59
+ else:
60
+ norm_builder = lambda: nn.Identity()
61
+
62
+ self.blocks = [
63
+ nn.Sequential(
64
+ norm_builder(),
65
+ nn.Conv1d(channels, c_multiple * channels, kernel_size, dilation=dilation,
66
+ padding=(dilation * (kernel_size - 1)) // 2, bias=bias),
67
+ LambdaLayer(lambda x: x * kernel_size ** -0.5),
68
+ nn.GELU(),
69
+ nn.Conv1d(c_multiple * channels, channels, 1, dilation=dilation, bias=bias),
70
+ )
71
+ for _ in range(n)
72
+ ]
73
+
74
+ self.blocks = nn.ModuleList(self.blocks)
75
+ self.dropout = dropout
76
+
77
+ def forward(self, x):
78
+ nonpadding = (x.abs().sum(1) > 0).float()[:, None, :]
79
+ for b in self.blocks:
80
+ x_ = b(x)
81
+ if self.dropout > 0 and self.training:
82
+ x_ = F.dropout(x_, self.dropout, training=self.training)
83
+ x = x + x_
84
+ x = x * nonpadding
85
+ return x
86
+
87
+
88
+ class ConvBlocks(nn.Module):
89
+ """Decodes the expanded phoneme encoding into spectrograms"""
90
+
91
+ def __init__(self, channels, out_dims, dilations, kernel_size,
92
+ norm_type='ln', layers_in_block=2, c_multiple=2,
93
+ dropout=0.0, ln_eps=1e-5, init_weights=True, is_BTC=True, bias=False):
94
+ super(ConvBlocks, self).__init__()
95
+ self.is_BTC = is_BTC
96
+ self.res_blocks = nn.Sequential(
97
+ *[ResidualBlock(channels, kernel_size, d,
98
+ n=layers_in_block, norm_type=norm_type, c_multiple=c_multiple,
99
+ dropout=dropout, ln_eps=ln_eps, bias=bias)
100
+ for d in dilations],
101
+ )
102
+ if norm_type == 'bn':
103
+ norm = nn.BatchNorm1d(channels)
104
+ elif norm_type == 'in':
105
+ norm = nn.InstanceNorm1d(channels, affine=True)
106
+ elif norm_type == 'gn':
107
+ norm = nn.GroupNorm(8, channels)
108
+ elif norm_type == 'ln':
109
+ norm = LayerNorm(channels, dim=1, eps=ln_eps)
110
+ self.last_norm = norm
111
+ self.post_net1 = nn.Conv1d(channels, out_dims, kernel_size=3, padding=1, bias=bias)
112
+ if init_weights:
113
+ self.apply(init_weights_func)
114
+
115
+ def forward(self, x):
116
+ """
117
+
118
+ :param x: [B, T, H]
119
+ :return: [B, T, H]
120
+ """
121
+ if self.is_BTC:
122
+ x = x.transpose(1, 2) # [B, C, T]
123
+ nonpadding = (x.abs().sum(1) > 0).float()[:, None, :]
124
+ x = self.res_blocks(x) * nonpadding
125
+ x = self.last_norm(x) * nonpadding
126
+ x = self.post_net1(x) * nonpadding
127
+ if self.is_BTC:
128
+ x = x.transpose(1, 2)
129
+ return x
130
+
131
+
132
+ class SeqLevelConvolutionalModel(nn.Module):
133
+ def __init__(self, out_dim=64, dropout=0.5, audio_feat_type='ppg', backbone_type='unet', norm_type='bn'):
134
+ nn.Module.__init__(self)
135
+ self.audio_feat_type = audio_feat_type
136
+ if audio_feat_type == 'ppg':
137
+ self.audio_encoder = nn.Sequential(*[
138
+ nn.Conv1d(29, 48, 3, 1, 1, bias=False),
139
+ nn.BatchNorm1d(48) if norm_type=='bn' else LayerNorm(48, dim=1),
140
+ nn.GELU(),
141
+ nn.Conv1d(48, 48, 3, 1, 1, bias=False)
142
+ ])
143
+ self.energy_encoder = nn.Sequential(*[
144
+ nn.Conv1d(1, 16, 3, 1, 1, bias=False),
145
+ nn.BatchNorm1d(16) if norm_type=='bn' else LayerNorm(16, dim=1),
146
+ nn.GELU(),
147
+ nn.Conv1d(16, 16, 3, 1, 1, bias=False)
148
+ ])
149
+ elif audio_feat_type == 'mel':
150
+ self.mel_encoder = nn.Sequential(*[
151
+ nn.Conv1d(80, 64, 3, 1, 1, bias=False),
152
+ nn.BatchNorm1d(64) if norm_type=='bn' else LayerNorm(64, dim=1),
153
+ nn.GELU(),
154
+ nn.Conv1d(64, 64, 3, 1, 1, bias=False)
155
+ ])
156
+ else:
157
+ raise NotImplementedError("now only ppg or mel are supported!")
158
+
159
+ self.style_encoder = nn.Sequential(*[
160
+ nn.Linear(135, 256),
161
+ nn.GELU(),
162
+ nn.Linear(256, 256)
163
+ ])
164
+
165
+ if backbone_type == 'resnet':
166
+ self.backbone = ResNetBackbone()
167
+ elif backbone_type == 'unet':
168
+ self.backbone = UNetBackbone()
169
+ elif backbone_type == 'resblocks':
170
+ self.backbone = ResBlocksBackbone()
171
+ else:
172
+ raise NotImplementedError("Now only resnet and unet are supported!")
173
+
174
+ self.out_layer = nn.Sequential(
175
+ nn.BatchNorm1d(512) if norm_type=='bn' else LayerNorm(512, dim=1),
176
+ nn.Conv1d(512, 64, 3, 1, 1, bias=False),
177
+ nn.PReLU(),
178
+ nn.Conv1d(64, out_dim, 3, 1, 1, bias=False)
179
+ )
180
+ self.feat_dropout = nn.Dropout(p=dropout)
181
+
182
+ @property
183
+ def device(self):
184
+ return self.backbone.parameters().__next__().device
185
+
186
+ def forward(self, batch, ret, log_dict=None):
187
+ style, x_mask = batch['style'].to(self.device), batch['x_mask'].to(self.device)
188
+ style_feat = self.style_encoder(style) # [B,C=135] => [B,C=128]
189
+
190
+ if self.audio_feat_type == 'ppg':
191
+ audio, energy = batch['audio'].to(self.device), batch['energy'].to(self.device)
192
+ audio_feat = self.audio_encoder(audio.transpose(1,2)).transpose(1,2) * x_mask.unsqueeze(2) # [B,T,C=29] => [B,T,C=48]
193
+ energy_feat = self.energy_encoder(energy.transpose(1,2)).transpose(1,2) * x_mask.unsqueeze(2) # [B,T,C=1] => [B,T,C=16]
194
+ feat = torch.cat([audio_feat, energy_feat], dim=2) # [B,T,C=48+16]
195
+ elif self.audio_feat_type == 'mel':
196
+ mel = batch['mel'].to(self.device)
197
+ feat = self.mel_encoder(mel.transpose(1,2)).transpose(1,2) * x_mask.unsqueeze(2) # [B,T,C=64]
198
+
199
+ feat, x_mask = self.backbone(x=feat, sty=style_feat, x_mask=x_mask)
200
+
201
+ out = self.out_layer(feat.transpose(1,2)).transpose(1,2) * x_mask.unsqueeze(2) # [B,T//2,C=256] => [B,T//2,C=64]
202
+
203
+ ret['pred'] = out
204
+ ret['mask'] = x_mask
205
+ return out
206
+
207
+
208
+ class ResBlocksBackbone(nn.Module):
209
+ def __init__(self, in_dim=64, out_dim=512, p_dropout=0.5, norm_type='bn'):
210
+ super(ResBlocksBackbone,self).__init__()
211
+ self.resblocks_0 = ConvBlocks(channels=in_dim, out_dims=64, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False)
212
+ self.resblocks_1 = ConvBlocks(channels=64, out_dims=128, dilations=[1]*4, kernel_size=3, norm_type=norm_type, is_BTC=False)
213
+ self.resblocks_2 = ConvBlocks(channels=128, out_dims=256, dilations=[1]*14, kernel_size=3, norm_type=norm_type, is_BTC=False)
214
+ self.resblocks_3 = ConvBlocks(channels=512, out_dims=512, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False)
215
+ self.resblocks_4 = ConvBlocks(channels=512, out_dims=out_dim, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False)
216
+
217
+ self.downsampler = LambdaLayer(lambda x: F.interpolate(x, scale_factor=0.5, mode='linear'))
218
+ self.upsampler = LambdaLayer(lambda x: F.interpolate(x, scale_factor=4, mode='linear'))
219
+
220
+ self.dropout = nn.Dropout(p=p_dropout)
221
+
222
+ def forward(self, x, sty, x_mask=1.):
223
+ """
224
+ x: [B, T, C]
225
+ sty: [B, C=256]
226
+ x_mask: [B, T]
227
+ ret: [B, T/2, C]
228
+ """
229
+ x = x.transpose(1, 2) # [B, C, T]
230
+ x_mask = x_mask[:, None, :] # [B, 1, T]
231
+
232
+ x = self.resblocks_0(x) * x_mask # [B, C, T]
233
+
234
+ x_mask = self.downsampler(x_mask) # [B, 1, T/2]
235
+ x = self.downsampler(x) * x_mask # [B, C, T/2]
236
+ x = self.resblocks_1(x) * x_mask # [B, C, T/2]
237
+ x = self.resblocks_2(x) * x_mask # [B, C, T/2]
238
+
239
+ x = self.dropout(x.transpose(1,2)).transpose(1,2)
240
+ sty = sty[:, :, None].repeat([1,1,x_mask.shape[2]]) # [B,C=256,T/2]
241
+ x = torch.cat([x, sty], dim=1) # [B, C=256+256, T/2]
242
+
243
+ x = self.resblocks_3(x) * x_mask # [B, C, T/2]
244
+ x = self.resblocks_4(x) * x_mask # [B, C, T/2]
245
+
246
+ x = x.transpose(1,2)
247
+ x_mask = x_mask.squeeze(1)
248
+ return x, x_mask
249
+
250
+
251
+
252
+ class ResNetBackbone(nn.Module):
253
+ def __init__(self, in_dim=64, out_dim=512, p_dropout=0.5, norm_type='bn'):
254
+ super(ResNetBackbone,self).__init__()
255
+ self.resblocks_0 = ConvBlocks(channels=in_dim, out_dims=64, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False)
256
+ self.resblocks_1 = ConvBlocks(channels=64, out_dims=128, dilations=[1]*4, kernel_size=3, norm_type=norm_type, is_BTC=False)
257
+ self.resblocks_2 = ConvBlocks(channels=128, out_dims=256, dilations=[1]*14, kernel_size=3, norm_type=norm_type, is_BTC=False)
258
+ self.resblocks_3 = ConvBlocks(channels=512, out_dims=512, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False)
259
+ self.resblocks_4 = ConvBlocks(channels=512, out_dims=out_dim, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False)
260
+
261
+ self.downsampler = LambdaLayer(lambda x: F.interpolate(x, scale_factor=0.5, mode='linear'))
262
+ self.upsampler = LambdaLayer(lambda x: F.interpolate(x, scale_factor=4, mode='linear'))
263
+
264
+ self.dropout = nn.Dropout(p=p_dropout)
265
+
266
+ def forward(self, x, sty, x_mask=1.):
267
+ """
268
+ x: [B, T, C]
269
+ sty: [B, C=256]
270
+ x_mask: [B, T]
271
+ ret: [B, T/2, C]
272
+ """
273
+ x = x.transpose(1, 2) # [B, C, T]
274
+ x_mask = x_mask[:, None, :] # [B, 1, T]
275
+
276
+ x = self.resblocks_0(x) * x_mask # [B, C, T]
277
+
278
+ x_mask = self.downsampler(x_mask) # [B, 1, T/2]
279
+ x = self.downsampler(x) * x_mask # [B, C, T/2]
280
+ x = self.resblocks_1(x) * x_mask # [B, C, T/2]
281
+
282
+ x_mask = self.downsampler(x_mask) # [B, 1, T/4]
283
+ x = self.downsampler(x) * x_mask # [B, C, T/4]
284
+ x = self.resblocks_2(x) * x_mask # [B, C, T/4]
285
+
286
+ x_mask = self.downsampler(x_mask) # [B, 1, T/8]
287
+ x = self.downsampler(x) * x_mask # [B, C, T/8]
288
+ x = self.dropout(x.transpose(1,2)).transpose(1,2)
289
+ sty = sty[:, :, None].repeat([1,1,x_mask.shape[2]]) # [B,C=256,T/8]
290
+ x = torch.cat([x, sty], dim=1) # [B, C=256+256, T/8]
291
+ x = self.resblocks_3(x) * x_mask # [B, C, T/8]
292
+
293
+ x_mask = self.upsampler(x_mask) # [B, 1, T/2]
294
+ x = self.upsampler(x) * x_mask # [B, C, T/2]
295
+ x = self.resblocks_4(x) * x_mask # [B, C, T/2]
296
+
297
+ x = x.transpose(1,2)
298
+ x_mask = x_mask.squeeze(1)
299
+ return x, x_mask
300
+
301
+
302
+ class UNetBackbone(nn.Module):
303
+ def __init__(self, in_dim=64, out_dim=512, p_dropout=0.5, norm_type='bn'):
304
+ super(UNetBackbone, self).__init__()
305
+ self.resblocks_0 = ConvBlocks(channels=in_dim, out_dims=64, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False)
306
+ self.resblocks_1 = ConvBlocks(channels=64, out_dims=128, dilations=[1]*4, kernel_size=3, norm_type=norm_type, is_BTC=False)
307
+ self.resblocks_2 = ConvBlocks(channels=128, out_dims=256, dilations=[1]*8, kernel_size=3, norm_type=norm_type, is_BTC=False)
308
+ self.resblocks_3 = ConvBlocks(channels=512, out_dims=512, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False)
309
+ self.resblocks_4 = ConvBlocks(channels=768, out_dims=512, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False) # [768 = c3(512) + c2(256)]
310
+ self.resblocks_5 = ConvBlocks(channels=640, out_dims=out_dim, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False) # [640 = c4(512) + c1(128)]
311
+
312
+ self.downsampler = nn.Upsample(scale_factor=0.5, mode='linear')
313
+ self.upsampler = nn.Upsample(scale_factor=2, mode='linear')
314
+ self.dropout = nn.Dropout(p=p_dropout)
315
+
316
+ def forward(self, x, sty, x_mask=1.):
317
+ """
318
+ x: [B, T, C]
319
+ sty: [B, C=256]
320
+ x_mask: [B, T]
321
+ ret: [B, T/2, C]
322
+ """
323
+ x = x.transpose(1, 2) # [B, C, T]
324
+ x_mask = x_mask[:, None, :] # [B, 1, T]
325
+
326
+ x0 = self.resblocks_0(x) * x_mask # [B, C, T]
327
+
328
+ x_mask = self.downsampler(x_mask) # [B, 1, T/2]
329
+ x = self.downsampler(x0) * x_mask # [B, C, T/2]
330
+ x1 = self.resblocks_1(x) * x_mask # [B, C, T/2]
331
+
332
+ x_mask = self.downsampler(x_mask) # [B, 1, T/4]
333
+ x = self.downsampler(x1) * x_mask # [B, C, T/4]
334
+ x2 = self.resblocks_2(x) * x_mask # [B, C, T/4]
335
+
336
+ x_mask = self.downsampler(x_mask) # [B, 1, T/8]
337
+ x = self.downsampler(x2) * x_mask # [B, C, T/8]
338
+ x = self.dropout(x.transpose(1,2)).transpose(1,2)
339
+ sty = sty[:, :, None].repeat([1,1,x_mask.shape[2]]) # [B,C=256,T/8]
340
+ x = torch.cat([x, sty], dim=1) # [B, C=256+256, T/8]
341
+ x3 = self.resblocks_3(x) * x_mask # [B, C, T/8]
342
+
343
+ x_mask = self.upsampler(x_mask) # [B, 1, T/4]
344
+ x = self.upsampler(x3) * x_mask # [B, C, T/4]
345
+ x = torch.cat([x, self.dropout(x2.transpose(1,2)).transpose(1,2)], dim=1) #
346
+ x4 = self.resblocks_4(x) * x_mask # [B, C, T/4]
347
+
348
+ x_mask = self.upsampler(x_mask) # [B, 1, T/2]
349
+ x = self.upsampler(x4) * x_mask # [B, C, T/2]
350
+ x = torch.cat([x, self.dropout(x1.transpose(1,2)).transpose(1,2)], dim=1)
351
+ x5 = self.resblocks_5(x) * x_mask # [B, C, T/2]
352
+
353
+ x = x5.transpose(1,2)
354
+ x_mask = x_mask.squeeze(1)
355
+ return x, x_mask
356
+
357
+
358
+ if __name__ == '__main__':
359
+ pass
modules/audio2motion/flow_base.py ADDED
@@ -0,0 +1,838 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import scipy
2
+ from scipy import linalg
3
+ from torch.nn import functional as F
4
+ import torch
5
+ from torch import nn
6
+ import numpy as np
7
+
8
+ import modules.audio2motion.utils as utils
9
+ from modules.audio2motion.transformer_models import FFTBlocks
10
+ from utils.commons.hparams import hparams
11
+
12
+
13
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
14
+ n_channels_int = n_channels[0]
15
+ in_act = input_a + input_b
16
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
17
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
18
+ acts = t_act * s_act
19
+ return acts
20
+
21
+ class WN(torch.nn.Module):
22
+ def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0,
23
+ p_dropout=0, share_cond_layers=False):
24
+ super(WN, self).__init__()
25
+ assert (kernel_size % 2 == 1)
26
+ assert (hidden_channels % 2 == 0)
27
+ self.hidden_channels = hidden_channels
28
+ self.kernel_size = kernel_size
29
+ self.dilation_rate = dilation_rate
30
+ self.n_layers = n_layers
31
+ self.gin_channels = gin_channels
32
+ self.p_dropout = p_dropout
33
+ self.share_cond_layers = share_cond_layers
34
+
35
+ self.in_layers = torch.nn.ModuleList()
36
+ self.res_skip_layers = torch.nn.ModuleList()
37
+
38
+ self.drop = nn.Dropout(p_dropout)
39
+
40
+ self.use_adapters = hparams.get("use_adapters", False)
41
+ if self.use_adapters:
42
+ self.adapter_layers = torch.nn.ModuleList()
43
+
44
+ if gin_channels != 0 and not share_cond_layers:
45
+ cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
46
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
47
+
48
+ for i in range(n_layers):
49
+ dilation = dilation_rate ** i
50
+ padding = int((kernel_size * dilation - dilation) / 2)
51
+ in_layer = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size,
52
+ dilation=dilation, padding=padding)
53
+ in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
54
+ self.in_layers.append(in_layer)
55
+
56
+ # last one is not necessary
57
+ if i < n_layers - 1:
58
+ res_skip_channels = 2 * hidden_channels
59
+ else:
60
+ res_skip_channels = hidden_channels
61
+
62
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
63
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
64
+ self.res_skip_layers.append(res_skip_layer)
65
+
66
+ if self.use_adapters:
67
+ adapter_layer = MlpAdapter(in_out_dim=res_skip_channels, hid_dim=res_skip_channels//4)
68
+ self.adapter_layers.append(adapter_layer)
69
+
70
+ def forward(self, x, x_mask=None, g=None, **kwargs):
71
+ output = torch.zeros_like(x)
72
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
73
+
74
+ if g is not None and not self.share_cond_layers:
75
+ g = self.cond_layer(g)
76
+
77
+ for i in range(self.n_layers):
78
+ x_in = self.in_layers[i](x)
79
+ x_in = self.drop(x_in)
80
+ if g is not None:
81
+ cond_offset = i * 2 * self.hidden_channels
82
+ g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :]
83
+ else:
84
+ g_l = torch.zeros_like(x_in)
85
+
86
+ acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
87
+
88
+ res_skip_acts = self.res_skip_layers[i](acts)
89
+ if self.use_adapters:
90
+ res_skip_acts = self.adapter_layers[i](res_skip_acts.transpose(1,2)).transpose(1,2)
91
+ if i < self.n_layers - 1:
92
+ x = (x + res_skip_acts[:, :self.hidden_channels, :]) * x_mask
93
+ output = output + res_skip_acts[:, self.hidden_channels:, :]
94
+ else:
95
+ output = output + res_skip_acts
96
+ return output * x_mask
97
+
98
+ def remove_weight_norm(self):
99
+ def remove_weight_norm(m):
100
+ try:
101
+ nn.utils.remove_weight_norm(m)
102
+ except ValueError: # this module didn't have weight norm
103
+ return
104
+
105
+ self.apply(remove_weight_norm)
106
+
107
+ def enable_adapters(self):
108
+ if not self.use_adapters:
109
+ return
110
+ for adapter_layer in self.adapter_layers:
111
+ adapter_layer.enable()
112
+
113
+ def disable_adapters(self):
114
+ if not self.use_adapters:
115
+ return
116
+ for adapter_layer in self.adapter_layers:
117
+ adapter_layer.disable()
118
+
119
+ class Permute(nn.Module):
120
+ def __init__(self, *args):
121
+ super(Permute, self).__init__()
122
+ self.args = args
123
+
124
+ def forward(self, x):
125
+ return x.permute(self.args)
126
+
127
+
128
+ class LayerNorm(nn.Module):
129
+ def __init__(self, channels, eps=1e-4):
130
+ super().__init__()
131
+ self.channels = channels
132
+ self.eps = eps
133
+
134
+ self.gamma = nn.Parameter(torch.ones(channels))
135
+ self.beta = nn.Parameter(torch.zeros(channels))
136
+
137
+ def forward(self, x):
138
+ n_dims = len(x.shape)
139
+ mean = torch.mean(x, 1, keepdim=True)
140
+ variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
141
+
142
+ x = (x - mean) * torch.rsqrt(variance + self.eps)
143
+
144
+ shape = [1, -1] + [1] * (n_dims - 2)
145
+ x = x * self.gamma.view(*shape) + self.beta.view(*shape)
146
+ return x
147
+
148
+
149
+ class ConvReluNorm(nn.Module):
150
+ def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
151
+ super().__init__()
152
+ self.in_channels = in_channels
153
+ self.hidden_channels = hidden_channels
154
+ self.out_channels = out_channels
155
+ self.kernel_size = kernel_size
156
+ self.n_layers = n_layers
157
+ self.p_dropout = p_dropout
158
+ assert n_layers > 1, "Number of layers should be larger than 0."
159
+
160
+ self.conv_layers = nn.ModuleList()
161
+ self.norm_layers = nn.ModuleList()
162
+ self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
163
+ self.norm_layers.append(LayerNorm(hidden_channels))
164
+ self.relu_drop = nn.Sequential(
165
+ nn.ReLU(),
166
+ nn.Dropout(p_dropout))
167
+ for _ in range(n_layers - 1):
168
+ self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
169
+ self.norm_layers.append(LayerNorm(hidden_channels))
170
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
171
+ self.proj.weight.data.zero_()
172
+ self.proj.bias.data.zero_()
173
+
174
+ def forward(self, x, x_mask):
175
+ x_org = x
176
+ for i in range(self.n_layers):
177
+ x = self.conv_layers[i](x * x_mask)
178
+ x = self.norm_layers[i](x)
179
+ x = self.relu_drop(x)
180
+ x = x_org + self.proj(x)
181
+ return x * x_mask
182
+
183
+
184
+
185
+ class ActNorm(nn.Module):
186
+ def __init__(self, channels, ddi=False, **kwargs):
187
+ super().__init__()
188
+ self.channels = channels
189
+ self.initialized = not ddi
190
+
191
+ self.logs = nn.Parameter(torch.zeros(1, channels, 1))
192
+ self.bias = nn.Parameter(torch.zeros(1, channels, 1))
193
+
194
+ def forward(self, x, x_mask=None, reverse=False, **kwargs):
195
+ if x_mask is None:
196
+ x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype)
197
+ x_len = torch.sum(x_mask, [1, 2])
198
+ if not self.initialized:
199
+ self.initialize(x, x_mask)
200
+ self.initialized = True
201
+
202
+ if reverse:
203
+ z = (x - self.bias) * torch.exp(-self.logs) * x_mask
204
+ logdet = torch.sum(-self.logs) * x_len
205
+ else:
206
+ z = (self.bias + torch.exp(self.logs) * x) * x_mask
207
+ logdet = torch.sum(self.logs) * x_len # [b]
208
+ return z, logdet
209
+
210
+ def store_inverse(self):
211
+ pass
212
+
213
+ def set_ddi(self, ddi):
214
+ self.initialized = not ddi
215
+
216
+ def initialize(self, x, x_mask):
217
+ with torch.no_grad():
218
+ denom = torch.sum(x_mask, [0, 2])
219
+ m = torch.sum(x * x_mask, [0, 2]) / denom
220
+ m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom
221
+ v = m_sq - (m ** 2)
222
+ logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
223
+
224
+ bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
225
+ logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype)
226
+
227
+ self.bias.data.copy_(bias_init)
228
+ self.logs.data.copy_(logs_init)
229
+
230
+
231
+ class InvConvNear(nn.Module):
232
+ def __init__(self, channels, n_split=4, no_jacobian=False, lu=True, n_sqz=2, **kwargs):
233
+ super().__init__()
234
+ assert (n_split % 2 == 0)
235
+ self.channels = channels
236
+ self.n_split = n_split
237
+ self.n_sqz = n_sqz
238
+ self.no_jacobian = no_jacobian
239
+
240
+ w_init = torch.qr(torch.FloatTensor(self.n_split, self.n_split).normal_())[0]
241
+ if torch.det(w_init) < 0:
242
+ w_init[:, 0] = -1 * w_init[:, 0]
243
+ self.lu = lu
244
+ if lu:
245
+ # LU decomposition can slightly speed up the inverse
246
+ np_p, np_l, np_u = linalg.lu(w_init)
247
+ np_s = np.diag(np_u)
248
+ np_sign_s = np.sign(np_s)
249
+ np_log_s = np.log(np.abs(np_s))
250
+ np_u = np.triu(np_u, k=1)
251
+ l_mask = np.tril(np.ones(w_init.shape, dtype=float), -1)
252
+ eye = np.eye(*w_init.shape, dtype=float)
253
+
254
+ self.register_buffer('p', torch.Tensor(np_p.astype(float)))
255
+ self.register_buffer('sign_s', torch.Tensor(np_sign_s.astype(float)))
256
+ self.l = nn.Parameter(torch.Tensor(np_l.astype(float)), requires_grad=True)
257
+ self.log_s = nn.Parameter(torch.Tensor(np_log_s.astype(float)), requires_grad=True)
258
+ self.u = nn.Parameter(torch.Tensor(np_u.astype(float)), requires_grad=True)
259
+ self.register_buffer('l_mask', torch.Tensor(l_mask))
260
+ self.register_buffer('eye', torch.Tensor(eye))
261
+ else:
262
+ self.weight = nn.Parameter(w_init)
263
+
264
+ def forward(self, x, x_mask=None, reverse=False, **kwargs):
265
+ b, c, t = x.size()
266
+ assert (c % self.n_split == 0)
267
+ if x_mask is None:
268
+ x_mask = 1
269
+ x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t
270
+ else:
271
+ x_len = torch.sum(x_mask, [1, 2])
272
+
273
+ x = x.view(b, self.n_sqz, c // self.n_split, self.n_split // self.n_sqz, t)
274
+ x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.n_split, c // self.n_split, t)
275
+
276
+ if self.lu:
277
+ self.weight, log_s = self._get_weight()
278
+ logdet = log_s.sum()
279
+ logdet = logdet * (c / self.n_split) * x_len
280
+ else:
281
+ logdet = torch.logdet(self.weight) * (c / self.n_split) * x_len # [b]
282
+
283
+ if reverse:
284
+ if hasattr(self, "weight_inv"):
285
+ weight = self.weight_inv
286
+ else:
287
+ weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
288
+ logdet = -logdet
289
+ else:
290
+ weight = self.weight
291
+ if self.no_jacobian:
292
+ logdet = 0
293
+
294
+ weight = weight.view(self.n_split, self.n_split, 1, 1)
295
+ z = F.conv2d(x, weight)
296
+
297
+ z = z.view(b, self.n_sqz, self.n_split // self.n_sqz, c // self.n_split, t)
298
+ z = z.permute(0, 1, 3, 2, 4).contiguous().view(b, c, t) * x_mask
299
+ return z, logdet
300
+
301
+ def _get_weight(self):
302
+ l, log_s, u = self.l, self.log_s, self.u
303
+ l = l * self.l_mask + self.eye
304
+ u = u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(log_s))
305
+ weight = torch.matmul(self.p, torch.matmul(l, u))
306
+ return weight, log_s
307
+
308
+ def store_inverse(self):
309
+ weight, _ = self._get_weight()
310
+ self.weight_inv = torch.inverse(weight.float()).to(next(self.parameters()).device)
311
+
312
+
313
+ class InvConv(nn.Module):
314
+ def __init__(self, channels, no_jacobian=False, lu=True, **kwargs):
315
+ super().__init__()
316
+ w_shape = [channels, channels]
317
+ w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(float)
318
+ LU_decomposed = lu
319
+ if not LU_decomposed:
320
+ # Sample a random orthogonal matrix:
321
+ self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init)))
322
+ else:
323
+ np_p, np_l, np_u = linalg.lu(w_init)
324
+ np_s = np.diag(np_u)
325
+ np_sign_s = np.sign(np_s)
326
+ np_log_s = np.log(np.abs(np_s))
327
+ np_u = np.triu(np_u, k=1)
328
+ l_mask = np.tril(np.ones(w_shape, dtype=float), -1)
329
+ eye = np.eye(*w_shape, dtype=float)
330
+
331
+ self.register_buffer('p', torch.Tensor(np_p.astype(float)))
332
+ self.register_buffer('sign_s', torch.Tensor(np_sign_s.astype(float)))
333
+ self.l = nn.Parameter(torch.Tensor(np_l.astype(float)))
334
+ self.log_s = nn.Parameter(torch.Tensor(np_log_s.astype(float)))
335
+ self.u = nn.Parameter(torch.Tensor(np_u.astype(float)))
336
+ self.l_mask = torch.Tensor(l_mask)
337
+ self.eye = torch.Tensor(eye)
338
+ self.w_shape = w_shape
339
+ self.LU = LU_decomposed
340
+ self.weight = None
341
+
342
+ def get_weight(self, device, reverse):
343
+ w_shape = self.w_shape
344
+ self.p = self.p.to(device)
345
+ self.sign_s = self.sign_s.to(device)
346
+ self.l_mask = self.l_mask.to(device)
347
+ self.eye = self.eye.to(device)
348
+ l = self.l * self.l_mask + self.eye
349
+ u = self.u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(self.log_s))
350
+ dlogdet = self.log_s.sum()
351
+ if not reverse:
352
+ w = torch.matmul(self.p, torch.matmul(l, u))
353
+ else:
354
+ l = torch.inverse(l.double()).float()
355
+ u = torch.inverse(u.double()).float()
356
+ w = torch.matmul(u, torch.matmul(l, self.p.inverse()))
357
+ return w.view(w_shape[0], w_shape[1], 1), dlogdet
358
+
359
+ def forward(self, x, x_mask=None, reverse=False, **kwargs):
360
+ """
361
+ log-det = log|abs(|W|)| * pixels
362
+ """
363
+ b, c, t = x.size()
364
+ if x_mask is None:
365
+ x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t
366
+ else:
367
+ x_len = torch.sum(x_mask, [1, 2])
368
+ logdet = 0
369
+ if not reverse:
370
+ weight, dlogdet = self.get_weight(x.device, reverse)
371
+ z = F.conv1d(x, weight)
372
+ if logdet is not None:
373
+ logdet = logdet + dlogdet * x_len
374
+ return z, logdet
375
+ else:
376
+ if self.weight is None:
377
+ weight, dlogdet = self.get_weight(x.device, reverse)
378
+ else:
379
+ weight, dlogdet = self.weight, self.dlogdet
380
+ z = F.conv1d(x, weight)
381
+ if logdet is not None:
382
+ logdet = logdet - dlogdet * x_len
383
+ return z, logdet
384
+
385
+ def store_inverse(self):
386
+ self.weight, self.dlogdet = self.get_weight('cuda', reverse=True)
387
+
388
+
389
+ class Flip(nn.Module):
390
+ def forward(self, x, *args, reverse=False, **kwargs):
391
+ x = torch.flip(x, [1])
392
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
393
+ return x, logdet
394
+
395
+ def store_inverse(self):
396
+ pass
397
+
398
+
399
+ class CouplingBlock(nn.Module):
400
+ def __init__(self, in_channels, hidden_channels, kernel_size, dilation_rate, n_layers,
401
+ gin_channels=0, p_dropout=0, sigmoid_scale=False,
402
+ share_cond_layers=False, wn=None):
403
+ super().__init__()
404
+ self.in_channels = in_channels
405
+ self.hidden_channels = hidden_channels
406
+ self.kernel_size = kernel_size
407
+ self.dilation_rate = dilation_rate
408
+ self.n_layers = n_layers
409
+ self.gin_channels = gin_channels
410
+ self.p_dropout = p_dropout
411
+ self.sigmoid_scale = sigmoid_scale
412
+
413
+ start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1)
414
+ start = torch.nn.utils.weight_norm(start)
415
+ self.start = start
416
+ # Initializing last layer to 0 makes the affine coupling layers
417
+ # do nothing at first. This helps with training stability
418
+ end = torch.nn.Conv1d(hidden_channels, in_channels, 1)
419
+ end.weight.data.zero_()
420
+ end.bias.data.zero_()
421
+ self.end = end
422
+ self.wn = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels,
423
+ p_dropout, share_cond_layers)
424
+ if wn is not None:
425
+ self.wn.in_layers = wn.in_layers
426
+ self.wn.res_skip_layers = wn.res_skip_layers
427
+
428
+ def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs):
429
+ if x_mask is None:
430
+ x_mask = 1
431
+ x_0, x_1 = x[:, :self.in_channels // 2], x[:, self.in_channels // 2:]
432
+
433
+ x = self.start(x_0) * x_mask
434
+ x = self.wn(x, x_mask, g)
435
+ out = self.end(x)
436
+
437
+ z_0 = x_0
438
+ m = out[:, :self.in_channels // 2, :]
439
+ logs = out[:, self.in_channels // 2:, :]
440
+ if self.sigmoid_scale:
441
+ logs = torch.log(1e-6 + torch.sigmoid(logs + 2))
442
+ if reverse:
443
+ z_1 = (x_1 - m) * torch.exp(-logs) * x_mask
444
+ logdet = torch.sum(-logs * x_mask, [1, 2])
445
+ else:
446
+ z_1 = (m + torch.exp(logs) * x_1) * x_mask
447
+ logdet = torch.sum(logs * x_mask, [1, 2])
448
+ z = torch.cat([z_0, z_1], 1)
449
+ return z, logdet
450
+
451
+ def store_inverse(self):
452
+ self.wn.remove_weight_norm()
453
+
454
+
455
+ class GlowFFTBlocks(FFTBlocks):
456
+ def __init__(self, hidden_size=128, gin_channels=256, num_layers=2, ffn_kernel_size=5,
457
+ dropout=None, num_heads=4, use_pos_embed=True, use_last_norm=True,
458
+ norm='ln', use_pos_embed_alpha=True):
459
+ super().__init__(hidden_size, num_layers, ffn_kernel_size, dropout, num_heads, use_pos_embed,
460
+ use_last_norm, norm, use_pos_embed_alpha)
461
+ self.inp_proj = nn.Conv1d(hidden_size + gin_channels, hidden_size, 1)
462
+
463
+ def forward(self, x, x_mask=None, g=None):
464
+ """
465
+ :param x: [B, C_x, T]
466
+ :param x_mask: [B, 1, T]
467
+ :param g: [B, C_g, T]
468
+ :return: [B, C_x, T]
469
+ """
470
+ if g is not None:
471
+ x = self.inp_proj(torch.cat([x, g], 1))
472
+ x = x.transpose(1, 2)
473
+ x = super(GlowFFTBlocks, self).forward(x, x_mask[:, 0] == 0)
474
+ x = x.transpose(1, 2)
475
+ return x
476
+
477
+
478
+ class TransformerCouplingBlock(nn.Module):
479
+ def __init__(self, in_channels, hidden_channels, n_layers,
480
+ gin_channels=0, p_dropout=0, sigmoid_scale=False):
481
+ super().__init__()
482
+ self.in_channels = in_channels
483
+ self.hidden_channels = hidden_channels
484
+ self.n_layers = n_layers
485
+ self.gin_channels = gin_channels
486
+ self.p_dropout = p_dropout
487
+ self.sigmoid_scale = sigmoid_scale
488
+
489
+ start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1)
490
+ self.start = start
491
+ # Initializing last layer to 0 makes the affine coupling layers
492
+ # do nothing at first. This helps with training stability
493
+ end = torch.nn.Conv1d(hidden_channels, in_channels, 1)
494
+ end.weight.data.zero_()
495
+ end.bias.data.zero_()
496
+ self.end = end
497
+ self.fft_blocks = GlowFFTBlocks(
498
+ hidden_size=hidden_channels,
499
+ ffn_kernel_size=3,
500
+ gin_channels=gin_channels,
501
+ num_layers=n_layers)
502
+
503
+ def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs):
504
+ if x_mask is None:
505
+ x_mask = 1
506
+ x_0, x_1 = x[:, :self.in_channels // 2], x[:, self.in_channels // 2:]
507
+
508
+ x = self.start(x_0) * x_mask
509
+ x = self.fft_blocks(x, x_mask, g)
510
+ out = self.end(x)
511
+
512
+ z_0 = x_0
513
+ m = out[:, :self.in_channels // 2, :]
514
+ logs = out[:, self.in_channels // 2:, :]
515
+ if self.sigmoid_scale:
516
+ logs = torch.log(1e-6 + torch.sigmoid(logs + 2))
517
+ if reverse:
518
+ z_1 = (x_1 - m) * torch.exp(-logs) * x_mask
519
+ logdet = torch.sum(-logs * x_mask, [1, 2])
520
+ else:
521
+ z_1 = (m + torch.exp(logs) * x_1) * x_mask
522
+ logdet = torch.sum(logs * x_mask, [1, 2])
523
+ z = torch.cat([z_0, z_1], 1)
524
+ return z, logdet
525
+
526
+ def store_inverse(self):
527
+ pass
528
+
529
+
530
+ class FreqFFTCouplingBlock(nn.Module):
531
+ def __init__(self, in_channels, hidden_channels, n_layers,
532
+ gin_channels=0, p_dropout=0, sigmoid_scale=False):
533
+ super().__init__()
534
+ self.in_channels = in_channels
535
+ self.hidden_channels = hidden_channels
536
+ self.n_layers = n_layers
537
+ self.gin_channels = gin_channels
538
+ self.p_dropout = p_dropout
539
+ self.sigmoid_scale = sigmoid_scale
540
+
541
+ hs = hidden_channels
542
+ stride = 8
543
+ self.start = torch.nn.Conv2d(3, hs, kernel_size=stride * 2,
544
+ stride=stride, padding=stride // 2)
545
+ end = nn.ConvTranspose2d(hs, 2, kernel_size=stride, stride=stride)
546
+ end.weight.data.zero_()
547
+ end.bias.data.zero_()
548
+ self.end = nn.Sequential(
549
+ nn.Conv2d(hs * 3, hs, 3, 1, 1),
550
+ nn.ReLU(),
551
+ nn.GroupNorm(4, hs),
552
+ nn.Conv2d(hs, hs, 3, 1, 1),
553
+ end
554
+ )
555
+ self.fft_v = FFTBlocks(hidden_size=hs, ffn_kernel_size=1, num_layers=n_layers)
556
+ self.fft_h = nn.Sequential(
557
+ nn.Conv1d(hs, hs, 3, 1, 1),
558
+ nn.ReLU(),
559
+ nn.Conv1d(hs, hs, 3, 1, 1),
560
+ )
561
+ self.fft_g = nn.Sequential(
562
+ nn.Conv1d(
563
+ gin_channels - 160, hs, kernel_size=stride * 2, stride=stride, padding=stride // 2),
564
+ Permute(0, 2, 1),
565
+ FFTBlocks(hidden_size=hs, ffn_kernel_size=1, num_layers=n_layers),
566
+ Permute(0, 2, 1),
567
+ )
568
+
569
+ def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs):
570
+ g_, _ = utils.unsqueeze(g)
571
+ g_mel = g_[:, :80]
572
+ g_txt = g_[:, 80:]
573
+ g_mel, _ = utils.squeeze(g_mel)
574
+ g_txt, _ = utils.squeeze(g_txt) # [B, C, T]
575
+
576
+ if x_mask is None:
577
+ x_mask = 1
578
+ x_0, x_1 = x[:, :self.in_channels // 2], x[:, self.in_channels // 2:]
579
+ x = torch.stack([x_0, g_mel[:, :80], g_mel[:, 80:]], 1)
580
+ x = self.start(x) # [B, C, N_bins, T]
581
+ B, C, N_bins, T = x.shape
582
+
583
+ x_v = self.fft_v(x.permute(0, 3, 2, 1).reshape(B * T, N_bins, C))
584
+ x_v = x_v.reshape(B, T, N_bins, -1).permute(0, 3, 2, 1)
585
+ # x_v = x
586
+
587
+ x_h = self.fft_h(x.permute(0, 2, 1, 3).reshape(B * N_bins, C, T))
588
+ x_h = x_h.reshape(B, N_bins, -1, T).permute(0, 2, 1, 3)
589
+ # x_h = x
590
+
591
+ x_g = self.fft_g(g_txt)[:, :, None, :].repeat(1, 1, 10, 1)
592
+ x = torch.cat([x_v, x_h, x_g], 1)
593
+ out = self.end(x)
594
+
595
+ z_0 = x_0
596
+ m = out[:, 0]
597
+ logs = out[:, 1]
598
+ if self.sigmoid_scale:
599
+ logs = torch.log(1e-6 + torch.sigmoid(logs + 2))
600
+ if reverse:
601
+ z_1 = (x_1 - m) * torch.exp(-logs) * x_mask
602
+ logdet = torch.sum(-logs * x_mask, [1, 2])
603
+ else:
604
+ z_1 = (m + torch.exp(logs) * x_1) * x_mask
605
+ logdet = torch.sum(logs * x_mask, [1, 2])
606
+ z = torch.cat([z_0, z_1], 1)
607
+ return z, logdet
608
+
609
+ def store_inverse(self):
610
+ pass
611
+
612
+
613
+
614
+ class ResidualCouplingLayer(nn.Module):
615
+ def __init__(self,
616
+ channels,
617
+ hidden_channels,
618
+ kernel_size,
619
+ dilation_rate,
620
+ n_layers,
621
+ p_dropout=0,
622
+ gin_channels=0,
623
+ mean_only=False,
624
+ nn_type='wn'):
625
+ assert channels % 2 == 0, "channels should be divisible by 2"
626
+ super().__init__()
627
+ self.channels = channels
628
+ self.hidden_channels = hidden_channels
629
+ self.kernel_size = kernel_size
630
+ self.dilation_rate = dilation_rate
631
+ self.n_layers = n_layers
632
+ self.half_channels = channels // 2
633
+ self.mean_only = mean_only
634
+
635
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
636
+ if nn_type == 'wn':
637
+ self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout,
638
+ gin_channels=gin_channels)
639
+ # elif nn_type == 'conv':
640
+ # self.enc = ConditionalConvBlocks(
641
+ # hidden_channels, gin_channels, hidden_channels, [1] * n_layers, kernel_size,
642
+ # layers_in_block=1, is_BTC=False)
643
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
644
+ self.post.weight.data.zero_()
645
+ self.post.bias.data.zero_()
646
+
647
+ def forward(self, x, x_mask, g=None, reverse=False):
648
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
649
+ h = self.pre(x0) * x_mask
650
+ h = self.enc(h, x_mask=x_mask, g=g)
651
+ stats = self.post(h) * x_mask
652
+ if not self.mean_only:
653
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
654
+ else:
655
+ m = stats
656
+ logs = torch.zeros_like(m)
657
+
658
+ if not reverse:
659
+ x1 = m + x1 * torch.exp(logs) * x_mask
660
+ x = torch.cat([x0, x1], 1)
661
+ logdet = torch.sum(logs, [1, 2])
662
+ return x, logdet
663
+ else:
664
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
665
+ x = torch.cat([x0, x1], 1)
666
+ logdet = -torch.sum(logs, [1, 2])
667
+ return x, logdet
668
+
669
+
670
+ class ResidualCouplingBlock(nn.Module):
671
+ def __init__(self,
672
+ channels,
673
+ hidden_channels,
674
+ kernel_size,
675
+ dilation_rate,
676
+ n_layers,
677
+ n_flows=4,
678
+ gin_channels=0,
679
+ nn_type='wn'):
680
+ super().__init__()
681
+ self.channels = channels
682
+ self.hidden_channels = hidden_channels
683
+ self.kernel_size = kernel_size
684
+ self.dilation_rate = dilation_rate
685
+ self.n_layers = n_layers
686
+ self.n_flows = n_flows
687
+ self.gin_channels = gin_channels
688
+
689
+ self.flows = nn.ModuleList()
690
+ for i in range(n_flows):
691
+ self.flows.append(
692
+ ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers,
693
+ gin_channels=gin_channels, mean_only=True, nn_type=nn_type))
694
+ self.flows.append(Flip())
695
+
696
+ def forward(self, x, x_mask, g=None, reverse=False):
697
+ if not reverse:
698
+ for flow in self.flows:
699
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
700
+ else:
701
+ for flow in reversed(self.flows):
702
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
703
+ return x
704
+
705
+
706
+ class Glow(nn.Module):
707
+ def __init__(self,
708
+ in_channels,
709
+ hidden_channels,
710
+ kernel_size,
711
+ dilation_rate,
712
+ n_blocks,
713
+ n_layers,
714
+ p_dropout=0.,
715
+ n_split=4,
716
+ n_sqz=2,
717
+ sigmoid_scale=False,
718
+ gin_channels=0,
719
+ inv_conv_type='near',
720
+ share_cond_layers=False,
721
+ share_wn_layers=0,
722
+ ):
723
+ super().__init__()
724
+ """
725
+ Note that regularization likes weight decay can leads to Nan error!
726
+ """
727
+
728
+ self.in_channels = in_channels
729
+ self.hidden_channels = hidden_channels
730
+ self.kernel_size = kernel_size
731
+ self.dilation_rate = dilation_rate
732
+ self.n_blocks = n_blocks
733
+ self.n_layers = n_layers
734
+ self.p_dropout = p_dropout
735
+ self.n_split = n_split
736
+ self.n_sqz = n_sqz
737
+ self.sigmoid_scale = sigmoid_scale
738
+ self.gin_channels = gin_channels
739
+ self.share_cond_layers = share_cond_layers
740
+ if gin_channels != 0 and share_cond_layers:
741
+ cond_layer = torch.nn.Conv1d(gin_channels * n_sqz, 2 * hidden_channels * n_layers, 1)
742
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
743
+ wn = None
744
+ self.flows = nn.ModuleList()
745
+ for b in range(n_blocks):
746
+ self.flows.append(ActNorm(channels=in_channels * n_sqz))
747
+ if inv_conv_type == 'near':
748
+ self.flows.append(InvConvNear(channels=in_channels * n_sqz, n_split=n_split, n_sqz=n_sqz))
749
+ if inv_conv_type == 'invconv':
750
+ self.flows.append(InvConv(channels=in_channels * n_sqz))
751
+ if share_wn_layers > 0:
752
+ if b % share_wn_layers == 0:
753
+ wn = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels * n_sqz,
754
+ p_dropout, share_cond_layers)
755
+ self.flows.append(
756
+ CouplingBlock(
757
+ in_channels * n_sqz,
758
+ hidden_channels,
759
+ kernel_size=kernel_size,
760
+ dilation_rate=dilation_rate,
761
+ n_layers=n_layers,
762
+ gin_channels=gin_channels * n_sqz,
763
+ p_dropout=p_dropout,
764
+ sigmoid_scale=sigmoid_scale,
765
+ share_cond_layers=share_cond_layers,
766
+ wn=wn
767
+ ))
768
+
769
+ def forward(self, x, x_mask=None, g=None, reverse=False, return_hiddens=False):
770
+ """
771
+ x: [B,T,C]
772
+ x_mask: [B,T]
773
+ g: [B,T,C]
774
+ """
775
+ x = x.transpose(1,2)
776
+ x_mask = x_mask.unsqueeze(1)
777
+ if g is not None:
778
+ g = g.transpose(1,2)
779
+
780
+ logdet_tot = 0
781
+ if not reverse:
782
+ flows = self.flows
783
+ else:
784
+ flows = reversed(self.flows)
785
+ if return_hiddens:
786
+ hs = []
787
+ if self.n_sqz > 1:
788
+ x, x_mask_ = utils.squeeze(x, x_mask, self.n_sqz)
789
+ if g is not None:
790
+ g, _ = utils.squeeze(g, x_mask, self.n_sqz)
791
+ x_mask = x_mask_
792
+ if self.share_cond_layers and g is not None:
793
+ g = self.cond_layer(g)
794
+ for f in flows:
795
+ x, logdet = f(x, x_mask, g=g, reverse=reverse)
796
+ if return_hiddens:
797
+ hs.append(x)
798
+ logdet_tot += logdet
799
+ if self.n_sqz > 1:
800
+ x, x_mask = utils.unsqueeze(x, x_mask, self.n_sqz)
801
+
802
+ x = x.transpose(1,2)
803
+ if return_hiddens:
804
+ return x, logdet_tot, hs
805
+ return x, logdet_tot
806
+
807
+ def store_inverse(self):
808
+ def remove_weight_norm(m):
809
+ try:
810
+ nn.utils.remove_weight_norm(m)
811
+ except ValueError: # this module didn't have weight norm
812
+ return
813
+
814
+ self.apply(remove_weight_norm)
815
+ for f in self.flows:
816
+ f.store_inverse()
817
+
818
+
819
+ if __name__ == '__main__':
820
+ model = Glow(in_channels=64,
821
+ hidden_channels=128,
822
+ kernel_size=5,
823
+ dilation_rate=1,
824
+ n_blocks=12,
825
+ n_layers=4,
826
+ p_dropout=0.0,
827
+ n_split=4,
828
+ n_sqz=2,
829
+ sigmoid_scale=False,
830
+ gin_channels=80
831
+ )
832
+ exp = torch.rand([1,1440,64])
833
+ mel = torch.rand([1,1440,80])
834
+ x_mask = torch.ones([1,1440],dtype=torch.float32)
835
+ y, logdet = model(exp, x_mask,g=mel, reverse=False)
836
+ pred_exp, logdet = model(y, x_mask,g=mel, reverse=False)
837
+ # y: [b, t,c=64]
838
+ print(" ")
modules/audio2motion/multi_length_disc.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import random
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from modules.audio2motion.cnn_models import LambdaLayer
7
+
8
+
9
+ class Discriminator1DFactory(nn.Module):
10
+ def __init__(self, time_length, kernel_size=3, in_dim=1, hidden_size=128, norm_type='bn'):
11
+ super(Discriminator1DFactory, self).__init__()
12
+ padding = kernel_size // 2
13
+
14
+ def discriminator_block(in_filters, out_filters, first=False):
15
+ """
16
+ Input: (B, c, T)
17
+ Output:(B, c, T//2)
18
+ """
19
+ conv = nn.Conv1d(in_filters, out_filters, kernel_size, 2, padding)
20
+ block = [
21
+ conv, # padding = kernel//2
22
+ nn.LeakyReLU(0.2, inplace=True),
23
+ nn.Dropout2d(0.25)
24
+ ]
25
+ if norm_type == 'bn' and not first:
26
+ block.append(nn.BatchNorm1d(out_filters, 0.8))
27
+ if norm_type == 'in' and not first:
28
+ block.append(nn.InstanceNorm1d(out_filters, affine=True))
29
+ block = nn.Sequential(*block)
30
+ return block
31
+
32
+ if time_length >= 8:
33
+ self.model = nn.ModuleList([
34
+ discriminator_block(in_dim, hidden_size, first=True),
35
+ discriminator_block(hidden_size, hidden_size),
36
+ discriminator_block(hidden_size, hidden_size),
37
+ ])
38
+ ds_size = time_length // (2 ** 3)
39
+ elif time_length == 3:
40
+ self.model = nn.ModuleList([
41
+ nn.Sequential(*[
42
+ nn.Conv1d(in_dim, hidden_size, 3, 1, 0),
43
+ nn.LeakyReLU(0.2, inplace=True),
44
+ nn.Dropout2d(0.25),
45
+ nn.Conv1d(hidden_size, hidden_size, 1, 1, 0),
46
+ nn.LeakyReLU(0.2, inplace=True),
47
+ nn.Dropout2d(0.25),
48
+ nn.BatchNorm1d(hidden_size, 0.8),
49
+ nn.Conv1d(hidden_size, hidden_size, 1, 1, 0),
50
+ nn.LeakyReLU(0.2, inplace=True),
51
+ nn.Dropout2d(0.25),
52
+ nn.BatchNorm1d(hidden_size, 0.8)
53
+ ])
54
+ ])
55
+ ds_size = 1
56
+ elif time_length == 1:
57
+ self.model = nn.ModuleList([
58
+ nn.Sequential(*[
59
+ nn.Linear(in_dim, hidden_size),
60
+ nn.LeakyReLU(0.2, inplace=True),
61
+ nn.Dropout2d(0.25),
62
+ nn.Linear(hidden_size, hidden_size),
63
+ nn.LeakyReLU(0.2, inplace=True),
64
+ nn.Dropout2d(0.25),
65
+ ])
66
+ ])
67
+ ds_size = 1
68
+
69
+ self.adv_layer = nn.Linear(hidden_size * ds_size, 1)
70
+
71
+ def forward(self, x):
72
+ """
73
+
74
+ :param x: [B, C, T]
75
+ :return: validity: [B, 1], h: List of hiddens
76
+ """
77
+ h = []
78
+ if x.shape[-1] == 1:
79
+ x = x.squeeze(-1)
80
+ for l in self.model:
81
+ x = l(x)
82
+ h.append(x)
83
+ if x.ndim == 2:
84
+ b, ct = x.shape
85
+ use_sigmoid = True
86
+ else:
87
+ b, c, t = x.shape
88
+ ct = c * t
89
+ use_sigmoid = False
90
+ x = x.view(b, ct)
91
+ validity = self.adv_layer(x) # [B, 1]
92
+ if use_sigmoid:
93
+ validity = torch.sigmoid(validity)
94
+ return validity, h
95
+
96
+
97
+ class CosineDiscriminator1DFactory(nn.Module):
98
+ def __init__(self, time_length, kernel_size=3, in_dim=1, hidden_size=128, norm_type='bn'):
99
+ super().__init__()
100
+ padding = kernel_size // 2
101
+
102
+ def discriminator_block(in_filters, out_filters, first=False):
103
+ """
104
+ Input: (B, c, T)
105
+ Output:(B, c, T//2)
106
+ """
107
+ conv = nn.Conv1d(in_filters, out_filters, kernel_size, 2, padding)
108
+ block = [
109
+ conv, # padding = kernel//2
110
+ nn.LeakyReLU(0.2, inplace=True),
111
+ nn.Dropout2d(0.25)
112
+ ]
113
+ if norm_type == 'bn' and not first:
114
+ block.append(nn.BatchNorm1d(out_filters, 0.8))
115
+ if norm_type == 'in' and not first:
116
+ block.append(nn.InstanceNorm1d(out_filters, affine=True))
117
+ block = nn.Sequential(*block)
118
+ return block
119
+
120
+ self.model1 = nn.ModuleList([
121
+ discriminator_block(in_dim, hidden_size, first=True),
122
+ discriminator_block(hidden_size, hidden_size),
123
+ discriminator_block(hidden_size, hidden_size),
124
+ ])
125
+
126
+ self.model2 = nn.ModuleList([
127
+ discriminator_block(in_dim, hidden_size, first=True),
128
+ discriminator_block(hidden_size, hidden_size),
129
+ discriminator_block(hidden_size, hidden_size),
130
+ ])
131
+
132
+ self.relu = nn.ReLU()
133
+ def forward(self, x1, x2):
134
+ """
135
+
136
+ :param x1: [B, C, T]
137
+ :param x2: [B, C, T]
138
+ :return: validity: [B, 1], h: List of hiddens
139
+ """
140
+ h1, h2 = [], []
141
+ for l in self.model1:
142
+ x1 = l(x1)
143
+ h1.append(x1)
144
+ for l in self.model2:
145
+ x2 = l(x2)
146
+ h2.append(x1)
147
+ b,c,t = x1.shape
148
+ x1 = x1.view(b, c*t)
149
+ x2 = x2.view(b, c*t)
150
+ x1 = self.relu(x1)
151
+ x2 = self.relu(x2)
152
+ # x1 = F.normalize(x1, p=2, dim=1)
153
+ # x2 = F.normalize(x2, p=2, dim=1)
154
+ validity = F.cosine_similarity(x1, x2)
155
+ return validity, [h1,h2]
156
+
157
+
158
+ class MultiWindowDiscriminator(nn.Module):
159
+ def __init__(self, time_lengths, cond_dim=80, in_dim=64, kernel_size=3, hidden_size=128, disc_type='standard', norm_type='bn', reduction='sum'):
160
+ super(MultiWindowDiscriminator, self).__init__()
161
+ self.win_lengths = time_lengths
162
+ self.reduction = reduction
163
+ self.disc_type = disc_type
164
+
165
+ if cond_dim > 0:
166
+ self.use_cond = True
167
+ self.cond_proj_layers = nn.ModuleList()
168
+ self.in_proj_layers = nn.ModuleList()
169
+ else:
170
+ self.use_cond = False
171
+
172
+ self.conv_layers = nn.ModuleList()
173
+ for time_length in time_lengths:
174
+ conv_layer = [
175
+ Discriminator1DFactory(
176
+ time_length, kernel_size, in_dim=64, hidden_size=hidden_size,
177
+ norm_type=norm_type) if self.disc_type == 'standard'
178
+ else CosineDiscriminator1DFactory(time_length, kernel_size, in_dim=64,
179
+ hidden_size=hidden_size,norm_type=norm_type)
180
+ ]
181
+ self.conv_layers += conv_layer
182
+ if self.use_cond:
183
+ self.cond_proj_layers.append(nn.Linear(cond_dim, 64))
184
+ self.in_proj_layers.append(nn.Linear(in_dim, 64))
185
+
186
+ def clip(self, x, cond, x_len, win_length, start_frames=None):
187
+ '''Ramdom clip x to win_length.
188
+ Args:
189
+ x (tensor) : (B, T, C).
190
+ cond (tensor) : (B, T, H).
191
+ x_len (tensor) : (B,).
192
+ win_length (int): target clip length
193
+
194
+ Returns:
195
+ (tensor) : (B, c_in, win_length, n_bins).
196
+
197
+ '''
198
+ clip_from_same_frame = start_frames is None
199
+ T_start = 0
200
+ # T_end = x_len.max() - win_length
201
+ T_end = x_len.min() - win_length
202
+ if T_end < 0:
203
+ return None, None, start_frames
204
+ T_end = T_end.item()
205
+ if start_frames is None:
206
+ start_frame = np.random.randint(low=T_start, high=T_end + 1)
207
+ start_frames = [start_frame] * x.size(0)
208
+ else:
209
+ start_frame = start_frames[0]
210
+
211
+
212
+ if clip_from_same_frame:
213
+ x_batch = x[:, start_frame: start_frame + win_length, :]
214
+ c_batch = cond[:, start_frame: start_frame + win_length, :] if cond is not None else None
215
+ else:
216
+ x_lst = []
217
+ c_lst = []
218
+ for i, start_frame in enumerate(start_frames):
219
+ x_lst.append(x[i, start_frame: start_frame + win_length, :])
220
+ if cond is not None:
221
+ c_lst.append(cond[i, start_frame: start_frame + win_length, :])
222
+ x_batch = torch.stack(x_lst, dim=0)
223
+ if cond is None:
224
+ c_batch = None
225
+ else:
226
+ c_batch = torch.stack(c_lst, dim=0)
227
+ return x_batch, c_batch, start_frames
228
+
229
+ def forward(self, x, x_len, cond=None, start_frames_wins=None):
230
+ '''
231
+ Args:
232
+ x (tensor): input mel, (B, T, C).
233
+ x_length (tensor): len of per mel. (B,).
234
+
235
+ Returns:
236
+ tensor : (B).
237
+ '''
238
+ validity = []
239
+ if start_frames_wins is None:
240
+ start_frames_wins = [None] * len(self.conv_layers)
241
+ h = []
242
+ for i, start_frames in zip(range(len(self.conv_layers)), start_frames_wins):
243
+ x_clip, c_clip, start_frames = self.clip(
244
+ x, cond, x_len, self.win_lengths[i], start_frames) # (B, win_length, C)
245
+ start_frames_wins[i] = start_frames
246
+ if x_clip is None:
247
+ continue
248
+ if self.disc_type == 'standard':
249
+ if self.use_cond:
250
+ x_clip = self.in_proj_layers[i](x_clip) # (B, T, C)
251
+ c_clip = self.cond_proj_layers[i](c_clip)
252
+ x_clip = x_clip + c_clip
253
+ validity_pred, h_ = self.conv_layers[i](x_clip.transpose(1,2))
254
+ elif self.disc_type == 'cosine':
255
+ assert self.use_cond is True
256
+ x_clip = self.in_proj_layers[i](x_clip) # (B, T, C)
257
+ c_clip = self.cond_proj_layers[i](c_clip)
258
+ validity_pred, h_ = self.conv_layers[i](x_clip.transpose(1,2), c_clip.transpose(1,2))
259
+ else:
260
+ raise NotImplementedError
261
+
262
+ h += h_
263
+ validity.append(validity_pred)
264
+ if len(validity) != len(self.conv_layers):
265
+ return None, start_frames_wins, h
266
+ if self.reduction == 'sum':
267
+ validity = sum(validity) # [B]
268
+ elif self.reduction == 'stack':
269
+ validity = torch.stack(validity, -1) # [B, W_L]
270
+ return validity, start_frames_wins, h
271
+
272
+
273
+ class Discriminator(nn.Module):
274
+ def __init__(self, x_dim=80, y_dim=64, disc_type='standard',
275
+ uncond_disc=False, kernel_size=3, hidden_size=128, norm_type='bn', reduction='sum', time_lengths=(8,16,32)):
276
+ """_summary_
277
+
278
+ Args:
279
+ time_lengths (list, optional): the list of window size. Defaults to [32, 64, 128].
280
+ x_dim (int, optional): the dim of audio features. Defaults to 80, corresponding to mel-spec.
281
+ y_dim (int, optional): the dim of facial coeff. Defaults to 64, correspond to exp; other options can be 7(pose) or 71(exp+pose).
282
+ kernel (tuple, optional): _description_. Defaults to (3, 3).
283
+ c_in (int, optional): _description_. Defaults to 1.
284
+ hidden_size (int, optional): _description_. Defaults to 128.
285
+ norm_type (str, optional): _description_. Defaults to 'bn'.
286
+ reduction (str, optional): _description_. Defaults to 'sum'.
287
+ uncond_disc (bool, optional): _description_. Defaults to False.
288
+ """
289
+ super(Discriminator, self).__init__()
290
+ self.time_lengths = time_lengths
291
+ self.x_dim, self.y_dim = x_dim, y_dim
292
+ self.disc_type = disc_type
293
+ self.reduction = reduction
294
+ self.uncond_disc = uncond_disc
295
+
296
+ if uncond_disc:
297
+ self.x_dim = 0
298
+ cond_dim = 0
299
+
300
+ else:
301
+ cond_dim = 64
302
+ self.mel_encoder = nn.Sequential(*[
303
+ nn.Conv1d(self.x_dim, 64, 3, 1, 1, bias=False),
304
+ nn.BatchNorm1d(64),
305
+ nn.GELU(),
306
+ nn.Conv1d(64, cond_dim, 3, 1, 1, bias=False)
307
+ ])
308
+
309
+ self.disc = MultiWindowDiscriminator(
310
+ time_lengths=self.time_lengths,
311
+ in_dim=self.y_dim,
312
+ cond_dim=cond_dim,
313
+ kernel_size=kernel_size,
314
+ hidden_size=hidden_size, norm_type=norm_type,
315
+ reduction=reduction,
316
+ disc_type=disc_type
317
+ )
318
+ self.downsampler = LambdaLayer(lambda x: F.interpolate(x.transpose(1,2), scale_factor=0.5, mode='nearest').transpose(1,2))
319
+
320
+ @property
321
+ def device(self):
322
+ return self.disc.parameters().__next__().device
323
+
324
+ def forward(self,x, batch, start_frames_wins=None):
325
+ """
326
+
327
+ :param x: [B, T, C]
328
+ :param cond: [B, T, cond_size]
329
+ :return:
330
+ """
331
+ x = x.to(self.device)
332
+ if not self.uncond_disc:
333
+ mel = self.downsampler(batch['mel'].to(self.device))
334
+ mel_feat = self.mel_encoder(mel.transpose(1,2)).transpose(1,2)
335
+ else:
336
+ mel_feat = None
337
+ x_len = x.sum(-1).ne(0).int().sum([1])
338
+ disc_confidence, start_frames_wins, h = self.disc(x, x_len, mel_feat, start_frames_wins=start_frames_wins)
339
+ return disc_confidence
340
+
modules/audio2motion/transformer_base.py ADDED
@@ -0,0 +1,988 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import Parameter
5
+ import torch.onnx.operators
6
+ import torch.nn.functional as F
7
+ from collections import defaultdict
8
+
9
+
10
+ def make_positions(tensor, padding_idx):
11
+ """Replace non-padding symbols with their position numbers.
12
+
13
+ Position numbers begin at padding_idx+1. Padding symbols are ignored.
14
+ """
15
+ # The series of casts and type-conversions here are carefully
16
+ # balanced to both work with ONNX export and XLA. In particular XLA
17
+ # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
18
+ # how to handle the dtype kwarg in cumsum.
19
+ mask = tensor.ne(padding_idx).int()
20
+ return (
21
+ torch.cumsum(mask, dim=1).type_as(mask) * mask
22
+ ).long() + padding_idx
23
+
24
+
25
+ def softmax(x, dim):
26
+ return F.softmax(x, dim=dim, dtype=torch.float32)
27
+
28
+
29
+ INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)
30
+
31
+ def _get_full_incremental_state_key(module_instance, key):
32
+ module_name = module_instance.__class__.__name__
33
+
34
+ # assign a unique ID to each module instance, so that incremental state is
35
+ # not shared across module instances
36
+ if not hasattr(module_instance, '_instance_id'):
37
+ INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1
38
+ module_instance._instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name]
39
+
40
+ return '{}.{}.{}'.format(module_name, module_instance._instance_id, key)
41
+
42
+
43
+
44
+ def get_incremental_state(module, incremental_state, key):
45
+ """Helper for getting incremental state for an nn.Module."""
46
+ full_key = _get_full_incremental_state_key(module, key)
47
+ if incremental_state is None or full_key not in incremental_state:
48
+ return None
49
+ return incremental_state[full_key]
50
+
51
+
52
+ def set_incremental_state(module, incremental_state, key, value):
53
+ """Helper for setting incremental state for an nn.Module."""
54
+ if incremental_state is not None:
55
+ full_key = _get_full_incremental_state_key(module, key)
56
+ incremental_state[full_key] = value
57
+
58
+
59
+
60
+ class Reshape(nn.Module):
61
+ def __init__(self, *args):
62
+ super(Reshape, self).__init__()
63
+ self.shape = args
64
+
65
+ def forward(self, x):
66
+ return x.view(self.shape)
67
+
68
+
69
+ class Permute(nn.Module):
70
+ def __init__(self, *args):
71
+ super(Permute, self).__init__()
72
+ self.args = args
73
+
74
+ def forward(self, x):
75
+ return x.permute(self.args)
76
+
77
+
78
+ class LinearNorm(torch.nn.Module):
79
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
80
+ super(LinearNorm, self).__init__()
81
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
82
+
83
+ torch.nn.init.xavier_uniform_(
84
+ self.linear_layer.weight,
85
+ gain=torch.nn.init.calculate_gain(w_init_gain))
86
+
87
+ def forward(self, x):
88
+ return self.linear_layer(x)
89
+
90
+
91
+ class ConvNorm(torch.nn.Module):
92
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
93
+ padding=None, dilation=1, bias=True, w_init_gain='linear'):
94
+ super(ConvNorm, self).__init__()
95
+ if padding is None:
96
+ assert (kernel_size % 2 == 1)
97
+ padding = int(dilation * (kernel_size - 1) / 2)
98
+
99
+ self.conv = torch.nn.Conv1d(in_channels, out_channels,
100
+ kernel_size=kernel_size, stride=stride,
101
+ padding=padding, dilation=dilation,
102
+ bias=bias)
103
+
104
+ torch.nn.init.xavier_uniform_(
105
+ self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
106
+
107
+ def forward(self, signal):
108
+ conv_signal = self.conv(signal)
109
+ return conv_signal
110
+
111
+
112
+ def Embedding(num_embeddings, embedding_dim, padding_idx=None):
113
+ m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
114
+ nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
115
+ if padding_idx is not None:
116
+ nn.init.constant_(m.weight[padding_idx], 0)
117
+ return m
118
+
119
+
120
+ class GroupNorm1DTBC(nn.GroupNorm):
121
+ def forward(self, input):
122
+ return super(GroupNorm1DTBC, self).forward(input.permute(1, 2, 0)).permute(2, 0, 1)
123
+
124
+
125
+ def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
126
+ if not export and torch.cuda.is_available():
127
+ try:
128
+ from apex.normalization import FusedLayerNorm
129
+ return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
130
+ except ImportError:
131
+ pass
132
+ return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
133
+
134
+
135
+ def Linear(in_features, out_features, bias=True):
136
+ m = nn.Linear(in_features, out_features, bias)
137
+ nn.init.xavier_uniform_(m.weight)
138
+ if bias:
139
+ nn.init.constant_(m.bias, 0.)
140
+ return m
141
+
142
+
143
+ class SinusoidalPositionalEmbedding(nn.Module):
144
+ """This module produces sinusoidal positional embeddings of any length.
145
+
146
+ Padding symbols are ignored.
147
+ """
148
+
149
+ def __init__(self, embedding_dim, padding_idx, init_size=1024):
150
+ super().__init__()
151
+ self.embedding_dim = embedding_dim
152
+ self.padding_idx = padding_idx
153
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
154
+ init_size,
155
+ embedding_dim,
156
+ padding_idx,
157
+ )
158
+ self.register_buffer('_float_tensor', torch.FloatTensor(1))
159
+
160
+ @staticmethod
161
+ def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
162
+ """Build sinusoidal embeddings.
163
+
164
+ This matches the implementation in tensor2tensor, but differs slightly
165
+ from the description in Section 3.5 of "Attention Is All You Need".
166
+ """
167
+ half_dim = embedding_dim // 2
168
+ emb = math.log(10000) / (half_dim - 1)
169
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
170
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
171
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
172
+ if embedding_dim % 2 == 1:
173
+ # zero pad
174
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
175
+ if padding_idx is not None:
176
+ emb[padding_idx, :] = 0
177
+ return emb
178
+
179
+ def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
180
+ """Input is expected to be of size [bsz x seqlen]."""
181
+ bsz, seq_len = input.shape[:2]
182
+ max_pos = self.padding_idx + 1 + seq_len
183
+ if self.weights is None or max_pos > self.weights.size(0):
184
+ # recompute/expand embeddings if needed
185
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
186
+ max_pos,
187
+ self.embedding_dim,
188
+ self.padding_idx,
189
+ )
190
+ self.weights = self.weights.to(self._float_tensor)
191
+
192
+ if incremental_state is not None:
193
+ # positions is the same for every token when decoding a single step
194
+ pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
195
+ return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
196
+
197
+ positions = make_positions(input, self.padding_idx) if positions is None else positions
198
+ return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
199
+
200
+ def max_positions(self):
201
+ """Maximum number of supported positions."""
202
+ return int(1e5) # an arbitrary large number
203
+
204
+
205
+ class ConvTBC(nn.Module):
206
+ def __init__(self, in_channels, out_channels, kernel_size, padding=0):
207
+ super(ConvTBC, self).__init__()
208
+ self.in_channels = in_channels
209
+ self.out_channels = out_channels
210
+ self.kernel_size = kernel_size
211
+ self.padding = padding
212
+
213
+ self.weight = torch.nn.Parameter(torch.Tensor(
214
+ self.kernel_size, in_channels, out_channels))
215
+ self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
216
+
217
+ def forward(self, input):
218
+ return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding)
219
+
220
+
221
+ class MultiheadAttention(nn.Module):
222
+ def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
223
+ add_bias_kv=False, add_zero_attn=False, self_attention=False,
224
+ encoder_decoder_attention=False):
225
+ super().__init__()
226
+ self.embed_dim = embed_dim
227
+ self.kdim = kdim if kdim is not None else embed_dim
228
+ self.vdim = vdim if vdim is not None else embed_dim
229
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
230
+
231
+ self.num_heads = num_heads
232
+ self.dropout = dropout
233
+ self.head_dim = embed_dim // num_heads
234
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
235
+ self.scaling = self.head_dim ** -0.5
236
+
237
+ self.self_attention = self_attention
238
+ self.encoder_decoder_attention = encoder_decoder_attention
239
+
240
+ assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
241
+ 'value to be of the same size'
242
+
243
+ if self.qkv_same_dim:
244
+ self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
245
+ else:
246
+ self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
247
+ self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
248
+ self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
249
+
250
+ if bias:
251
+ self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
252
+ else:
253
+ self.register_parameter('in_proj_bias', None)
254
+
255
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
256
+
257
+ if add_bias_kv:
258
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
259
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
260
+ else:
261
+ self.bias_k = self.bias_v = None
262
+
263
+ self.add_zero_attn = add_zero_attn
264
+
265
+ self.reset_parameters()
266
+
267
+ self.enable_torch_version = False
268
+ if hasattr(F, "multi_head_attention_forward"):
269
+ self.enable_torch_version = True
270
+ else:
271
+ self.enable_torch_version = False
272
+ self.last_attn_probs = None
273
+
274
+ def reset_parameters(self):
275
+ if self.qkv_same_dim:
276
+ nn.init.xavier_uniform_(self.in_proj_weight)
277
+ else:
278
+ nn.init.xavier_uniform_(self.k_proj_weight)
279
+ nn.init.xavier_uniform_(self.v_proj_weight)
280
+ nn.init.xavier_uniform_(self.q_proj_weight)
281
+
282
+ nn.init.xavier_uniform_(self.out_proj.weight)
283
+ if self.in_proj_bias is not None:
284
+ nn.init.constant_(self.in_proj_bias, 0.)
285
+ nn.init.constant_(self.out_proj.bias, 0.)
286
+ if self.bias_k is not None:
287
+ nn.init.xavier_normal_(self.bias_k)
288
+ if self.bias_v is not None:
289
+ nn.init.xavier_normal_(self.bias_v)
290
+
291
+ def forward(
292
+ self,
293
+ query, key, value,
294
+ key_padding_mask=None,
295
+ incremental_state=None,
296
+ need_weights=True,
297
+ static_kv=False,
298
+ attn_mask=None,
299
+ before_softmax=False,
300
+ need_head_weights=False,
301
+ enc_dec_attn_constraint_mask=None,
302
+ reset_attn_weight=None
303
+ ):
304
+ """Input shape: Time x Batch x Channel
305
+
306
+ Args:
307
+ key_padding_mask (ByteTensor, optional): mask to exclude
308
+ keys that are pads, of shape `(batch, src_len)`, where
309
+ padding elements are indicated by 1s.
310
+ need_weights (bool, optional): return the attention weights,
311
+ averaged over heads (default: False).
312
+ attn_mask (ByteTensor, optional): typically used to
313
+ implement causal attention, where the mask prevents the
314
+ attention from looking forward in time (default: None).
315
+ before_softmax (bool, optional): return the raw attention
316
+ weights and values before the attention softmax.
317
+ need_head_weights (bool, optional): return the attention
318
+ weights for each head. Implies *need_weights*. Default:
319
+ return the average attention weights over all heads.
320
+ """
321
+ if need_head_weights:
322
+ need_weights = True
323
+
324
+ tgt_len, bsz, embed_dim = query.size()
325
+ assert embed_dim == self.embed_dim
326
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
327
+ if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None:
328
+ if self.qkv_same_dim:
329
+ return F.multi_head_attention_forward(query, key, value,
330
+ self.embed_dim, self.num_heads,
331
+ self.in_proj_weight,
332
+ self.in_proj_bias, self.bias_k, self.bias_v,
333
+ self.add_zero_attn, self.dropout,
334
+ self.out_proj.weight, self.out_proj.bias,
335
+ self.training, key_padding_mask, need_weights,
336
+ attn_mask)
337
+ else:
338
+ return F.multi_head_attention_forward(query, key, value,
339
+ self.embed_dim, self.num_heads,
340
+ torch.empty([0]),
341
+ self.in_proj_bias, self.bias_k, self.bias_v,
342
+ self.add_zero_attn, self.dropout,
343
+ self.out_proj.weight, self.out_proj.bias,
344
+ self.training, key_padding_mask, need_weights,
345
+ attn_mask, use_separate_proj_weight=True,
346
+ q_proj_weight=self.q_proj_weight,
347
+ k_proj_weight=self.k_proj_weight,
348
+ v_proj_weight=self.v_proj_weight)
349
+
350
+ if incremental_state is not None:
351
+ saved_state = self._get_input_buffer(incremental_state)
352
+ if 'prev_key' in saved_state:
353
+ # previous time steps are cached - no need to recompute
354
+ # key and value if they are static
355
+ if static_kv:
356
+ assert self.encoder_decoder_attention and not self.self_attention
357
+ key = value = None
358
+ else:
359
+ saved_state = None
360
+
361
+ if self.self_attention:
362
+ # self-attention
363
+ q, k, v = self.in_proj_qkv(query)
364
+ elif self.encoder_decoder_attention:
365
+ # encoder-decoder attention
366
+ q = self.in_proj_q(query)
367
+ if key is None:
368
+ assert value is None
369
+ k = v = None
370
+ else:
371
+ k = self.in_proj_k(key)
372
+ v = self.in_proj_v(key)
373
+
374
+ else:
375
+ q = self.in_proj_q(query)
376
+ k = self.in_proj_k(key)
377
+ v = self.in_proj_v(value)
378
+ q *= self.scaling
379
+
380
+ if self.bias_k is not None:
381
+ assert self.bias_v is not None
382
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
383
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
384
+ if attn_mask is not None:
385
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
386
+ if key_padding_mask is not None:
387
+ key_padding_mask = torch.cat(
388
+ [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
389
+
390
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
391
+ if k is not None:
392
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
393
+ if v is not None:
394
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
395
+
396
+ if saved_state is not None:
397
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
398
+ if 'prev_key' in saved_state:
399
+ prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
400
+ if static_kv:
401
+ k = prev_key
402
+ else:
403
+ k = torch.cat((prev_key, k), dim=1)
404
+ if 'prev_value' in saved_state:
405
+ prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
406
+ if static_kv:
407
+ v = prev_value
408
+ else:
409
+ v = torch.cat((prev_value, v), dim=1)
410
+ if 'prev_key_padding_mask' in saved_state and saved_state['prev_key_padding_mask'] is not None:
411
+ prev_key_padding_mask = saved_state['prev_key_padding_mask']
412
+ if static_kv:
413
+ key_padding_mask = prev_key_padding_mask
414
+ else:
415
+ key_padding_mask = torch.cat((prev_key_padding_mask, key_padding_mask), dim=1)
416
+
417
+ saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
418
+ saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
419
+ saved_state['prev_key_padding_mask'] = key_padding_mask
420
+
421
+ self._set_input_buffer(incremental_state, saved_state)
422
+
423
+ src_len = k.size(1)
424
+
425
+ # This is part of a workaround to get around fork/join parallelism
426
+ # not supporting Optional types.
427
+ if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
428
+ key_padding_mask = None
429
+
430
+ if key_padding_mask is not None:
431
+ assert key_padding_mask.size(0) == bsz
432
+ assert key_padding_mask.size(1) == src_len
433
+
434
+ if self.add_zero_attn:
435
+ src_len += 1
436
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
437
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
438
+ if attn_mask is not None:
439
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
440
+ if key_padding_mask is not None:
441
+ key_padding_mask = torch.cat(
442
+ [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
443
+
444
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
445
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
446
+
447
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
448
+
449
+ if attn_mask is not None:
450
+ if len(attn_mask.shape) == 2:
451
+ attn_mask = attn_mask.unsqueeze(0)
452
+ elif len(attn_mask.shape) == 3:
453
+ attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
454
+ bsz * self.num_heads, tgt_len, src_len)
455
+ attn_weights = attn_weights + attn_mask
456
+
457
+ if enc_dec_attn_constraint_mask is not None: # bs x head x L_kv
458
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
459
+ attn_weights = attn_weights.masked_fill(
460
+ enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
461
+ -1e8,
462
+ )
463
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
464
+
465
+ if key_padding_mask is not None:
466
+ # don't attend to padding symbols
467
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
468
+ attn_weights = attn_weights.masked_fill(
469
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
470
+ -1e8,
471
+ )
472
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
473
+
474
+ attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
475
+
476
+ if before_softmax:
477
+ return attn_weights, v
478
+
479
+ attn_weights_float = softmax(attn_weights, dim=-1)
480
+ attn_weights = attn_weights_float.type_as(attn_weights)
481
+ attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
482
+
483
+ if reset_attn_weight is not None:
484
+ if reset_attn_weight:
485
+ self.last_attn_probs = attn_probs.detach()
486
+ else:
487
+ assert self.last_attn_probs is not None
488
+ attn_probs = self.last_attn_probs
489
+ attn = torch.bmm(attn_probs, v)
490
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
491
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
492
+ attn = self.out_proj(attn)
493
+
494
+ if need_weights:
495
+ attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
496
+ if not need_head_weights:
497
+ # average attention weights over heads
498
+ attn_weights = attn_weights.mean(dim=0)
499
+ else:
500
+ attn_weights = None
501
+
502
+ return attn, (attn_weights, attn_logits)
503
+
504
+ def in_proj_qkv(self, query):
505
+ return self._in_proj(query).chunk(3, dim=-1)
506
+
507
+ def in_proj_q(self, query):
508
+ if self.qkv_same_dim:
509
+ return self._in_proj(query, end=self.embed_dim)
510
+ else:
511
+ bias = self.in_proj_bias
512
+ if bias is not None:
513
+ bias = bias[:self.embed_dim]
514
+ return F.linear(query, self.q_proj_weight, bias)
515
+
516
+ def in_proj_k(self, key):
517
+ if self.qkv_same_dim:
518
+ return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
519
+ else:
520
+ weight = self.k_proj_weight
521
+ bias = self.in_proj_bias
522
+ if bias is not None:
523
+ bias = bias[self.embed_dim:2 * self.embed_dim]
524
+ return F.linear(key, weight, bias)
525
+
526
+ def in_proj_v(self, value):
527
+ if self.qkv_same_dim:
528
+ return self._in_proj(value, start=2 * self.embed_dim)
529
+ else:
530
+ weight = self.v_proj_weight
531
+ bias = self.in_proj_bias
532
+ if bias is not None:
533
+ bias = bias[2 * self.embed_dim:]
534
+ return F.linear(value, weight, bias)
535
+
536
+ def _in_proj(self, input, start=0, end=None):
537
+ weight = self.in_proj_weight
538
+ bias = self.in_proj_bias
539
+ weight = weight[start:end, :]
540
+ if bias is not None:
541
+ bias = bias[start:end]
542
+ return F.linear(input, weight, bias)
543
+
544
+ def _get_input_buffer(self, incremental_state):
545
+ return get_incremental_state(
546
+ self,
547
+ incremental_state,
548
+ 'attn_state',
549
+ ) or {}
550
+
551
+ def _set_input_buffer(self, incremental_state, buffer):
552
+ set_incremental_state(
553
+ self,
554
+ incremental_state,
555
+ 'attn_state',
556
+ buffer,
557
+ )
558
+
559
+ def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
560
+ return attn_weights
561
+
562
+ def clear_buffer(self, incremental_state=None):
563
+ if incremental_state is not None:
564
+ saved_state = self._get_input_buffer(incremental_state)
565
+ if 'prev_key' in saved_state:
566
+ del saved_state['prev_key']
567
+ if 'prev_value' in saved_state:
568
+ del saved_state['prev_value']
569
+ self._set_input_buffer(incremental_state, saved_state)
570
+
571
+
572
+ class Swish(torch.autograd.Function):
573
+ @staticmethod
574
+ def forward(ctx, i):
575
+ result = i * torch.sigmoid(i)
576
+ ctx.save_for_backward(i)
577
+ return result
578
+
579
+ @staticmethod
580
+ def backward(ctx, grad_output):
581
+ i = ctx.saved_variables[0]
582
+ sigmoid_i = torch.sigmoid(i)
583
+ return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
584
+
585
+
586
+ class CustomSwish(nn.Module):
587
+ def forward(self, input_tensor):
588
+ return Swish.apply(input_tensor)
589
+
590
+
591
+ class TransformerFFNLayer(nn.Module):
592
+ def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu'):
593
+ super().__init__()
594
+ self.kernel_size = kernel_size
595
+ self.dropout = dropout
596
+ self.act = act
597
+ if padding == 'SAME':
598
+ self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size, padding=kernel_size // 2)
599
+ elif padding == 'LEFT':
600
+ self.ffn_1 = nn.Sequential(
601
+ nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
602
+ nn.Conv1d(hidden_size, filter_size, kernel_size)
603
+ )
604
+ self.ffn_2 = Linear(filter_size, hidden_size)
605
+ if self.act == 'swish':
606
+ self.swish_fn = CustomSwish()
607
+
608
+ def forward(self, x, incremental_state=None):
609
+ # x: T x B x C
610
+ if incremental_state is not None:
611
+ saved_state = self._get_input_buffer(incremental_state)
612
+ if 'prev_input' in saved_state:
613
+ prev_input = saved_state['prev_input']
614
+ x = torch.cat((prev_input, x), dim=0)
615
+ x = x[-self.kernel_size:]
616
+ saved_state['prev_input'] = x
617
+ self._set_input_buffer(incremental_state, saved_state)
618
+
619
+ x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
620
+ x = x * self.kernel_size ** -0.5
621
+
622
+ if incremental_state is not None:
623
+ x = x[-1:]
624
+ if self.act == 'gelu':
625
+ x = F.gelu(x)
626
+ if self.act == 'relu':
627
+ x = F.relu(x)
628
+ if self.act == 'swish':
629
+ x = self.swish_fn(x)
630
+ x = F.dropout(x, self.dropout, training=self.training)
631
+ x = self.ffn_2(x)
632
+ return x
633
+
634
+ def _get_input_buffer(self, incremental_state):
635
+ return get_incremental_state(
636
+ self,
637
+ incremental_state,
638
+ 'f',
639
+ ) or {}
640
+
641
+ def _set_input_buffer(self, incremental_state, buffer):
642
+ set_incremental_state(
643
+ self,
644
+ incremental_state,
645
+ 'f',
646
+ buffer,
647
+ )
648
+
649
+ def clear_buffer(self, incremental_state):
650
+ if incremental_state is not None:
651
+ saved_state = self._get_input_buffer(incremental_state)
652
+ if 'prev_input' in saved_state:
653
+ del saved_state['prev_input']
654
+ self._set_input_buffer(incremental_state, saved_state)
655
+
656
+
657
+ class BatchNorm1dTBC(nn.Module):
658
+ def __init__(self, c):
659
+ super(BatchNorm1dTBC, self).__init__()
660
+ self.bn = nn.BatchNorm1d(c)
661
+
662
+ def forward(self, x):
663
+ """
664
+
665
+ :param x: [T, B, C]
666
+ :return: [T, B, C]
667
+ """
668
+ x = x.permute(1, 2, 0) # [B, C, T]
669
+ x = self.bn(x) # [B, C, T]
670
+ x = x.permute(2, 0, 1) # [T, B, C]
671
+ return x
672
+
673
+
674
+ class EncSALayer(nn.Module):
675
+ def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
676
+ relu_dropout=0.1, kernel_size=9, padding='SAME', norm='ln', act='gelu'):
677
+ super().__init__()
678
+ self.c = c
679
+ self.dropout = dropout
680
+ self.num_heads = num_heads
681
+ if num_heads > 0:
682
+ if norm == 'ln':
683
+ self.layer_norm1 = LayerNorm(c)
684
+ elif norm == 'bn':
685
+ self.layer_norm1 = BatchNorm1dTBC(c)
686
+ elif norm == 'gn':
687
+ self.layer_norm1 = GroupNorm1DTBC(8, c)
688
+ self.self_attn = MultiheadAttention(
689
+ self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False)
690
+ if norm == 'ln':
691
+ self.layer_norm2 = LayerNorm(c)
692
+ elif norm == 'bn':
693
+ self.layer_norm2 = BatchNorm1dTBC(c)
694
+ elif norm == 'gn':
695
+ self.layer_norm2 = GroupNorm1DTBC(8, c)
696
+ self.ffn = TransformerFFNLayer(
697
+ c, 4 * c, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act)
698
+
699
+ def forward(self, x, encoder_padding_mask=None, **kwargs):
700
+ layer_norm_training = kwargs.get('layer_norm_training', None)
701
+ if layer_norm_training is not None:
702
+ self.layer_norm1.training = layer_norm_training
703
+ self.layer_norm2.training = layer_norm_training
704
+ if self.num_heads > 0:
705
+ residual = x
706
+ x = self.layer_norm1(x)
707
+ x, _, = self.self_attn(
708
+ query=x,
709
+ key=x,
710
+ value=x,
711
+ key_padding_mask=encoder_padding_mask
712
+ )
713
+ x = F.dropout(x, self.dropout, training=self.training)
714
+ x = residual + x
715
+ x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
716
+
717
+ residual = x
718
+ x = self.layer_norm2(x)
719
+ x = self.ffn(x)
720
+ x = F.dropout(x, self.dropout, training=self.training)
721
+ x = residual + x
722
+ x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
723
+ return x
724
+
725
+
726
+ class DecSALayer(nn.Module):
727
+ def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1,
728
+ kernel_size=9, act='gelu', norm='ln'):
729
+ super().__init__()
730
+ self.c = c
731
+ self.dropout = dropout
732
+ if norm == 'ln':
733
+ self.layer_norm1 = LayerNorm(c)
734
+ elif norm == 'gn':
735
+ self.layer_norm1 = GroupNorm1DTBC(8, c)
736
+ self.self_attn = MultiheadAttention(
737
+ c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
738
+ )
739
+ if norm == 'ln':
740
+ self.layer_norm2 = LayerNorm(c)
741
+ elif norm == 'gn':
742
+ self.layer_norm2 = GroupNorm1DTBC(8, c)
743
+ self.encoder_attn = MultiheadAttention(
744
+ c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False,
745
+ )
746
+ if norm == 'ln':
747
+ self.layer_norm3 = LayerNorm(c)
748
+ elif norm == 'gn':
749
+ self.layer_norm3 = GroupNorm1DTBC(8, c)
750
+ self.ffn = TransformerFFNLayer(
751
+ c, 4 * c, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)
752
+
753
+ def forward(
754
+ self,
755
+ x,
756
+ encoder_out=None,
757
+ encoder_padding_mask=None,
758
+ incremental_state=None,
759
+ self_attn_mask=None,
760
+ self_attn_padding_mask=None,
761
+ attn_out=None,
762
+ reset_attn_weight=None,
763
+ **kwargs,
764
+ ):
765
+ layer_norm_training = kwargs.get('layer_norm_training', None)
766
+ if layer_norm_training is not None:
767
+ self.layer_norm1.training = layer_norm_training
768
+ self.layer_norm2.training = layer_norm_training
769
+ self.layer_norm3.training = layer_norm_training
770
+ residual = x
771
+ x = self.layer_norm1(x)
772
+ x, _ = self.self_attn(
773
+ query=x,
774
+ key=x,
775
+ value=x,
776
+ key_padding_mask=self_attn_padding_mask,
777
+ incremental_state=incremental_state,
778
+ attn_mask=self_attn_mask
779
+ )
780
+ x = F.dropout(x, self.dropout, training=self.training)
781
+ x = residual + x
782
+
783
+ attn_logits = None
784
+ if encoder_out is not None or attn_out is not None:
785
+ residual = x
786
+ x = self.layer_norm2(x)
787
+ if encoder_out is not None:
788
+ x, attn = self.encoder_attn(
789
+ query=x,
790
+ key=encoder_out,
791
+ value=encoder_out,
792
+ key_padding_mask=encoder_padding_mask,
793
+ incremental_state=incremental_state,
794
+ static_kv=True,
795
+ enc_dec_attn_constraint_mask=get_incremental_state(self, incremental_state,
796
+ 'enc_dec_attn_constraint_mask'),
797
+ reset_attn_weight=reset_attn_weight
798
+ )
799
+ attn_logits = attn[1]
800
+ elif attn_out is not None:
801
+ x = self.encoder_attn.in_proj_v(attn_out)
802
+ if encoder_out is not None or attn_out is not None:
803
+ x = F.dropout(x, self.dropout, training=self.training)
804
+ x = residual + x
805
+
806
+ residual = x
807
+ x = self.layer_norm3(x)
808
+ x = self.ffn(x, incremental_state=incremental_state)
809
+ x = F.dropout(x, self.dropout, training=self.training)
810
+ x = residual + x
811
+ return x, attn_logits
812
+
813
+ def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
814
+ self.encoder_attn.clear_buffer(incremental_state)
815
+ self.ffn.clear_buffer(incremental_state)
816
+
817
+ def set_buffer(self, name, tensor, incremental_state):
818
+ return set_incremental_state(self, incremental_state, name, tensor)
819
+
820
+
821
+ class ConvBlock(nn.Module):
822
+ def __init__(self, idim=80, n_chans=256, kernel_size=3, stride=1, norm='gn', dropout=0):
823
+ super().__init__()
824
+ self.conv = ConvNorm(idim, n_chans, kernel_size, stride=stride)
825
+ self.norm = norm
826
+ if self.norm == 'bn':
827
+ self.norm = nn.BatchNorm1d(n_chans)
828
+ elif self.norm == 'in':
829
+ self.norm = nn.InstanceNorm1d(n_chans, affine=True)
830
+ elif self.norm == 'gn':
831
+ self.norm = nn.GroupNorm(n_chans // 16, n_chans)
832
+ elif self.norm == 'ln':
833
+ self.norm = LayerNorm(n_chans // 16, n_chans)
834
+ elif self.norm == 'wn':
835
+ self.conv = torch.nn.utils.weight_norm(self.conv.conv)
836
+ self.dropout = nn.Dropout(dropout)
837
+ self.relu = nn.ReLU()
838
+
839
+ def forward(self, x):
840
+ """
841
+
842
+ :param x: [B, C, T]
843
+ :return: [B, C, T]
844
+ """
845
+ x = self.conv(x)
846
+ if not isinstance(self.norm, str):
847
+ if self.norm == 'none':
848
+ pass
849
+ elif self.norm == 'ln':
850
+ x = self.norm(x.transpose(1, 2)).transpose(1, 2)
851
+ else:
852
+ x = self.norm(x)
853
+ x = self.relu(x)
854
+ x = self.dropout(x)
855
+ return x
856
+
857
+
858
+ class ConvStacks(nn.Module):
859
+ def __init__(self, idim=80, n_layers=5, n_chans=256, odim=32, kernel_size=5, norm='gn',
860
+ dropout=0, strides=None, res=True):
861
+ super().__init__()
862
+ self.conv = torch.nn.ModuleList()
863
+ self.kernel_size = kernel_size
864
+ self.res = res
865
+ self.in_proj = Linear(idim, n_chans)
866
+ if strides is None:
867
+ strides = [1] * n_layers
868
+ else:
869
+ assert len(strides) == n_layers
870
+ for idx in range(n_layers):
871
+ self.conv.append(ConvBlock(
872
+ n_chans, n_chans, kernel_size, stride=strides[idx], norm=norm, dropout=dropout))
873
+ self.out_proj = Linear(n_chans, odim)
874
+
875
+ def forward(self, x, return_hiddens=False):
876
+ """
877
+
878
+ :param x: [B, T, H]
879
+ :return: [B, T, H]
880
+ """
881
+ x = self.in_proj(x)
882
+ x = x.transpose(1, -1) # (B, idim, Tmax)
883
+ hiddens = []
884
+ for f in self.conv:
885
+ x_ = f(x)
886
+ x = x + x_ if self.res else x_ # (B, C, Tmax)
887
+ hiddens.append(x)
888
+ x = x.transpose(1, -1)
889
+ x = self.out_proj(x) # (B, Tmax, H)
890
+ if return_hiddens:
891
+ hiddens = torch.stack(hiddens, 1) # [B, L, C, T]
892
+ return x, hiddens
893
+ return x
894
+
895
+
896
+ class ConvGlobalStacks(nn.Module):
897
+ def __init__(self, idim=80, n_layers=5, n_chans=256, odim=32, kernel_size=5, norm='gn', dropout=0,
898
+ strides=[2, 2, 2, 2, 2]):
899
+ super().__init__()
900
+ self.conv = torch.nn.ModuleList()
901
+ self.pooling = torch.nn.ModuleList()
902
+ self.kernel_size = kernel_size
903
+ self.in_proj = Linear(idim, n_chans)
904
+ for idx in range(n_layers):
905
+ self.conv.append(ConvBlock(n_chans, n_chans, kernel_size, stride=strides[idx],
906
+ norm=norm, dropout=dropout))
907
+ self.pooling.append(nn.MaxPool1d(strides[idx]))
908
+ self.out_proj = Linear(n_chans, odim)
909
+
910
+ def forward(self, x):
911
+ """
912
+
913
+ :param x: [B, T, H]
914
+ :return: [B, T, H]
915
+ """
916
+ x = self.in_proj(x)
917
+ x = x.transpose(1, -1) # (B, idim, Tmax)
918
+ for f, p in zip(self.conv, self.pooling):
919
+ x = f(x) # (B, C, T)
920
+ x = x.transpose(1, -1)
921
+ x = self.out_proj(x.mean(1)) # (B, H)
922
+ return x
923
+
924
+
925
+ class ConvDecoder(nn.Module):
926
+ def __init__(self, c, dropout, kernel_size=9, act='gelu'):
927
+ super().__init__()
928
+ self.c = c
929
+ self.dropout = dropout
930
+
931
+ self.pre_convs = nn.ModuleList()
932
+ self.pre_lns = nn.ModuleList()
933
+ for i in range(2):
934
+ self.pre_convs.append(TransformerFFNLayer(
935
+ c, c * 2, padding='LEFT', kernel_size=kernel_size, dropout=dropout, act=act))
936
+ self.pre_lns.append(LayerNorm(c))
937
+
938
+ self.layer_norm_attn = LayerNorm(c)
939
+ self.encoder_attn = MultiheadAttention(c, 1, encoder_decoder_attention=True, bias=False)
940
+
941
+ self.post_convs = nn.ModuleList()
942
+ self.post_lns = nn.ModuleList()
943
+ for i in range(8):
944
+ self.post_convs.append(TransformerFFNLayer(
945
+ c, c * 2, padding='LEFT', kernel_size=kernel_size, dropout=dropout, act=act))
946
+ self.post_lns.append(LayerNorm(c))
947
+
948
+ def forward(
949
+ self,
950
+ x,
951
+ encoder_out=None,
952
+ encoder_padding_mask=None,
953
+ incremental_state=None,
954
+ **kwargs,
955
+ ):
956
+ attn_logits = None
957
+ for conv, ln in zip(self.pre_convs, self.pre_lns):
958
+ residual = x
959
+ x = ln(x)
960
+ x = conv(x) + residual
961
+ if encoder_out is not None:
962
+ residual = x
963
+ x = self.layer_norm_attn(x)
964
+ x, attn = self.encoder_attn(
965
+ query=x,
966
+ key=encoder_out,
967
+ value=encoder_out,
968
+ key_padding_mask=encoder_padding_mask,
969
+ incremental_state=incremental_state,
970
+ static_kv=True,
971
+ enc_dec_attn_constraint_mask=get_incremental_state(self, incremental_state,
972
+ 'enc_dec_attn_constraint_mask'),
973
+ )
974
+ attn_logits = attn[1]
975
+ x = F.dropout(x, self.dropout, training=self.training)
976
+ x = residual + x
977
+ for conv, ln in zip(self.post_convs, self.post_lns):
978
+ residual = x
979
+ x = ln(x)
980
+ x = conv(x) + residual
981
+ return x, attn_logits
982
+
983
+ def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
984
+ self.encoder_attn.clear_buffer(incremental_state)
985
+ self.ffn.clear_buffer(incremental_state)
986
+
987
+ def set_buffer(self, name, tensor, incremental_state):
988
+ return set_incremental_state(self, incremental_state, name, tensor)
modules/audio2motion/transformer_models.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from numpy import isin
2
+ import torch
3
+ import torch.nn as nn
4
+ from modules.audio2motion.transformer_base import *
5
+
6
+ DEFAULT_MAX_SOURCE_POSITIONS = 2000
7
+ DEFAULT_MAX_TARGET_POSITIONS = 2000
8
+
9
+
10
+ class TransformerEncoderLayer(nn.Module):
11
+ def __init__(self, hidden_size, dropout, kernel_size=None, num_heads=2, norm='ln'):
12
+ super().__init__()
13
+ self.hidden_size = hidden_size
14
+ self.dropout = dropout
15
+ self.num_heads = num_heads
16
+ self.op = EncSALayer(
17
+ hidden_size, num_heads, dropout=dropout,
18
+ attention_dropout=0.0, relu_dropout=dropout,
19
+ kernel_size=kernel_size
20
+ if kernel_size is not None else 9,
21
+ padding='SAME',
22
+ norm=norm, act='gelu'
23
+ )
24
+
25
+ def forward(self, x, **kwargs):
26
+ return self.op(x, **kwargs)
27
+
28
+
29
+ ######################
30
+ # fastspeech modules
31
+ ######################
32
+ class LayerNorm(torch.nn.LayerNorm):
33
+ """Layer normalization module.
34
+ :param int nout: output dim size
35
+ :param int dim: dimension to be normalized
36
+ """
37
+
38
+ def __init__(self, nout, dim=-1, eps=1e-5):
39
+ """Construct an LayerNorm object."""
40
+ super(LayerNorm, self).__init__(nout, eps=eps)
41
+ self.dim = dim
42
+
43
+ def forward(self, x):
44
+ """Apply layer normalization.
45
+ :param torch.Tensor x: input tensor
46
+ :return: layer normalized tensor
47
+ :rtype torch.Tensor
48
+ """
49
+ if self.dim == -1:
50
+ return super(LayerNorm, self).forward(x)
51
+ return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
52
+
53
+
54
+ class FFTBlocks(nn.Module):
55
+ def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout=None,
56
+ num_heads=2, use_pos_embed=True, use_last_norm=True, norm='ln',
57
+ use_pos_embed_alpha=True):
58
+ super().__init__()
59
+ self.num_layers = num_layers
60
+ embed_dim = self.hidden_size = hidden_size
61
+ self.dropout = dropout if dropout is not None else 0.1
62
+ self.use_pos_embed = use_pos_embed
63
+ self.use_last_norm = use_last_norm
64
+ if use_pos_embed:
65
+ self.max_source_positions = DEFAULT_MAX_TARGET_POSITIONS
66
+ self.padding_idx = 0
67
+ self.pos_embed_alpha = nn.Parameter(torch.Tensor([1])) if use_pos_embed_alpha else 1
68
+ self.embed_positions = SinusoidalPositionalEmbedding(
69
+ embed_dim, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
70
+ )
71
+
72
+ self.layers = nn.ModuleList([])
73
+ self.layers.extend([
74
+ TransformerEncoderLayer(self.hidden_size, self.dropout,
75
+ kernel_size=ffn_kernel_size, num_heads=num_heads,
76
+ norm=norm)
77
+ for _ in range(self.num_layers)
78
+ ])
79
+ if self.use_last_norm:
80
+ if norm == 'ln':
81
+ self.layer_norm = nn.LayerNorm(embed_dim)
82
+ elif norm == 'bn':
83
+ self.layer_norm = BatchNorm1dTBC(embed_dim)
84
+ elif norm == 'gn':
85
+ self.layer_norm = GroupNorm1DTBC(8, embed_dim)
86
+ else:
87
+ self.layer_norm = None
88
+
89
+ def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens=False):
90
+ """
91
+ :param x: [B, T, C]
92
+ :param padding_mask: [B, T]
93
+ :return: [B, T, C] or [L, B, T, C]
94
+ """
95
+ padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask
96
+ nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None] # [T, B, 1]
97
+ if self.use_pos_embed:
98
+ positions = self.pos_embed_alpha * self.embed_positions(x[..., 0])
99
+ x = x + positions
100
+ x = F.dropout(x, p=self.dropout, training=self.training)
101
+ # B x T x C -> T x B x C
102
+ x = x.transpose(0, 1) * nonpadding_mask_TB
103
+ hiddens = []
104
+ for layer in self.layers:
105
+ x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB
106
+ hiddens.append(x)
107
+ if self.use_last_norm:
108
+ x = self.layer_norm(x) * nonpadding_mask_TB
109
+ if return_hiddens:
110
+ x = torch.stack(hiddens, 0) # [L, T, B, C]
111
+ x = x.transpose(1, 2) # [L, B, T, C]
112
+ else:
113
+ x = x.transpose(0, 1) # [B, T, C]
114
+ return x
115
+
116
+ class SequentialSA(nn.Module):
117
+ def __init__(self,layers):
118
+ super(SequentialSA,self).__init__()
119
+ self.layers = nn.ModuleList(layers)
120
+
121
+ def forward(self,x,x_mask):
122
+ """
123
+ x: [batch, T, H]
124
+ x_mask: [batch, T]
125
+ """
126
+ pad_mask = 1. - x_mask
127
+ for layer in self.layers:
128
+ if isinstance(layer, EncSALayer):
129
+ x = x.permute(1,0,2)
130
+ x = layer(x,pad_mask)
131
+ x = x.permute(1,0,2)
132
+ elif isinstance(layer, nn.Linear):
133
+ x = layer(x) * x_mask.unsqueeze(2)
134
+ elif isinstance(layer, nn.AvgPool1d):
135
+ x = x.permute(0,2,1)
136
+ x = layer(x)
137
+ x = x.permute(0,2,1)
138
+ elif isinstance(layer, nn.PReLU):
139
+ bs, t, hid = x.shape
140
+ x = x.reshape([bs*t,hid])
141
+ x = layer(x)
142
+ x = x.reshape([bs, t, hid])
143
+ else: # Relu
144
+ x = layer(x)
145
+
146
+ return x
147
+
148
+ class TransformerStyleFusionModel(nn.Module):
149
+ def __init__(self, num_heads=4, dropout = 0.1, out_dim = 64):
150
+ super(TransformerStyleFusionModel, self).__init__()
151
+ self.audio_layer = SequentialSA([
152
+ nn.Linear(29, 48),
153
+ nn.ReLU(48),
154
+ nn.Linear(48, 128),
155
+ ])
156
+
157
+ self.energy_layer = SequentialSA([
158
+ nn.Linear(1, 16),
159
+ nn.ReLU(16),
160
+ nn.Linear(16, 64),
161
+ ])
162
+
163
+ self.backbone1 = FFTBlocks(hidden_size=192,num_layers=3)
164
+
165
+ self.sty_encoder = nn.Sequential(*[
166
+ nn.Linear(135, 64),
167
+ nn.ReLU(),
168
+ nn.Linear(64, 128)
169
+ ])
170
+
171
+ self.backbone2 = FFTBlocks(hidden_size=320,num_layers=3)
172
+
173
+ self.out_layer = SequentialSA([
174
+ nn.AvgPool1d(kernel_size=2,stride=2,padding=0), #[b,hid,t_audio]=>[b,hid,t_audio//2]
175
+ nn.Linear(320,out_dim),
176
+ nn.PReLU(out_dim),
177
+ nn.Linear(out_dim,out_dim),
178
+ ])
179
+
180
+ self.dropout = nn.Dropout(p = dropout)
181
+
182
+ def forward(self, audio, energy, style, x_mask, y_mask):
183
+ pad_mask = 1. - x_mask
184
+ audio_feat = self.audio_layer(audio, x_mask)
185
+ energy_feat = self.energy_layer(energy, x_mask)
186
+ feat = torch.cat((audio_feat, energy_feat), dim=-1) # [batch, T, H=48+16]
187
+ feat = self.backbone1(feat, pad_mask)
188
+ feat = self.dropout(feat)
189
+
190
+ sty_feat = self.sty_encoder(style) # [batch,135]=>[batch, H=64]
191
+ sty_feat = sty_feat.unsqueeze(1).repeat(1, feat.shape[1], 1) # [batch, T, H=64]
192
+
193
+ feat = torch.cat([feat, sty_feat], dim=-1) # [batch, T, H=64+64]
194
+ feat = self.backbone2(feat, pad_mask) # [batch, T, H=128]
195
+ out = self.out_layer(feat, y_mask) # [batch, T//2, H=out_dim]
196
+
197
+ return out
198
+
199
+
200
+ if __name__ == '__main__':
201
+ model = TransformerStyleFusionModel()
202
+ audio = torch.rand(4,200,29) # [B,T,H]
203
+ energy = torch.rand(4,200,1) # [B,T,H]
204
+ style = torch.ones(4,135) # [B,T]
205
+ x_mask = torch.ones(4,200) # [B,T]
206
+ x_mask[3,10:] = 0
207
+ ret = model(audio,energy,style, x_mask)
208
+ print(" ")
modules/audio2motion/utils.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def squeeze(x, x_mask=None, n_sqz=2):
5
+ b, c, t = x.size()
6
+
7
+ t = (t // n_sqz) * n_sqz
8
+ x = x[:, :, :t]
9
+ x_sqz = x.view(b, c, t // n_sqz, n_sqz)
10
+ x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz)
11
+
12
+ if x_mask is not None:
13
+ x_mask = x_mask[:, :, n_sqz - 1::n_sqz]
14
+ else:
15
+ x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype)
16
+ return x_sqz * x_mask, x_mask
17
+
18
+
19
+ def unsqueeze(x, x_mask=None, n_sqz=2):
20
+ b, c, t = x.size()
21
+
22
+ x_unsqz = x.view(b, n_sqz, c // n_sqz, t)
23
+ x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz)
24
+
25
+ if x_mask is not None:
26
+ x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz)
27
+ else:
28
+ x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype)
29
+ return x_unsqz * x_mask, x_mask