ameerazam08 committed
Commit: a5c5b03
Parent(s): fbeb913
Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes; see the raw diff for the full changeset.
- .gitattributes +8 -0
- .gitignore +199 -0
- README-zh.md +137 -0
- README.md +137 -0
- checkpoints/.gitkeep +0 -0
- data_gen/eg3d/convert_to_eg3d_convention.py +146 -0
- data_gen/runs/binarizer_nerf.py +335 -0
- data_gen/runs/nerf/process_guide.md +49 -0
- data_gen/runs/nerf/run.sh +51 -0
- data_gen/utils/mp_feature_extractors/face_landmarker.py +130 -0
- data_gen/utils/mp_feature_extractors/face_landmarker.task +3 -0
- data_gen/utils/mp_feature_extractors/mp_segmenter.py +274 -0
- data_gen/utils/mp_feature_extractors/selfie_multiclass_256x256.tflite +3 -0
- data_gen/utils/path_converter.py +24 -0
- data_gen/utils/process_audio/extract_hubert.py +95 -0
- data_gen/utils/process_audio/extract_mel_f0.py +148 -0
- data_gen/utils/process_audio/resample_audio_to_16k.py +49 -0
- data_gen/utils/process_image/extract_lm2d.py +197 -0
- data_gen/utils/process_image/extract_segment_imgs.py +114 -0
- data_gen/utils/process_image/fit_3dmm_landmark.py +369 -0
- data_gen/utils/process_video/euler2quaterion.py +35 -0
- data_gen/utils/process_video/extract_blink.py +50 -0
- data_gen/utils/process_video/extract_lm2d.py +164 -0
- data_gen/utils/process_video/extract_segment_imgs.py +500 -0
- data_gen/utils/process_video/fit_3dmm_landmark.py +565 -0
- data_gen/utils/process_video/inpaint_torso_imgs.py +193 -0
- data_gen/utils/process_video/resample_video_to_25fps_resize_to_512.py +87 -0
- data_gen/utils/process_video/split_video_to_imgs.py +53 -0
- data_util/face3d_helper.py +309 -0
- deep_3drecon/BFM/.gitkeep +0 -0
- deep_3drecon/bfm_left_eye_faces.npy +3 -0
- deep_3drecon/bfm_right_eye_faces.npy +3 -0
- deep_3drecon/deep_3drecon_models/bfm.py +426 -0
- deep_3drecon/ncc_code.npy +3 -0
- deep_3drecon/secc_renderer.py +78 -0
- deep_3drecon/util/mesh_renderer.py +131 -0
- docs/prepare_env/install_guide-zh.md +35 -0
- docs/prepare_env/install_guide.md +34 -0
- docs/prepare_env/requirements.txt +75 -0
- inference/app_real3dportrait.py +244 -0
- inference/edit_secc.py +147 -0
- inference/infer_utils.py +154 -0
- inference/real3d_infer.py +542 -0
- insta.sh +18 -0
- modules/audio2motion/cnn_models.py +359 -0
- modules/audio2motion/flow_base.py +838 -0
- modules/audio2motion/multi_length_disc.py +340 -0
- modules/audio2motion/transformer_base.py +988 -0
- modules/audio2motion/transformer_models.py +208 -0
- modules/audio2motion/utils.py +29 -0
.gitattributes
CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+data_gen/utils/mp_feature_extractors/face_landmarker.task filter=lfs diff=lfs merge=lfs -text
+pytorch3d/.github/bundle_adjust.gif filter=lfs diff=lfs merge=lfs -text
+pytorch3d/.github/camera_position_teapot.gif filter=lfs diff=lfs merge=lfs -text
+pytorch3d/.github/fit_nerf.gif filter=lfs diff=lfs merge=lfs -text
+pytorch3d/.github/fit_textured_volume.gif filter=lfs diff=lfs merge=lfs -text
+pytorch3d/.github/implicitron_config.gif filter=lfs diff=lfs merge=lfs -text
+pytorch3d/.github/nerf_project_logo.gif filter=lfs diff=lfs merge=lfs -text
+pytorch3d/docs/notes/assets/batch_modes.gif filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,199 @@
# big files
data_util/face_tracking/3DMM/01_MorphableModel.mat
data_util/face_tracking/3DMM/3DMM_info.npy

!/deep_3drecon/BFM/.gitkeep
deep_3drecon/BFM/Exp_Pca.bin
deep_3drecon/BFM/01_MorphableModel.mat
deep_3drecon/BFM/BFM_model_front.mat
deep_3drecon/network/FaceReconModel.pb
deep_3drecon/checkpoints/*

.vscode
### Project ignore
/checkpoints/*
!/checkpoints/.gitkeep
/data/*
!/data/.gitkeep
infer_out
rsync
.idea
.DS_Store
bak
tmp
*.tar.gz
mos
nbs
/configs_usr/*
!/configs_usr/.gitkeep
/egs_usr/*
!/egs_usr/.gitkeep
/rnnoise
#/usr/*
#!/usr/.gitkeep
scripts_usr

# Created by .ignore support plugin (hsz.mobi)
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
data_util/deepspeech_features/deepspeech-0.9.2-models.pbmm
deep_3drecon/mesh_renderer/bazel-bin
deep_3drecon/mesh_renderer/bazel-mesh_renderer
deep_3drecon/mesh_renderer/bazel-out
deep_3drecon/mesh_renderer/bazel-testlogs

.nfs*
infer_outs/*

*.pth
venv_113/*
*.pt
experiments/trials
flame_3drecon/*

temp/
/kill.sh
/datasets
data_util/imagenet_classes.txt
process_data_May.sh
/env_prepare_reproduce.md
/my_debug.py

utils/metrics/shape_predictor_68_face_landmarks.dat
*.mp4
_torchshow/
*.png
*.jpg

*.mrc

deep_3drecon/BFM/BFM_exp_idx.mat
deep_3drecon/BFM/BFM_front_idx.mat
deep_3drecon/BFM/facemodel_info.mat
deep_3drecon/BFM/index_mp468_from_mesh35709.npy
deep_3drecon/BFM/mediapipe_in_bfm53201.npy
deep_3drecon/BFM/std_exp.txt
!data/raw/examples/*
README-zh.md
ADDED
@@ -0,0 +1,137 @@
# Real3D-Portrait: One-shot Realistic 3D Talking Portrait Synthesis | ICLR 2024 Spotlight
[![arXiv](https://img.shields.io/badge/arXiv-Paper-%3CCOLOR%3E.svg)](https://arxiv.org/abs/2401.08503) | [![GitHub Stars](https://img.shields.io/github/stars/yerfor/Real3DPortrait)](https://github.com/yerfor/Real3DPortrait) | [English Readme](./README.md)

这个仓库是Real3D-Portrait的官方PyTorch实现, 用于实现单参考图(one-shot)、高视频真实度(video reality)的虚拟人视频合成。您可以访问我们的[项目页面](https://real3dportrait.github.io/)以观看Demo视频, 阅读我们的[论文](https://arxiv.org/pdf/2401.08503.pdf)以了解技术细节。

<p align="center">
<br>
<img src="assets/real3dportrait.png" width="100%"/>
<br>
</p>

# 快速上手!
## 安装环境
请参照[环境配置文档](docs/prepare_env/install_guide-zh.md),配置Conda环境`real3dportrait`
## 下载预训练与第三方模型
### 3DMM BFM模型
下载3DMM BFM模型:[Google Drive](https://drive.google.com/drive/folders/1o4t5YIw7w4cMUN4bgU9nPf6IyWVG1bEk?usp=sharing) 或 [BaiduYun Disk](https://pan.baidu.com/s/1aqv1z_qZ23Vp2VP4uxxblQ?pwd=m9q5) 提取码: m9q5

下载完成后,放置全部的文件到`deep_3drecon/BFM`里,文件结构如下:
```
deep_3drecon/BFM/
├── 01_MorphableModel.mat
├── BFM_exp_idx.mat
├── BFM_front_idx.mat
├── BFM_model_front.mat
├── Exp_Pca.bin
├── facemodel_info.mat
├── index_mp468_from_mesh35709.npy
├── mediapipe_in_bfm53201.npy
└── std_exp.txt
```

### 预训练模型
下载预训练的Real3D-Portrait:[Google Drive](https://drive.google.com/drive/folders/1MAveJf7RvJ-Opg1f5qhLdoRoC_Gc6nD9?usp=sharing) 或 [BaiduYun Disk](https://pan.baidu.com/s/1Mjmbn0UtA1Zm9owZ7zWNgQ?pwd=6x4f) 提取码: 6x4f

下载完成后,放置全部的文件到`checkpoints`里并解压,文件结构如下:
```
checkpoints/
├── 240126_real3dportrait_orig
│   ├── audio2secc_vae
│   │   ├── config.yaml
│   │   └── model_ckpt_steps_400000.ckpt
│   └── secc2plane_torso_orig
│       ├── config.yaml
│       └── model_ckpt_steps_100000.ckpt
└── pretrained_ckpts
    └── mit_b0.pth
```

## 推理测试
我们目前提供了**命令行(CLI)**与**Gradio WebUI**推理方式,并将在未来提供Google Colab方式。我们同时支持音频驱动(Audio-Driven)与视频驱动(Video-Driven):

- 音频驱动场景下,需要至少提供`source image`与`driving audio`
- 视频驱动场景下,需要至少提供`source image`与`driving expression video`

### Gradio WebUI推理
启动Gradio WebUI,按照提示上传素材,点击`Generate`按钮即可推理:
```bash
python inference/app_real3dportrait.py
```

### 命令行推理
首先,切换至项目根目录并启用Conda环境:
```bash
cd <Real3DPortraitRoot>
conda activate real3dportrait
export PYTHON_PATH=./
```
音频驱动场景下,需要至少提供source image与driving audio,推理指令:
```bash
python inference/real3d_infer.py \
--src_img <PATH_TO_SOURCE_IMAGE> \
--drv_aud <PATH_TO_AUDIO> \
--drv_pose <PATH_TO_POSE_VIDEO, OPTIONAL> \
--bg_img <PATH_TO_BACKGROUND_IMAGE, OPTIONAL> \
--out_name <PATH_TO_OUTPUT_VIDEO, OPTIONAL>
```
视频驱动场景下,需要至少提供source image与driving expression video(作为drv_aud参数),推理指令:
```bash
python inference/real3d_infer.py \
--src_img <PATH_TO_SOURCE_IMAGE> \
--drv_aud <PATH_TO_EXP_VIDEO> \
--drv_pose <PATH_TO_POSE_VIDEO, OPTIONAL> \
--bg_img <PATH_TO_BACKGROUND_IMAGE, OPTIONAL> \
--out_name <PATH_TO_OUTPUT_VIDEO, OPTIONAL>
```
一些可选参数注释:
- `--drv_pose` 指定时提供了运动pose信息,不指定则为静态运动
- `--bg_img` 指定时提供了背景信息,不指定则为source image提取的背景
- `--mouth_amp` 嘴部张幅参数,值越大张幅越大
- `--map_to_init_pose` 值为`True`时,首帧的pose将被映射到source pose,后续帧也作相同变换
- `--temperature` 代表audio2motion的采样温度,值越大结果越多样,但同时精确度越低
- `--out_name` 不指定时,结果将保存在`infer_out/tmp/`中
- `--out_mode` 值为`final`时,只输出说话人视频;值为`concat_debug`时,同时输出一些可视化的中间结果

指令示例:
```bash
python inference/real3d_infer.py \
--src_img data/raw/examples/Macron.png \
--drv_aud data/raw/examples/Obama_5s.wav \
--drv_pose data/raw/examples/May_5s.mp4 \
--bg_img data/raw/examples/bg.png \
--out_name output.mp4 \
--out_mode concat_debug
```

## ToDo
- [x] **Release Pre-trained weights of Real3D-Portrait.**
- [x] **Release Inference Code of Real3D-Portrait.**
- [x] **Release Gradio Demo of Real3D-Portrait.**
- [ ] **Release Google Colab of Real3D-Portrait.**
- [ ] **Release Training Code of Real3D-Portrait.**

# 引用我们
如果这个仓库对你有帮助,请考虑引用我们的工作:
```
@article{ye2024real3d,
  title={Real3D-Portrait: One-shot Realistic 3D Talking Portrait Synthesis},
  author={Ye, Zhenhui and Zhong, Tianyun and Ren, Yi and Yang, Jiaqi and Li, Weichuang and Huang, Jiawei and Jiang, Ziyue and He, Jinzheng and Huang, Rongjie and Liu, Jinglin and others},
  journal={arXiv preprint arXiv:2401.08503},
  year={2024}
}
@article{ye2023geneface++,
  title={GeneFace++: Generalized and Stable Real-Time Audio-Driven 3D Talking Face Generation},
  author={Ye, Zhenhui and He, Jinzheng and Jiang, Ziyue and Huang, Rongjie and Huang, Jiawei and Liu, Jinglin and Ren, Yi and Yin, Xiang and Ma, Zejun and Zhao, Zhou},
  journal={arXiv preprint arXiv:2305.00787},
  year={2023}
}
@article{ye2023geneface,
  title={GeneFace: Generalized and High-Fidelity Audio-Driven 3D Talking Face Synthesis},
  author={Ye, Zhenhui and Jiang, Ziyue and Ren, Yi and Liu, Jinglin and He, Jinzheng and Zhao, Zhou},
  journal={arXiv preprint arXiv:2301.13430},
  year={2023}
}
```
README.md
ADDED
@@ -0,0 +1,137 @@
# Real3D-Portrait: One-shot Realistic 3D Talking Portrait Synthesis | ICLR 2024 Spotlight
[![arXiv](https://img.shields.io/badge/arXiv-Paper-%3CCOLOR%3E.svg)](https://arxiv.org/abs/2401.08503) | [![GitHub Stars](https://img.shields.io/github/stars/yerfor/Real3DPortrait)](https://github.com/yerfor/Real3DPortrait) | [中文文档](./README-zh.md)

This is the official PyTorch implementation of Real3D-Portrait, for one-shot, realistic talking portrait video synthesis. Visit our [Demo Page](https://real3dportrait.github.io/) to watch demo videos, and read our [Paper](https://arxiv.org/pdf/2401.08503.pdf) for technical details.

<p align="center">
<br>
<img src="assets/real3dportrait.png" width="100%"/>
<br>
</p>

# Quick Start!
## Environment Installation
Please refer to the [Installation Guide](docs/prepare_env/install_guide.md) to prepare a Conda environment named `real3dportrait`.
## Download Pre-trained & Third-Party Models
### 3DMM BFM Model
Download the 3DMM BFM model from [Google Drive](https://drive.google.com/drive/folders/1o4t5YIw7w4cMUN4bgU9nPf6IyWVG1bEk?usp=sharing) or [BaiduYun Disk](https://pan.baidu.com/s/1aqv1z_qZ23Vp2VP4uxxblQ?pwd=m9q5) with password m9q5.

Put all the files in `deep_3drecon/BFM`; the file structure should look like this:
```
deep_3drecon/BFM/
├── 01_MorphableModel.mat
├── BFM_exp_idx.mat
├── BFM_front_idx.mat
├── BFM_model_front.mat
├── Exp_Pca.bin
├── facemodel_info.mat
├── index_mp468_from_mesh35709.npy
├── mediapipe_in_bfm53201.npy
└── std_exp.txt
```

### Pre-trained Real3D-Portrait
Download the pre-trained Real3D-Portrait checkpoints from [Google Drive](https://drive.google.com/drive/folders/1MAveJf7RvJ-Opg1f5qhLdoRoC_Gc6nD9?usp=sharing) or [BaiduYun Disk](https://pan.baidu.com/s/1Mjmbn0UtA1Zm9owZ7zWNgQ?pwd=6x4f) with password 6x4f.

Put the zip files in `checkpoints` and unzip them; the file structure should look like this:
```
checkpoints/
├── 240126_real3dportrait_orig
│   ├── audio2secc_vae
│   │   ├── config.yaml
│   │   └── model_ckpt_steps_400000.ckpt
│   └── secc2plane_torso_orig
│       ├── config.yaml
│       └── model_ckpt_steps_100000.ckpt
└── pretrained_ckpts
    └── mit_b0.pth
```

## Inference
Currently we provide **CLI** and **Gradio WebUI** inference; a Google Colab will be provided in the future. Both audio-driven and video-driven modes are supported:

- For audio-driven, prepare at least a `source image` and a `driving audio`
- For video-driven, prepare at least a `source image` and a `driving expression video`

### Gradio WebUI
Run the Gradio WebUI demo, upload the resources in the webpage, and click the `Generate` button to run inference:
```bash
python inference/app_real3dportrait.py
```

### CLI Inference
First, switch to the project root and activate the conda environment:
```bash
cd <Real3DPortraitRoot>
conda activate real3dportrait
export PYTHON_PATH=./
```
For audio-driven inference, provide a source image and a driving audio:
```bash
python inference/real3d_infer.py \
--src_img <PATH_TO_SOURCE_IMAGE> \
--drv_aud <PATH_TO_AUDIO> \
--drv_pose <PATH_TO_POSE_VIDEO, OPTIONAL> \
--bg_img <PATH_TO_BACKGROUND_IMAGE, OPTIONAL> \
--out_name <PATH_TO_OUTPUT_VIDEO, OPTIONAL>
```
For video-driven inference, provide a source image and a driving expression video (passed as the `--drv_aud` parameter):
```bash
python inference/real3d_infer.py \
--src_img <PATH_TO_SOURCE_IMAGE> \
--drv_aud <PATH_TO_EXP_VIDEO> \
--drv_pose <PATH_TO_POSE_VIDEO, OPTIONAL> \
--bg_img <PATH_TO_BACKGROUND_IMAGE, OPTIONAL> \
--out_name <PATH_TO_OUTPUT_VIDEO, OPTIONAL>
```
Some optional parameters:
- `--drv_pose` provides head pose information; if omitted, a static pose is used
- `--bg_img` provides the background; if omitted, the background is extracted from the source image
- `--mouth_amp` controls the mouth amplitude; a higher value leads to a wider mouth opening
- `--map_to_init_pose` when set to `True`, the pose of the first frame is mapped to the source pose, and subsequent frames are transformed in the same way
- `--temperature` is the sampling temperature of audio2motion; higher values give more diverse results at the expense of lower accuracy
- `--out_name` if not assigned, the results are stored in `infer_out/tmp/`
- `--out_mode` when `final`, only the talking-head video is output; when `concat_debug`, visualizations of several intermediate results are also output

Commandline example:
```bash
python inference/real3d_infer.py \
--src_img data/raw/examples/Macron.png \
--drv_aud data/raw/examples/Obama_5s.wav \
--drv_pose data/raw/examples/May_5s.mp4 \
--bg_img data/raw/examples/bg.png \
--out_name output.mp4 \
--out_mode concat_debug
```

# ToDo
- [x] **Release Pre-trained weights of Real3D-Portrait.**
- [x] **Release Inference Code of Real3D-Portrait.**
- [x] **Release Gradio Demo of Real3D-Portrait.**
- [ ] **Release Google Colab of Real3D-Portrait.**
- [ ] **Release Training Code of Real3D-Portrait.**

# Citation
If you found this repo helpful to your work, please consider citing our papers:
```
@article{ye2024real3d,
  title={Real3D-Portrait: One-shot Realistic 3D Talking Portrait Synthesis},
  author={Ye, Zhenhui and Zhong, Tianyun and Ren, Yi and Yang, Jiaqi and Li, Weichuang and Huang, Jiawei and Jiang, Ziyue and He, Jinzheng and Huang, Rongjie and Liu, Jinglin and others},
  journal={arXiv preprint arXiv:2401.08503},
  year={2024}
}
@article{ye2023geneface++,
  title={GeneFace++: Generalized and Stable Real-Time Audio-Driven 3D Talking Face Generation},
  author={Ye, Zhenhui and He, Jinzheng and Jiang, Ziyue and Huang, Rongjie and Huang, Jiawei and Liu, Jinglin and Ren, Yi and Yin, Xiang and Ma, Zejun and Zhao, Zhou},
  journal={arXiv preprint arXiv:2305.00787},
  year={2023}
}
@article{ye2023geneface,
  title={GeneFace: Generalized and High-Fidelity Audio-Driven 3D Talking Face Synthesis},
  author={Ye, Zhenhui and Jiang, Ziyue and Ren, Yi and Liu, Jinglin and He, Jinzheng and Zhao, Zhou},
  journal={arXiv preprint arXiv:2301.13430},
  year={2023}
}
```
checkpoints/.gitkeep
ADDED
File without changes
data_gen/eg3d/convert_to_eg3d_convention.py
ADDED
@@ -0,0 +1,146 @@
import numpy as np
import torch
import copy
from utils.commons.tensor_utils import convert_to_tensor, convert_to_np
from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel


def _fix_intrinsics(intrinsics):
    """
    intrinsics: [3,3], not batch-wise
    """
    # unnormalized                 normalized
    # [[ f_x, s=0, x_0]           [[ f_x/size_x, s=0,        x_0/size_x=0.5]
    #  [ 0,   f_y, y_0]     ->     [ 0,          f_y/size_y, y_0/size_y=0.5]
    #  [ 0,   0,   1  ]]           [ 0,          0,          1             ]]
    intrinsics = np.array(intrinsics).copy()
    assert intrinsics.shape == (3, 3), intrinsics
    intrinsics[0,0] = 2985.29/700
    intrinsics[1,1] = 2985.29/700
    intrinsics[0,2] = 1/2
    intrinsics[1,2] = 1/2
    assert intrinsics[0,1] == 0
    assert intrinsics[2,2] == 1
    assert intrinsics[1,0] == 0
    assert intrinsics[2,0] == 0
    assert intrinsics[2,1] == 0
    return intrinsics

# Used in original submission
def _fix_pose_orig(pose):
    """
    pose: [4,4], not batch-wise
    """
    pose = np.array(pose).copy()
    location = pose[:3, 3]
    radius = np.linalg.norm(location)
    pose[:3, 3] = pose[:3, 3]/radius * 2.7
    return pose


def get_eg3d_convention_camera_pose_intrinsic(item):
    """
    item: a dict during binarize
    """
    if item['euler'].ndim == 1:
        angle = convert_to_tensor(copy.copy(item['euler']))
        trans = copy.deepcopy(item['trans'])

        # handle the difference of euler axis between eg3d and ours
        # see data_gen/process_ffhq_for_eg3d/transplant_eg3d_ckpt_into_our_convention.ipynb
        # angle += torch.tensor([0, 3.1415926535, 3.1415926535], device=angle.device)
        R = ParametricFaceModel.compute_rotation(angle.unsqueeze(0))[0].cpu().numpy()
        trans[2] += -10
        c = -np.dot(R, trans)
        pose = np.eye(4)
        pose[:3,:3] = R
        c *= 0.27 # normalize camera radius
        c[1] += 0.006 # additional offset used in submission
        c[2] += 0.161 # additional offset used in submission
        pose[0,3] = c[0]
        pose[1,3] = c[1]
        pose[2,3] = c[2]

        focal = 2985.29 # = 1015*1024/224*(300/466.285),
        # todo: if the camera intrinsic of the fit-3dmm stage changes, this must change accordingly
        pp = 512 # 112
        w = 1024 # 224
        h = 1024 # 224

        K = np.eye(3)
        K[0][0] = focal
        K[1][1] = focal
        K[0][2] = w/2.0
        K[1][2] = h/2.0
        convention_K = _fix_intrinsics(K)

        Rot = np.eye(3)
        Rot[0, 0] = 1
        Rot[1, 1] = -1
        Rot[2, 2] = -1
        pose[:3, :3] = np.dot(pose[:3, :3], Rot) # permute axes
        convention_pose = _fix_pose_orig(pose)

        item['c2w'] = pose
        item['convention_c2w'] = convention_pose
        item['intrinsics'] = convention_K
        return item
    else:
        num_samples = len(item['euler'])
        eulers_all = convert_to_tensor(copy.deepcopy(item['euler'])) # [B, 3]
        trans_all = copy.deepcopy(item['trans']) # [B, 3]

        # handle the difference of euler axis between eg3d and ours
        # see data_gen/process_ffhq_for_eg3d/transplant_eg3d_ckpt_into_our_convention.ipynb
        # eulers_all += torch.tensor([0, 3.1415926535, 3.1415926535], device=eulers_all.device).unsqueeze(0).repeat([eulers_all.shape[0],1])

        intrinsics = []
        poses = []
        convention_poses = []
        for i in range(num_samples):
            angle = eulers_all[i]
            trans = trans_all[i]
            R = ParametricFaceModel.compute_rotation(angle.unsqueeze(0))[0].cpu().numpy()
            trans[2] += -10
            c = -np.dot(R, trans)
            pose = np.eye(4)
            pose[:3,:3] = R
            c *= 0.27 # normalize camera radius
            c[1] += 0.006 # additional offset used in submission
            c[2] += 0.161 # additional offset used in submission
            pose[0,3] = c[0]
            pose[1,3] = c[1]
            pose[2,3] = c[2]

            focal = 2985.29 # = 1015*1024/224*(300/466.285),
            # todo: if the camera intrinsic of the fit-3dmm stage changes, this must change accordingly
            pp = 512 # 112
            w = 1024 # 224
            h = 1024 # 224

            K = np.eye(3)
            K[0][0] = focal
            K[1][1] = focal
            K[0][2] = w/2.0
            K[1][2] = h/2.0
            convention_K = _fix_intrinsics(K)
            intrinsics.append(convention_K)

            Rot = np.eye(3)
            Rot[0, 0] = 1
            Rot[1, 1] = -1
            Rot[2, 2] = -1
            pose[:3, :3] = np.dot(pose[:3, :3], Rot)
            convention_pose = _fix_pose_orig(pose)
            convention_poses.append(convention_pose)
            poses.append(pose)

        intrinsics = np.stack(intrinsics) # [B, 3, 3]
        poses = np.stack(poses) # [B, 4, 4]
        convention_poses = np.stack(convention_poses) # [B, 4, 4]
        item['intrinsics'] = intrinsics
        item['c2w'] = poses
        item['convention_c2w'] = convention_poses
        return item
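The two `_fix_*` helpers above are plain NumPy functions; the sketch below (my own illustration, not part of the commit, assuming the repo root is on `PYTHONPATH` and the module's imports resolve) shows the normalization they perform: the pose translation is rescaled to EG3D's camera radius of 2.7, and the intrinsics are rewritten in EG3D's normalized convention.

```python
# Illustrative only: the pose location and intrinsics below are made-up values.
import numpy as np
from data_gen.eg3d.convert_to_eg3d_convention import _fix_pose_orig, _fix_intrinsics

pose = np.eye(4)
pose[:3, 3] = [0.0, 0.3, 4.0]             # some camera location
fixed_pose = _fix_pose_orig(pose)
print(np.linalg.norm(fixed_pose[:3, 3]))  # ~2.7, the normalized camera radius

K = np.eye(3)
K[0, 0] = K[1, 1] = 2985.29               # focal length used above
K[0, 2] = K[1, 2] = 512.0                 # principal point at w/2 = h/2
print(_fix_intrinsics(K))                 # focal -> 2985.29/700, principal point -> 0.5
```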
data_gen/runs/binarizer_nerf.py
ADDED
@@ -0,0 +1,335 @@
import os
import numpy as np
import math
import json
import imageio
import torch
import tqdm
import cv2

from data_util.face3d_helper import Face3DHelper
from utils.commons.euler2rot import euler_trans_2_c2w, c2w_to_euler_trans
from data_gen.utils.process_video.euler2quaterion import euler2quaterion, quaterion2euler
from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel


def euler2rot(euler_angle):
    batch_size = euler_angle.shape[0]
    theta = euler_angle[:, 0].reshape(-1, 1, 1)
    phi = euler_angle[:, 1].reshape(-1, 1, 1)
    psi = euler_angle[:, 2].reshape(-1, 1, 1)
    one = torch.ones(batch_size, 1, 1).to(euler_angle.device)
    zero = torch.zeros(batch_size, 1, 1).to(euler_angle.device)
    rot_x = torch.cat((
        torch.cat((one, zero, zero), 1),
        torch.cat((zero, theta.cos(), theta.sin()), 1),
        torch.cat((zero, -theta.sin(), theta.cos()), 1),
    ), 2)
    rot_y = torch.cat((
        torch.cat((phi.cos(), zero, -phi.sin()), 1),
        torch.cat((zero, one, zero), 1),
        torch.cat((phi.sin(), zero, phi.cos()), 1),
    ), 2)
    rot_z = torch.cat((
        torch.cat((psi.cos(), -psi.sin(), zero), 1),
        torch.cat((psi.sin(), psi.cos(), zero), 1),
        torch.cat((zero, zero, one), 1)
    ), 2)
    return torch.bmm(rot_x, torch.bmm(rot_y, rot_z))


def rot2euler(rot_mat):
    batch_size = len(rot_mat)
    # we assert that y is in [-0.5pi, 0.5pi]
    cos_y = torch.sqrt(rot_mat[:, 1, 2] * rot_mat[:, 1, 2] + rot_mat[:, 2, 2] * rot_mat[:, 2, 2])
    theta_x = torch.atan2(-rot_mat[:, 1, 2], rot_mat[:, 2, 2])
    theta_y = torch.atan2(rot_mat[:, 2, 0], cos_y)
    theta_z = torch.atan2(rot_mat[:, 0, 1], rot_mat[:, 0, 0])
    euler_angles = torch.zeros([batch_size, 3])
    euler_angles[:, 0] = theta_x
    euler_angles[:, 1] = theta_y
    euler_angles[:, 2] = theta_z
    return euler_angles

index_lm68_from_lm468 = [127,234,93,132,58,136,150,176,152,400,379,365,288,361,323,454,356,70,63,105,66,107,336,296,334,293,300,168,197,5,4,75,97,2,326,305,
                         33,160,158,133,153,144,362,385,387,263,373,380,61,40,37,0,267,270,291,321,314,17,84,91,78,81,13,311,308,402,14,178]

def plot_lm2d(lm2d):
    WH = 512
    img = np.ones([WH, WH, 3], dtype=np.uint8) * 255

    for i in range(len(lm2d)):
        x, y = lm2d[i]
        color = (255,0,0)
        img = cv2.circle(img, center=(int(x),int(y)), radius=3, color=color, thickness=-1)
    font = cv2.FONT_HERSHEY_SIMPLEX
    for i in range(len(lm2d)):
        x, y = lm2d[i]
        img = cv2.putText(img, f"{i}", org=(int(x),int(y)), fontFace=font, fontScale=0.3, color=(255,0,0))
    return img

def get_face_rect(lms, h, w):
    """
    lms: [68, 2]
    h, w: int
    return: [4,]
    """
    assert len(lms) == 68
    # min_x, max_x = np.min(lms, 0)[0], np.max(lms, 0)[0]
    min_x, max_x = np.min(lms[:, 0]), np.max(lms[:, 0])
    cx = int((min_x+max_x)/2.0)
    cy = int(lms[27, 1])
    h_w = int((max_x-cx)*1.5)
    h_h = int((lms[8, 1]-cy)*1.15)
    rect_x = cx - h_w
    rect_y = cy - h_h
    if rect_x < 0:
        rect_x = 0
    if rect_y < 0:
        rect_y = 0
    rect_w = min(w-1-rect_x, 2*h_w)
    rect_h = min(h-1-rect_y, 2*h_h)
    # rect = np.array((rect_x, rect_y, rect_w, rect_h), dtype=np.int32)
    # rect = [rect_x, rect_y, rect_w, rect_h]
    rect = [rect_x, rect_x + rect_w, rect_y, rect_y + rect_h] # min_j, max_j, min_i, max_i
    return rect # this x is width, y is height

def get_lip_rect(lms, h, w):
    """
    lms: [68, 2]
    h, w: int
    return: [4,]
    """
    # this x is width, y is height
    # for lms, lms[:, 0] is width, lms[:, 1] is height
    assert len(lms) == 68
    lips = slice(48, 60)
    lms = lms[lips]
    min_x, max_x = np.min(lms[:, 0]), np.max(lms[:, 0])
    min_y, max_y = np.min(lms[:, 1]), np.max(lms[:, 1])
    cx = int((min_x+max_x)/2.0)
    cy = int((min_y+max_y)/2.0)
    h_w = int((max_x-cx)*1.2)
    h_h = int((max_y-cy)*1.2)

    h_w = max(h_w, h_h)
    h_h = h_w

    rect_x = cx - h_w
    rect_y = cy - h_h
    rect_w = 2*h_w
    rect_h = 2*h_h
    if rect_x < 0:
        rect_x = 0
    if rect_y < 0:
        rect_y = 0

    if rect_x + rect_w > w:
        rect_x = w - rect_w
    if rect_y + rect_h > h:
        rect_y = h - rect_h

    rect = [rect_x, rect_x + rect_w, rect_y, rect_y + rect_h] # min_j, max_j, min_i, max_i
    return rect # this x is width, y is height


# def get_lip_rect(lms, h, w):
#     """
#     lms: [68, 2]
#     h, w: int
#     return: [4,]
#     """
#     assert len(lms) == 68
#     lips = slice(48, 60)
#     # this x is width, y is height
#     xmin, xmax = int(lms[lips, 1].min()), int(lms[lips, 1].max())
#     ymin, ymax = int(lms[lips, 0].min()), int(lms[lips, 0].max())
#     # padding to H == W
#     cx = (xmin + xmax) // 2
#     cy = (ymin + ymax) // 2
#     l = max(xmax - xmin, ymax - ymin) // 2
#     xmin = max(0, cx - l)
#     xmax = min(h, cx + l)
#     ymin = max(0, cy - l)
#     ymax = min(w, cy + l)
#     lip_rect = [xmin, xmax, ymin, ymax]
#     return lip_rect

def get_win_conds(conds, idx, smo_win_size=8, pad_option='zero'):
    """
    conds: [b, t=16, h=29]
    idx: long, time index of the selected frame
    """
    idx = max(0, idx)
    idx = min(idx, conds.shape[0]-1)
    smo_half_win_size = smo_win_size//2
    left_i = idx - smo_half_win_size
    right_i = idx + (smo_win_size - smo_half_win_size)
    pad_left, pad_right = 0, 0
    if left_i < 0:
        pad_left = -left_i
        left_i = 0
    if right_i > conds.shape[0]:
        pad_right = right_i - conds.shape[0]
        right_i = conds.shape[0]
    conds_win = conds[left_i:right_i]
    if pad_left > 0:
        if pad_option == 'zero':
            conds_win = np.concatenate([np.zeros_like(conds_win)[:pad_left], conds_win], axis=0)
        elif pad_option == 'edge':
            edge_value = conds[0][np.newaxis, ...]
            conds_win = np.concatenate([edge_value] * pad_left + [conds_win], axis=0)
        else:
            raise NotImplementedError
    if pad_right > 0:
        if pad_option == 'zero':
            conds_win = np.concatenate([conds_win, np.zeros_like(conds_win)[:pad_right]], axis=0)
        elif pad_option == 'edge':
            edge_value = conds[-1][np.newaxis, ...]
            conds_win = np.concatenate([conds_win] + [edge_value] * pad_right, axis=0)
        else:
            raise NotImplementedError
    assert conds_win.shape[0] == smo_win_size
    return conds_win


def load_processed_data(processed_dir):
    # load necessary files
    background_img_name = os.path.join(processed_dir, "bg.jpg")
    assert os.path.exists(background_img_name)
    head_img_dir = os.path.join(processed_dir, "head_imgs")
    torso_img_dir = os.path.join(processed_dir, "inpaint_torso_imgs")
    gt_img_dir = os.path.join(processed_dir, "gt_imgs")

    hubert_npy_name = os.path.join(processed_dir, "aud_hubert.npy")
    mel_f0_npy_name = os.path.join(processed_dir, "aud_mel_f0.npy")
    coeff_npy_name = os.path.join(processed_dir, "coeff_fit_mp.npy")
    lm2d_npy_name = os.path.join(processed_dir, "lms_2d.npy")

    ret_dict = {}

    ret_dict['bg_img'] = imageio.imread(background_img_name)
    ret_dict['H'], ret_dict['W'] = ret_dict['bg_img'].shape[:2]
    ret_dict['focal'], ret_dict['cx'], ret_dict['cy'] = face_model.focal, face_model.center, face_model.center

    print("loading lm2d coeff ...")
    lm2d_arr = np.load(lm2d_npy_name)
    face_rect_lst = []
    lip_rect_lst = []
    for lm2d in lm2d_arr:
        if len(lm2d) in [468, 478]:
            lm2d = lm2d[index_lm68_from_lm468]
        face_rect = get_face_rect(lm2d, ret_dict['H'], ret_dict['W'])
        lip_rect = get_lip_rect(lm2d, ret_dict['H'], ret_dict['W'])
        face_rect_lst.append(face_rect)
        lip_rect_lst.append(lip_rect)
    face_rects = np.stack(face_rect_lst, axis=0) # [T, 4]

    print("loading fitted 3dmm coeff ...")
    coeff_dict = np.load(coeff_npy_name, allow_pickle=True).tolist()
    identity_arr = coeff_dict['id']
    exp_arr = coeff_dict['exp']
    ret_dict['id'] = identity_arr
    ret_dict['exp'] = exp_arr
    euler_arr = ret_dict['euler'] = coeff_dict['euler']
    trans_arr = ret_dict['trans'] = coeff_dict['trans']
    print("calculating lm3d ...")
    idexp_lm3d_arr = face3d_helper.reconstruct_idexp_lm3d(torch.from_numpy(identity_arr), torch.from_numpy(exp_arr)).cpu().numpy().reshape([-1, 68*3])
    len_motion = len(idexp_lm3d_arr)
    video_idexp_lm3d_mean = idexp_lm3d_arr.mean(axis=0)
    video_idexp_lm3d_std = idexp_lm3d_arr.std(axis=0)
    ret_dict['idexp_lm3d'] = idexp_lm3d_arr
    ret_dict['idexp_lm3d_mean'] = video_idexp_lm3d_mean
    ret_dict['idexp_lm3d_std'] = video_idexp_lm3d_std

    # now we convert the euler_trans from deep3d convention to adnerf convention
    eulers = torch.FloatTensor(euler_arr)
    trans = torch.FloatTensor(trans_arr)
    rots = face_model.compute_rotation(eulers) # rotation matrix is a better intermediate for convention-transplant than euler

    # handle the camera pose to geneface's convention
    trans[:, 2] = 10 - trans[:, 2] # undo the to_camera operation of the fitting stage, i.e. trans[...,2] = 10 - trans[...,2]
    rots = rots.permute(0, 2, 1)
    trans[:, 2] = - trans[:,2] # because the intrinsic projection differs
    # below is the NeRF camera preprocessing strategy, see `save_transforms` in data_util/process.py
    trans = trans / 10.0
    rots_inv = rots.permute(0, 2, 1)
    trans_inv = - torch.bmm(rots_inv, trans.unsqueeze(2))

    pose = torch.eye(4, dtype=torch.float32).unsqueeze(0).repeat([len_motion, 1, 1]) # [T, 4, 4]
    pose[:, :3, :3] = rots_inv
    pose[:, :3, 3] = trans_inv[:, :, 0]
    c2w_transform_matrices = pose.numpy()

    # process the audio features used for postnet training
    print("loading hubert ...")
    hubert_features = np.load(hubert_npy_name)
    print("loading Mel and F0 ...")
    mel_f0_features = np.load(mel_f0_npy_name, allow_pickle=True).tolist()

    ret_dict['hubert'] = hubert_features
    ret_dict['mel'] = mel_f0_features['mel']
    ret_dict['f0'] = mel_f0_features['f0']

    # obtaining train samples
    frame_indices = list(range(len_motion))
    num_train = len_motion // 11 * 10
    train_indices = frame_indices[:num_train]
    val_indices = frame_indices[num_train:]

    for split in ['train', 'val']:
        if split == 'train':
            indices = train_indices
            samples = []
            ret_dict['train_samples'] = samples
        elif split == 'val':
            indices = val_indices
            samples = []
            ret_dict['val_samples'] = samples

        for idx in indices:
            sample = {}
            sample['idx'] = idx
            sample['head_img_fname'] = os.path.join(head_img_dir,f"{idx:08d}.png")
            sample['torso_img_fname'] = os.path.join(torso_img_dir,f"{idx:08d}.png")
            sample['gt_img_fname'] = os.path.join(gt_img_dir,f"{idx:08d}.jpg")
            # assert os.path.exists(sample['head_img_fname']) and os.path.exists(sample['torso_img_fname']) and os.path.exists(sample['gt_img_fname'])
            sample['face_rect'] = face_rects[idx]
            sample['lip_rect'] = lip_rect_lst[idx]
            sample['c2w'] = c2w_transform_matrices[idx]
            samples.append(sample)
    return ret_dict


class Binarizer:
    def __init__(self):
        self.data_dir = 'data/'

    def parse(self, video_id):
        processed_dir = os.path.join(self.data_dir, 'processed/videos', video_id)
        binary_dir = os.path.join(self.data_dir, 'binary/videos', video_id)
        out_fname = os.path.join(binary_dir, "trainval_dataset.npy")
        os.makedirs(binary_dir, exist_ok=True)
        ret = load_processed_data(processed_dir)
        mel_name = os.path.join(processed_dir, 'aud_mel_f0.npy')
        mel_f0_dict = np.load(mel_name, allow_pickle=True).tolist()
        ret.update(mel_f0_dict)
        np.save(out_fname, ret, allow_pickle=True)


if __name__ == '__main__':
    from argparse import ArgumentParser
    parser = ArgumentParser()
    parser.add_argument('--video_id', type=str, default='May', help='')
    args = parser.parse_args()
    ### Process Single Long Audio for NeRF dataset
    video_id = args.video_id
    face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM',
                                     camera_distance=10, focal=1015)
    face_model.to("cpu")
    face3d_helper = Face3DHelper()

    binarizer = Binarizer()
    binarizer.parse(video_id)
    print(f"Binarization for {video_id} Done!")
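`get_win_conds` above cuts a fixed-size window of per-frame conditions around a frame index and pads it at the sequence boundaries. A minimal sketch of that boundary behaviour, using a toy array of my own (it assumes the repo environment is installed so `binarizer_nerf` imports cleanly):

```python
# Toy input: 10 frames with 2 feature channels each; real callers pass audio features.
import numpy as np
from data_gen.runs.binarizer_nerf import get_win_conds

conds = np.arange(20, dtype=np.float32).reshape(10, 2)  # [T=10, h=2]
win = get_win_conds(conds, idx=1, smo_win_size=8, pad_option='zero')
print(win.shape)  # (8, 2): frames -3..4 around idx=1, the 3 out-of-range rows zero-padded
```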
data_gen/runs/nerf/process_guide.md
ADDED
@@ -0,0 +1,49 @@
# 温馨提示:第一次执行可以先一步步跑完下面的命令行,把环境跑通后,之后可以直接运行同目录的run.sh,一键完成下面的所有步骤。

# Step0. 将视频Crop到512x512分辨率,25FPS,确保每一帧都有目标人脸
```
ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 data/raw/videos/${VIDEO_ID}_512.mp4
mv data/raw/videos/${VIDEO_ID}.mp4 data/raw/videos/${VIDEO_ID}_to_rm.mp4
mv data/raw/videos/${VIDEO_ID}_512.mp4 data/raw/videos/${VIDEO_ID}.mp4
```
# Step1. 提取音频特征, 如mel, f0, hubert, esperanto
```
export CUDA_VISIBLE_DEVICES=0
export VIDEO_ID=May
mkdir -p data/processed/videos/${VIDEO_ID}
ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -f wav -ar 16000 data/processed/videos/${VIDEO_ID}/aud.wav
python data_gen/utils/process_audio/extract_hubert.py --video_id=${VIDEO_ID}
python data_gen/utils/process_audio/extract_mel_f0.py --video_id=${VIDEO_ID}
```

# Step2. 提取图片
```
export VIDEO_ID=May
export CUDA_VISIBLE_DEVICES=0
mkdir -p data/processed/videos/${VIDEO_ID}/gt_imgs
ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 -start_number 0 data/processed/videos/${VIDEO_ID}/gt_imgs/%08d.jpg
python data_gen/utils/process_video/extract_segment_imgs.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 # extract image, segmap, and background
```

# Step3. 提取lm2d_mediapipe
### 提取2D landmark用于之后Fit 3DMM
### num_workers是本机上的CPU worker数量;total_process是使用的机器数;process_id是本机的编号

```
export VIDEO_ID=May
python data_gen/utils/process_video/extract_lm2d.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4
```

# Step4. fit 3dmm
```
export VIDEO_ID=May
export CUDA_VISIBLE_DEVICES=0
python data_gen/utils/process_video/fit_3dmm_landmark.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 --reset --debug --id_mode=global
```

# Step5. Binarize
```
export VIDEO_ID=May
python data_gen/runs/binarizer_nerf.py --video_id=${VIDEO_ID}
```
可以看到在`data/binary/videos/May`目录下得到了数据集。
data_gen/runs/nerf/run.sh
ADDED
@@ -0,0 +1,51 @@
# usage: CUDA_VISIBLE_DEVICES=0 bash data_gen/runs/nerf/run.sh <VIDEO_ID>
# please place the video at data/raw/videos/${VIDEO_ID}.mp4
VIDEO_ID=$1
echo Processing $VIDEO_ID

echo Resizing the video to 512x512
ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 -y data/raw/videos/${VIDEO_ID}_512.mp4
mv data/raw/videos/${VIDEO_ID}.mp4 data/raw/videos/${VIDEO_ID}_to_rm.mp4
mv data/raw/videos/${VIDEO_ID}_512.mp4 data/raw/videos/${VIDEO_ID}.mp4
echo Done
echo The old video data/raw/videos/${VIDEO_ID}.mp4 is moved to data/raw/videos/${VIDEO_ID}_to_rm.mp4

echo mkdir -p data/processed/videos/${VIDEO_ID}
mkdir -p data/processed/videos/${VIDEO_ID}
echo Done

# extract audio file from the training video
echo ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -f wav -ar 16000 -v quiet -y data/processed/videos/${VIDEO_ID}/aud.wav
ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -f wav -ar 16000 -v quiet -y data/processed/videos/${VIDEO_ID}/aud.wav
echo Done

# extract hubert_mel_f0 from audio
echo python data_gen/utils/process_audio/extract_hubert.py --video_id=${VIDEO_ID}
python data_gen/utils/process_audio/extract_hubert.py --video_id=${VIDEO_ID}
echo python data_gen/utils/process_audio/extract_mel_f0.py --video_id=${VIDEO_ID}
python data_gen/utils/process_audio/extract_mel_f0.py --video_id=${VIDEO_ID}
echo Done

# extract segment images
echo mkdir -p data/processed/videos/${VIDEO_ID}/gt_imgs
mkdir -p data/processed/videos/${VIDEO_ID}/gt_imgs
echo ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 -start_number 0 -v quiet data/processed/videos/${VIDEO_ID}/gt_imgs/%08d.jpg
ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 -start_number 0 -v quiet data/processed/videos/${VIDEO_ID}/gt_imgs/%08d.jpg
echo Done

echo python data_gen/utils/process_video/extract_segment_imgs.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 # extract image, segmap, and background
python data_gen/utils/process_video/extract_segment_imgs.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 # extract image, segmap, and background
echo Done

echo python data_gen/utils/process_video/extract_lm2d.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4
python data_gen/utils/process_video/extract_lm2d.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4
echo Done

pkill -f void*
echo python data_gen/utils/process_video/fit_3dmm_landmark.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 --reset --debug --id_mode=global
python data_gen/utils/process_video/fit_3dmm_landmark.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 --reset --debug --id_mode=global
echo Done

echo python data_gen/runs/binarizer_nerf.py --video_id=${VIDEO_ID}
python data_gen/runs/binarizer_nerf.py --video_id=${VIDEO_ID}
echo Done
data_gen/utils/mp_feature_extractors/face_landmarker.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import mediapipe as mp
|
2 |
+
from mediapipe.tasks import python
|
3 |
+
from mediapipe.tasks.python import vision
|
4 |
+
import numpy as np
|
5 |
+
import cv2
|
6 |
+
import os
|
7 |
+
import copy
|
8 |
+
|
9 |
+
# simplified mediapipe ldm at https://github.com/k-m-irfan/simplified_mediapipe_face_landmarks
|
10 |
+
index_lm141_from_lm478 = [70,63,105,66,107,55,65,52,53,46] + [300,293,334,296,336,285,295,282,283,276] + [33,246,161,160,159,158,157,173,133,155,154,153,145,144,163,7] + [263,466,388,387,386,385,384,398,362,382,381,380,374,373,390,249] + [78,191,80,81,82,13,312,311,310,415,308,324,318,402,317,14,87,178,88,95] + [61,185,40,39,37,0,267,269,270,409,291,375,321,405,314,17,84,181,91,146] + [10,338,297,332,284,251,389,356,454,323,361,288,397,365,379,378,400,377,152,148,176,149,150,136,172,58,132,93,234,127,162,21,54,103,67,109] + [468,469,470,471,472] + [473,474,475,476,477] + [64,4,294]
|
11 |
+
# lm141 without iris
|
12 |
+
index_lm131_from_lm478 = [70,63,105,66,107,55,65,52,53,46] + [300,293,334,296,336,285,295,282,283,276] + [33,246,161,160,159,158,157,173,133,155,154,153,145,144,163,7] + [263,466,388,387,386,385,384,398,362,382,381,380,374,373,390,249] + [78,191,80,81,82,13,312,311,310,415,308,324,318,402,317,14,87,178,88,95] + [61,185,40,39,37,0,267,269,270,409,291,375,321,405,314,17,84,181,91,146] + [10,338,297,332,284,251,389,356,454,323,361,288,397,365,379,378,400,377,152,148,176,149,150,136,172,58,132,93,234,127,162,21,54,103,67,109] + [64,4,294]
|
13 |
+
|
14 |
+
# face alignment lm68
index_lm68_from_lm478 = [127,234,93,132,58,136,150,176,152,400,379,365,288,361,323,454,356,70,63,105,66,107,336,296,334,293,300,168,197,5,4,75,97,2,326,305,
                         33,160,158,133,153,144,362,385,387,263,373,380,61,40,37,0,267,270,291,321,314,17,84,91,78,81,13,311,308,402,14,178]
# used for weights for key parts
unmatch_mask_from_lm478 = [ 93, 127, 132, 234, 323, 356, 361, 454]
index_eye_from_lm478 = [33,246,161,160,159,158,157,173,133,155,154,153,145,144,163,7] + [263,466,388,387,386,385,384,398,362,382,381,380,374,373,390,249]
index_innerlip_from_lm478 = [78,191,80,81,82,13,312,311,310,415,308,324,318,402,317,14,87,178,88,95]
index_outerlip_from_lm478 = [61,185,40,39,37,0,267,269,270,409,291,375,321,405,314,17,84,181,91,146]
index_withinmouth_from_lm478 = [76, 62] + [184, 183, 74, 72, 73, 41, 72, 38, 11, 12, 302, 268, 303, 271, 304, 272, 408, 407] + [292, 306] + [325, 307, 319, 320, 403, 404, 316, 315, 15, 16, 86, 85, 179, 180, 89, 90, 96, 77]
index_mouth_from_lm478 = index_innerlip_from_lm478 + index_outerlip_from_lm478 + index_withinmouth_from_lm478

index_yaw_from_lm68 = list(range(0, 17))
index_brow_from_lm68 = list(range(17, 27))
index_nose_from_lm68 = list(range(27, 36))
index_eye_from_lm68 = list(range(36, 48))
index_mouth_from_lm68 = list(range(48, 68))


def read_video_to_frames(video_name):
    frames = []
    cap = cv2.VideoCapture(video_name)
    while cap.isOpened():
        ret, frame_bgr = cap.read()
        if frame_bgr is None:
            break
        frames.append(frame_bgr)
    frames = np.stack(frames)
    frames = np.flip(frames, -1) # BGR ==> RGB
    return frames

class MediapipeLandmarker:
    def __init__(self):
        model_path = 'data_gen/utils/mp_feature_extractors/face_landmarker.task'
        if not os.path.exists(model_path):
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            print("downloading face_landmarker model from mediapipe...")
            model_url = 'https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/latest/face_landmarker.task'
            os.system(f"wget {model_url}")
            os.system(f"mv face_landmarker.task {model_path}")
            print("download success")
        base_options = python.BaseOptions(model_asset_path=model_path)
        self.image_mode_options = vision.FaceLandmarkerOptions(base_options=base_options,
                                        running_mode=vision.RunningMode.IMAGE, # IMAGE, VIDEO, LIVE_STREAM
                                        num_faces=1)
        self.video_mode_options = vision.FaceLandmarkerOptions(base_options=base_options,
                                        running_mode=vision.RunningMode.VIDEO, # IMAGE, VIDEO, LIVE_STREAM
                                        num_faces=1)

    def extract_lm478_from_img_name(self, img_name):
        img = cv2.imread(img_name)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_lm478 = self.extract_lm478_from_img(img)
        return img_lm478

    def extract_lm478_from_img(self, img):
        img_landmarker = vision.FaceLandmarker.create_from_options(self.image_mode_options)
        frame = mp.Image(image_format=mp.ImageFormat.SRGB, data=img.astype(np.uint8))
        img_face_landmarker_result = img_landmarker.detect(image=frame)
        img_ldm_i = img_face_landmarker_result.face_landmarks[0]
        img_face_landmarks = np.array([[l.x, l.y, l.z] for l in img_ldm_i])
        H, W, _ = img.shape
        img_lm478 = np.array(img_face_landmarks)[:, :2] * np.array([W, H]).reshape([1,2]) # [478, 2]
        return img_lm478

    def extract_lm478_from_video_name(self, video_name, fps=25, anti_smooth_factor=2):
        frames = read_video_to_frames(video_name)
        img_lm478, vid_lm478 = self.extract_lm478_from_frames(frames, fps, anti_smooth_factor)
        return img_lm478, vid_lm478

    def extract_lm478_from_frames(self, frames, fps=25, anti_smooth_factor=20):
        """
        frames: RGB, uint8
        anti_smooth_factor: float, scales the timestamp interval passed to video mode; 1 means unchanged, larger values behave closer to image mode
        """
        img_mpldms = []
        vid_mpldms = []
        img_landmarker = vision.FaceLandmarker.create_from_options(self.image_mode_options)
        vid_landmarker = vision.FaceLandmarker.create_from_options(self.video_mode_options)

        for i in range(len(frames)):
            frame = mp.Image(image_format=mp.ImageFormat.SRGB, data=frames[i].astype(np.uint8))
            img_face_landmarker_result = img_landmarker.detect(image=frame)
            vid_face_landmarker_result = vid_landmarker.detect_for_video(image=frame, timestamp_ms=int((1000/fps)*anti_smooth_factor*i))
            try:
                img_ldm_i = img_face_landmarker_result.face_landmarks[0]
                vid_ldm_i = vid_face_landmarker_result.face_landmarks[0]
            except:
                print(f"Warning: failed detect ldm in idx={i}, use previous frame results.")
            img_face_landmarks = np.array([[l.x, l.y, l.z] for l in img_ldm_i])
            vid_face_landmarks = np.array([[l.x, l.y, l.z] for l in vid_ldm_i])
            img_mpldms.append(img_face_landmarks)
            vid_mpldms.append(vid_face_landmarks)
        img_lm478 = np.stack(img_mpldms)[..., :2]
        vid_lm478 = np.stack(vid_mpldms)[..., :2]
        bs, H, W, _ = frames.shape
        img_lm478 = np.array(img_lm478)[..., :2] * np.array([W, H]).reshape([1,1,2]) # [T, 478, 2]
        vid_lm478 = np.array(vid_lm478)[..., :2] * np.array([W, H]).reshape([1,1,2]) # [T, 478, 2]
        return img_lm478, vid_lm478

    def combine_vid_img_lm478_to_lm68(self, img_lm478, vid_lm478):
        img_lm68 = img_lm478[:, index_lm68_from_lm478]
        vid_lm68 = vid_lm478[:, index_lm68_from_lm478]
        combined_lm68 = copy.deepcopy(img_lm68)
        combined_lm68[:, index_yaw_from_lm68] = vid_lm68[:, index_yaw_from_lm68]
        combined_lm68[:, index_brow_from_lm68] = vid_lm68[:, index_brow_from_lm68]
        combined_lm68[:, index_nose_from_lm68] = vid_lm68[:, index_nose_from_lm68]
        return combined_lm68

    def combine_vid_img_lm478_to_lm478(self, img_lm478, vid_lm478):
        combined_lm478 = copy.deepcopy(vid_lm478)
        combined_lm478[:, index_mouth_from_lm478] = img_lm478[:, index_mouth_from_lm478]
        combined_lm478[:, index_eye_from_lm478] = img_lm478[:, index_eye_from_lm478]
        return combined_lm478

if __name__ == '__main__':
    landmarker = MediapipeLandmarker()
    ret = landmarker.extract_lm478_from_video_name("00000.mp4")
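Usage sketch (not part of the uploaded file; the video path is a placeholder) showing how the per-frame image-mode track and the smoother video-mode track above are typically merged:

import numpy as np
from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker

landmarker = MediapipeLandmarker()
# image mode: crisper per-frame landmarks; video mode: temporally smoother landmarks
img_lm478, vid_lm478 = landmarker.extract_lm478_from_video_name("00000.mp4")  # each [T, 478, 2]
# lm68: smooth contour/brow/nose from video mode, sharper eyes/mouth kept from image mode
lm68 = landmarker.combine_vid_img_lm478_to_lm68(img_lm478, vid_lm478)    # [T, 68, 2]
# lm478: video-mode base with eyes and mouth overwritten by image-mode points
lm478 = landmarker.combine_vid_img_lm478_to_lm478(img_lm478, vid_lm478)  # [T, 478, 2]
np.save("00000_lm68.npy", lm68)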
data_gen/utils/mp_feature_extractors/face_landmarker.task
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:64184e229b263107bc2b804c6625db1341ff2bb731874b0bcc2fe6544e0bc9ff
size 3758596
data_gen/utils/mp_feature_extractors/mp_segmenter.py
ADDED
@@ -0,0 +1,274 @@
import os
import copy
import numpy as np
import tqdm
import mediapipe as mp
import torch
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from utils.commons.multiprocess_utils import multiprocess_run_tqdm, multiprocess_run
from utils.commons.tensor_utils import convert_to_np
from sklearn.neighbors import NearestNeighbors

def scatter_np(condition_img, classSeg=5):
    # def scatter(condition_img, classSeg=19, label_size=(512, 512)):
    batch, c, height, width = condition_img.shape
    # if height != label_size[0] or width != label_size[1]:
    #     condition_img = F.interpolate(condition_img, size=label_size, mode='nearest')
    input_label = np.zeros([batch, classSeg, condition_img.shape[2], condition_img.shape[3]]).astype(np.int_)
    # input_label = torch.zeros(batch, classSeg, *label_size, device=condition_img.device)
    np.put_along_axis(input_label, condition_img, 1, 1)
    return input_label

def scatter(condition_img, classSeg=19):
    # def scatter(condition_img, classSeg=19, label_size=(512, 512)):
    batch, c, height, width = condition_img.size()
    # if height != label_size[0] or width != label_size[1]:
    #     condition_img = F.interpolate(condition_img, size=label_size, mode='nearest')
    input_label = torch.zeros(batch, classSeg, condition_img.shape[2], condition_img.shape[3], device=condition_img.device)
    # input_label = torch.zeros(batch, classSeg, *label_size, device=condition_img.device)
    return input_label.scatter_(1, condition_img.long(), 1)

def encode_segmap_mask_to_image(segmap):
    # rgb
    _, h, w = segmap.shape
    encoded_img = np.ones([h, w, 3], dtype=np.uint8) * 255
    colors = [(255,255,255),(255,255,0),(255,0,255),(0,255,255),(255,0,0),(0,255,0)]
    for i, color in enumerate(colors):
        mask = segmap[i].astype(int)
        index = np.where(mask != 0)
        encoded_img[index[0], index[1], :] = np.array(color)
    return encoded_img.astype(np.uint8)

def decode_segmap_mask_from_image(encoded_img):
    # rgb
    colors = [(255,255,255),(255,255,0),(255,0,255),(0,255,255),(255,0,0),(0,255,0)]
    bg = (encoded_img[..., 0] == 255) & (encoded_img[..., 1] == 255) & (encoded_img[..., 2] == 255)
    hair = (encoded_img[..., 0] == 255) & (encoded_img[..., 1] == 255) & (encoded_img[..., 2] == 0)
    body_skin = (encoded_img[..., 0] == 255) & (encoded_img[..., 1] == 0) & (encoded_img[..., 2] == 255)
    face_skin = (encoded_img[..., 0] == 0) & (encoded_img[..., 1] == 255) & (encoded_img[..., 2] == 255)
    clothes = (encoded_img[..., 0] == 255) & (encoded_img[..., 1] == 0) & (encoded_img[..., 2] == 0)
    others = (encoded_img[..., 0] == 0) & (encoded_img[..., 1] == 255) & (encoded_img[..., 2] == 0)
    segmap = np.stack([bg, hair, body_skin, face_skin, clothes, others], axis=0)
    return segmap.astype(np.uint8)

def read_video_frame(video_name, frame_id):
    # https://blog.csdn.net/bby1987/article/details/108923361
    # frame_num = video_capture.get(cv2.CAP_PROP_FRAME_COUNT)   # ==> total number of frames
    # fps = video_capture.get(cv2.CAP_PROP_FPS)                 # ==> frame rate
    # width = video_capture.get(cv2.CAP_PROP_FRAME_WIDTH)       # ==> video width
    # height = video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT)     # ==> video height
    # pos = video_capture.get(cv2.CAP_PROP_POS_FRAMES)          # ==> current handle position
    # video_capture.set(cv2.CAP_PROP_POS_FRAMES, 1000)          # ==> set handle position
    # pos = video_capture.get(cv2.CAP_PROP_POS_FRAMES)          # ==> now pos = 1000.0
    # video_capture.release()
    vr = cv2.VideoCapture(video_name)
    vr.set(cv2.CAP_PROP_POS_FRAMES, frame_id)
    _, frame = vr.read()
    return frame

def decode_segmap_mask_from_segmap_video_frame(video_frame):
    # video_frame: 0~255 BGR, obtained by read_video_frame
    def assign_values(array):
        remainder = array % 40 # remainder of each value with respect to 40
        assigned_values = np.where(remainder <= 20, array - remainder, array + (40 - remainder))
        return assigned_values
    segmap = video_frame.mean(-1)
    segmap = assign_values(segmap) // 40 # [H, W] with value 0~5
    segmap_mask = scatter_np(segmap[None, None, ...], classSeg=6)[0] # [6, H, W]
    return segmap.astype(np.uint8)

def extract_background(img_lst, segmap_lst=None):
    """
    img_lst: list of rgb ndarray
    """
    # only use 1/20 images
    num_frames = len(img_lst)
    img_lst = img_lst[::20] if num_frames > 20 else img_lst[0:1]

    if segmap_lst is not None:
        segmap_lst = segmap_lst[::20] if num_frames > 20 else segmap_lst[0:1]
        assert len(img_lst) == len(segmap_lst)
    # get H/W
    h, w = img_lst[0].shape[:2]

    # nearest neighbors
    all_xys = np.mgrid[0:h, 0:w].reshape(2, -1).transpose()
    distss = []
    for idx, img in enumerate(img_lst):
        if segmap_lst is not None:
            segmap = segmap_lst[idx]
        else:
            segmap = seg_model._cal_seg_map(img)
        bg = (segmap[0]).astype(bool)
        fg_xys = np.stack(np.nonzero(~bg)).transpose(1, 0)
        nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
        dists, _ = nbrs.kneighbors(all_xys)
        distss.append(dists)

    distss = np.stack(distss)
    max_dist = np.max(distss, 0)
    max_id = np.argmax(distss, 0)

    bc_pixs = max_dist > 10  # 5
    bc_pixs_id = np.nonzero(bc_pixs)
    bc_ids = max_id[bc_pixs]

    num_pixs = distss.shape[1]
    imgs = np.stack(img_lst).reshape(-1, num_pixs, 3)

    bg_img = np.zeros((h*w, 3), dtype=np.uint8)
    bg_img[bc_pixs_id, :] = imgs[bc_ids, bc_pixs_id, :]
    bg_img = bg_img.reshape(h, w, 3)

    max_dist = max_dist.reshape(h, w)
    bc_pixs = max_dist > 10  # 5
    bg_xys = np.stack(np.nonzero(~bc_pixs)).transpose()
    fg_xys = np.stack(np.nonzero(bc_pixs)).transpose()
    nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
    distances, indices = nbrs.kneighbors(bg_xys)
    bg_fg_xys = fg_xys[indices[:, 0]]
    bg_img[bg_xys[:, 0], bg_xys[:, 1], :] = bg_img[bg_fg_xys[:, 0], bg_fg_xys[:, 1], :]
    return bg_img


class MediapipeSegmenter:
    def __init__(self):
        model_path = 'data_gen/utils/mp_feature_extractors/selfie_multiclass_256x256.tflite'
        if not os.path.exists(model_path):
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            print("downloading segmenter model from mediapipe...")
            os.system(f"wget https://storage.googleapis.com/mediapipe-models/image_segmenter/selfie_multiclass_256x256/float32/latest/selfie_multiclass_256x256.tflite")
            os.system(f"mv selfie_multiclass_256x256.tflite {model_path}")
            print("download success")
        base_options = python.BaseOptions(model_asset_path=model_path)
        self.options = vision.ImageSegmenterOptions(base_options=base_options, running_mode=vision.RunningMode.IMAGE, output_category_mask=True)
        self.video_options = vision.ImageSegmenterOptions(base_options=base_options, running_mode=vision.RunningMode.VIDEO, output_category_mask=True)

    def _cal_seg_map_for_video(self, imgs, segmenter=None, return_onehot_mask=True, return_segmap_image=True, debug_fill=False):
        segmenter = vision.ImageSegmenter.create_from_options(self.video_options) if segmenter is None else segmenter
        assert return_onehot_mask or return_segmap_image # you should at least return one
        segmap_masks = []
        segmap_images = []
        for i in tqdm.trange(len(imgs), desc="extracting segmaps from a video..."):
            # for i in range(len(imgs)):
            img = imgs[i]
            mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)
            out = segmenter.segment_for_video(mp_image, 40 * i)
            segmap = out.category_mask.numpy_view().copy() # [H, W]
            if debug_fill:
                # print(f'segmap {segmap}')
                for x in range(-80 + 1, 0):
                    for y in range(200, 350):
                        segmap[x][y] = 4

            if return_onehot_mask:
                segmap_mask = scatter_np(segmap[None, None, ...], classSeg=6)[0] # [6, H, W]
                segmap_masks.append(segmap_mask)
            if return_segmap_image:
                segmap_image = segmap[:, :, None].repeat(3, 2).astype(float)
                segmap_image = (segmap_image * 40).astype(np.uint8)
                segmap_images.append(segmap_image)

        if return_onehot_mask and return_segmap_image:
            return segmap_masks, segmap_images
        elif return_onehot_mask:
            return segmap_masks
        elif return_segmap_image:
            return segmap_images

    def _cal_seg_map(self, img, segmenter=None, return_onehot_mask=True):
        """
        segmenter: vision.ImageSegmenter.create_from_options(options)
        img: numpy, [H, W, 3], 0~255
        segmap: [C, H, W]
            0 - background
            1 - hair
            2 - body-skin
            3 - face-skin
            4 - clothes
            5 - others (accessories)
        """
        assert img.ndim == 3
        segmenter = vision.ImageSegmenter.create_from_options(self.options) if segmenter is None else segmenter
        image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)
        out = segmenter.segment(image)
        segmap = out.category_mask.numpy_view().copy() # [H, W]
        if return_onehot_mask:
            segmap = scatter_np(segmap[None, None, ...], classSeg=6)[0] # [6, H, W]
        return segmap

    def _seg_out_img_with_segmap(self, img, segmap, mode='head'):
        """
        img: [h,w,c], img is in 0~255, np
        """
        img = copy.deepcopy(img)
        if mode == 'head':
            selected_mask = segmap[[1,3,5], :, :].sum(axis=0)[None,:] > 0.5 # glasses also belong to "others"
            img[~selected_mask.repeat(3,axis=0).transpose(1,2,0)] = 0 # (-1,-1,-1) denotes black in our [-1,1] convention
            # selected_mask = segmap[[1,3], :, :].sum(dim=0, keepdim=True) > 0.5
        elif mode == 'person':
            selected_mask = segmap[[1,2,3,4,5], :, :].sum(axis=0)[None,:] > 0.5
            img[~selected_mask.repeat(3,axis=0).transpose(1,2,0)] = 0 # (-1,-1,-1) denotes black in our [-1,1] convention
        elif mode == 'torso':
            selected_mask = segmap[[2,4], :, :].sum(axis=0)[None,:] > 0.5
            img[~selected_mask.repeat(3,axis=0).transpose(1,2,0)] = 0 # (-1,-1,-1) denotes black in our [-1,1] convention
        elif mode == 'torso_with_bg':
            selected_mask = segmap[[0,2,4], :, :].sum(axis=0)[None,:] > 0.5
            img[~selected_mask.repeat(3,axis=0).transpose(1,2,0)] = 0 # (-1,-1,-1) denotes black in our [-1,1] convention
        elif mode == 'bg':
            selected_mask = segmap[[0], :, :].sum(axis=0)[None,:] > 0.5 # only seg out 0, which means background
            img[~selected_mask.repeat(3,axis=0).transpose(1,2,0)] = 0 # (-1,-1,-1) denotes black in our [-1,1] convention
        elif mode == 'full':
            pass
        else:
            raise NotImplementedError()
        return img, selected_mask

    def _seg_out_img(self, img, segmenter=None, mode='head'):
        """
        imgs [H, W, 3] 0-255
        return: person_img [B, 3, H, W]
        """
        segmenter = vision.ImageSegmenter.create_from_options(self.options) if segmenter is None else segmenter
        segmap = self._cal_seg_map(img, segmenter=segmenter, return_onehot_mask=True) # [B, 19, H, W]
        return self._seg_out_img_with_segmap(img, segmap, mode=mode)

    def seg_out_imgs(self, img, mode='head'):
        """
        api for pytorch img, -1~1
        img: [B, 3, H, W], -1~1
        """
        device = img.device
        img = convert_to_np(img.permute(0, 2, 3, 1)) # [B, H, W, 3]
        img = ((img + 1) * 127.5).astype(np.uint8)
        img_lst = [copy.deepcopy(img[i]) for i in range(len(img))]
        out_lst = []
        for im in img_lst:
            out = self._seg_out_img(im, mode=mode)
            out_lst.append(out)
        seg_imgs = np.stack(out_lst) # [B, H, W, 3]
        seg_imgs = (seg_imgs - 127.5) / 127.5
        seg_imgs = torch.from_numpy(seg_imgs).permute(0, 3, 1, 2).to(device)
        return seg_imgs

if __name__ == '__main__':
    import imageio, cv2, tqdm
    import torchshow as ts
    img = imageio.imread("1.png")
    img = cv2.resize(img, (512,512))

    seg_model = MediapipeSegmenter()
    img = torch.tensor(img).unsqueeze(0).repeat([1, 1, 1, 1]).permute(0, 3,1,2)
    img = (img-127.5)/127.5
    out = seg_model.seg_out_imgs(img, 'torso')
    ts.save(out,"torso.png")
    out = seg_model.seg_out_imgs(img, 'head')
    ts.save(out,"head.png")
    out = seg_model.seg_out_imgs(img, 'bg')
    ts.save(out,"bg.png")
    img = convert_to_np(img.permute(0, 2, 3, 1)) # [B, H, W, 3]
    img = ((img + 1) * 127.5).astype(np.uint8)
    bg = extract_background(img)
    ts.save(bg,"bg2.png")
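A minimal single-image sketch (not part of the uploaded file; "example.png" is a placeholder) of how the segmenter above is typically used:

import cv2
from data_gen.utils.mp_feature_extractors.mp_segmenter import MediapipeSegmenter

seg_model = MediapipeSegmenter()
img = cv2.cvtColor(cv2.imread("example.png"), cv2.COLOR_BGR2RGB)          # RGB uint8, [H, W, 3]
segmap = seg_model._cal_seg_map(img)                                      # [6, H, W] one-hot: bg/hair/body-skin/face-skin/clothes/others
head_img, head_mask = seg_model._seg_out_img_with_segmap(img, segmap, mode='head')  # head pixels kept, rest set to black
cv2.imwrite("head.png", cv2.cvtColor(head_img, cv2.COLOR_RGB2BGR))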
data_gen/utils/mp_feature_extractors/selfie_multiclass_256x256.tflite
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c6748b1253a99067ef71f7e26ca71096cd449baefa8f101900ea23016507e0e0
size 16371837
data_gen/utils/path_converter.py
ADDED
@@ -0,0 +1,24 @@
import os


class PathConverter():
    def __init__(self):
        self.prefixs = {
            "vid": "/video/",
            "gt": "/gt_imgs/",
            "head": "/head_imgs/",
            "torso": "/torso_imgs/",
            "person": "/person_imgs/",
            "torso_with_bg": "/torso_with_bg_imgs/",
            "single_bg": "/bg_img/",
            "bg": "/bg_imgs/",
            "segmaps": "/segmaps/",
            "inpaint_torso": "/inpaint_torso_imgs/",
            "com": "/com_imgs/",
            "inpaint_torso_with_com_bg": "/inpaint_torso_with_com_bg_imgs/",
        }

    def to(self, path: str, old_pattern: str, new_pattern: str):
        return path.replace(self.prefixs[old_pattern], self.prefixs[new_pattern], 1)

pc = PathConverter()
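Usage sketch (not part of the uploaded file; the example path is hypothetical): the converter only swaps the first matching directory prefix, so sibling image folders can be addressed from one canonical path.

from data_gen.utils.path_converter import pc

gt_path = "data/processed/videos/May/gt_imgs/00000.jpg"   # hypothetical path
head_path = pc.to(gt_path, "gt", "head")                  # -> .../head_imgs/00000.jpg
torso_path = pc.to(gt_path, "gt", "torso")                # -> .../torso_imgs/00000.jpg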
data_gen/utils/process_audio/extract_hubert.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import Wav2Vec2Processor, HubertModel
|
2 |
+
import soundfile as sf
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import os
|
6 |
+
from utils.commons.hparams import set_hparams, hparams
|
7 |
+
|
8 |
+
|
9 |
+
wav2vec2_processor = None
|
10 |
+
hubert_model = None
|
11 |
+
|
12 |
+
|
13 |
+
def get_hubert_from_16k_wav(wav_16k_name):
|
14 |
+
speech_16k, _ = sf.read(wav_16k_name)
|
15 |
+
hubert = get_hubert_from_16k_speech(speech_16k)
|
16 |
+
return hubert
|
17 |
+
|
18 |
+
@torch.no_grad()
|
19 |
+
def get_hubert_from_16k_speech(speech, device="cuda:0"):
|
20 |
+
global hubert_model, wav2vec2_processor
|
21 |
+
local_path = '/home/tiger/.cache/huggingface/hub/models--facebook--hubert-large-ls960-ft/snapshots/ece5fabbf034c1073acae96d5401b25be96709d8'
|
22 |
+
if hubert_model is None:
|
23 |
+
print("Loading the HuBERT Model...")
|
24 |
+
if os.path.exists(local_path):
|
25 |
+
hubert_model = HubertModel.from_pretrained(local_path)
|
26 |
+
else:
|
27 |
+
hubert_model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
|
28 |
+
hubert_model = hubert_model.to(device)
|
29 |
+
if wav2vec2_processor is None:
|
30 |
+
print("Loading the Wav2Vec2 Processor...")
|
31 |
+
if os.path.exists(local_path):
|
32 |
+
wav2vec2_processor = Wav2Vec2Processor.from_pretrained(local_path)
|
33 |
+
else:
|
34 |
+
wav2vec2_processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
|
35 |
+
|
36 |
+
if speech.ndim ==2:
|
37 |
+
speech = speech[:, 0] # [T, 2] ==> [T,]
|
38 |
+
|
39 |
+
input_values_all = wav2vec2_processor(speech, return_tensors="pt", sampling_rate=16000).input_values # [1, T]
|
40 |
+
input_values_all = input_values_all.to(device)
|
41 |
+
# For long audio sequence, due to the memory limitation, we cannot process them in one run
|
42 |
+
# HuBERT process the wav with a CNN of stride [5,2,2,2,2,2], making a stride of 320
|
43 |
+
# Besides, the kernel is [10,3,3,3,3,2,2], making 400 a fundamental unit to get 1 time step.
|
44 |
+
# So the CNN is euqal to a big Conv1D with kernel k=400 and stride s=320
|
45 |
+
# We have the equation to calculate out time step: T = floor((t-k)/s)
|
46 |
+
# To prevent overlap, we set each clip length of (K+S*(N-1)), where N is the expected length T of this clip
|
47 |
+
# The start point of next clip should roll back with a length of (kernel-stride) so it is stride * N
|
48 |
+
kernel = 400
|
49 |
+
stride = 320
|
50 |
+
clip_length = stride * 1000
|
51 |
+
num_iter = input_values_all.shape[1] // clip_length
|
52 |
+
expected_T = (input_values_all.shape[1] - (kernel-stride)) // stride
|
53 |
+
res_lst = []
|
54 |
+
for i in range(num_iter):
|
55 |
+
if i == 0:
|
56 |
+
start_idx = 0
|
57 |
+
end_idx = clip_length - stride + kernel
|
58 |
+
else:
|
59 |
+
start_idx = clip_length * i
|
60 |
+
end_idx = start_idx + (clip_length - stride + kernel)
|
61 |
+
input_values = input_values_all[:, start_idx: end_idx]
|
62 |
+
hidden_states = hubert_model.forward(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024]
|
63 |
+
res_lst.append(hidden_states[0])
|
64 |
+
if num_iter > 0:
|
65 |
+
input_values = input_values_all[:, clip_length * num_iter:]
|
66 |
+
else:
|
67 |
+
input_values = input_values_all
|
68 |
+
|
69 |
+
if input_values.shape[1] >= kernel: # if the last batch is shorter than kernel_size, skip it
|
70 |
+
hidden_states = hubert_model(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024]
|
71 |
+
res_lst.append(hidden_states[0])
|
72 |
+
ret = torch.cat(res_lst, dim=0).cpu() # [T, 1024]
|
73 |
+
|
74 |
+
assert abs(ret.shape[0] - expected_T) <= 1
|
75 |
+
if ret.shape[0] < expected_T: # if skipping the last short
|
76 |
+
ret = torch.cat([ret, ret[:, -1:, :].repeat([1,expected_T-ret.shape[0],1])], dim=1)
|
77 |
+
else:
|
78 |
+
ret = ret[:expected_T]
|
79 |
+
|
80 |
+
return ret
|
81 |
+
|
82 |
+
|
83 |
+
if __name__ == '__main__':
|
84 |
+
from argparse import ArgumentParser
|
85 |
+
parser = ArgumentParser()
|
86 |
+
parser.add_argument('--video_id', type=str, default='May', help='')
|
87 |
+
args = parser.parse_args()
|
88 |
+
### Process Single Long Audio for NeRF dataset
|
89 |
+
person_id = args.video_id
|
90 |
+
wav_16k_name = f"data/processed/videos/{person_id}/aud.wav"
|
91 |
+
hubert_npy_name = f"data/processed/videos/{person_id}/aud_hubert.npy"
|
92 |
+
speech_16k, _ = sf.read(wav_16k_name)
|
93 |
+
hubert_hidden = get_hubert_from_16k_speech(speech_16k)
|
94 |
+
np.save(hubert_npy_name, hubert_hidden.detach().numpy())
|
95 |
+
print(f"Saved at {hubert_npy_name}")
|
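Worked example (a sketch, not part of the uploaded file) of the chunking arithmetic described in the comments above, for a hypothetical 10-second clip at 16 kHz:

sr, kernel, stride = 16000, 400, 320
num_samples = 10 * sr                                        # 160000 samples (hypothetical clip length)
expected_T = (num_samples - (kernel - stride)) // stride     # floor((t - k) / s) = 499 HuBERT frames
clip_length = stride * 1000                                  # 320000 samples per chunk fed to the model
num_iter = num_samples // clip_length                        # 0 full chunks; the whole clip goes through the final pass
print(expected_T, num_iter)                                  # -> 499 0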
data_gen/utils/process_audio/extract_mel_f0.py
ADDED
@@ -0,0 +1,148 @@
import numpy as np
import torch
import glob
import os
import tqdm
import librosa
import parselmouth
from utils.commons.pitch_utils import f0_to_coarse
from utils.commons.multiprocess_utils import multiprocess_run_tqdm
from utils.commons.os_utils import multiprocess_glob
from utils.audio.io import save_wav

from moviepy.editor import VideoFileClip
from utils.commons.hparams import hparams, set_hparams

def resample_wav(wav_name, out_name, sr=16000):
    wav_raw, sr = librosa.core.load(wav_name, sr=sr)
    save_wav(wav_raw, out_name, sr)

def split_wav(mp4_name, wav_name=None):
    if wav_name is None:
        wav_name = mp4_name.replace(".mp4", ".wav").replace("/video/", "/audio/")
    if os.path.exists(wav_name):
        return wav_name
    os.makedirs(os.path.dirname(wav_name), exist_ok=True)

    video = VideoFileClip(mp4_name, verbose=False)
    dur = video.duration
    audio = video.audio
    assert audio is not None
    audio.write_audiofile(wav_name, fps=16000, verbose=False, logger=None)
    return wav_name

def librosa_pad_lr(x, fsize, fshift, pad_sides=1):
    '''compute right padding (final frame) or both sides padding (first and final frames)
    '''
    assert pad_sides in (1, 2)
    # return int(fsize // 2)
    pad = (x.shape[0] // fshift + 1) * fshift - x.shape[0]
    if pad_sides == 1:
        return 0, pad
    else:
        return pad // 2, pad // 2 + pad % 2

def extract_mel_from_fname(wav_path,
                           fft_size=512,
                           hop_size=320,
                           win_length=512,
                           window="hann",
                           num_mels=80,
                           fmin=80,
                           fmax=7600,
                           eps=1e-6,
                           sample_rate=16000,
                           min_level_db=-100):
    if isinstance(wav_path, str):
        wav, _ = librosa.core.load(wav_path, sr=sample_rate)
    else:
        wav = wav_path

    # get amplitude spectrogram
    x_stft = librosa.stft(wav, n_fft=fft_size, hop_length=hop_size,
                          win_length=win_length, window=window, center=False)
    spc = np.abs(x_stft)  # (n_bins, T)

    # get mel basis
    fmin = 0 if fmin == -1 else fmin
    fmax = sample_rate / 2 if fmax == -1 else fmax
    mel_basis = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax)
    mel = mel_basis @ spc

    mel = np.log10(np.maximum(eps, mel))  # (n_mel_bins, T)
    mel = mel.T

    l_pad, r_pad = librosa_pad_lr(wav, fft_size, hop_size, 1)
    wav = np.pad(wav, (l_pad, r_pad), mode='constant', constant_values=0.0)

    return wav.T, mel

def extract_f0_from_wav_and_mel(wav, mel,
                                hop_size=320,
                                audio_sample_rate=16000,
                                ):
    time_step = hop_size / audio_sample_rate * 1000
    f0_min = 80
    f0_max = 750
    f0 = parselmouth.Sound(wav, audio_sample_rate).to_pitch_ac(
        time_step=time_step / 1000, voicing_threshold=0.6,
        pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency']

    delta_l = len(mel) - len(f0)
    assert np.abs(delta_l) <= 8
    if delta_l > 0:
        f0 = np.concatenate([f0, [f0[-1]] * delta_l], 0)
    f0 = f0[:len(mel)]
    pitch_coarse = f0_to_coarse(f0)
    return f0, pitch_coarse


def extract_mel_f0_from_fname(wav_name=None, out_name=None):
    try:
        out_name = wav_name.replace(".wav", "_mel_f0.npy").replace("/audio/", "/mel_f0/")
        os.makedirs(os.path.dirname(out_name), exist_ok=True)

        wav, mel = extract_mel_from_fname(wav_name)
        f0, f0_coarse = extract_f0_from_wav_and_mel(wav, mel)
        out_dict = {
            "mel": mel,  # [T, 80]
            "f0": f0,
        }
        np.save(out_name, out_dict)
    except Exception as e:
        print(e)

def extract_mel_f0_from_video_name(mp4_name, wav_name=None, out_name=None):
    if mp4_name.endswith(".mp4"):
        wav_name = split_wav(mp4_name, wav_name)
        if out_name is None:
            out_name = mp4_name.replace(".mp4", "_mel_f0.npy").replace("/video/", "/mel_f0/")
    elif mp4_name.endswith(".wav"):
        wav_name = mp4_name
        if out_name is None:
            out_name = mp4_name.replace(".wav", "_mel_f0.npy").replace("/audio/", "/mel_f0/")

    os.makedirs(os.path.dirname(out_name), exist_ok=True)

    wav, mel = extract_mel_from_fname(wav_name)

    f0, f0_coarse = extract_f0_from_wav_and_mel(wav, mel)
    out_dict = {
        "mel": mel,  # [T, 80]
        "f0": f0,
    }
    np.save(out_name, out_dict)


if __name__ == '__main__':
    from argparse import ArgumentParser
    parser = ArgumentParser()
    parser.add_argument('--video_id', type=str, default='May', help='')
    args = parser.parse_args()
    ### Process Single Long Audio for NeRF dataset
    person_id = args.video_id

    wav_16k_name = f"data/processed/videos/{person_id}/aud.wav"
    out_name = f"data/processed/videos/{person_id}/aud_mel_f0.npy"
    extract_mel_f0_from_video_name(wav_16k_name, out_name)
    print(f"Saved at {out_name}")
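Sketch (not part of the uploaded file; the path follows the script's default for video_id "May") of reading back the dict saved above:

import numpy as np

out = np.load("data/processed/videos/May/aud_mel_f0.npy", allow_pickle=True).item()
mel = out["mel"]   # [T, 80] log10 mel-spectrogram, hop 320 at 16 kHz (20 ms per frame)
f0 = out["f0"]     # [T] pitch in Hz, 0 where unvoiced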
data_gen/utils/process_audio/resample_audio_to_16k.py
ADDED
@@ -0,0 +1,49 @@
import os, glob
from utils.commons.os_utils import multiprocess_glob
from utils.commons.multiprocess_utils import multiprocess_run_tqdm


def extract_wav16k_job(audio_name: str):
    out_path = audio_name.replace("/audio_raw/", "/audio/", 1)
    assert out_path != audio_name # prevent inplace
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    ffmpeg_path = "/usr/bin/ffmpeg"

    cmd = f'{ffmpeg_path} -i {audio_name} -ar 16000 -v quiet -y {out_path}'
    os.system(cmd)

if __name__ == '__main__':
    import argparse, glob, tqdm, random
    parser = argparse.ArgumentParser()
    parser.add_argument("--aud_dir", default='/home/tiger/datasets/raw/CMLR/audio_raw/')
    parser.add_argument("--ds_name", default='CMLR')
    parser.add_argument("--num_workers", default=64, type=int)
    parser.add_argument("--process_id", default=0, type=int)
    parser.add_argument("--total_process", default=1, type=int)
    args = parser.parse_args()
    print(f"args {args}")

    aud_dir = args.aud_dir
    ds_name = args.ds_name
    if ds_name in ['CMLR']:
        aud_name_pattern = os.path.join(aud_dir, "*/*/*.wav")
        aud_names = multiprocess_glob(aud_name_pattern)
    else:
        raise NotImplementedError()
    aud_names = sorted(aud_names)
    print(f"total audio number : {len(aud_names)}")
    print(f"first {aud_names[0]} last {aud_names[-1]}")
    # exit()
    process_id = args.process_id
    total_process = args.total_process
    if total_process > 1:
        assert process_id <= total_process - 1
        num_samples_per_process = len(aud_names) // total_process
        if process_id == total_process:
            aud_names = aud_names[process_id * num_samples_per_process : ]
        else:
            aud_names = aud_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]

    for i, res in multiprocess_run_tqdm(extract_wav16k_job, aud_names, num_workers=args.num_workers, desc="resampling videos"):
        pass
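For reference, a sketch (hypothetical file path under the script's default CMLR layout) of the single-file resampling the job performs:

from data_gen.utils.process_audio.resample_audio_to_16k import extract_wav16k_job

# writes a 16 kHz copy under .../audio/ mirroring the .../audio_raw/ tree
extract_wav16k_job("/home/tiger/datasets/raw/CMLR/audio_raw/s1/001/001.wav")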
data_gen/utils/process_image/extract_lm2d.py
ADDED
@@ -0,0 +1,197 @@
import os
os.environ["OMP_NUM_THREADS"] = "1"
import sys

import glob
import cv2
import tqdm
import numpy as np
from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker
from utils.commons.multiprocess_utils import multiprocess_run_tqdm
import warnings
warnings.filterwarnings('ignore')

import random
random.seed(42)

import pickle
import json
import gzip
from typing import Any

def load_file(filename, is_gzip: bool = False, is_json: bool = False) -> Any:
    if is_json:
        if is_gzip:
            with gzip.open(filename, "r", encoding="utf-8") as f:
                loaded_object = json.load(f)
                return loaded_object
        else:
            with open(filename, "r", encoding="utf-8") as f:
                loaded_object = json.load(f)
                return loaded_object
    else:
        if is_gzip:
            with gzip.open(filename, "rb") as f:
                loaded_object = pickle.load(f)
                return loaded_object
        else:
            with open(filename, "rb") as f:
                loaded_object = pickle.load(f)
                return loaded_object

def save_file(filename, content, is_gzip: bool = False, is_json: bool = False) -> None:
    if is_json:
        if is_gzip:
            with gzip.open(filename, "w", encoding="utf-8") as f:
                json.dump(content, f)
        else:
            with open(filename, "w", encoding="utf-8") as f:
                json.dump(content, f)
    else:
        if is_gzip:
            with gzip.open(filename, "wb") as f:
                pickle.dump(content, f)
        else:
            with open(filename, "wb") as f:
                pickle.dump(content, f)

face_landmarker = None

def extract_lms_mediapipe_job(img):
    if img is None:
        return None
    global face_landmarker
    if face_landmarker is None:
        face_landmarker = MediapipeLandmarker()
    lm478 = face_landmarker.extract_lm478_from_img(img)
    return lm478

def extract_landmark_job(img_name):
    try:
        # if img_name == 'datasets/PanoHeadGen/raw/images/multi_view/chunk_0/seed0000002.png':
        #     print(1)
        #     input()
        out_name = img_name.replace("/images_512/", "/lms_2d/").replace(".png", "_lms.npy")
        if os.path.exists(out_name):
            print("out exists, skip...")
            return
        try:
            os.makedirs(os.path.dirname(out_name), exist_ok=True)
        except:
            pass
        img = cv2.imread(img_name)[:, :, ::-1]

        if img is not None:
            lm468 = extract_lms_mediapipe_job(img)
            if lm468 is not None:
                np.save(out_name, lm468)
        # print("Hahaha, solve one item!!!")
    except Exception as e:
        print(e)
        pass

def out_exist_job(img_name):
    out_name = img_name.replace("/images_512/", "/lms_2d/").replace(".png", "_lms.npy")
    if os.path.exists(out_name):
        return None
    else:
        return img_name

# def get_todo_img_names(img_names):
#     todo_img_names = []
#     for i, res in multiprocess_run_tqdm(out_exist_job, img_names, num_workers=64):
#         if res is not None:
#             todo_img_names.append(res)
#     return todo_img_names


if __name__ == '__main__':
    import argparse, glob, tqdm, random
    parser = argparse.ArgumentParser()
    parser.add_argument("--img_dir", default='/home/tiger/datasets/raw/FFHQ/images_512/')
    parser.add_argument("--ds_name", default='FFHQ')
    parser.add_argument("--num_workers", default=64, type=int)
    parser.add_argument("--process_id", default=0, type=int)
    parser.add_argument("--total_process", default=1, type=int)
    parser.add_argument("--reset", action='store_true')
    parser.add_argument("--img_names_file", default="img_names.pkl", type=str)
    parser.add_argument("--load_img_names", action="store_true")

    args = parser.parse_args()
    print(f"args {args}")
    img_dir = args.img_dir
    img_names_file = os.path.join(img_dir, args.img_names_file)
    if args.load_img_names:
        img_names = load_file(img_names_file)
        print(f"load image names from {img_names_file}")
    else:
        if args.ds_name == 'FFHQ_MV':
            img_name_pattern1 = os.path.join(img_dir, "ref_imgs/*.png")
            img_names1 = glob.glob(img_name_pattern1)
            img_name_pattern2 = os.path.join(img_dir, "mv_imgs/*.png")
            img_names2 = glob.glob(img_name_pattern2)
            img_names = img_names1 + img_names2
            img_names = sorted(img_names)
        elif args.ds_name == 'FFHQ':
            img_name_pattern = os.path.join(img_dir, "*.png")
            img_names = glob.glob(img_name_pattern)
            img_names = sorted(img_names)
        elif args.ds_name == "PanoHeadGen":
            # img_name_patterns = ["ref/*/*.png", "multi_view/*/*.png", "reverse/*/*.png"]
            img_name_patterns = ["ref/*/*.png"]
            img_names = []
            for img_name_pattern in img_name_patterns:
                img_name_pattern_full = os.path.join(img_dir, img_name_pattern)
                img_names_part = glob.glob(img_name_pattern_full)
                img_names.extend(img_names_part)
            img_names = sorted(img_names)

    # save image names
    if not args.load_img_names:
        save_file(img_names_file, img_names)
        print(f"save image names in {img_names_file}")

    print(f"total images number: {len(img_names)}")


    process_id = args.process_id
    total_process = args.total_process
    if total_process > 1:
        assert process_id <= total_process - 1
        num_samples_per_process = len(img_names) // total_process
        if process_id == total_process:
            img_names = img_names[process_id * num_samples_per_process : ]
        else:
            img_names = img_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]

    # if not args.reset:
    #     img_names = get_todo_img_names(img_names)


    print(f"todo_image {img_names[:10]}")
    print(f"processing images number in this process: {len(img_names)}")
    # print(f"todo images number: {len(img_names)}")
    # input()
    # exit()

    if args.num_workers == 1:
        index = 0
        for img_name in tqdm.tqdm(img_names, desc=f"Root process {args.process_id}: extracting MP-based landmark2d"):
            try:
                extract_landmark_job(img_name)
            except Exception as e:
                print(e)
                pass
            if index % max(1, int(len(img_names) * 0.003)) == 0:
                print(f"processed {index} / {len(img_names)}")
                sys.stdout.flush()
            index += 1
    else:
        for i, res in multiprocess_run_tqdm(
            extract_landmark_job, img_names,
            num_workers=args.num_workers,
            desc=f"Root {args.process_id}: extracting MP-based landmark2d"):
            # if index % max(1, int(len(img_names) * 0.003)) == 0:
            print(f"processed {i+1} / {len(img_names)}")
            sys.stdout.flush()
    print(f"Root {args.process_id}: Finished extracting.")
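Single-image sketch (not part of the uploaded file; the FFHQ path is hypothetical) showing where the per-image job writes its output:

from data_gen.utils.process_image.extract_lm2d import extract_landmark_job

img_name = "/home/tiger/datasets/raw/FFHQ/images_512/00000.png"   # hypothetical image path
extract_landmark_job(img_name)
# on success writes /home/tiger/datasets/raw/FFHQ/lms_2d/00000_lms.npy, a [478, 2] array of pixel coordinates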
data_gen/utils/process_image/extract_segment_imgs.py
ADDED
@@ -0,0 +1,114 @@
import os
os.environ["OMP_NUM_THREADS"] = "1"

import glob
import cv2
import tqdm
import numpy as np
import PIL
from utils.commons.tensor_utils import convert_to_np
import torch
import mediapipe as mp
from utils.commons.multiprocess_utils import multiprocess_run_tqdm
from data_gen.utils.mp_feature_extractors.mp_segmenter import MediapipeSegmenter
from data_gen.utils.process_video.extract_segment_imgs import inpaint_torso_job, extract_background, save_rgb_image_to_path
seg_model = MediapipeSegmenter()


def extract_segment_job(img_name):
    try:
        img = cv2.imread(img_name)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        segmap = seg_model._cal_seg_map(img)
        bg_img = extract_background([img], [segmap])
        out_img_name = img_name.replace("/images_512/", f"/bg_img/").replace(".mp4", ".jpg")
        save_rgb_image_to_path(bg_img, out_img_name)

        com_img = img.copy()
        bg_part = segmap[0].astype(bool)[..., None].repeat(3, axis=-1)
        com_img[bg_part] = bg_img[bg_part]
        out_img_name = img_name.replace("/images_512/", f"/com_imgs/")
        save_rgb_image_to_path(com_img, out_img_name)

        for mode in ['head', 'torso', 'person', 'torso_with_bg', 'bg']:
            out_img, _ = seg_model._seg_out_img_with_segmap(img, segmap, mode=mode)
            out_img_name = img_name.replace("/images_512/", f"/{mode}_imgs/")
            out_img = cv2.cvtColor(out_img, cv2.COLOR_RGB2BGR)
            try: os.makedirs(os.path.dirname(out_img_name), exist_ok=True)
            except: pass
            cv2.imwrite(out_img_name, out_img)

        inpaint_torso_img, inpaint_torso_with_bg_img, _, _ = inpaint_torso_job(img, segmap)
        out_img_name = img_name.replace("/images_512/", f"/inpaint_torso_imgs/")
        save_rgb_image_to_path(inpaint_torso_img, out_img_name)
        inpaint_torso_with_bg_img[bg_part] = bg_img[bg_part]
        out_img_name = img_name.replace("/images_512/", f"/inpaint_torso_with_com_bg_imgs/")
        save_rgb_image_to_path(inpaint_torso_with_bg_img, out_img_name)
        return 0
    except Exception as e:
        print(e)
        return 1

def out_exist_job(img_name):
    out_name1 = img_name.replace("/images_512/", "/head_imgs/")
    out_name2 = img_name.replace("/images_512/", "/com_imgs/")
    out_name3 = img_name.replace("/images_512/", "/inpaint_torso_with_com_bg_imgs/")

    if os.path.exists(out_name1) and os.path.exists(out_name2) and os.path.exists(out_name3):
        return None
    else:
        return img_name

def get_todo_img_names(img_names):
    todo_img_names = []
    for i, res in multiprocess_run_tqdm(out_exist_job, img_names, num_workers=64):
        if res is not None:
            todo_img_names.append(res)
    return todo_img_names


if __name__ == '__main__':
    import argparse, glob, tqdm, random
    parser = argparse.ArgumentParser()
    parser.add_argument("--img_dir", default='./images_512')
    # parser.add_argument("--img_dir", default='/home/tiger/datasets/raw/FFHQ/images_512')
    parser.add_argument("--ds_name", default='FFHQ')
    parser.add_argument("--num_workers", default=1, type=int)
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--process_id", default=0, type=int)
    parser.add_argument("--total_process", default=1, type=int)
    parser.add_argument("--reset", action='store_true')

    args = parser.parse_args()
    img_dir = args.img_dir
    if args.ds_name == 'FFHQ_MV':
        img_name_pattern1 = os.path.join(img_dir, "ref_imgs/*.png")
        img_names1 = glob.glob(img_name_pattern1)
        img_name_pattern2 = os.path.join(img_dir, "mv_imgs/*.png")
        img_names2 = glob.glob(img_name_pattern2)
        img_names = img_names1 + img_names2
    elif args.ds_name == 'FFHQ':
        img_name_pattern = os.path.join(img_dir, "*.png")
        img_names = glob.glob(img_name_pattern)

    img_names = sorted(img_names)
    random.seed(args.seed)
    random.shuffle(img_names)

    process_id = args.process_id
    total_process = args.total_process
    if total_process > 1:
        assert process_id <= total_process - 1
        num_samples_per_process = len(img_names) // total_process
        if process_id == total_process:
            img_names = img_names[process_id * num_samples_per_process : ]
        else:
            img_names = img_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]

    if not args.reset:
        img_names = get_todo_img_names(img_names)
    print(f"todo images number: {len(img_names)}")

    for vid_name in multiprocess_run_tqdm(extract_segment_job, img_names, desc=f"Root process {args.process_id}: extracting segment images", num_workers=args.num_workers):
        pass
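Sketch (not part of the uploaded file; "/path/to/FFHQ" is a placeholder) of running the job for a single 512x512 image:

from data_gen.utils.process_image.extract_segment_imgs import extract_segment_job

# returns 0 on success, 1 on failure; writes bg_img/, com_imgs/, head_imgs/, torso_imgs/, person_imgs/,
# torso_with_bg_imgs/, bg_imgs/ and the inpainted-torso variants as siblings of images_512/
extract_segment_job("/path/to/FFHQ/images_512/00000.png")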
data_gen/utils/process_image/fit_3dmm_landmark.py
ADDED
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from numpy.core.numeric import require
|
2 |
+
from numpy.lib.function_base import quantile
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import copy
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
import cv2
|
11 |
+
import argparse
|
12 |
+
import tqdm
|
13 |
+
from utils.commons.multiprocess_utils import multiprocess_run_tqdm
|
14 |
+
from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker
|
15 |
+
|
16 |
+
from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
|
17 |
+
import pickle
|
18 |
+
|
19 |
+
face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM',
|
20 |
+
camera_distance=10, focal=1015, keypoint_mode='mediapipe')
|
21 |
+
face_model.to("cuda")
|
22 |
+
|
23 |
+
|
24 |
+
index_lm68_from_lm468 = [127,234,93,132,58,136,150,176,152,400,379,365,288,361,323,454,356,70,63,105,66,107,336,296,334,293,300,168,197,5,4,75,97,2,326,305,
|
25 |
+
33,160,158,133,153,144,362,385,387,263,373,380,61,40,37,0,267,270,291,321,314,17,84,91,78,81,13,311,308,402,14,178]
|
26 |
+
|
27 |
+
dir_path = os.path.dirname(os.path.realpath(__file__))
|
28 |
+
|
29 |
+
LAMBDA_REG_ID = 0.3
|
30 |
+
LAMBDA_REG_EXP = 0.05
|
31 |
+
|
32 |
+
def save_file(name, content):
|
33 |
+
with open(name, "wb") as f:
|
34 |
+
pickle.dump(content, f)
|
35 |
+
|
36 |
+
def load_file(name):
|
37 |
+
with open(name, "rb") as f:
|
38 |
+
content = pickle.load(f)
|
39 |
+
return content
|
40 |
+
|
41 |
+
def cal_lan_loss_mp(proj_lan, gt_lan):
|
42 |
+
# [B, 68, 2]
|
43 |
+
loss = (proj_lan - gt_lan).pow(2)
|
44 |
+
# loss = (proj_lan - gt_lan).abs()
|
45 |
+
unmatch_mask = [ 93, 127, 132, 234, 323, 356, 361, 454]
|
46 |
+
eye = [33,246,161,160,159,158,157,173,133,155,154,153,145,144,163,7] + [263,466,388,387,386,385,384,398,362,382,381,380,374,373,390,249]
|
47 |
+
inner_lip = [78,191,80,81,82,13,312,311,310,415,308,324,318,402,317,14,87,178,88,95]
|
48 |
+
outer_lip = [61,185,40,39,37,0,267,269,270,409,291,375,321,405,314,17,84,181,91,146]
|
49 |
+
weights = torch.ones_like(loss)
|
50 |
+
weights[:, eye] = 5
|
51 |
+
weights[:, inner_lip] = 2
|
52 |
+
weights[:, outer_lip] = 2
|
53 |
+
weights[:, unmatch_mask] = 0
|
54 |
+
loss = loss * weights
|
55 |
+
return torch.mean(loss)
|
56 |
+
|
57 |
+
def cal_lan_loss(proj_lan, gt_lan):
|
58 |
+
# [B, 68, 2]
|
59 |
+
loss = (proj_lan - gt_lan)** 2
|
60 |
+
# use the ldm weights from deep3drecon, see deep_3drecon/deep_3drecon_models/losses.py
|
61 |
+
weights = torch.zeros_like(loss)
|
62 |
+
weights = torch.ones_like(loss)
|
63 |
+
weights[:, 36:48, :] = 3 # eye 12 points
|
64 |
+
weights[:, -8:, :] = 3 # inner lip 8 points
|
65 |
+
weights[:, 28:31, :] = 3 # nose 3 points
|
66 |
+
loss = loss * weights
|
67 |
+
return torch.mean(loss)
|
68 |
+
|
69 |
+
def set_requires_grad(tensor_list):
|
70 |
+
for tensor in tensor_list:
|
71 |
+
tensor.requires_grad = True
|
72 |
+
|
73 |
+
def read_video_to_frames(img_name):
|
74 |
+
frames = []
|
75 |
+
cap = cv2.VideoCapture(img_name)
|
76 |
+
while cap.isOpened():
|
77 |
+
ret, frame_bgr = cap.read()
|
78 |
+
if frame_bgr is None:
|
79 |
+
break
|
80 |
+
frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
|
81 |
+
frames.append(frame_rgb)
|
82 |
+
return np.stack(frames)
|
83 |
+
|
84 |
+
@torch.enable_grad()
|
85 |
+
def fit_3dmm_for_a_image(img_name, debug=False, keypoint_mode='mediapipe', device="cuda:0", save=True):
|
86 |
+
img = cv2.imread(img_name)
|
87 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
88 |
+
img_h, img_w = img.shape[0], img.shape[0]
|
89 |
+
assert img_h == img_w
|
90 |
+
num_frames = 1
|
91 |
+
|
92 |
+
lm_name = img_name.replace("/images_512/", "/lms_2d/").replace(".png", "_lms.npy")
|
93 |
+
if lm_name.endswith('_lms.npy') and os.path.exists(lm_name):
|
94 |
+
lms = np.load(lm_name)
|
95 |
+
else:
|
96 |
+
# print("lms_2d file not found, try to extract it from image...")
|
97 |
+
try:
|
98 |
+
landmarker = MediapipeLandmarker()
|
99 |
+
lms = landmarker.extract_lm478_from_img_name(img_name)
|
100 |
+
# lms = landmarker.extract_lm478_from_img(img)
|
101 |
+
except Exception as e:
|
102 |
+
print(e)
|
103 |
+
return
|
104 |
+
if lms is None:
|
105 |
+
print("get None lms_2d, please check whether each frame has one head, exiting...")
|
106 |
+
return
|
107 |
+
lms = lms[:468].reshape([468,2])
|
108 |
+
lms = torch.FloatTensor(lms).to(device=device)
|
109 |
+
lms[..., 1] = img_h - lms[..., 1] # flip the height axis
|
110 |
+
|
111 |
+
if keypoint_mode == 'mediapipe':
|
112 |
+
cal_lan_loss_fn = cal_lan_loss_mp
|
113 |
+
out_name = img_name.replace("/images_512/", "/coeff_fit_mp/").replace(".png", "_coeff_fit_mp.npy")
|
114 |
+
else:
|
115 |
+
cal_lan_loss_fn = cal_lan_loss
|
116 |
+
out_name = img_name.replace("/images_512/", "/coeff_fit_lm68/").replace(".png", "_coeff_fit_lm68.npy")
|
117 |
+
try:
|
118 |
+
os.makedirs(os.path.dirname(out_name), exist_ok=True)
|
119 |
+
except:
|
120 |
+
pass
|
121 |
+
|
122 |
+
id_dim, exp_dim = 80, 64
|
123 |
+
sel_ids = np.arange(0, num_frames, 40)
|
124 |
+
sel_num = sel_ids.shape[0]
|
125 |
+
arg_focal = face_model.focal
|
126 |
+
|
127 |
+
h = w = face_model.center * 2
|
128 |
+
img_scale_factor = img_h / h
|
129 |
+
lms /= img_scale_factor
|
130 |
+
cxy = torch.tensor((w / 2.0, h / 2.0), dtype=torch.float).to(device=device)
|
131 |
+
|
132 |
+
id_para = lms.new_zeros((num_frames, id_dim), requires_grad=True) # lms.new_zeros((1, id_dim), requires_grad=True)
|
133 |
+
exp_para = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
|
134 |
+
euler_angle = lms.new_zeros((num_frames, 3), requires_grad=True)
|
135 |
+
trans = lms.new_zeros((num_frames, 3), requires_grad=True)
|
136 |
+
|
137 |
+
focal_length = lms.new_zeros(1, requires_grad=True)
|
138 |
+
focal_length.data += arg_focal
|
139 |
+
|
140 |
+
set_requires_grad([id_para, exp_para, euler_angle, trans])
|
141 |
+
|
142 |
+
optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=.1)
|
143 |
+
optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=.1)
|
144 |
+
|
145 |
+
# 其他参数初始化,先训练euler和trans
|
146 |
+
for _ in range(200):
|
147 |
+
proj_geo = face_model.compute_for_landmark_fit(
|
148 |
+
id_para, exp_para, euler_angle, trans)
|
149 |
+
loss_lan = cal_lan_loss_fn(proj_geo[:, :, :2], lms.detach())
|
150 |
+
loss = loss_lan
|
151 |
+
optimizer_frame.zero_grad()
|
152 |
+
loss.backward()
|
153 |
+
optimizer_frame.step()
|
154 |
+
# print(f"loss_lan: {loss_lan.item():.2f}, euler_abs_mean: {euler_angle.abs().mean().item():.4f}, euler_std: {euler_angle.std().item():.4f}, euler_min: {euler_angle.min().item():.4f}, euler_max: {euler_angle.max().item():.4f}")
|
155 |
+
# print(f"trans_z_mean: {trans[...,2].mean().item():.4f}, trans_z_std: {trans[...,2].std().item():.4f}, trans_min: {trans[...,2].min().item():.4f}, trans_max: {trans[...,2].max().item():.4f}")
|
156 |
+
|
157 |
+
for param_group in optimizer_frame.param_groups:
|
158 |
+
param_group['lr'] = 0.1
|
159 |
+
|
160 |
+
# "jointly roughly training id exp euler trans"
|
161 |
+
for _ in range(200):
|
162 |
+
proj_geo = face_model.compute_for_landmark_fit(
|
163 |
+
id_para, exp_para, euler_angle, trans)
|
164 |
+
loss_lan = cal_lan_loss_fn(
|
165 |
+
proj_geo[:, :, :2], lms.detach())
|
166 |
+
loss_regid = torch.mean(id_para*id_para) # 正则化
|
167 |
+
loss_regexp = torch.mean(exp_para * exp_para)
|
168 |
+
|
169 |
+
loss = loss_lan + loss_regid * LAMBDA_REG_ID + loss_regexp * LAMBDA_REG_EXP
|
170 |
+
optimizer_idexp.zero_grad()
|
171 |
+
optimizer_frame.zero_grad()
|
172 |
+
loss.backward()
|
173 |
+
optimizer_idexp.step()
|
174 |
+
optimizer_frame.step()
|
175 |
+
# print(f"loss_lan: {loss_lan.item():.2f}, loss_reg_id: {loss_regid.item():.2f},loss_reg_exp: {loss_regexp.item():.2f},")
|
176 |
+
# print(f"euler_abs_mean: {euler_angle.abs().mean().item():.4f}, euler_std: {euler_angle.std().item():.4f}, euler_min: {euler_angle.min().item():.4f}, euler_max: {euler_angle.max().item():.4f}")
|
177 |
+
# print(f"trans_z_mean: {trans[...,2].mean().item():.4f}, trans_z_std: {trans[...,2].std().item():.4f}, trans_min: {trans[...,2].min().item():.4f}, trans_max: {trans[...,2].max().item():.4f}")
|
178 |
+
|
179 |
+
# start fine training, initialize from the roughly trained results
|
180 |
+
id_para_ = lms.new_zeros((num_frames, id_dim), requires_grad=True)
|
181 |
+
id_para_.data = id_para.data.clone()
|
182 |
+
id_para = id_para_
|
183 |
+
exp_para_ = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
|
184 |
+
exp_para_.data = exp_para.data.clone()
|
185 |
+
exp_para = exp_para_
|
186 |
+
euler_angle_ = lms.new_zeros((num_frames, 3), requires_grad=True)
|
187 |
+
euler_angle_.data = euler_angle.data.clone()
|
188 |
+
euler_angle = euler_angle_
|
189 |
+
trans_ = lms.new_zeros((num_frames, 3), requires_grad=True)
|
190 |
+
trans_.data = trans.data.clone()
|
191 |
+
trans = trans_
|
192 |
+
|
193 |
+
batch_size = 1
|
194 |
+
|
195 |
+
# "fine fitting the 3DMM in batches"
|
196 |
+
for i in range(int((num_frames-1)/batch_size+1)):
|
197 |
+
if (i+1)*batch_size > num_frames:
|
198 |
+
start_n = num_frames-batch_size
|
199 |
+
sel_ids = np.arange(max(num_frames-batch_size,0), num_frames)
|
200 |
+
else:
|
201 |
+
start_n = i*batch_size
|
202 |
+
sel_ids = np.arange(i*batch_size, i*batch_size+batch_size)
|
203 |
+
sel_lms = lms[sel_ids]
|
204 |
+
|
205 |
+
sel_id_para = id_para.new_zeros(
|
206 |
+
(batch_size, id_dim), requires_grad=True)
|
207 |
+
sel_id_para.data = id_para[sel_ids].clone()
|
208 |
+
sel_exp_para = exp_para.new_zeros(
|
209 |
+
(batch_size, exp_dim), requires_grad=True)
|
210 |
+
sel_exp_para.data = exp_para[sel_ids].clone()
|
211 |
+
sel_euler_angle = euler_angle.new_zeros(
|
212 |
+
(batch_size, 3), requires_grad=True)
|
213 |
+
sel_euler_angle.data = euler_angle[sel_ids].clone()
|
214 |
+
sel_trans = trans.new_zeros((batch_size, 3), requires_grad=True)
|
215 |
+
sel_trans.data = trans[sel_ids].clone()
|
216 |
+
|
217 |
+
set_requires_grad([sel_id_para, sel_exp_para, sel_euler_angle, sel_trans])
|
218 |
+
optimizer_cur_batch = torch.optim.Adam(
|
219 |
+
[sel_id_para, sel_exp_para, sel_euler_angle, sel_trans], lr=0.005)
|
220 |
+
|
221 |
+
for j in range(50):
|
222 |
+
proj_geo = face_model.compute_for_landmark_fit(
|
223 |
+
sel_id_para, sel_exp_para, sel_euler_angle, sel_trans)
|
224 |
+
loss_lan = cal_lan_loss_fn(
|
225 |
+
proj_geo[:, :, :2], sel_lms.detach()) # compare against the landmarks of the selected frames
|
226 |
+
|
227 |
+
loss_regid = torch.mean(sel_id_para*sel_id_para) # regularization
|
228 |
+
loss_regexp = torch.mean(sel_exp_para*sel_exp_para)
|
229 |
+
loss = loss_lan + loss_regid * LAMBDA_REG_ID + loss_regexp * LAMBDA_REG_EXP
|
230 |
+
optimizer_cur_batch.zero_grad()
|
231 |
+
loss.backward()
|
232 |
+
optimizer_cur_batch.step()
|
233 |
+
print(f"batch {i} | loss_lan: {loss_lan.item():.2f}, loss_reg_id: {loss_regid.item():.2f},loss_reg_exp: {loss_regexp.item():.2f}")
|
234 |
+
id_para.data[sel_ids] = sel_id_para.data.clone() # index .data directly so the update is written back (indexing the tensor first would modify a copy)
|
235 |
+
exp_para.data[sel_ids] = sel_exp_para.data.clone()
|
236 |
+
euler_angle.data[sel_ids] = sel_euler_angle.data.clone()
|
237 |
+
trans.data[sel_ids] = sel_trans.data.clone()
|
238 |
+
|
239 |
+
coeff_dict = {'id': id_para.detach().cpu().numpy(), 'exp': exp_para.detach().cpu().numpy(),
|
240 |
+
'euler': euler_angle.detach().cpu().numpy(), 'trans': trans.detach().cpu().numpy()}
|
241 |
+
if save:
|
242 |
+
np.save(out_name, coeff_dict, allow_pickle=True)
|
243 |
+
|
244 |
+
if debug:
|
245 |
+
import imageio
|
246 |
+
debug_name = img_name.replace("/images_512/", "/coeff_fit_mp_debug/").replace(".png", "_debug.png").replace(".jpg", "_debug.jpg")
|
247 |
+
try: os.makedirs(os.path.dirname(debug_name), exist_ok=True)
|
248 |
+
except: pass
|
249 |
+
proj_geo = face_model.compute_for_landmark_fit(id_para, exp_para, euler_angle, trans)
|
250 |
+
lm68s = proj_geo[:,:,:2].detach().cpu().numpy() # [T, 68,2]
|
251 |
+
lm68s = lm68s * img_scale_factor
|
252 |
+
lms = lms * img_scale_factor
|
253 |
+
lm68s[..., 1] = img_h - lm68s[..., 1] # flip the height axis
|
254 |
+
lms[..., 1] = img_h - lms[..., 1] # flip the height axis
|
255 |
+
lm68s = lm68s.astype(int)
|
256 |
+
lm68s = lm68s.reshape([-1,2])
|
257 |
+
lms = lms.cpu().numpy().astype(int).reshape([-1,2])
|
258 |
+
for lm in lm68s:
|
259 |
+
img = cv2.circle(img, lm, 1, (0, 0, 255), thickness=-1)
|
260 |
+
for gt_lm in lms:
|
261 |
+
img = cv2.circle(img, gt_lm, 2, (255, 0, 0), thickness=1)
|
262 |
+
imageio.imwrite(debug_name, img)
|
263 |
+
print(f"debug img saved at {debug_name}")
|
264 |
+
return coeff_dict
|
265 |
+
|
266 |
+
def out_exist_job(vid_name):
|
267 |
+
out_name = vid_name.replace("/images_512/", "/coeff_fit_mp/").replace(".png","_coeff_fit_mp.npy")
|
268 |
+
# if os.path.exists(out_name) or not os.path.exists(lms_name):
|
269 |
+
if os.path.exists(out_name):
|
270 |
+
return None
|
271 |
+
else:
|
272 |
+
return vid_name
|
273 |
+
|
274 |
+
def get_todo_img_names(img_names):
|
275 |
+
todo_img_names = []
|
276 |
+
for i, res in multiprocess_run_tqdm(out_exist_job, img_names, num_workers=16):
|
277 |
+
if res is not None:
|
278 |
+
todo_img_names.append(res)
|
279 |
+
return todo_img_names
|
280 |
+
|
281 |
+
|
282 |
+
if __name__ == '__main__':
|
283 |
+
import argparse, glob, tqdm
|
284 |
+
parser = argparse.ArgumentParser()
|
285 |
+
parser.add_argument("--img_dir", default='/home/tiger/datasets/raw/FFHQ/images_512')
|
286 |
+
parser.add_argument("--ds_name", default='FFHQ')
|
287 |
+
parser.add_argument("--seed", default=0, type=int)
|
288 |
+
parser.add_argument("--process_id", default=0, type=int)
|
289 |
+
parser.add_argument("--total_process", default=1, type=int)
|
290 |
+
parser.add_argument("--keypoint_mode", default='mediapipe', type=str)
|
291 |
+
parser.add_argument("--debug", action='store_true')
|
292 |
+
parser.add_argument("--reset", action='store_true')
|
293 |
+
parser.add_argument("--device", default="cuda:0", type=str)
|
294 |
+
parser.add_argument("--output_log", action='store_true')
|
295 |
+
parser.add_argument("--load_names", action="store_true")
|
296 |
+
|
297 |
+
args = parser.parse_args()
|
298 |
+
img_dir = args.img_dir
|
299 |
+
load_names = args.load_names
|
300 |
+
|
301 |
+
print(f"args {args}")
|
302 |
+
|
303 |
+
if args.ds_name == 'single_img':
|
304 |
+
img_names = [img_dir]
|
305 |
+
else:
|
306 |
+
img_names_path = os.path.join(img_dir, "img_dir.pkl")
|
307 |
+
if os.path.exists(img_names_path) and load_names:
|
308 |
+
print(f"loading vid names from {img_names_path}")
|
309 |
+
img_names = load_file(img_names_path)
|
310 |
+
else:
|
311 |
+
if args.ds_name == 'FFHQ_MV':
|
312 |
+
img_name_pattern1 = os.path.join(img_dir, "ref_imgs/*.png")
|
313 |
+
img_names1 = glob.glob(img_name_pattern1)
|
314 |
+
img_name_pattern2 = os.path.join(img_dir, "mv_imgs/*.png")
|
315 |
+
img_names2 = glob.glob(img_name_pattern2)
|
316 |
+
img_names = img_names1 + img_names2
|
317 |
+
img_names = sorted(img_names)
|
318 |
+
elif args.ds_name == 'FFHQ':
|
319 |
+
img_name_pattern = os.path.join(img_dir, "*.png")
|
320 |
+
img_names = glob.glob(img_name_pattern)
|
321 |
+
img_names = sorted(img_names)
|
322 |
+
elif args.ds_name == "PanoHeadGen":
|
323 |
+
img_name_patterns = ["ref/*/*.png"]
|
324 |
+
img_names = []
|
325 |
+
for img_name_pattern in img_name_patterns:
|
326 |
+
img_name_pattern_full = os.path.join(img_dir, img_name_pattern)
|
327 |
+
img_names_part = glob.glob(img_name_pattern_full)
|
328 |
+
img_names.extend(img_names_part)
|
329 |
+
img_names = sorted(img_names)
|
330 |
+
print(f"saving image names to {img_names_path}")
|
331 |
+
save_file(img_names_path, img_names)
|
332 |
+
|
333 |
+
# import random
|
334 |
+
# random.seed(args.seed)
|
335 |
+
# random.shuffle(img_names)
|
336 |
+
|
337 |
+
face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM',
|
338 |
+
camera_distance=10, focal=1015, keypoint_mode=args.keypoint_mode)
|
339 |
+
face_model.to(torch.device(args.device))
|
340 |
+
|
341 |
+
process_id = args.process_id
|
342 |
+
total_process = args.total_process
|
343 |
+
if total_process > 1:
|
344 |
+
assert process_id <= total_process -1 and process_id >= 0
|
345 |
+
num_samples_per_process = len(img_names) // total_process
|
346 |
+
if process_id == total_process - 1: # the last process also takes the remainder
|
347 |
+
img_names = img_names[process_id * num_samples_per_process : ]
|
348 |
+
else:
|
349 |
+
img_names = img_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
|
350 |
+
print(f"image names number (before fileter): {len(img_names)}")
|
351 |
+
|
352 |
+
|
353 |
+
if not args.reset:
|
354 |
+
img_names = get_todo_img_names(img_names)
|
355 |
+
|
356 |
+
print(f"image names number (after fileter): {len(img_names)}")
|
357 |
+
for i in tqdm.trange(len(img_names), desc=f"process {process_id}: fitting 3dmm ..."):
|
358 |
+
img_name = img_names[i]
|
359 |
+
try:
|
360 |
+
fit_3dmm_for_a_image(img_name, args.debug, device=args.device)
|
361 |
+
except Exception as e:
|
362 |
+
print(img_name, e)
|
363 |
+
if args.output_log and i % max(int(len(img_names) * 0.003), 1) == 0:
|
364 |
+
print(f"process {process_id}: {i + 1} / {len(img_names)} done")
|
365 |
+
sys.stdout.flush()
|
366 |
+
sys.stderr.flush()
|
367 |
+
|
368 |
+
print(f"process {process_id}: fitting 3dmm all done")
|
369 |
+
|
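Note (added for readers of this diff, not part of the uploaded file): a minimal usage sketch for the image-fitting entry point above. It assumes the repo root is on PYTHONPATH, the BFM assets sit in deep_3drecon/BFM, and the portrait path contains "/images_512/" so the output-path substitution works; the module path and the sample image path are placeholders.
from data_gen.utils.process_image.fit_3dmm_landmark import fit_3dmm_for_a_image
# fits identity/expression/pose coefficients for one 512x512 portrait and returns the coeff dict
coeff = fit_3dmm_for_a_image("data/FFHQ/images_512/00000.png", False, device="cuda:0")
print(coeff['exp'].shape)  # expression coefficients, dim 64 as configured above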
data_gen/utils/process_video/euler2quaterion.py
ADDED
@@ -0,0 +1,35 @@
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import math
|
4 |
+
import numba
|
5 |
+
from scipy.spatial.transform import Rotation as R
|
6 |
+
|
7 |
+
def euler2quaterion(euler, use_radian=True):
|
8 |
+
"""
|
9 |
+
euler: np.array, [batch, 3]
|
10 |
+
return: the quaterion, np.array, [batch, 4]
|
11 |
+
"""
|
12 |
+
r = R.from_euler('xyz',euler, degrees=not use_radian)
|
13 |
+
return r.as_quat()
|
14 |
+
|
15 |
+
def quaterion2euler(quat, use_radian=True):
|
16 |
+
"""
|
17 |
+
quat: np.array, [batch, 4]
|
18 |
+
return: the euler, np.array, [batch, 3]
|
19 |
+
"""
|
20 |
+
r = R.from_quat(quat)
|
21 |
+
return r.as_euler('xyz', degrees=not use_radian)
|
22 |
+
|
23 |
+
def rot2quaterion(rot):
|
24 |
+
r = R.from_matrix(rot)
|
25 |
+
return r.as_quat()
|
26 |
+
|
27 |
+
def quaterion2rot(quat):
|
28 |
+
r = R.from_quat(quat)
|
29 |
+
return r.as_matrix()
|
30 |
+
|
31 |
+
if __name__ == '__main__':
|
32 |
+
euler = np.array([89.999,89.999,89.999] * 100).reshape([100,3])
|
33 |
+
q = euler2quaterion(euler, use_radian=False)
|
34 |
+
e = quaterion2euler(q, use_radian=False)
|
35 |
+
print(" ")
|
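Note (illustrative, not part of the uploaded file): the helpers above are thin wrappers around scipy's Rotation, so a euler -> quaternion -> euler round trip should recover the input up to numerical precision (away from gimbal-lock angles). The import path assumes the repo root is on PYTHONPATH.
import numpy as np
from data_gen.utils.process_video.euler2quaterion import euler2quaterion, quaterion2euler
euler = np.array([[10.0, -20.0, 30.0]])          # degrees, shape [1, 3]
quat = euler2quaterion(euler, use_radian=False)  # shape [1, 4], scipy's xyzw order
recovered = quaterion2euler(quat, use_radian=False)
print(np.allclose(euler, recovered, atol=1e-4))  # expected: True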
data_gen/utils/process_video/extract_blink.py
ADDED
@@ -0,0 +1,50 @@
1 |
+
import numpy as np
|
2 |
+
from data_util.face3d_helper import Face3DHelper
|
3 |
+
from utils.commons.tensor_utils import convert_to_tensor
|
4 |
+
|
5 |
+
def polygon_area(x, y):
|
6 |
+
"""
|
7 |
+
x: [T, K=6]
|
8 |
+
y: [T, K=6]
|
9 |
+
return: [T,]
|
10 |
+
"""
|
11 |
+
x_ = x - x.mean(axis=-1, keepdims=True)
|
12 |
+
y_ = y - y.mean(axis=-1, keepdims=True)
|
13 |
+
correction = x_[:,-1] * y_[:,0] - y_[:,-1]* x_[:,0]
|
14 |
+
main_area = (x_[:,:-1] * y_[:,1:]).sum(axis=-1) - (y_[:,:-1] * x_[:,1:]).sum(axis=-1)
|
15 |
+
return 0.5 * np.abs(main_area + correction)
|
16 |
+
|
17 |
+
def get_eye_area_percent(id, exp, face3d_helper):
|
18 |
+
id = convert_to_tensor(id)
|
19 |
+
exp = convert_to_tensor(exp)
|
20 |
+
cano_lm3d = face3d_helper.reconstruct_cano_lm3d(id, exp)
|
21 |
+
cano_lm2d = (cano_lm3d[..., :2] + 1) / 2
|
22 |
+
lms = cano_lm2d.cpu().numpy()
|
23 |
+
eyes_left = slice(36, 42)
|
24 |
+
eyes_right = slice(42, 48)
|
25 |
+
area_left = polygon_area(lms[:, eyes_left, 0], lms[:, eyes_left, 1])
|
26 |
+
area_right = polygon_area(lms[:, eyes_right, 0], lms[:, eyes_right, 1])
|
27 |
+
# area percentage of two eyes of the whole image...
|
28 |
+
area_percent = (area_left + area_right) / 1 * 100 # recommend threshold is 0.25%
|
29 |
+
return area_percent # [T,]
|
30 |
+
|
31 |
+
|
32 |
+
if __name__ == '__main__':
|
33 |
+
import numpy as np
|
34 |
+
import imageio
|
35 |
+
import cv2
|
36 |
+
import torch
|
37 |
+
from data_gen.utils.process_video.extract_lm2d import extract_lms_mediapipe_job, read_video_to_frames, index_lm68_from_lm468
|
38 |
+
from data_gen.utils.process_video.fit_3dmm_landmark import fit_3dmm_for_a_video
|
39 |
+
from data_util.face3d_helper import Face3DHelper
|
40 |
+
|
41 |
+
face3d_helper = Face3DHelper()
|
42 |
+
video_name = 'data/raw/videos/May_10s.mp4'
|
43 |
+
frames = read_video_to_frames(video_name)
|
44 |
+
coeff = fit_3dmm_for_a_video(video_name, save=False)
|
45 |
+
area_percent = get_eye_area_percent(torch.tensor(coeff['id']), torch.tensor(coeff['exp']), face3d_helper)
|
46 |
+
writer = imageio.get_writer("1.mp4", fps=25)
|
47 |
+
for idx, frame in enumerate(frames):
|
48 |
+
frame = cv2.putText(frame, f"{area_percent[idx]:.2f}", org=(128,128), fontFace=cv2.FONT_HERSHEY_COMPLEX, fontScale=1, color=(255,0,0), thickness=1)
|
49 |
+
writer.append_data(frame)
|
50 |
+
writer.close()
|
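Note (illustrative, not part of the uploaded file): polygon_area above is a batched shoelace formula, so a quick sanity check is that an axis-aligned unit square traversed in order yields an area of 1.0; the import path assumes the repo root is on PYTHONPATH.
import numpy as np
from data_gen.utils.process_video.extract_blink import polygon_area
x = np.array([[0.0, 1.0, 1.0, 0.0]])  # [T=1, K=4] vertex x-coordinates in traversal order
y = np.array([[0.0, 0.0, 1.0, 1.0]])  # [T=1, K=4] vertex y-coordinates
print(polygon_area(x, y))             # expected: [1.]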
data_gen/utils/process_video/extract_lm2d.py
ADDED
@@ -0,0 +1,164 @@
1 |
+
import os
|
2 |
+
os.environ["OMP_NUM_THREADS"] = "1"
|
3 |
+
import sys
|
4 |
+
import glob
|
5 |
+
import cv2
|
6 |
+
import pickle
|
7 |
+
import tqdm
|
8 |
+
import numpy as np
|
9 |
+
import mediapipe as mp
|
10 |
+
from utils.commons.multiprocess_utils import multiprocess_run_tqdm
|
11 |
+
from utils.commons.os_utils import multiprocess_glob
|
12 |
+
from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker
|
13 |
+
import warnings
|
14 |
+
import traceback
|
15 |
+
|
16 |
+
warnings.filterwarnings('ignore')
|
17 |
+
|
18 |
+
"""
|
19 |
+
The face_alignment-based lm68 has been deprecated because:
|
20 |
+
1. its prediction accuracy around the eyes is very low;
|
21 |
+
2. it cannot accurately predict the occluded jawline at large yaw angles, so the 3DMM GT labels are already wrong at large angles, which hurts performance.
|
22 |
+
We now use the mediapipe-based lm68 instead.
|
23 |
+
"""
|
24 |
+
# def extract_landmarks(ori_imgs_dir):
|
25 |
+
|
26 |
+
# print(f'[INFO] ===== extract face landmarks from {ori_imgs_dir} =====')
|
27 |
+
|
28 |
+
# fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False)
|
29 |
+
# image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.png'))
|
30 |
+
# for image_path in tqdm.tqdm(image_paths):
|
31 |
+
# out_name = image_path.replace("/images_512/", "/lms_2d/").replace(".png",".lms")
|
32 |
+
# if os.path.exists(out_name):
|
33 |
+
# continue
|
34 |
+
# input = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3]
|
35 |
+
# input = cv2.cvtColor(input, cv2.COLOR_BGR2RGB)
|
36 |
+
# preds = fa.get_landmarks(input)
|
37 |
+
# if preds is None:
|
38 |
+
# print(f"Skip {image_path} for no face detected")
|
39 |
+
# continue
|
40 |
+
# if len(preds) > 0:
|
41 |
+
# lands = preds[0].reshape(-1, 2)[:,:2]
|
42 |
+
# os.makedirs(os.path.dirname(out_name), exist_ok=True)
|
43 |
+
# np.savetxt(out_name, lands, '%f')
|
44 |
+
# del fa
|
45 |
+
# print(f'[INFO] ===== extracted face landmarks =====')
|
46 |
+
|
47 |
+
def save_file(name, content):
|
48 |
+
with open(name, "wb") as f:
|
49 |
+
pickle.dump(content, f)
|
50 |
+
|
51 |
+
def load_file(name):
|
52 |
+
with open(name, "rb") as f:
|
53 |
+
content = pickle.load(f)
|
54 |
+
return content
|
55 |
+
|
56 |
+
|
57 |
+
face_landmarker = None
|
58 |
+
|
59 |
+
def extract_landmark_job(video_name, nerf=False):
|
60 |
+
try:
|
61 |
+
if nerf:
|
62 |
+
out_name = video_name.replace("/raw/", "/processed/").replace(".mp4","/lms_2d.npy")
|
63 |
+
else:
|
64 |
+
out_name = video_name.replace("/video/", "/lms_2d/").replace(".mp4","_lms.npy")
|
65 |
+
if os.path.exists(out_name):
|
66 |
+
# print("out exists, skip...")
|
67 |
+
return
|
68 |
+
try:
|
69 |
+
os.makedirs(os.path.dirname(out_name), exist_ok=True)
|
70 |
+
except:
|
71 |
+
pass
|
72 |
+
global face_landmarker
|
73 |
+
if face_landmarker is None:
|
74 |
+
face_landmarker = MediapipeLandmarker()
|
75 |
+
img_lm478, vid_lm478 = face_landmarker.extract_lm478_from_video_name(video_name)
|
76 |
+
lm478 = face_landmarker.combine_vid_img_lm478_to_lm478(img_lm478, vid_lm478)
|
77 |
+
np.save(out_name, lm478)
|
78 |
+
return True
|
79 |
+
# print("Hahaha, solve one item!!!")
|
80 |
+
except Exception as e:
|
81 |
+
traceback.print_exc()
|
82 |
+
return False
|
83 |
+
|
84 |
+
def out_exist_job(vid_name):
|
85 |
+
out_name = vid_name.replace("/video/", "/lms_2d/").replace(".mp4","_lms.npy")
|
86 |
+
if os.path.exists(out_name):
|
87 |
+
return None
|
88 |
+
else:
|
89 |
+
return vid_name
|
90 |
+
|
91 |
+
def get_todo_vid_names(vid_names):
|
92 |
+
if len(vid_names) == 1: # nerf
|
93 |
+
return vid_names
|
94 |
+
todo_vid_names = []
|
95 |
+
for i, res in multiprocess_run_tqdm(out_exist_job, vid_names, num_workers=128):
|
96 |
+
if res is not None:
|
97 |
+
todo_vid_names.append(res)
|
98 |
+
return todo_vid_names
|
99 |
+
|
100 |
+
if __name__ == '__main__':
|
101 |
+
import argparse, glob, tqdm, random
|
102 |
+
parser = argparse.ArgumentParser()
|
103 |
+
parser.add_argument("--vid_dir", default='nerf')
|
104 |
+
parser.add_argument("--ds_name", default='data/raw/videos/May.mp4')
|
105 |
+
parser.add_argument("--num_workers", default=2, type=int)
|
106 |
+
parser.add_argument("--process_id", default=0, type=int)
|
107 |
+
parser.add_argument("--total_process", default=1, type=int)
|
108 |
+
parser.add_argument("--reset", action="store_true")
|
109 |
+
parser.add_argument("--load_names", action="store_true")
|
110 |
+
|
111 |
+
args = parser.parse_args()
|
112 |
+
vid_dir = args.vid_dir
|
113 |
+
ds_name = args.ds_name
|
114 |
+
load_names = args.load_names
|
115 |
+
|
116 |
+
if ds_name.lower() == 'nerf': # process a single video
|
117 |
+
vid_names = [vid_dir]
|
118 |
+
out_names = [video_name.replace("/raw/", "/processed/").replace(".mp4","/lms_2d.npy") for video_name in vid_names]
|
119 |
+
else: # process the whole dataset
|
120 |
+
if ds_name in ['lrs3_trainval']:
|
121 |
+
vid_name_pattern = os.path.join(vid_dir, "*/*.mp4")
|
122 |
+
elif ds_name in ['TH1KH_512', 'CelebV-HQ']:
|
123 |
+
vid_name_pattern = os.path.join(vid_dir, "*.mp4")
|
124 |
+
elif ds_name in ['lrs2', 'lrs3', 'voxceleb2', 'CMLR']:
|
125 |
+
vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
|
126 |
+
elif ds_name in ["RAVDESS", 'VFHQ']:
|
127 |
+
vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4")
|
128 |
+
else:
|
129 |
+
raise NotImplementedError()
|
130 |
+
|
131 |
+
vid_names_path = os.path.join(vid_dir, "vid_names.pkl")
|
132 |
+
if os.path.exists(vid_names_path) and load_names:
|
133 |
+
print(f"loading vid names from {vid_names_path}")
|
134 |
+
vid_names = load_file(vid_names_path)
|
135 |
+
else:
|
136 |
+
vid_names = multiprocess_glob(vid_name_pattern)
|
137 |
+
vid_names = sorted(vid_names)
|
138 |
+
if not load_names:
|
139 |
+
print(f"saving vid names to {vid_names_path}")
|
140 |
+
save_file(vid_names_path, vid_names)
|
141 |
+
out_names = [video_name.replace("/video/", "/lms_2d/").replace(".mp4","_lms.npy") for video_name in vid_names]
|
142 |
+
|
143 |
+
process_id = args.process_id
|
144 |
+
total_process = args.total_process
|
145 |
+
if total_process > 1:
|
146 |
+
assert process_id <= total_process -1
|
147 |
+
num_samples_per_process = len(vid_names) // total_process
|
148 |
+
if process_id == total_process - 1: # the last process also takes the remainder
|
149 |
+
vid_names = vid_names[process_id * num_samples_per_process : ]
|
150 |
+
else:
|
151 |
+
vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
|
152 |
+
|
153 |
+
if not args.reset:
|
154 |
+
vid_names = get_todo_vid_names(vid_names)
|
155 |
+
print(f"todo videos number: {len(vid_names)}")
|
156 |
+
|
157 |
+
fail_cnt = 0
|
158 |
+
job_args = [(vid_name, ds_name=='nerf') for vid_name in vid_names]
|
159 |
+
for (i, res) in multiprocess_run_tqdm(extract_landmark_job, job_args, num_workers=args.num_workers, desc=f"Root {args.process_id}: extracting MP-based landmark2d"):
|
160 |
+
if res is False:
|
161 |
+
fail_cnt += 1
|
162 |
+
print(f"finished {i + 1} / {len(vid_names)} = {(i + 1) / len(vid_names):.4f}, failed {fail_cnt} / {i + 1} = {fail_cnt / (i + 1):.4f}")
|
163 |
+
sys.stdout.flush()
|
164 |
+
pass
|
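Note (illustrative, not part of the uploaded file): for the single-video ("nerf") mode above, extract_landmark_job derives the output path by replacing "/raw/" with "/processed/", so the sketch below assumes the video lives under a data/raw/... folder, mediapipe is installed, and the repo root is on PYTHONPATH.
from data_gen.utils.process_video.extract_lm2d import extract_landmark_job
ok = extract_landmark_job("data/raw/videos/May.mp4", nerf=True)
print(ok)  # True on success; lms_2d.npy is written under the matching data/processed/videos/May/ folder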
data_gen/utils/process_video/extract_segment_imgs.py
ADDED
@@ -0,0 +1,500 @@
1 |
+
import os
|
2 |
+
os.environ["OMP_NUM_THREADS"] = "1"
|
3 |
+
import random
|
4 |
+
import glob
|
5 |
+
import cv2
|
6 |
+
import tqdm
|
7 |
+
import numpy as np
|
8 |
+
import PIL
|
9 |
+
from utils.commons.tensor_utils import convert_to_np
|
10 |
+
from utils.commons.os_utils import multiprocess_glob
|
11 |
+
import pickle
|
12 |
+
import torch
|
13 |
+
import mediapipe as mp
|
14 |
+
import traceback
|
15 |
+
import multiprocessing
|
16 |
+
from utils.commons.multiprocess_utils import multiprocess_run_tqdm
|
17 |
+
from scipy.ndimage import binary_erosion, binary_dilation
|
18 |
+
from sklearn.neighbors import NearestNeighbors
|
19 |
+
from mediapipe.tasks.python import vision
|
20 |
+
from data_gen.utils.mp_feature_extractors.mp_segmenter import MediapipeSegmenter, encode_segmap_mask_to_image, decode_segmap_mask_from_image
|
21 |
+
|
22 |
+
seg_model = None
|
23 |
+
segmenter = None
|
24 |
+
mat_model = None
|
25 |
+
lama_model = None
|
26 |
+
lama_config = None
|
27 |
+
|
28 |
+
from data_gen.utils.process_video.split_video_to_imgs import extract_img_job
|
29 |
+
|
30 |
+
BG_NAME_MAP = {
|
31 |
+
"knn": "",
|
32 |
+
"mat": "_mat",
|
33 |
+
"ddnm": "_ddnm",
|
34 |
+
"lama": "_lama",
|
35 |
+
}
|
36 |
+
FRAME_SELECT_INTERVAL = 5
|
37 |
+
SIM_METHOD = "mse"
|
38 |
+
SIM_THRESHOLD = 3
|
39 |
+
|
40 |
+
def save_file(name, content):
|
41 |
+
with open(name, "wb") as f:
|
42 |
+
pickle.dump(content, f)
|
43 |
+
|
44 |
+
def load_file(name):
|
45 |
+
with open(name, "rb") as f:
|
46 |
+
content = pickle.load(f)
|
47 |
+
return content
|
48 |
+
|
49 |
+
def save_rgb_alpha_image_to_path(img, alpha, img_path):
|
50 |
+
try: os.makedirs(os.path.dirname(img_path), exist_ok=True)
|
51 |
+
except: pass
|
52 |
+
cv2.imwrite(img_path, np.concatenate([cv2.cvtColor(img, cv2.COLOR_RGB2BGR), alpha], axis=-1))
|
53 |
+
|
54 |
+
def save_rgb_image_to_path(img, img_path):
|
55 |
+
try: os.makedirs(os.path.dirname(img_path), exist_ok=True)
|
56 |
+
except: pass
|
57 |
+
cv2.imwrite(img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
|
58 |
+
|
59 |
+
def load_rgb_image_to_path(img_path):
|
60 |
+
return cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
|
61 |
+
|
62 |
+
def image_similarity(x: np.ndarray, y: np.ndarray, method="mse"):
|
63 |
+
if method == "mse":
|
64 |
+
return np.mean((x - y) ** 2)
|
65 |
+
else:
|
66 |
+
raise NotImplementedError
|
67 |
+
|
68 |
+
def extract_background(img_lst, segmap_mask_lst=None, method="knn", device='cpu', mix_bg=True):
|
69 |
+
"""
|
70 |
+
img_lst: list of rgb ndarray
|
71 |
+
method: "knn", "mat" or "ddnm"
|
72 |
+
"""
|
73 |
+
# only use 1/20 images
|
74 |
+
global segmenter
|
75 |
+
global seg_model
|
76 |
+
global mat_model
|
77 |
+
global lama_model
|
78 |
+
global lama_config
|
79 |
+
|
80 |
+
assert len(img_lst) > 0
|
81 |
+
if segmap_mask_lst is not None:
|
82 |
+
assert len(segmap_mask_lst) == len(img_lst)
|
83 |
+
else:
|
84 |
+
del segmenter
|
85 |
+
del seg_model
|
86 |
+
seg_model = MediapipeSegmenter()
|
87 |
+
segmenter = vision.ImageSegmenter.create_from_options(seg_model.video_options)
|
88 |
+
|
89 |
+
def get_segmap_mask(img_lst, segmap_mask_lst, index):
|
90 |
+
if segmap_mask_lst is not None:
|
91 |
+
segmap = segmap_mask_lst[index]
|
92 |
+
else:
|
93 |
+
segmap = seg_model._cal_seg_map(img_lst[index], segmenter=segmenter)
|
94 |
+
return segmap
|
95 |
+
|
96 |
+
if method == "knn":
|
97 |
+
num_frames = len(img_lst)
|
98 |
+
img_lst = img_lst[::FRAME_SELECT_INTERVAL] if num_frames > FRAME_SELECT_INTERVAL else img_lst[0:1]
|
99 |
+
|
100 |
+
if segmap_mask_lst is not None:
|
101 |
+
segmap_mask_lst = segmap_mask_lst[::FRAME_SELECT_INTERVAL] if num_frames > FRAME_SELECT_INTERVAL else segmap_mask_lst[0:1]
|
102 |
+
assert len(img_lst) == len(segmap_mask_lst)
|
103 |
+
# get H/W
|
104 |
+
h, w = img_lst[0].shape[:2]
|
105 |
+
|
106 |
+
# nearest neighbors
|
107 |
+
all_xys = np.mgrid[0:h, 0:w].reshape(2, -1).transpose() # [512*512, 2] coordinate grid
|
108 |
+
distss = []
|
109 |
+
for idx, img in enumerate(img_lst):
|
110 |
+
segmap = get_segmap_mask(img_lst=img_lst, segmap_mask_lst=segmap_mask_lst, index=idx)
|
111 |
+
bg = (segmap[0]).astype(bool) # [h,w] bool mask
|
112 |
+
fg_xys = np.stack(np.nonzero(~bg)).transpose(1, 0) # [N_nonbg,2] coordinate of non-bg pixels
|
113 |
+
nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
|
114 |
+
dists, _ = nbrs.kneighbors(all_xys) # [512*512, 1] distance to nearest non-bg pixel
|
115 |
+
distss.append(dists)
|
116 |
+
|
117 |
+
distss = np.stack(distss) # [B, 512*512, 1]
|
118 |
+
max_dist = np.max(distss, 0) # [512*512, 1]
|
119 |
+
max_id = np.argmax(distss, 0) # id of frame
|
120 |
+
|
121 |
+
bc_pixs = max_dist > 10 # pixels that were background in at least one frame; a pixel counts as bg if its distance to the nearest non-bg pixel exceeds 10
|
122 |
+
bc_pixs_id = np.nonzero(bc_pixs)
|
123 |
+
bc_ids = max_id[bc_pixs]
|
124 |
+
|
125 |
+
num_pixs = distss.shape[1]
|
126 |
+
imgs = np.stack(img_lst).reshape(-1, num_pixs, 3)
|
127 |
+
|
128 |
+
bg_img = np.zeros((h*w, 3), dtype=np.uint8)
|
129 |
+
bg_img[bc_pixs_id, :] = imgs[bc_ids, bc_pixs_id, :] # for pixels that are clearly background, sample the color directly from the corresponding frame
|
130 |
+
bg_img = bg_img.reshape(h, w, 3)
|
131 |
+
|
132 |
+
max_dist = max_dist.reshape(h, w)
|
133 |
+
bc_pixs = max_dist > 10 # 5
|
134 |
+
bg_xys = np.stack(np.nonzero(~bc_pixs)).transpose()
|
135 |
+
fg_xys = np.stack(np.nonzero(bc_pixs)).transpose()
|
136 |
+
nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
|
137 |
+
distances, indices = nbrs.kneighbors(bg_xys) # 对non-bg img,���KNN找最近的bg pixel
|
138 |
+
bg_fg_xys = fg_xys[indices[:, 0]]
|
139 |
+
bg_img[bg_xys[:, 0], bg_xys[:, 1], :] = bg_img[bg_fg_xys[:, 0], bg_fg_xys[:, 1], :]
|
140 |
+
else:
|
141 |
+
raise NotImplementedError # deprecated
|
142 |
+
|
143 |
+
return bg_img
|
144 |
+
|
145 |
+
def inpaint_torso_job(gt_img, segmap):
|
146 |
+
bg_part = (segmap[0]).astype(bool)
|
147 |
+
head_part = (segmap[1] + segmap[3] + segmap[5]).astype(bool)
|
148 |
+
neck_part = (segmap[2]).astype(bool)
|
149 |
+
torso_part = (segmap[4]).astype(bool)
|
150 |
+
img = gt_img.copy()
|
151 |
+
img[head_part] = 0
|
152 |
+
|
153 |
+
# torso part "vertical" in-painting...
|
154 |
+
L = 8 + 1
|
155 |
+
torso_coords = np.stack(np.nonzero(torso_part), axis=-1) # [M, 2]
|
156 |
+
# lexsort: sort 2D coords first by y then by x,
|
157 |
+
# ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
|
158 |
+
inds = np.lexsort((torso_coords[:, 0], torso_coords[:, 1]))
|
159 |
+
torso_coords = torso_coords[inds]
|
160 |
+
# choose the top pixel for each column
|
161 |
+
u, uid, ucnt = np.unique(torso_coords[:, 1], return_index=True, return_counts=True)
|
162 |
+
top_torso_coords = torso_coords[uid] # [m, 2]
|
163 |
+
# only keep top-is-head pixels
|
164 |
+
top_torso_coords_up = top_torso_coords.copy() - np.array([1, 0]) # [N, 2]
|
165 |
+
mask = head_part[tuple(top_torso_coords_up.T)]
|
166 |
+
if mask.any():
|
167 |
+
top_torso_coords = top_torso_coords[mask]
|
168 |
+
# get the color
|
169 |
+
top_torso_colors = gt_img[tuple(top_torso_coords.T)] # [m, 3]
|
170 |
+
# construct inpaint coords (vertically up, or minus in x)
|
171 |
+
inpaint_torso_coords = top_torso_coords[None].repeat(L, 0) # [L, m, 2]
|
172 |
+
inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
|
173 |
+
inpaint_torso_coords += inpaint_offsets
|
174 |
+
inpaint_torso_coords = inpaint_torso_coords.reshape(-1, 2) # [Lm, 2]
|
175 |
+
inpaint_torso_colors = top_torso_colors[None].repeat(L, 0) # [L, m, 3]
|
176 |
+
darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
|
177 |
+
inpaint_torso_colors = (inpaint_torso_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
|
178 |
+
# set color
|
179 |
+
img[tuple(inpaint_torso_coords.T)] = inpaint_torso_colors
|
180 |
+
inpaint_torso_mask = np.zeros_like(img[..., 0]).astype(bool)
|
181 |
+
inpaint_torso_mask[tuple(inpaint_torso_coords.T)] = True
|
182 |
+
else:
|
183 |
+
inpaint_torso_mask = None
|
184 |
+
|
185 |
+
# neck part "vertical" in-painting...
|
186 |
+
push_down = 4
|
187 |
+
L = 48 + push_down + 1
|
188 |
+
neck_part = binary_dilation(neck_part, structure=np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=bool), iterations=3)
|
189 |
+
neck_coords = np.stack(np.nonzero(neck_part), axis=-1) # [M, 2]
|
190 |
+
# lexsort: sort 2D coords first by y then by x,
|
191 |
+
# ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
|
192 |
+
inds = np.lexsort((neck_coords[:, 0], neck_coords[:, 1]))
|
193 |
+
neck_coords = neck_coords[inds]
|
194 |
+
# choose the top pixel for each column
|
195 |
+
u, uid, ucnt = np.unique(neck_coords[:, 1], return_index=True, return_counts=True)
|
196 |
+
top_neck_coords = neck_coords[uid] # [m, 2]
|
197 |
+
# only keep top-is-head pixels
|
198 |
+
top_neck_coords_up = top_neck_coords.copy() - np.array([1, 0])
|
199 |
+
mask = head_part[tuple(top_neck_coords_up.T)]
|
200 |
+
top_neck_coords = top_neck_coords[mask]
|
201 |
+
# push these top down for 4 pixels to make the neck inpainting more natural...
|
202 |
+
offset_down = np.minimum(ucnt[mask] - 1, push_down)
|
203 |
+
top_neck_coords += np.stack([offset_down, np.zeros_like(offset_down)], axis=-1)
|
204 |
+
# get the color
|
205 |
+
top_neck_colors = gt_img[tuple(top_neck_coords.T)] # [m, 3]
|
206 |
+
# construct inpaint coords (vertically up, or minus in x)
|
207 |
+
inpaint_neck_coords = top_neck_coords[None].repeat(L, 0) # [L, m, 2]
|
208 |
+
inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
|
209 |
+
inpaint_neck_coords += inpaint_offsets
|
210 |
+
inpaint_neck_coords = inpaint_neck_coords.reshape(-1, 2) # [Lm, 2]
|
211 |
+
inpaint_neck_colors = top_neck_colors[None].repeat(L, 0) # [L, m, 3]
|
212 |
+
darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
|
213 |
+
inpaint_neck_colors = (inpaint_neck_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
|
214 |
+
# set color
|
215 |
+
img[tuple(inpaint_neck_coords.T)] = inpaint_neck_colors
|
216 |
+
# apply blurring to the inpaint area to avoid vertical-line artifacts...
|
217 |
+
inpaint_mask = np.zeros_like(img[..., 0]).astype(bool)
|
218 |
+
inpaint_mask[tuple(inpaint_neck_coords.T)] = True
|
219 |
+
|
220 |
+
blur_img = img.copy()
|
221 |
+
blur_img = cv2.GaussianBlur(blur_img, (5, 5), cv2.BORDER_DEFAULT)
|
222 |
+
img[inpaint_mask] = blur_img[inpaint_mask]
|
223 |
+
|
224 |
+
# set mask
|
225 |
+
torso_img_mask = (neck_part | torso_part | inpaint_mask)
|
226 |
+
torso_with_bg_img_mask = (bg_part | neck_part | torso_part | inpaint_mask)
|
227 |
+
if inpaint_torso_mask is not None:
|
228 |
+
torso_img_mask = torso_img_mask | inpaint_torso_mask
|
229 |
+
torso_with_bg_img_mask = torso_with_bg_img_mask | inpaint_torso_mask
|
230 |
+
|
231 |
+
torso_img = img.copy()
|
232 |
+
torso_img[~torso_img_mask] = 0
|
233 |
+
torso_with_bg_img = img.copy()
|
234 |
+
torso_with_bg_img[~torso_with_bg_img_mask] = 0
|
235 |
+
|
236 |
+
return torso_img, torso_img_mask, torso_with_bg_img, torso_with_bg_img_mask
|
237 |
+
|
238 |
+
|
239 |
+
def extract_segment_job(video_name, nerf=False, idx=None, total=None, background_method='knn', device="cpu", total_gpus=0, mix_bg=True):
|
240 |
+
global segmenter
|
241 |
+
global seg_model
|
242 |
+
del segmenter
|
243 |
+
del seg_model
|
244 |
+
seg_model = MediapipeSegmenter()
|
245 |
+
segmenter = vision.ImageSegmenter.create_from_options(seg_model.video_options)
|
246 |
+
try:
|
247 |
+
if "cuda" in device:
|
248 |
+
# determine which cuda index from subprocess id
|
249 |
+
pname = multiprocessing.current_process().name
|
250 |
+
pid = int(pname.rsplit("-", 1)[-1]) - 1
|
251 |
+
cuda_id = pid % total_gpus
|
252 |
+
device = f"cuda:{cuda_id}"
|
253 |
+
|
254 |
+
if nerf: # single video
|
255 |
+
raw_img_dir = video_name.replace(".mp4", "/gt_imgs/").replace("/raw/","/processed/")
|
256 |
+
else: # whole dataset
|
257 |
+
raw_img_dir = video_name.replace(".mp4", "").replace("/video/", "/gt_imgs/")
|
258 |
+
if not os.path.exists(raw_img_dir):
|
259 |
+
extract_img_job(video_name, raw_img_dir) # use ffmpeg to split video into imgs
|
260 |
+
|
261 |
+
img_names = glob.glob(os.path.join(raw_img_dir, "*.jpg"))
|
262 |
+
|
263 |
+
img_lst = []
|
264 |
+
|
265 |
+
for img_name in img_names:
|
266 |
+
img = cv2.imread(img_name)
|
267 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
268 |
+
img_lst.append(img)
|
269 |
+
|
270 |
+
segmap_mask_lst, segmap_image_lst = seg_model._cal_seg_map_for_video(img_lst, segmenter=segmenter, return_onehot_mask=True, return_segmap_image=True)
|
271 |
+
del segmap_image_lst
|
272 |
+
# for i in range(len(img_lst)):
|
273 |
+
for i in tqdm.trange(len(img_lst), desc='generating segment images using segmaps...'):
|
274 |
+
img_name = img_names[i]
|
275 |
+
segmap = segmap_mask_lst[i]
|
276 |
+
img = img_lst[i]
|
277 |
+
out_img_name = img_name.replace("/gt_imgs/", "/segmaps/").replace(".jpg", ".png") # 存成jpg的话,pixel value会有误差
|
278 |
+
try: os.makedirs(os.path.dirname(out_img_name), exist_ok=True)
|
279 |
+
except: pass
|
280 |
+
encoded_segmap = encode_segmap_mask_to_image(segmap)
|
281 |
+
save_rgb_image_to_path(encoded_segmap, out_img_name)
|
282 |
+
|
283 |
+
for mode in ['head', 'torso', 'person', 'bg']:
|
284 |
+
out_img, mask = seg_model._seg_out_img_with_segmap(img, segmap, mode=mode)
|
285 |
+
img_alpha = 255 * np.ones((img.shape[0], img.shape[1], 1), dtype=np.uint8) # alpha
|
286 |
+
mask = mask[0][..., None]
|
287 |
+
img_alpha[~mask] = 0
|
288 |
+
out_img_name = img_name.replace("/gt_imgs/", f"/{mode}_imgs/").replace(".jpg", ".png")
|
289 |
+
save_rgb_alpha_image_to_path(out_img, img_alpha, out_img_name)
|
290 |
+
|
291 |
+
inpaint_torso_img, inpaint_torso_img_mask, inpaint_torso_with_bg_img, inpaint_torso_with_bg_img_mask = inpaint_torso_job(img, segmap)
|
292 |
+
img_alpha = 255 * np.ones((img.shape[0], img.shape[1], 1), dtype=np.uint8) # alpha
|
293 |
+
img_alpha[~inpaint_torso_img_mask[..., None]] = 0
|
294 |
+
out_img_name = img_name.replace("/gt_imgs/", f"/inpaint_torso_imgs/").replace(".jpg", ".png")
|
295 |
+
save_rgb_alpha_image_to_path(inpaint_torso_img, img_alpha, out_img_name)
|
296 |
+
|
297 |
+
bg_prefix_name = f"bg{BG_NAME_MAP[background_method]}"
|
298 |
+
bg_img = extract_background(img_lst, segmap_mask_lst, method=background_method, device=device, mix_bg=mix_bg)
|
299 |
+
if nerf:
|
300 |
+
out_img_name = video_name.replace("/raw/", "/processed/").replace(".mp4", f"/{bg_prefix_name}.jpg")
|
301 |
+
else:
|
302 |
+
out_img_name = video_name.replace("/video/", f"/{bg_prefix_name}_img/").replace(".mp4", ".jpg")
|
303 |
+
save_rgb_image_to_path(bg_img, out_img_name)
|
304 |
+
|
305 |
+
com_prefix_name = f"com{BG_NAME_MAP[background_method]}"
|
306 |
+
for i, img_name in enumerate(img_names):
|
307 |
+
com_img = img_lst[i].copy()
|
308 |
+
segmap = segmap_mask_lst[i]
|
309 |
+
bg_part = segmap[0].astype(bool)[..., None].repeat(3,axis=-1)
|
310 |
+
com_img[bg_part] = bg_img[bg_part]
|
311 |
+
out_img_name = img_name.replace("/gt_imgs/", f"/{com_prefix_name}_imgs/")
|
312 |
+
save_rgb_image_to_path(com_img, out_img_name)
|
313 |
+
return 0
|
314 |
+
except Exception as e:
|
315 |
+
print(str(type(e)), e)
|
316 |
+
traceback.print_exc()
|
317 |
+
return 1
|
318 |
+
|
319 |
+
# def check_bg_img_job_finished(raw_img_dir, bg_name, com_dir):
|
320 |
+
# img_names = glob.glob(os.path.join(raw_img_dir, "*.jpg"))
|
321 |
+
# com_names = glob.glob(os.path.join(com_dir, "*.jpg"))
|
322 |
+
# return len(img_names) == len(com_names) and os.path.exists(bg_name)
|
323 |
+
|
324 |
+
# extract background and combined image
|
325 |
+
# need pre-processed "gt_imgs" and "segmaps"
|
326 |
+
def extract_bg_img_job(video_name, nerf=False, idx=None, total=None, background_method='knn', device="cpu", total_gpus=0, mix_bg=True):
|
327 |
+
try:
|
328 |
+
bg_prefix_name = f"bg{BG_NAME_MAP[background_method]}"
|
329 |
+
com_prefix_name = f"com{BG_NAME_MAP[background_method]}"
|
330 |
+
|
331 |
+
if "cuda" in device:
|
332 |
+
# determine which cuda index from subprocess id
|
333 |
+
pname = multiprocessing.current_process().name
|
334 |
+
pid = int(pname.rsplit("-", 1)[-1]) - 1
|
335 |
+
cuda_id = pid % total_gpus
|
336 |
+
device = f"cuda:{cuda_id}"
|
337 |
+
|
338 |
+
if nerf: # single video
|
339 |
+
raw_img_dir = video_name.replace(".mp4", "/gt_imgs/").replace("/raw/","/processed/")
|
340 |
+
else: # whole dataset
|
341 |
+
raw_img_dir = video_name.replace(".mp4", "").replace("/video/", "/gt_imgs/")
|
342 |
+
if nerf:
|
343 |
+
bg_name = video_name.replace("/raw/", "/processed/").replace(".mp4", f"/{bg_prefix_name}.jpg")
|
344 |
+
else:
|
345 |
+
bg_name = video_name.replace("/video/", f"/{bg_prefix_name}_img/").replace(".mp4", ".jpg")
|
346 |
+
# com_dir = raw_img_dir.replace("/gt_imgs/", f"/{com_prefix_name}_imgs/")
|
347 |
+
# if check_bg_img_job_finished(raw_img_dir=raw_img_dir, bg_name=bg_name, com_dir=com_dir):
|
348 |
+
# print(f"Already finished, skip {raw_img_dir} ")
|
349 |
+
# return 0
|
350 |
+
|
351 |
+
img_names = glob.glob(os.path.join(raw_img_dir, "*.jpg"))
|
352 |
+
img_lst = []
|
353 |
+
for img_name in img_names:
|
354 |
+
img = cv2.imread(img_name)
|
355 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
356 |
+
img_lst.append(img)
|
357 |
+
|
358 |
+
segmap_mask_lst = []
|
359 |
+
for img_name in img_names:
|
360 |
+
segmap_img_name = img_name.replace("/gt_imgs/", "/segmaps/").replace(".jpg", ".png")
|
361 |
+
segmap_img = load_rgb_image_to_path(segmap_img_name)
|
362 |
+
|
363 |
+
segmap_mask = decode_segmap_mask_from_image(segmap_img)
|
364 |
+
segmap_mask_lst.append(segmap_mask)
|
365 |
+
|
366 |
+
bg_img = extract_background(img_lst, segmap_mask_lst, method=background_method, device=device, mix_bg=mix_bg)
|
367 |
+
save_rgb_image_to_path(bg_img, bg_name)
|
368 |
+
|
369 |
+
for i, img_name in enumerate(img_names):
|
370 |
+
com_img = img_lst[i].copy()
|
371 |
+
segmap = segmap_mask_lst[i]
|
372 |
+
bg_part = segmap[0].astype(bool)[..., None].repeat(3, axis=-1)
|
373 |
+
com_img[bg_part] = bg_img[bg_part]
|
374 |
+
com_name = img_name.replace("/gt_imgs/", f"/{com_prefix_name}_imgs/")
|
375 |
+
save_rgb_image_to_path(com_img, com_name)
|
376 |
+
return 0
|
377 |
+
|
378 |
+
except Exception as e:
|
379 |
+
print(str(type(e)), e)
|
380 |
+
traceback.print_exc()
|
381 |
+
return 1
|
382 |
+
|
383 |
+
def out_exist_job(vid_name, background_method='knn', only_bg_img=False):
|
384 |
+
com_prefix_name = f"com{BG_NAME_MAP[background_method]}"
|
385 |
+
img_dir = vid_name.replace("/video/", "/gt_imgs/").replace(".mp4", "")
|
386 |
+
out_dir1 = img_dir.replace("/gt_imgs/", "/head_imgs/")
|
387 |
+
out_dir2 = img_dir.replace("/gt_imgs/", f"/{com_prefix_name}_imgs/")
|
388 |
+
|
389 |
+
if not only_bg_img:
|
390 |
+
if os.path.exists(img_dir) and os.path.exists(out_dir1) and os.path.exists(out_dir2):
|
391 |
+
num_frames = len(os.listdir(img_dir))
|
392 |
+
if len(os.listdir(out_dir1)) == num_frames and len(os.listdir(out_dir2)) == num_frames:
|
393 |
+
return None
|
394 |
+
else:
|
395 |
+
return vid_name
|
396 |
+
else:
|
397 |
+
return vid_name
|
398 |
+
else:
|
399 |
+
if os.path.exists(img_dir) and os.path.exists(out_dir2):
|
400 |
+
num_frames = len(os.listdir(img_dir))
|
401 |
+
if len(os.listdir(out_dir2)) == num_frames:
|
402 |
+
return None
|
403 |
+
else:
|
404 |
+
return vid_name
|
405 |
+
else:
|
406 |
+
return vid_name
|
407 |
+
|
408 |
+
def get_todo_vid_names(vid_names, background_method='knn', only_bg_img=False):
|
409 |
+
if len(vid_names) == 1: # nerf
|
410 |
+
return vid_names
|
411 |
+
todo_vid_names = []
|
412 |
+
fn_args = [(vid_name, background_method, only_bg_img) for vid_name in vid_names]
|
413 |
+
for i, res in multiprocess_run_tqdm(out_exist_job, fn_args, num_workers=16, desc="checking todo videos..."):
|
414 |
+
if res is not None:
|
415 |
+
todo_vid_names.append(res)
|
416 |
+
return todo_vid_names
|
417 |
+
|
418 |
+
if __name__ == '__main__':
|
419 |
+
import argparse, glob, tqdm, random
|
420 |
+
parser = argparse.ArgumentParser()
|
421 |
+
parser.add_argument("--vid_dir", default='/home/tiger/datasets/raw/CelebV-HQ/video')
|
422 |
+
parser.add_argument("--ds_name", default='CelebV-HQ')
|
423 |
+
parser.add_argument("--num_workers", default=48, type=int)
|
424 |
+
parser.add_argument("--seed", default=0, type=int)
|
425 |
+
parser.add_argument("--process_id", default=0, type=int)
|
426 |
+
parser.add_argument("--total_process", default=1, type=int)
|
427 |
+
parser.add_argument("--reset", action='store_true')
|
428 |
+
parser.add_argument("--load_names", action="store_true")
|
429 |
+
parser.add_argument("--background_method", choices=['knn', 'mat', 'ddnm', 'lama'], type=str, default='knn')
|
430 |
+
parser.add_argument("--total_gpus", default=0, type=int) # zero gpus means utilizing cpu
|
431 |
+
parser.add_argument("--only_bg_img", action="store_true")
|
432 |
+
parser.add_argument("--no_mix_bg", action="store_true")
|
433 |
+
|
434 |
+
args = parser.parse_args()
|
435 |
+
vid_dir = args.vid_dir
|
436 |
+
ds_name = args.ds_name
|
437 |
+
load_names = args.load_names
|
438 |
+
background_method = args.background_method
|
439 |
+
total_gpus = args.total_gpus
|
440 |
+
only_bg_img = args.only_bg_img
|
441 |
+
mix_bg = not args.no_mix_bg
|
442 |
+
|
443 |
+
devices = os.environ.get('CUDA_VISIBLE_DEVICES', '').split(",")
|
444 |
+
for d in devices[:total_gpus]:
|
445 |
+
os.system(f'pkill -f "voidgpu{d}"')
|
446 |
+
|
447 |
+
if ds_name.lower() == 'nerf': # process a single video
|
448 |
+
vid_names = [vid_dir]
|
449 |
+
out_names = [video_name.replace("/raw/", "/processed/").replace(".mp4","_lms.npy") for video_name in vid_names]
|
450 |
+
else: # process the whole dataset
|
451 |
+
if ds_name in ['lrs3_trainval']:
|
452 |
+
vid_name_pattern = os.path.join(vid_dir, "*/*.mp4")
|
453 |
+
elif ds_name in ['TH1KH_512', 'CelebV-HQ']:
|
454 |
+
vid_name_pattern = os.path.join(vid_dir, "*.mp4")
|
455 |
+
elif ds_name in ['lrs2', 'lrs3', 'voxceleb2']:
|
456 |
+
vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
|
457 |
+
elif ds_name in ["RAVDESS", 'VFHQ']:
|
458 |
+
vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4")
|
459 |
+
else:
|
460 |
+
raise NotImplementedError()
|
461 |
+
|
462 |
+
vid_names_path = os.path.join(vid_dir, "vid_names.pkl")
|
463 |
+
if os.path.exists(vid_names_path) and load_names:
|
464 |
+
print(f"loading vid names from {vid_names_path}")
|
465 |
+
vid_names = load_file(vid_names_path)
|
466 |
+
else:
|
467 |
+
vid_names = multiprocess_glob(vid_name_pattern)
|
468 |
+
vid_names = sorted(vid_names)
|
469 |
+
print(f"saving vid names to {vid_names_path}")
|
470 |
+
save_file(vid_names_path, vid_names)
|
471 |
+
|
472 |
+
vid_names = sorted(vid_names)
|
473 |
+
random.seed(args.seed)
|
474 |
+
random.shuffle(vid_names)
|
475 |
+
|
476 |
+
process_id = args.process_id
|
477 |
+
total_process = args.total_process
|
478 |
+
if total_process > 1:
|
479 |
+
assert process_id <= total_process -1
|
480 |
+
num_samples_per_process = len(vid_names) // total_process
|
481 |
+
if process_id == total_process - 1: # the last process also takes the remainder
|
482 |
+
vid_names = vid_names[process_id * num_samples_per_process : ]
|
483 |
+
else:
|
484 |
+
vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
|
485 |
+
|
486 |
+
if not args.reset:
|
487 |
+
vid_names = get_todo_vid_names(vid_names, background_method, only_bg_img)
|
488 |
+
print(f"todo videos number: {len(vid_names)}")
|
489 |
+
# exit()
|
490 |
+
|
491 |
+
device = "cuda" if total_gpus > 0 else "cpu"
|
492 |
+
if only_bg_img:
|
493 |
+
extract_job = extract_bg_img_job
|
494 |
+
fn_args = [(vid_name,ds_name=='nerf',i,len(vid_names), background_method, device, total_gpus, mix_bg) for i, vid_name in enumerate(vid_names)]
|
495 |
+
else:
|
496 |
+
extract_job = extract_segment_job
|
497 |
+
fn_args = [(vid_name,ds_name=='nerf',i,len(vid_names), background_method, device, total_gpus, mix_bg) for i, vid_name in enumerate(vid_names)]
|
498 |
+
|
499 |
+
for vid_name in multiprocess_run_tqdm(extract_job, fn_args, desc=f"Root process {args.process_id}: segment images", num_workers=args.num_workers):
|
500 |
+
pass
|
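Note (illustrative, not part of the uploaded file): a toy version of the KNN trick used in extract_background above. A pixel is taken as reliable background when, in some frame, it lies far from every non-background pixel of that frame's segmentation mask; the sketch below only shows the distance computation on a tiny grid with made-up foreground coordinates.
import numpy as np
from sklearn.neighbors import NearestNeighbors
h = w = 8
all_xys = np.mgrid[0:h, 0:w].reshape(2, -1).transpose()    # [h*w, 2] pixel coordinate grid
fg_xys = np.array([[3, 3], [3, 4], [4, 3], [4, 4]])          # stand-in for non-bg pixels of one frame
nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
dists, _ = nbrs.kneighbors(all_xys)                          # distance of every pixel to its nearest non-bg pixel
print(int((dists.reshape(h, w) > 2).sum()), "pixels would count as background in this frame")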
data_gen/utils/process_video/fit_3dmm_landmark.py
ADDED
@@ -0,0 +1,565 @@
1 |
+
# This is a script for efficient 3DMM coefficient extraction.
|
2 |
+
# It can reconstruct an accurate 3D face in real time.
|
3 |
+
# It is built upon BFM 2009 model and mediapipe landmark extractor.
|
4 |
+
# It is authored by ZhenhuiYe (zhenhuiye@zju.edu.cn), free to contact him for any suggestion on improvement!
|
5 |
+
|
6 |
+
from numpy.core.numeric import require
|
7 |
+
from numpy.lib.function_base import quantile
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import copy
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
import random
|
14 |
+
import pickle
|
15 |
+
import os
|
16 |
+
import sys
|
17 |
+
import cv2
|
18 |
+
import argparse
|
19 |
+
import tqdm
|
20 |
+
from utils.commons.multiprocess_utils import multiprocess_run_tqdm
|
21 |
+
from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker, read_video_to_frames
|
22 |
+
from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
|
23 |
+
from deep_3drecon.secc_renderer import SECC_Renderer
|
24 |
+
from utils.commons.os_utils import multiprocess_glob
|
25 |
+
|
26 |
+
|
27 |
+
face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM',
|
28 |
+
camera_distance=10, focal=1015, keypoint_mode='mediapipe')
|
29 |
+
face_model.to(torch.device("cuda:0"))
|
30 |
+
|
31 |
+
dir_path = os.path.dirname(os.path.realpath(__file__))
|
32 |
+
|
33 |
+
|
34 |
+
def draw_axes(img, pitch, yaw, roll, tx, ty, size=50):
|
35 |
+
# yaw = -yaw
|
36 |
+
pitch = - pitch
|
37 |
+
roll = - roll
|
38 |
+
rotation_matrix = cv2.Rodrigues(np.array([pitch, yaw, roll]))[0].astype(np.float64)
|
39 |
+
axes_points = np.array([
|
40 |
+
[1, 0, 0, 0],
|
41 |
+
[0, 1, 0, 0],
|
42 |
+
[0, 0, 1, 0]
|
43 |
+
], dtype=np.float64)
|
44 |
+
axes_points = rotation_matrix @ axes_points
|
45 |
+
axes_points = (axes_points[:2, :] * size).astype(int)
|
46 |
+
axes_points[0, :] = axes_points[0, :] + tx
|
47 |
+
axes_points[1, :] = axes_points[1, :] + ty
|
48 |
+
|
49 |
+
new_img = img.copy()
|
50 |
+
cv2.line(new_img, tuple(axes_points[:, 3].ravel()), tuple(axes_points[:, 0].ravel()), (255, 0, 0), 3)
|
51 |
+
cv2.line(new_img, tuple(axes_points[:, 3].ravel()), tuple(axes_points[:, 1].ravel()), (0, 255, 0), 3)
|
52 |
+
cv2.line(new_img, tuple(axes_points[:, 3].ravel()), tuple(axes_points[:, 2].ravel()), (0, 0, 255), 3)
|
53 |
+
return new_img
|
54 |
+
|
55 |
+
def save_file(name, content):
|
56 |
+
with open(name, "wb") as f:
|
57 |
+
pickle.dump(content, f)
|
58 |
+
|
59 |
+
def load_file(name):
|
60 |
+
with open(name, "rb") as f:
|
61 |
+
content = pickle.load(f)
|
62 |
+
return content
|
63 |
+
|
64 |
+
def cal_lap_loss(in_tensor):
|
65 |
+
# [T, 68, 2]
|
66 |
+
t = in_tensor.shape[0]
|
67 |
+
in_tensor = in_tensor.reshape([t, -1]).permute(1,0).unsqueeze(1) # [c, 1, t]
|
68 |
+
in_tensor = torch.cat([in_tensor[:, :, 0:1], in_tensor, in_tensor[:, :, -1:]], dim=-1)
|
69 |
+
lap_kernel = torch.Tensor((-0.5, 1.0, -0.5)).reshape([1,1,3]).float().to(in_tensor.device) # [1, 1, kw]
|
70 |
+
loss_lap = 0
|
71 |
+
|
72 |
+
out_tensor = F.conv1d(in_tensor, lap_kernel)
|
73 |
+
loss_lap += torch.mean(out_tensor**2)
|
74 |
+
return loss_lap
|
75 |
+
|
76 |
+
def cal_vel_loss(ldm):
|
77 |
+
# [B, 68, 2]
|
78 |
+
vel = ldm[1:] - ldm[:-1]
|
79 |
+
return torch.mean(torch.abs(vel))
|
80 |
+
|
81 |
+
def cal_lan_loss(proj_lan, gt_lan):
|
82 |
+
# [B, 68, 2]
|
83 |
+
loss = (proj_lan - gt_lan)** 2
|
84 |
+
# use the ldm weights from deep3drecon, see deep_3drecon/deep_3drecon_models/losses.py
|
85 |
+
weights = torch.zeros_like(loss)
|
86 |
+
weights = torch.ones_like(loss)
|
87 |
+
weights[:, 36:48, :] = 3 # eye 12 points
|
88 |
+
weights[:, -8:, :] = 3 # inner lip 8 points
|
89 |
+
weights[:, 28:31, :] = 3 # nose 3 points
|
90 |
+
loss = loss * weights
|
91 |
+
return torch.mean(loss)
|
92 |
+
|
93 |
+
def cal_lan_loss_mp(proj_lan, gt_lan, mean:bool=True):
|
94 |
+
# [B, 68, 2]
|
95 |
+
loss = (proj_lan - gt_lan).pow(2)
|
96 |
+
# loss = (proj_lan - gt_lan).abs()
|
97 |
+
unmatch_mask = [ 93, 127, 132, 234, 323, 356, 361, 454]
|
98 |
+
upper_eye = [161,160,159,158,157] + [388,387,386,385,384]
|
99 |
+
eye = [33,246,161,160,159,158,157,173,133,155,154,153,145,144,163,7] + [263,466,388,387,386,385,384,398,362,382,381,380,374,373,390,249]
|
100 |
+
inner_lip = [78,191,80,81,82,13,312,311,310,415,308,324,318,402,317,14,87,178,88,95]
|
101 |
+
outer_lip = [61,185,40,39,37,0,267,269,270,409,291,375,321,405,314,17,84,181,91,146]
|
102 |
+
weights = torch.ones_like(loss)
|
103 |
+
weights[:, eye] = 3
|
104 |
+
weights[:, upper_eye] = 20
|
105 |
+
weights[:, inner_lip] = 5
|
106 |
+
weights[:, outer_lip] = 5
|
107 |
+
weights[:, unmatch_mask] = 0
|
108 |
+
loss = loss * weights
|
109 |
+
if mean:
|
110 |
+
loss = torch.mean(loss)
|
111 |
+
return loss
|
112 |
+
|
113 |
+
def cal_acceleration_loss(trans):
|
114 |
+
vel = trans[1:] - trans[:-1]
|
115 |
+
acc = vel[1:] - vel[:-1]
|
116 |
+
return torch.mean(torch.abs(acc))
|
117 |
+
|
118 |
+
def cal_acceleration_ldm_loss(ldm):
|
119 |
+
# [B, 68, 2]
|
120 |
+
vel = ldm[1:] - ldm[:-1]
|
121 |
+
acc = vel[1:] - vel[:-1]
|
122 |
+
lip_weight = 0.25 # we don't want to smooth the lips too much
|
123 |
+
acc[48:68] *= lip_weight
|
124 |
+
return torch.mean(torch.abs(acc))
|
125 |
+
|
126 |
+
def set_requires_grad(tensor_list):
|
127 |
+
for tensor in tensor_list:
|
128 |
+
tensor.requires_grad = True
|
129 |
+
|
130 |
+
@torch.enable_grad()
|
131 |
+
def fit_3dmm_for_a_video(
|
132 |
+
video_name,
|
133 |
+
nerf=False, # use the file name convention for GeneFace++
|
134 |
+
id_mode='global',
|
135 |
+
debug=False,
|
136 |
+
keypoint_mode='mediapipe',
|
137 |
+
large_yaw_threshold=9999999.9,
|
138 |
+
save=True
|
139 |
+
) -> bool: # True: good, False: bad
|
140 |
+
assert video_name.endswith(".mp4"), "this function only support video as input"
|
141 |
+
if id_mode == 'global':
|
142 |
+
LAMBDA_REG_ID = 0.2
|
143 |
+
LAMBDA_REG_EXP = 0.6
|
144 |
+
LAMBDA_REG_LAP = 1.0
|
145 |
+
LAMBDA_REG_VEL_ID = 0.0 # laplacian is all you need for temporal consistency
|
146 |
+
LAMBDA_REG_VEL_EXP = 0.0 # laplacian is all you need for temporal consistency
|
147 |
+
else:
|
148 |
+
LAMBDA_REG_ID = 0.3
|
149 |
+
LAMBDA_REG_EXP = 0.05
|
150 |
+
LAMBDA_REG_LAP = 1.0
|
151 |
+
LAMBDA_REG_VEL_ID = 0.0 # laplacian is all you need for temporal consistency
|
152 |
+
LAMBDA_REG_VEL_EXP = 0.0 # laplacian is all you need for temporal consistency
|
153 |
+
|
154 |
+
frames = read_video_to_frames(video_name) # [T, H, W, 3]
|
155 |
+
img_h, img_w = frames.shape[1], frames.shape[2]
|
156 |
+
assert img_h == img_w
|
157 |
+
num_frames = len(frames)
|
158 |
+
|
159 |
+
if nerf: # single video
|
160 |
+
lm_name = video_name.replace("/raw/", "/processed/").replace(".mp4","/lms_2d.npy")
|
161 |
+
else:
|
162 |
+
lm_name = video_name.replace("/video/", "/lms_2d/").replace(".mp4", "_lms.npy")
|
163 |
+
|
164 |
+
if os.path.exists(lm_name):
|
165 |
+
lms = np.load(lm_name)
|
166 |
+
else:
|
167 |
+
print(f"lms_2d file not found, try to extract it from video... {lm_name}")
|
168 |
+
try:
|
169 |
+
landmarker = MediapipeLandmarker()
|
170 |
+
img_lm478, vid_lm478 = landmarker.extract_lm478_from_frames(frames, anti_smooth_factor=20)
|
171 |
+
lms = landmarker.combine_vid_img_lm478_to_lm478(img_lm478, vid_lm478)
|
172 |
+
except Exception as e:
|
173 |
+
print(e)
|
174 |
+
return False
|
175 |
+
if lms is None:
|
176 |
+
print(f"get None lms_2d, please check whether each frame has one head, exiting... {lm_name}")
|
177 |
+
return False
|
178 |
+
lms = lms[:, :468, :]
|
179 |
+
lms = torch.FloatTensor(lms).cuda()
|
180 |
+
lms[..., 1] = img_h - lms[..., 1] # flip the height axis
|
181 |
+
|
182 |
+
if keypoint_mode == 'mediapipe':
|
183 |
+
# default
|
184 |
+
cal_lan_loss_fn = cal_lan_loss_mp
|
185 |
+
if nerf: # single video
|
186 |
+
out_name = video_name.replace("/raw/", "/processed/").replace(".mp4", "/coeff_fit_mp.npy")
|
187 |
+
else:
|
188 |
+
out_name = video_name.replace("/video/", "/coeff_fit_mp/").replace(".mp4", "_coeff_fit_mp.npy")
|
189 |
+
else:
|
190 |
+
# lm68 is less accurate than mp
|
191 |
+
cal_lan_loss_fn = cal_lan_loss
|
192 |
+
if nerf: # single video
|
193 |
+
out_name = video_name.replace("/raw/", "/processed/").replace(".mp4", "_coeff_fit_lm68.npy")
|
194 |
+
else:
|
195 |
+
out_name = video_name.replace("/video/", "/coeff_fit_lm68/").replace(".mp4", "_coeff_fit_lm68.npy")
|
196 |
+
try:
|
197 |
+
os.makedirs(os.path.dirname(out_name), exist_ok=True)
|
198 |
+
except:
|
199 |
+
pass
|
200 |
+
|
201 |
+
id_dim, exp_dim = 80, 64
|
202 |
+
sel_ids = np.arange(0, num_frames, 40)
|
203 |
+
|
204 |
+
h = w = face_model.center * 2
|
205 |
+
img_scale_factor = img_h / h
|
206 |
+
lms /= img_scale_factor # rescale lms into [0,224]
|
207 |
+
|
208 |
+
if id_mode == 'global':
|
209 |
+
# default choice by GeneFace++ and later works
|
210 |
+
id_para = lms.new_zeros((1, id_dim), requires_grad=True)
|
211 |
+
elif id_mode == 'finegrained':
|
212 |
+
# legacy choice by GeneFace1 (ICLR 2023)
|
213 |
+
id_para = lms.new_zeros((num_frames, id_dim), requires_grad=True)
|
214 |
+
else: raise NotImplementedError(f"id mode {id_mode} not supported! we only support global or finegrained.")
|
215 |
+
exp_para = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
|
216 |
+
euler_angle = lms.new_zeros((num_frames, 3), requires_grad=True)
|
217 |
+
trans = lms.new_zeros((num_frames, 3), requires_grad=True)
|
218 |
+
|
219 |
+
set_requires_grad([id_para, exp_para, euler_angle, trans])
|
220 |
+
|
221 |
+
optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=.1)
|
222 |
+
optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=.1)
|
223 |
+
|
224 |
+
# 其他参数初始化,先训练euler和trans
|
225 |
+
for _ in range(200):
|
226 |
+
if id_mode == 'global':
|
227 |
+
proj_geo = face_model.compute_for_landmark_fit(
|
228 |
+
id_para.expand((num_frames, id_dim)), exp_para, euler_angle, trans)
|
229 |
+
else:
|
230 |
+
proj_geo = face_model.compute_for_landmark_fit(
|
231 |
+
id_para, exp_para, euler_angle, trans)
|
232 |
+
loss_lan = cal_lan_loss_fn(proj_geo[:, :, :2], lms.detach())
|
233 |
+
loss = loss_lan
|
234 |
+
optimizer_frame.zero_grad()
|
235 |
+
loss.backward()
|
236 |
+
optimizer_frame.step()
|
237 |
+
|
238 |
+
# print(f"loss_lan: {loss_lan.item():.2f}, euler_abs_mean: {euler_angle.abs().mean().item():.4f}, euler_std: {euler_angle.std().item():.4f}, euler_min: {euler_angle.min().item():.4f}, euler_max: {euler_angle.max().item():.4f}")
|
239 |
+
# print(f"trans_z_mean: {trans[...,2].mean().item():.4f}, trans_z_std: {trans[...,2].std().item():.4f}, trans_min: {trans[...,2].min().item():.4f}, trans_max: {trans[...,2].max().item():.4f}")
|
240 |
+
|
241 |
+
for param_group in optimizer_frame.param_groups:
|
242 |
+
param_group['lr'] = 0.1
|
243 |
+
|
244 |
+
# "jointly roughly training id exp euler trans"
|
245 |
+
for _ in range(200):
|
246 |
+
ret = {}
|
247 |
+
if id_mode == 'global':
|
248 |
+
proj_geo = face_model.compute_for_landmark_fit(
|
249 |
+
id_para.expand((num_frames, id_dim)), exp_para, euler_angle, trans, ret)
|
250 |
+
else:
|
251 |
+
proj_geo = face_model.compute_for_landmark_fit(
|
252 |
+
id_para, exp_para, euler_angle, trans, ret)
|
253 |
+
loss_lan = cal_lan_loss_fn(
|
254 |
+
proj_geo[:, :, :2], lms.detach())
|
255 |
+
# loss_lap = cal_lap_loss(proj_geo)
|
256 |
+
# laplacian对euler影响不大,但是对trans的提升很大
|
257 |
+
loss_lap = cal_lap_loss(id_para) + cal_lap_loss(exp_para) + cal_lap_loss(euler_angle) * 0.3 + cal_lap_loss(trans) * 0.3
|
258 |
+
|
259 |
+
loss_regid = torch.mean(id_para*id_para) # 正则化
|
260 |
+
loss_regexp = torch.mean(exp_para * exp_para)
|
261 |
+
|
262 |
+
loss_vel_id = cal_vel_loss(id_para)
|
263 |
+
loss_vel_exp = cal_vel_loss(exp_para)
|
264 |
+
loss = loss_lan + loss_regid * LAMBDA_REG_ID + loss_regexp * LAMBDA_REG_EXP + loss_vel_id * LAMBDA_REG_VEL_ID + loss_vel_exp * LAMBDA_REG_VEL_EXP + loss_lap * LAMBDA_REG_LAP
|
265 |
+
optimizer_idexp.zero_grad()
|
266 |
+
optimizer_frame.zero_grad()
|
267 |
+
loss.backward()
|
268 |
+
optimizer_idexp.step()
|
269 |
+
optimizer_frame.step()
|
270 |
+
|
271 |
+
# print(f"loss_lan: {loss_lan.item():.2f}, loss_reg_id: {loss_regid.item():.2f},loss_reg_exp: {loss_regexp.item():.2f},")
|
272 |
+
# print(f"euler_abs_mean: {euler_angle.abs().mean().item():.4f}, euler_std: {euler_angle.std().item():.4f}, euler_min: {euler_angle.min().item():.4f}, euler_max: {euler_angle.max().item():.4f}")
|
273 |
+
# print(f"trans_z_mean: {trans[...,2].mean().item():.4f}, trans_z_std: {trans[...,2].std().item():.4f}, trans_min: {trans[...,2].min().item():.4f}, trans_max: {trans[...,2].max().item():.4f}")
|
274 |
+
|
275 |
+
# start fine training, intialize from the roughly trained results
|
276 |
+
if id_mode == 'global':
|
277 |
+
id_para_ = lms.new_zeros((1, id_dim), requires_grad=False)
|
278 |
+
else:
|
279 |
+
id_para_ = lms.new_zeros((num_frames, id_dim), requires_grad=True)
|
280 |
+
id_para_.data = id_para.data.clone()
|
281 |
+
id_para = id_para_
|
282 |
+
exp_para_ = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
|
283 |
+
exp_para_.data = exp_para.data.clone()
|
284 |
+
exp_para = exp_para_
|
285 |
+
euler_angle_ = lms.new_zeros((num_frames, 3), requires_grad=True)
|
286 |
+
euler_angle_.data = euler_angle.data.clone()
|
287 |
+
euler_angle = euler_angle_
|
288 |
+
trans_ = lms.new_zeros((num_frames, 3), requires_grad=True)
|
289 |
+
trans_.data = trans.data.clone()
|
290 |
+
trans = trans_
|
291 |
+
|
292 |
+
batch_size = 50
|
293 |
+
# "fine fitting the 3DMM in batches"
|
294 |
+
for i in range(int((num_frames-1)/batch_size+1)):
|
295 |
+
if (i+1)*batch_size > num_frames:
|
296 |
+
start_n = num_frames-batch_size
|
297 |
+
sel_ids = np.arange(max(num_frames-batch_size,0), num_frames)
|
298 |
+
else:
|
299 |
+
start_n = i*batch_size
|
300 |
+
sel_ids = np.arange(i*batch_size, i*batch_size+batch_size)
|
301 |
+
sel_lms = lms[sel_ids]
|
302 |
+
|
303 |
+
if id_mode == 'global':
|
304 |
+
sel_id_para = id_para.expand((sel_ids.shape[0], id_dim))
|
305 |
+
else:
|
306 |
+
sel_id_para = id_para.new_zeros((batch_size, id_dim), requires_grad=True)
|
307 |
+
sel_id_para.data = id_para[sel_ids].clone()
|
308 |
+
sel_exp_para = exp_para.new_zeros(
|
309 |
+
(batch_size, exp_dim), requires_grad=True)
|
310 |
+
sel_exp_para.data = exp_para[sel_ids].clone()
|
311 |
+
sel_euler_angle = euler_angle.new_zeros(
|
312 |
+
(batch_size, 3), requires_grad=True)
|
313 |
+
sel_euler_angle.data = euler_angle[sel_ids].clone()
|
314 |
+
sel_trans = trans.new_zeros((batch_size, 3), requires_grad=True)
|
315 |
+
sel_trans.data = trans[sel_ids].clone()
|
316 |
+
|
317 |
+
if id_mode == 'global':
|
318 |
+
set_requires_grad([sel_exp_para, sel_euler_angle, sel_trans])
|
319 |
+
optimizer_cur_batch = torch.optim.Adam(
|
320 |
+
[sel_exp_para, sel_euler_angle, sel_trans], lr=0.005)
|
321 |
+
else:
|
322 |
+
set_requires_grad([sel_id_para, sel_exp_para, sel_euler_angle, sel_trans])
|
323 |
+
optimizer_cur_batch = torch.optim.Adam(
|
324 |
+
[sel_id_para, sel_exp_para, sel_euler_angle, sel_trans], lr=0.005)
|
325 |
+
|
326 |
+
for j in range(50):
|
327 |
+
ret = {}
|
328 |
+
proj_geo = face_model.compute_for_landmark_fit(
|
329 |
+
sel_id_para, sel_exp_para, sel_euler_angle, sel_trans, ret)
|
330 |
+
loss_lan = cal_lan_loss_fn(
|
331 |
+
proj_geo[:, :, :2], lms[sel_ids].detach())
|
332 |
+
|
333 |
+
# loss_lap = cal_lap_loss(proj_geo)
|
334 |
+
loss_lap = cal_lap_loss(sel_id_para) + cal_lap_loss(sel_exp_para) + cal_lap_loss(sel_euler_angle) * 0.3 + cal_lap_loss(sel_trans) * 0.3
|
335 |
+
loss_vel_id = cal_vel_loss(sel_id_para)
|
336 |
+
loss_vel_exp = cal_vel_loss(sel_exp_para)
|
337 |
+
log_dict = {
|
338 |
+
'loss_vel_id': loss_vel_id,
|
339 |
+
'loss_vel_exp': loss_vel_exp,
|
340 |
+
'loss_vel_euler': cal_vel_loss(sel_euler_angle),
|
341 |
+
'loss_vel_trans': cal_vel_loss(sel_trans),
|
342 |
+
}
|
343 |
+
loss_regid = torch.mean(sel_id_para*sel_id_para) # 正则化
|
344 |
+
loss_regexp = torch.mean(sel_exp_para*sel_exp_para)
|
345 |
+
loss = loss_lan + loss_regid * LAMBDA_REG_ID + loss_regexp * LAMBDA_REG_EXP + loss_lap * LAMBDA_REG_LAP + loss_vel_id * LAMBDA_REG_VEL_ID + loss_vel_exp * LAMBDA_REG_VEL_EXP
|
346 |
+
|
347 |
+
optimizer_cur_batch.zero_grad()
|
348 |
+
loss.backward()
|
349 |
+
optimizer_cur_batch.step()
|
350 |
+
|
351 |
+
if debug:
|
352 |
+
print(f"batch {i} | loss_lan: {loss_lan.item():.2f}, loss_reg_id: {loss_regid.item():.2f},loss_reg_exp: {loss_regexp.item():.2f},loss_lap_ldm:{loss_lap.item():.4f}")
|
353 |
+
print("|--------" + ', '.join([f"{k}: {v:.4f}" for k,v in log_dict.items()]))
|
354 |
+
if id_mode != 'global':
|
355 |
+
id_para[sel_ids].data = sel_id_para.data.clone()
|
356 |
+
exp_para[sel_ids].data = sel_exp_para.data.clone()
|
357 |
+
euler_angle[sel_ids].data = sel_euler_angle.data.clone()
|
358 |
+
trans[sel_ids].data = sel_trans.data.clone()
|
359 |
+
|
360 |
+
coeff_dict = {'id': id_para.detach().cpu().numpy(), 'exp': exp_para.detach().cpu().numpy(),
|
361 |
+
'euler': euler_angle.detach().cpu().numpy(), 'trans': trans.detach().cpu().numpy()}
|
362 |
+
|
363 |
+
# filter data by side-view pose
|
364 |
+
# bad_yaw = False
|
365 |
+
# yaws = [] # not so accurate
|
366 |
+
# for index in range(coeff_dict["trans"].shape[0]):
|
367 |
+
# yaw = coeff_dict["euler"][index][1]
|
368 |
+
# yaw = np.abs(yaw)
|
369 |
+
# yaws.append(yaw)
|
370 |
+
# if yaw > large_yaw_threshold:
|
371 |
+
# bad_yaw = True
|
372 |
+
|
373 |
+
if debug:
|
374 |
+
import imageio
|
375 |
+
from utils.visualization.vis_cam3d.camera_pose_visualizer import CameraPoseVisualizer
|
376 |
+
from data_util.face3d_helper import Face3DHelper
|
377 |
+
from data_gen.utils.process_video.extract_blink import get_eye_area_percent
|
378 |
+
face3d_helper = Face3DHelper('deep_3drecon/BFM', keypoint_mode='mediapipe')
|
379 |
+
|
380 |
+
t = coeff_dict['exp'].shape[0]
|
381 |
+
if len(coeff_dict['id']) == 1:
|
382 |
+
coeff_dict['id'] = np.repeat(coeff_dict['id'], t, axis=0)
|
383 |
+
idexp_lm3d = face3d_helper.reconstruct_idexp_lm3d_np(coeff_dict['id'], coeff_dict['exp']).reshape([t, -1])
|
384 |
+
cano_lm3d = idexp_lm3d / 10 + face3d_helper.key_mean_shape.squeeze().reshape([1, -1]).cpu().numpy()
|
385 |
+
cano_lm3d = cano_lm3d.reshape([t, -1, 3])
|
386 |
+
WH = 512
|
387 |
+
cano_lm3d = (cano_lm3d * WH/2 + WH/2).astype(int)
|
388 |
+
|
389 |
+
with torch.no_grad():
|
390 |
+
rot = ParametricFaceModel.compute_rotation(euler_angle)
|
391 |
+
extrinsic = torch.zeros([rot.shape[0], 4, 4]).to(rot.device)
|
392 |
+
extrinsic[:, :3,:3] = rot
|
393 |
+
extrinsic[:, :3, 3] = trans # / 10
|
394 |
+
extrinsic[:, 3, 3] = 1
|
395 |
+
extrinsic = extrinsic.cpu().numpy()
|
396 |
+
|
397 |
+
xy_camera_visualizer = CameraPoseVisualizer(xlim=[extrinsic[:,0,3].min().item()-0.5,extrinsic[:,0,3].max().item()+0.5],ylim=[extrinsic[:,1,3].min().item()-0.5,extrinsic[:,1,3].max().item()+0.5], zlim=[extrinsic[:,2,3].min().item()-0.5,extrinsic[:,2,3].max().item()+0.5], view_mode='xy')
|
398 |
+
xz_camera_visualizer = CameraPoseVisualizer(xlim=[extrinsic[:,0,3].min().item()-0.5,extrinsic[:,0,3].max().item()+0.5],ylim=[extrinsic[:,1,3].min().item()-0.5,extrinsic[:,1,3].max().item()+0.5], zlim=[extrinsic[:,2,3].min().item()-0.5,extrinsic[:,2,3].max().item()+0.5], view_mode='xz')
|
399 |
+
|
400 |
+
if nerf:
|
401 |
+
debug_name = video_name.replace("/raw/", "/processed/").replace(".mp4", "/debug_fit_3dmm.mp4")
|
402 |
+
else:
|
403 |
+
debug_name = video_name.replace("/video/", "/coeff_fit_debug/").replace(".mp4", "_debug.mp4")
|
404 |
+
try:
|
405 |
+
os.makedirs(os.path.dirname(debug_name), exist_ok=True)
|
406 |
+
except: pass
|
407 |
+
writer = imageio.get_writer(debug_name, fps=25)
|
408 |
+
if id_mode == 'global':
|
409 |
+
id_para = id_para.repeat([exp_para.shape[0], 1])
|
410 |
+
proj_geo = face_model.compute_for_landmark_fit(id_para, exp_para, euler_angle, trans)
|
411 |
+
lm68s = proj_geo[:,:,:2].detach().cpu().numpy() # [T, 68,2]
|
412 |
+
lm68s = lm68s * img_scale_factor
|
413 |
+
lms = lms * img_scale_factor
|
414 |
+
lm68s[..., 1] = img_h - lm68s[..., 1] # flip the height axis
|
415 |
+
lms[..., 1] = img_h - lms[..., 1] # flip the height axis
|
416 |
+
lm68s = lm68s.astype(int)
|
417 |
+
for i in tqdm.trange(min(250, len(frames)), desc=f'rendering debug video to {debug_name}..'):
|
418 |
+
xy_cam3d_img = xy_camera_visualizer.extrinsic2pyramid(extrinsic[i], focal_len_scaled=0.25)
|
419 |
+
xy_cam3d_img = cv2.resize(xy_cam3d_img, (512,512))
|
420 |
+
xz_cam3d_img = xz_camera_visualizer.extrinsic2pyramid(extrinsic[i], focal_len_scaled=0.25)
|
421 |
+
xz_cam3d_img = cv2.resize(xz_cam3d_img, (512,512))
|
422 |
+
|
423 |
+
img = copy.deepcopy(frames[i])
|
424 |
+
img2 = copy.deepcopy(frames[i])
|
425 |
+
|
426 |
+
img = draw_axes(img, euler_angle[i,0].item(), euler_angle[i,1].item(), euler_angle[i,2].item(), lm68s[i][4][0].item(), lm68s[i, 4][1].item(), size=50)
|
427 |
+
|
428 |
+
gt_lm_color = (255, 0, 0)
|
429 |
+
|
430 |
+
for lm in lm68s[i]:
|
431 |
+
img = cv2.circle(img, lm, 1, (0, 0, 255), thickness=-1) # blue
|
432 |
+
for gt_lm in lms[i]:
|
433 |
+
img2 = cv2.circle(img2, gt_lm.cpu().numpy().astype(int), 2, gt_lm_color, thickness=1)
|
434 |
+
|
435 |
+
cano_lm3d_img = np.ones([WH, WH, 3], dtype=np.uint8) * 255
|
436 |
+
for j in range(len(cano_lm3d[i])):
|
437 |
+
x, y, _ = cano_lm3d[i, j]
|
438 |
+
color = (255,0,0)
|
439 |
+
cano_lm3d_img = cv2.circle(cano_lm3d_img, center=(x,y), radius=3, color=color, thickness=-1)
|
440 |
+
cano_lm3d_img = cv2.flip(cano_lm3d_img, 0)
|
441 |
+
|
442 |
+
_, secc_img = secc_renderer(id_para[0:1], exp_para[i:i+1], euler_angle[i:i+1]*0, trans[i:i+1]*0)
|
443 |
+
secc_img = (secc_img +1)*127.5
|
444 |
+
secc_img = F.interpolate(secc_img, size=(img_h, img_w))
|
445 |
+
secc_img = secc_img.permute(0, 2,3,1).int().cpu().numpy()[0]
|
446 |
+
out_img1 = np.concatenate([img, img2, secc_img], axis=1).astype(np.uint8)
|
447 |
+
font = cv2.FONT_HERSHEY_SIMPLEX
|
448 |
+
out_img2 = np.concatenate([xy_cam3d_img, xz_cam3d_img, cano_lm3d_img], axis=1).astype(np.uint8)
|
449 |
+
out_img = np.concatenate([out_img1, out_img2], axis=0)
|
450 |
+
writer.append_data(out_img)
|
451 |
+
writer.close()
|
452 |
+
|
453 |
+
# if bad_yaw:
|
454 |
+
# print(f"Skip {video_name} due to TOO LARGE YAW")
|
455 |
+
# return False
|
456 |
+
|
457 |
+
if save:
|
458 |
+
np.save(out_name, coeff_dict, allow_pickle=True)
|
459 |
+
return coeff_dict
|
460 |
+
|
461 |
+
def out_exist_job(vid_name):
|
462 |
+
out_name = vid_name.replace("/video/", "/coeff_fit_mp/").replace(".mp4","_coeff_fit_mp.npy")
|
463 |
+
lms_name = vid_name.replace("/video/", "/lms_2d/").replace(".mp4","_lms.npy")
|
464 |
+
if os.path.exists(out_name) or not os.path.exists(lms_name):
|
465 |
+
return None
|
466 |
+
else:
|
467 |
+
return vid_name
|
468 |
+
|
469 |
+
def get_todo_vid_names(vid_names):
|
470 |
+
if len(vid_names) == 1: # single video, nerf
|
471 |
+
return vid_names
|
472 |
+
todo_vid_names = []
|
473 |
+
for i, res in multiprocess_run_tqdm(out_exist_job, vid_names, num_workers=16):
|
474 |
+
if res is not None:
|
475 |
+
todo_vid_names.append(res)
|
476 |
+
return todo_vid_names
|
477 |
+
|
478 |
+
|
479 |
+
if __name__ == '__main__':
|
480 |
+
import argparse, glob, tqdm
|
481 |
+
parser = argparse.ArgumentParser()
|
482 |
+
# parser.add_argument("--vid_dir", default='/home/tiger/datasets/raw/CelebV-HQ/video')
|
483 |
+
parser.add_argument("--vid_dir", default='data/raw/videos/May_10s.mp4')
|
484 |
+
parser.add_argument("--ds_name", default='nerf') # 'nerf' | 'CelebV-HQ' | 'TH1KH_512' | etc
|
485 |
+
parser.add_argument("--seed", default=0, type=int)
|
486 |
+
parser.add_argument("--process_id", default=0, type=int)
|
487 |
+
parser.add_argument("--total_process", default=1, type=int)
|
488 |
+
parser.add_argument("--id_mode", default='global', type=str) # global | finegrained
|
489 |
+
parser.add_argument("--keypoint_mode", default='mediapipe', type=str)
|
490 |
+
parser.add_argument("--large_yaw_threshold", default=9999999.9, type=float) # could be 0.7
|
491 |
+
parser.add_argument("--debug", action='store_true')
|
492 |
+
parser.add_argument("--reset", action='store_true')
|
493 |
+
parser.add_argument("--load_names", action="store_true")
|
494 |
+
|
495 |
+
args = parser.parse_args()
|
496 |
+
vid_dir = args.vid_dir
|
497 |
+
ds_name = args.ds_name
|
498 |
+
load_names = args.load_names
|
499 |
+
|
500 |
+
print(f"args {args}")
|
501 |
+
|
502 |
+
if ds_name.lower() == 'nerf': # 处理单个视频
|
503 |
+
vid_names = [vid_dir]
|
504 |
+
out_names = [video_name.replace("/raw/", "/processed/").replace(".mp4","_coeff_fit_mp.npy") for video_name in vid_names]
|
505 |
+
else: # 处理整个数据集
|
506 |
+
if ds_name in ['lrs3_trainval']:
|
507 |
+
vid_name_pattern = os.path.join(vid_dir, "*/*.mp4")
|
508 |
+
elif ds_name in ['TH1KH_512', 'CelebV-HQ']:
|
509 |
+
vid_name_pattern = os.path.join(vid_dir, "*.mp4")
|
510 |
+
elif ds_name in ['lrs2', 'lrs3', 'voxceleb2', 'CMLR']:
|
511 |
+
vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
|
512 |
+
elif ds_name in ["RAVDESS", 'VFHQ']:
|
513 |
+
vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4")
|
514 |
+
else:
|
515 |
+
raise NotImplementedError()
|
516 |
+
|
517 |
+
vid_names_path = os.path.join(vid_dir, "vid_names.pkl")
|
518 |
+
if os.path.exists(vid_names_path) and load_names:
|
519 |
+
print(f"loading vid names from {vid_names_path}")
|
520 |
+
vid_names = load_file(vid_names_path)
|
521 |
+
else:
|
522 |
+
vid_names = multiprocess_glob(vid_name_pattern)
|
523 |
+
vid_names = sorted(vid_names)
|
524 |
+
print(f"saving vid names to {vid_names_path}")
|
525 |
+
save_file(vid_names_path, vid_names)
|
526 |
+
out_names = [video_name.replace("/video/", "/coeff_fit_mp/").replace(".mp4","_coeff_fit_mp.npy") for video_name in vid_names]
|
527 |
+
|
528 |
+
print(vid_names[:10])
|
529 |
+
random.seed(args.seed)
|
530 |
+
random.shuffle(vid_names)
|
531 |
+
|
532 |
+
face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM',
|
533 |
+
camera_distance=10, focal=1015, keypoint_mode=args.keypoint_mode)
|
534 |
+
face_model.to(torch.device("cuda:0"))
|
535 |
+
secc_renderer = SECC_Renderer(512)
|
536 |
+
secc_renderer.to("cuda:0")
|
537 |
+
|
538 |
+
process_id = args.process_id
|
539 |
+
total_process = args.total_process
|
540 |
+
if total_process > 1:
|
541 |
+
assert process_id <= total_process -1
|
542 |
+
num_samples_per_process = len(vid_names) // total_process
|
543 |
+
if process_id == total_process:
|
544 |
+
vid_names = vid_names[process_id * num_samples_per_process : ]
|
545 |
+
else:
|
546 |
+
vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
|
547 |
+
|
548 |
+
if not args.reset:
|
549 |
+
vid_names = get_todo_vid_names(vid_names)
|
550 |
+
|
551 |
+
failed_img_names = []
|
552 |
+
for i in tqdm.trange(len(vid_names), desc=f"process {process_id}: fitting 3dmm ..."):
|
553 |
+
img_name = vid_names[i]
|
554 |
+
try:
|
555 |
+
is_person_specific_data = ds_name=='nerf'
|
556 |
+
success = fit_3dmm_for_a_video(img_name, is_person_specific_data, args.id_mode, args.debug, large_yaw_threshold=args.large_yaw_threshold)
|
557 |
+
if not success:
|
558 |
+
failed_img_names.append(img_name)
|
559 |
+
except Exception as e:
|
560 |
+
print(img_name, e)
|
561 |
+
failed_img_names.append(img_name)
|
562 |
+
print(f"finished {i + 1} / {len(vid_names)} = {(i + 1) / len(vid_names):.4f}, failed {len(failed_img_names)} / {i + 1} = {len(failed_img_names) / (i + 1):.4f}")
|
563 |
+
sys.stdout.flush()
|
564 |
+
print(f"all failed image names: {failed_img_names}")
|
565 |
+
print(f"All finished!")
|
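
For quick reference, the fitting script above is normally driven from the command line; a minimal invocation for a single person-specific video, using only flags defined in its argparse block (the video path is the script's own default and purely illustrative):

python data_gen/utils/process_video/fit_3dmm_landmark.py --ds_name=nerf --vid_dir=data/raw/videos/May_10s.mp4 --id_mode=global --keypoint_mode=mediapipe

With --ds_name=nerf the script fits exactly one video and writes the coefficient dict next to it under the /processed/ naming convention shown in fit_3dmm_for_a_video.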
data_gen/utils/process_video/inpaint_torso_imgs.py
ADDED
@@ -0,0 +1,193 @@
import cv2
import os
import glob
import tqdm
import numpy as np
from utils.commons.multiprocess_utils import multiprocess_run_tqdm
from scipy.ndimage import binary_erosion, binary_dilation

from tasks.eg3ds.loss_utils.segment_loss.mp_segmenter import MediapipeSegmenter
seg_model = MediapipeSegmenter()

def inpaint_torso_job(video_name, idx=None, total=None):
    raw_img_dir = video_name.replace(".mp4", "").replace("/video/","/gt_imgs/")
    img_names = glob.glob(os.path.join(raw_img_dir, "*.jpg"))

    for image_path in tqdm.tqdm(img_names):
        # read ori image
        ori_image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3]
        segmap = seg_model._cal_seg_map(cv2.cvtColor(ori_image, cv2.COLOR_BGR2RGB))
        head_part = (segmap[1] + segmap[3] + segmap[5]).astype(bool)
        torso_part = (segmap[4]).astype(bool)
        neck_part = (segmap[2]).astype(bool)
        bg_part = segmap[0].astype(bool)
        head_image = cv2.imread(image_path.replace("/gt_imgs/", "/head_imgs/"), cv2.IMREAD_UNCHANGED) # [H, W, 3]
        torso_image = cv2.imread(image_path.replace("/gt_imgs/", "/torso_imgs/"), cv2.IMREAD_UNCHANGED) # [H, W, 3]
        bg_image = cv2.imread(image_path.replace("/gt_imgs/", "/bg_imgs/"), cv2.IMREAD_UNCHANGED) # [H, W, 3]

        # head_part = (head_image[...,0] != 0) & (head_image[...,1] != 0) & (head_image[...,2] != 0)
        # torso_part = (torso_image[...,0] != 0) & (torso_image[...,1] != 0) & (torso_image[...,2] != 0)
        # bg_part = (bg_image[...,0] != 0) & (bg_image[...,1] != 0) & (bg_image[...,2] != 0)

        # get gt image
        gt_image = ori_image.copy()
        gt_image[bg_part] = bg_image[bg_part]
        cv2.imwrite(image_path.replace('ori_imgs', 'gt_imgs'), gt_image)

        # get torso image
        torso_image = gt_image.copy() # rgb
        torso_image[head_part] = 0
        torso_alpha = 255 * np.ones((gt_image.shape[0], gt_image.shape[1], 1), dtype=np.uint8) # alpha

        # torso part "vertical" in-painting...
        L = 8 + 1
        torso_coords = np.stack(np.nonzero(torso_part), axis=-1) # [M, 2]
        # lexsort: sort 2D coords first by y then by x,
        # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
        inds = np.lexsort((torso_coords[:, 0], torso_coords[:, 1]))
        torso_coords = torso_coords[inds]
        # choose the top pixel for each column
        u, uid, ucnt = np.unique(torso_coords[:, 1], return_index=True, return_counts=True)
        top_torso_coords = torso_coords[uid] # [m, 2]
        # only keep top-is-head pixels
        top_torso_coords_up = top_torso_coords.copy() - np.array([1, 0]) # [N, 2]
        mask = head_part[tuple(top_torso_coords_up.T)]
        if mask.any():
            top_torso_coords = top_torso_coords[mask]
            # get the color
            top_torso_colors = gt_image[tuple(top_torso_coords.T)] # [m, 3]
            # construct inpaint coords (vertically up, or minus in x)
            inpaint_torso_coords = top_torso_coords[None].repeat(L, 0) # [L, m, 2]
            inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
            inpaint_torso_coords += inpaint_offsets
            inpaint_torso_coords = inpaint_torso_coords.reshape(-1, 2) # [Lm, 2]
            inpaint_torso_colors = top_torso_colors[None].repeat(L, 0) # [L, m, 3]
            darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
            inpaint_torso_colors = (inpaint_torso_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
            # set color
            torso_image[tuple(inpaint_torso_coords.T)] = inpaint_torso_colors

            inpaint_torso_mask = np.zeros_like(torso_image[..., 0]).astype(bool)
            inpaint_torso_mask[tuple(inpaint_torso_coords.T)] = True
        else:
            inpaint_torso_mask = None

        # neck part "vertical" in-painting...
        push_down = 4
        L = 48 + push_down + 1

        neck_part = binary_dilation(neck_part, structure=np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=bool), iterations=3)

        neck_coords = np.stack(np.nonzero(neck_part), axis=-1) # [M, 2]
        # lexsort: sort 2D coords first by y then by x,
        # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
        inds = np.lexsort((neck_coords[:, 0], neck_coords[:, 1]))
        neck_coords = neck_coords[inds]
        # choose the top pixel for each column
        u, uid, ucnt = np.unique(neck_coords[:, 1], return_index=True, return_counts=True)
        top_neck_coords = neck_coords[uid] # [m, 2]
        # only keep top-is-head pixels
        top_neck_coords_up = top_neck_coords.copy() - np.array([1, 0])
        mask = head_part[tuple(top_neck_coords_up.T)]

        top_neck_coords = top_neck_coords[mask]
        # push these top down for 4 pixels to make the neck inpainting more natural...
        offset_down = np.minimum(ucnt[mask] - 1, push_down)
        top_neck_coords += np.stack([offset_down, np.zeros_like(offset_down)], axis=-1)
        # get the color
        top_neck_colors = gt_image[tuple(top_neck_coords.T)] # [m, 3]
        # construct inpaint coords (vertically up, or minus in x)
        inpaint_neck_coords = top_neck_coords[None].repeat(L, 0) # [L, m, 2]
        inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
        inpaint_neck_coords += inpaint_offsets
        inpaint_neck_coords = inpaint_neck_coords.reshape(-1, 2) # [Lm, 2]
        inpaint_neck_colors = top_neck_colors[None].repeat(L, 0) # [L, m, 3]
        darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
        inpaint_neck_colors = (inpaint_neck_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
        # set color
        torso_image[tuple(inpaint_neck_coords.T)] = inpaint_neck_colors

        # apply blurring to the inpaint area to avoid vertical-line artifacts...
        inpaint_mask = np.zeros_like(torso_image[..., 0]).astype(bool)
        inpaint_mask[tuple(inpaint_neck_coords.T)] = True

        blur_img = torso_image.copy()
        blur_img = cv2.GaussianBlur(blur_img, (5, 5), cv2.BORDER_DEFAULT)

        torso_image[inpaint_mask] = blur_img[inpaint_mask]

        # set mask
        mask = (neck_part | torso_part | inpaint_mask)
        if inpaint_torso_mask is not None:
            mask = mask | inpaint_torso_mask
        torso_image[~mask] = 0
        torso_alpha[~mask] = 0

        cv2.imwrite("0.png", np.concatenate([torso_image, torso_alpha], axis=-1))

    print(f'[INFO] ===== extracted torso and gt images =====')


def out_exist_job(vid_name):
    out_dir1 = vid_name.replace("/video/", "/inpaint_torso_imgs/").replace(".mp4","")
    out_dir2 = vid_name.replace("/video/", "/inpaint_torso_with_bg_imgs/").replace(".mp4","")
    out_dir3 = vid_name.replace("/video/", "/torso_imgs/").replace(".mp4","")
    out_dir4 = vid_name.replace("/video/", "/torso_with_bg_imgs/").replace(".mp4","")

    if os.path.exists(out_dir1) and os.path.exists(out_dir2) and os.path.exists(out_dir3) and os.path.exists(out_dir4):
        num_frames = len(os.listdir(out_dir1))
        if len(os.listdir(out_dir1)) == num_frames and len(os.listdir(out_dir2)) == num_frames and len(os.listdir(out_dir3)) == num_frames and len(os.listdir(out_dir4)) == num_frames:
            return None
        else:
            return vid_name
    else:
        return vid_name

def get_todo_vid_names(vid_names):
    todo_vid_names = []
    for i, res in multiprocess_run_tqdm(out_exist_job, vid_names, num_workers=16):
        if res is not None:
            todo_vid_names.append(res)
    return todo_vid_names

if __name__ == '__main__':
    import argparse, glob, tqdm, random
    parser = argparse.ArgumentParser()
    parser.add_argument("--vid_dir", default='/home/tiger/datasets/raw/CelebV-HQ/video')
    parser.add_argument("--ds_name", default='CelebV-HQ')
    parser.add_argument("--num_workers", default=48, type=int)
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--process_id", default=0, type=int)
    parser.add_argument("--total_process", default=1, type=int)
    parser.add_argument("--reset", action='store_true')

    inpaint_torso_job('/home/tiger/datasets/raw/CelebV-HQ/video/dgdEr-mXQT4_8.mp4')
    # args = parser.parse_args()
    # vid_dir = args.vid_dir
    # ds_name = args.ds_name
    # if ds_name in ['lrs3_trainval']:
    #     mp4_name_pattern = os.path.join(vid_dir, "*/*.mp4")
    # if ds_name in ['TH1KH_512', 'CelebV-HQ']:
    #     vid_names = glob.glob(os.path.join(vid_dir, "*.mp4"))
    # elif ds_name in ['lrs2', 'lrs3', 'voxceleb2']:
    #     vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
    #     vid_names = glob.glob(vid_name_pattern)
    # vid_names = sorted(vid_names)
    # random.seed(args.seed)
    # random.shuffle(vid_names)

    # process_id = args.process_id
    # total_process = args.total_process
    # if total_process > 1:
    #     assert process_id <= total_process -1
    #     num_samples_per_process = len(vid_names) // total_process
    #     if process_id == total_process:
    #         vid_names = vid_names[process_id * num_samples_per_process : ]
    #     else:
    #         vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]

    # if not args.reset:
    #     vid_names = get_todo_vid_names(vid_names)
    # print(f"todo videos number: {len(vid_names)}")

    # fn_args = [(vid_name,i,len(vid_names)) for i, vid_name in enumerate(vid_names)]
    # for vid_name in multiprocess_run_tqdm(inpaint_torso_job ,fn_args, desc=f"Root process {args.process_id}: extracting segment images", num_workers=args.num_workers):
    #     pass
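
A note on the column-wise "vertical" inpainting above: np.lexsort orders the (y, x) coordinates by column and then by row, so np.unique(..., return_index=True) picks the topmost pixel of every column. A tiny self-contained sketch of that pattern on toy data (illustrative, not part of the uploaded files):

import numpy as np

# toy boolean mask: rows are y (0 = top), columns are x
part = np.array([[0, 0, 1],
                 [1, 0, 1],
                 [1, 1, 1]], dtype=bool)
coords = np.stack(np.nonzero(part), axis=-1)          # [M, 2] of (y, x) pairs
inds = np.lexsort((coords[:, 0], coords[:, 1]))       # sort by x first, then by y
coords = coords[inds]
_, uid, _ = np.unique(coords[:, 1], return_index=True, return_counts=True)
print(coords[uid])   # [[1 0] [2 1] [0 2]]: the topmost pixel of each column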
data_gen/utils/process_video/resample_video_to_25fps_resize_to_512.py
ADDED
@@ -0,0 +1,87 @@
import os, glob
import cv2
from utils.commons.os_utils import multiprocess_glob
from utils.commons.multiprocess_utils import multiprocess_run_tqdm

def get_video_infos(video_path):
    vid_cap = cv2.VideoCapture(video_path)
    height = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    width = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    fps = vid_cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(vid_cap.get(cv2.CAP_PROP_FRAME_COUNT))
    return {'height': height, 'width': width, 'fps': fps, 'total_frames':total_frames}

def extract_img_job(video_name:str):
    out_path = video_name.replace("/video_raw/","/video/",1)
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    ffmpeg_path = "/usr/bin/ffmpeg"
    vid_info = get_video_infos(video_name)
    assert vid_info['width'] == vid_info['height']
    cmd = f'{ffmpeg_path} -i {video_name} -vf fps={25},scale=w=512:h=512 -q:v 1 -c:v libx264 -pix_fmt yuv420p -b:v 2000k -v quiet -y {out_path}'
    os.system(cmd)

def extract_img_job_crop(video_name:str):
    out_path = video_name.replace("/video_raw/","/video/",1)
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    ffmpeg_path = "/usr/bin/ffmpeg"
    vid_info = get_video_infos(video_name)
    wh = min(vid_info['width'], vid_info['height'])
    cmd = f'{ffmpeg_path} -i {video_name} -vf fps={25},crop={wh}:{wh},scale=w=512:h=512 -q:v 1 -c:v libx264 -pix_fmt yuv420p -b:v 2000k -v quiet -y {out_path}'
    os.system(cmd)

def extract_img_job_crop_ravdess(video_name:str):
    out_path = video_name.replace("/video_raw/","/video/",1)
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    ffmpeg_path = "/usr/bin/ffmpeg"
    cmd = f'{ffmpeg_path} -i {video_name} -vf fps={25},crop=720:720,scale=w=512:h=512 -q:v 1 -c:v libx264 -pix_fmt yuv420p -b:v 2000k -v quiet -y {out_path}'
    os.system(cmd)

if __name__ == '__main__':
    import argparse, glob, tqdm, random
    parser = argparse.ArgumentParser()
    parser.add_argument("--vid_dir", default='/home/tiger/datasets/raw/CelebV-HQ/video_raw/')
    parser.add_argument("--ds_name", default='CelebV-HQ')
    parser.add_argument("--num_workers", default=32, type=int)
    parser.add_argument("--process_id", default=0, type=int)
    parser.add_argument("--total_process", default=1, type=int)
    args = parser.parse_args()
    print(f"args {args}")

    vid_dir = args.vid_dir
    ds_name = args.ds_name
    if ds_name in ['lrs3_trainval']:
        mp4_name_pattern = os.path.join(vid_dir, "*/*.mp4")
    elif ds_name in ['TH1KH_512', 'CelebV-HQ']:
        vid_names = multiprocess_glob(os.path.join(vid_dir, "*.mp4"))
    elif ds_name in ['lrs2', 'lrs3', 'voxceleb2', 'CMLR']:
        vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
        vid_names = multiprocess_glob(vid_name_pattern)
    elif ds_name in ["RAVDESS", 'VFHQ']:
        vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4")
        vid_names = multiprocess_glob(vid_name_pattern)
    else:
        raise NotImplementedError()
    vid_names = sorted(vid_names)
    print(f"total video number : {len(vid_names)}")
    print(f"first {vid_names[0]} last {vid_names[-1]}")
    # exit()
    process_id = args.process_id
    total_process = args.total_process
    if total_process > 1:
        assert process_id <= total_process -1
        num_samples_per_process = len(vid_names) // total_process
        if process_id == total_process:
            vid_names = vid_names[process_id * num_samples_per_process : ]
        else:
            vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]

    if ds_name == "RAVDESS":
        for i, res in multiprocess_run_tqdm(extract_img_job_crop_ravdess, vid_names, num_workers=args.num_workers, desc="resampling videos"):
            pass
    elif ds_name == "CMLR":
        for i, res in multiprocess_run_tqdm(extract_img_job_crop, vid_names, num_workers=args.num_workers, desc="resampling videos"):
            pass
    else:
        for i, res in multiprocess_run_tqdm(extract_img_job, vid_names, num_workers=args.num_workers, desc="resampling videos"):
            pass
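
Each of the jobs above only builds an ffmpeg command string and runs it via os.system; for a hypothetical input file the command produced by extract_img_job expands to the following (paths are illustrative):

/usr/bin/ffmpeg -i in.mp4 -vf fps=25,scale=w=512:h=512 -q:v 1 -c:v libx264 -pix_fmt yuv420p -b:v 2000k -v quiet -y out.mp4

extract_img_job_crop additionally inserts a crop=W:W filter (W being the smaller of the input's width and height) before the scale step, and the RAVDESS variant hard-codes crop=720:720.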
data_gen/utils/process_video/split_video_to_imgs.py
ADDED
@@ -0,0 +1,53 @@
import os, glob
from utils.commons.multiprocess_utils import multiprocess_run_tqdm

from data_gen.utils.path_converter import PathConverter, pc

# mp4_names = glob.glob("/home/tiger/datasets/raw/CelebV-HQ/video/*.mp4")

def extract_img_job(video_name, raw_img_dir=None):
    if raw_img_dir is not None:
        out_path = raw_img_dir
    else:
        out_path = pc.to(video_name.replace(".mp4", ""), "vid", "gt")
    os.makedirs(out_path, exist_ok=True)
    ffmpeg_path = "/usr/bin/ffmpeg"
    cmd = f'{ffmpeg_path} -i {video_name} -vf fps={25},scale=w=512:h=512 -qmin 1 -q:v 1 -start_number 0 -v quiet {os.path.join(out_path, "%8d.jpg")}'
    os.system(cmd)

if __name__ == '__main__':
    import argparse, glob, tqdm, random
    parser = argparse.ArgumentParser()
    parser.add_argument("--vid_dir", default='/home/tiger/datasets/raw/CelebV-HQ/video')
    parser.add_argument("--ds_name", default='CelebV-HQ')
    parser.add_argument("--num_workers", default=64, type=int)
    parser.add_argument("--process_id", default=0, type=int)
    parser.add_argument("--total_process", default=1, type=int)
    args = parser.parse_args()
    vid_dir = args.vid_dir
    ds_name = args.ds_name
    if ds_name in ['lrs3_trainval']:
        mp4_name_pattern = os.path.join(vid_dir, "*/*.mp4")
    elif ds_name in ['TH1KH_512', 'CelebV-HQ']:
        vid_names = glob.glob(os.path.join(vid_dir, "*.mp4"))
    elif ds_name in ['lrs2', 'lrs3', 'voxceleb2']:
        vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
        vid_names = glob.glob(vid_name_pattern)
    elif ds_name in ["RAVDESS", 'VFHQ']:
        vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4")
        vid_names = glob.glob(vid_name_pattern)
    vid_names = sorted(vid_names)

    process_id = args.process_id
    total_process = args.total_process
    if total_process > 1:
        assert process_id <= total_process -1
        num_samples_per_process = len(vid_names) // total_process
        if process_id == total_process:
            vid_names = vid_names[process_id * num_samples_per_process : ]
        else:
            vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]

    for i, res in multiprocess_run_tqdm(extract_img_job, vid_names, num_workers=args.num_workers, desc="extracting images"):
        pass
data_util/face3d_helper.py
ADDED
@@ -0,0 +1,309 @@
import os
import numpy as np
import torch
import torch.nn as nn
from scipy.io import loadmat

from deep_3drecon.deep_3drecon_models.bfm import perspective_projection


class Face3DHelper(nn.Module):
    def __init__(self, bfm_dir='deep_3drecon/BFM', keypoint_mode='lm68', use_gpu=True):
        super().__init__()
        self.keypoint_mode = keypoint_mode # lm68 | mediapipe
        self.bfm_dir = bfm_dir
        self.load_3dmm()
        if use_gpu: self.to("cuda")

    def load_3dmm(self):
        model = loadmat(os.path.join(self.bfm_dir, "BFM_model_front.mat"))
        self.register_buffer('mean_shape',torch.from_numpy(model['meanshape'].transpose()).float()) # mean face shape. [3*N, 1], N=35709, xyz=3, ==> 3*N=107127
        mean_shape = self.mean_shape.reshape([-1, 3])
        # re-center
        mean_shape = mean_shape - torch.mean(mean_shape, dim=0, keepdims=True)
        self.mean_shape = mean_shape.reshape([-1, 1])
        self.register_buffer('id_base',torch.from_numpy(model['idBase']).float()) # identity basis. [3*N,80], we have 80 eigen faces for identity
        self.register_buffer('exp_base',torch.from_numpy(model['exBase']).float()) # expression basis. [3*N,64], we have 64 eigen faces for expression

        self.register_buffer('mean_texure',torch.from_numpy(model['meantex'].transpose()).float()) # mean face texture. [3*N,1] (0-255)
        self.register_buffer('tex_base',torch.from_numpy(model['texBase']).float()) # texture basis. [3*N,80], rgb=3

        self.register_buffer('point_buf',torch.from_numpy(model['point_buf']).float()) # triangle indices for each vertex that lies in. starts from 1. [N,8] (1-F)
        self.register_buffer('face_buf',torch.from_numpy(model['tri']).float()) # vertex indices in each triangle. starts from 1. [F,3] (1-N)
        if self.keypoint_mode == 'mediapipe':
            self.register_buffer('key_points', torch.from_numpy(np.load("deep_3drecon/BFM/index_mp468_from_mesh35709.npy").astype(np.int64)))
            unmatch_mask = self.key_points < 0
            self.key_points[unmatch_mask] = 0
        else:
            self.register_buffer('key_points',torch.from_numpy(model['keypoints'].squeeze().astype(np.int_)).long()) # vertex indices of 68 facial landmarks. starts from 1. [68,1]

        self.register_buffer('key_mean_shape',self.mean_shape.reshape([-1,3])[self.key_points,:])
        self.register_buffer('key_id_base', self.id_base.reshape([-1,3,80])[self.key_points, :, :].reshape([-1,80]))
        self.register_buffer('key_exp_base', self.exp_base.reshape([-1,3,64])[self.key_points, :, :].reshape([-1,64]))
        self.key_id_base_np = self.key_id_base.cpu().numpy()
        self.key_exp_base_np = self.key_exp_base.cpu().numpy()

        self.register_buffer('persc_proj', torch.tensor(perspective_projection(focal=1015, center=112)))

    def split_coeff(self, coeff):
        """
        coeff: Tensor[B, T, c=257] or [T, c=257]
        """
        ret_dict = {
            'identity': coeff[..., :80],  # identity, [b, t, c=80]
            'expression': coeff[..., 80:144],  # expression, [b, t, c=64]
            'texture': coeff[..., 144:224],  # texture, [b, t, c=80]
            'euler': coeff[..., 224:227],  # euler angles for pose, [b, t, c=3]
            'translation': coeff[..., 254:257],  # translation, [b, t, c=3]
            'gamma': coeff[..., 227:254]  # lighting, [b, t, c=27]
        }
        return ret_dict

    def reconstruct_face_mesh(self, id_coeff, exp_coeff):
        """
        Generate a pose-independent 3D face mesh!
        id_coeff: Tensor[T, c=80]
        exp_coeff: Tensor[T, c=64]
        """
        id_coeff = id_coeff.to(self.key_id_base.device)
        exp_coeff = exp_coeff.to(self.key_id_base.device)
        mean_face = self.mean_shape.squeeze().reshape([1, -1]) # [3N, 1] ==> [1, 3N]
        id_base, exp_base = self.id_base, self.exp_base # [3*N, C]
        identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3N] ==> [t,3N]
        expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3N] ==> [t,3N]

        face = mean_face + identity_diff_face + expression_diff_face # [t,3N]
        face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
        # re-centering the face with mean_xyz, so the face will be in [-1, 1]
        # mean_xyz = self.mean_shape.squeeze().reshape([-1,3]).mean(dim=0) # [1, 3]
        # face_mesh = face - mean_xyz.unsqueeze(0) # [t,N,3]
        return face

    def reconstruct_cano_lm3d(self, id_coeff, exp_coeff):
        """
        Generate 3D landmark with keypoint base!
        id_coeff: Tensor[T, c=80]
        exp_coeff: Tensor[T, c=64]
        """
        id_coeff = id_coeff.to(self.key_id_base.device)
        exp_coeff = exp_coeff.to(self.key_id_base.device)
        mean_face = self.key_mean_shape.squeeze().reshape([1, -1]) # [3*68, 1] ==> [1, 3*68]
        id_base, exp_base = self.key_id_base, self.key_exp_base # [3*68, C]
        identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
        expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]

        face = mean_face + identity_diff_face + expression_diff_face # [t,3N]
        face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
        # re-centering the face with mean_xyz, so the face will be in [-1, 1]
        # mean_xyz = self.key_mean_shape.squeeze().reshape([-1,3]).mean(dim=0) # [1, 3]
        # lm3d = face - mean_xyz.unsqueeze(0) # [t,N,3]
        return face

    def reconstruct_lm3d(self, id_coeff, exp_coeff, euler, trans, to_camera=True):
        """
        Generate 3D landmark with keypoint base!
        id_coeff: Tensor[T, c=80]
        exp_coeff: Tensor[T, c=64]
        """
        id_coeff = id_coeff.to(self.key_id_base.device)
        exp_coeff = exp_coeff.to(self.key_id_base.device)
        mean_face = self.key_mean_shape.squeeze().reshape([1, -1]) # [3*68, 1] ==> [1, 3*68]
        id_base, exp_base = self.key_id_base, self.key_exp_base # [3*68, C]
        identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
        expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]

        face = mean_face + identity_diff_face + expression_diff_face # [t,3N]
        face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
        # re-centering the face with mean_xyz, so the face will be in [-1, 1]
        rot = self.compute_rotation(euler)
        # transform
        lm3d = face @ rot + trans.unsqueeze(1) # [t, N, 3]
        # to camera
        if to_camera:
            lm3d[...,-1] = 10 - lm3d[...,-1]
        return lm3d

    def reconstruct_lm2d_nerf(self, id_coeff, exp_coeff, euler, trans):
        lm2d = self.reconstruct_lm2d(id_coeff, exp_coeff, euler, trans, to_camera=False)
        lm2d[..., 0] = 1 - lm2d[..., 0]
        lm2d[..., 1] = 1 - lm2d[..., 1]
        return lm2d

    def reconstruct_lm2d(self, id_coeff, exp_coeff, euler, trans, to_camera=True):
        """
        Generate 3D landmark with keypoint base!
        id_coeff: Tensor[T, c=80]
        exp_coeff: Tensor[T, c=64]
        """
        is_btc_flag = True if id_coeff.ndim == 3 else False
        if is_btc_flag:
            b,t,_ = id_coeff.shape
            id_coeff = id_coeff.reshape([b*t,-1])
            exp_coeff = exp_coeff.reshape([b*t,-1])
            euler = euler.reshape([b*t,-1])
            trans = trans.reshape([b*t,-1])
        id_coeff = id_coeff.to(self.key_id_base.device)
        exp_coeff = exp_coeff.to(self.key_id_base.device)
        mean_face = self.key_mean_shape.squeeze().reshape([1, -1]) # [3*68, 1] ==> [1, 3*68]
        id_base, exp_base = self.key_id_base, self.key_exp_base # [3*68, C]
        identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
        expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]

        face = mean_face + identity_diff_face + expression_diff_face # [t,3N]
        face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
        # re-centering the face with mean_xyz, so the face will be in [-1, 1]
        rot = self.compute_rotation(euler)
        # transform
        lm3d = face @ rot + trans.unsqueeze(1) # [t, N, 3]
        # to camera
        if to_camera:
            lm3d[...,-1] = 10 - lm3d[...,-1]
        # to image_plane
        lm3d = lm3d @ self.persc_proj
        lm2d = lm3d[..., :2] / lm3d[..., 2:]
        # flip
        lm2d[..., 1] = 224 - lm2d[..., 1]
        lm2d /= 224
        if is_btc_flag:
            return lm2d.reshape([b,t,-1,2])
        return lm2d

    def compute_rotation(self, euler):
        """
        Return:
            rot -- torch.tensor, size (B, 3, 3) pts @ trans_mat

        Parameters:
            euler -- torch.tensor, size (B, 3), radian
        """
        batch_size = euler.shape[0]
        euler = euler.to(self.key_id_base.device)
        ones = torch.ones([batch_size, 1]).to(self.key_id_base.device)
        zeros = torch.zeros([batch_size, 1]).to(self.key_id_base.device)
        x, y, z = euler[:, :1], euler[:, 1:2], euler[:, 2:],

        rot_x = torch.cat([
            ones, zeros, zeros,
            zeros, torch.cos(x), -torch.sin(x),
            zeros, torch.sin(x), torch.cos(x)
        ], dim=1).reshape([batch_size, 3, 3])

        rot_y = torch.cat([
            torch.cos(y), zeros, torch.sin(y),
            zeros, ones, zeros,
            -torch.sin(y), zeros, torch.cos(y)
        ], dim=1).reshape([batch_size, 3, 3])

        rot_z = torch.cat([
            torch.cos(z), -torch.sin(z), zeros,
            torch.sin(z), torch.cos(z), zeros,
            zeros, zeros, ones
        ], dim=1).reshape([batch_size, 3, 3])

        rot = rot_z @ rot_y @ rot_x
        return rot.permute(0, 2, 1)

    def reconstruct_idexp_lm3d(self, id_coeff, exp_coeff):
        """
        Generate 3D landmark with keypoint base!
        id_coeff: Tensor[T, c=80]
        exp_coeff: Tensor[T, c=64]
        """
        id_coeff = id_coeff.to(self.key_id_base.device)
        exp_coeff = exp_coeff.to(self.key_id_base.device)
        id_base, exp_base = self.key_id_base, self.key_exp_base # [3*68, C]
        identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
        expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]

        face = identity_diff_face + expression_diff_face # [t,3N]
        face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
        lm3d = face * 10
        return lm3d

    def reconstruct_idexp_lm3d_np(self, id_coeff, exp_coeff):
        """
        Generate 3D landmark with keypoint base!
        id_coeff: Tensor[T, c=80]
        exp_coeff: Tensor[T, c=64]
        """
        id_base, exp_base = self.key_id_base_np, self.key_exp_base_np # [3*68, C]
        identity_diff_face = np.dot(id_coeff, id_base.T) # [t,c],[c,3*68] ==> [t,3*68]
        expression_diff_face = np.dot(exp_coeff, exp_base.T) # [t,c],[c,3*68] ==> [t,3*68]

        face = identity_diff_face + expression_diff_face # [t,3N]
        face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
        lm3d = face * 10
        return lm3d

    def get_eye_mouth_lm_from_lm3d(self, lm3d):
        eye_lm = lm3d[:, 17:48] # [T, 31, 3]
        mouth_lm = lm3d[:, 48:68] # [T, 20, 3]
        return eye_lm, mouth_lm

    def get_eye_mouth_lm_from_lm3d_batch(self, lm3d):
        eye_lm = lm3d[:, :, 17:48] # [T, 31, 3]
        mouth_lm = lm3d[:, :, 48:68] # [T, 20, 3]
        return eye_lm, mouth_lm

    def close_mouth_for_idexp_lm3d(self, idexp_lm3d, freeze_as_first_frame=True):
        idexp_lm3d = idexp_lm3d.reshape([-1, 68, 3])
        num_frames = idexp_lm3d.shape[0]
        eps = 0.0
        # [n_landmarks=68, xyz=3]; x is left-right, y is up-down, z is depth
        idexp_lm3d[:,49:54, 1] = (idexp_lm3d[:,49:54, 1] + idexp_lm3d[:,range(59,54,-1), 1])/2 + eps * 2
        idexp_lm3d[:,range(59,54,-1), 1] = (idexp_lm3d[:,49:54, 1] + idexp_lm3d[:,range(59,54,-1), 1])/2 - eps * 2

        idexp_lm3d[:,61:64, 1] = (idexp_lm3d[:,61:64, 1] + idexp_lm3d[:,range(67,64,-1), 1])/2 + eps
        idexp_lm3d[:,range(67,64,-1), 1] = (idexp_lm3d[:,61:64, 1] + idexp_lm3d[:,range(67,64,-1), 1])/2 - eps

        idexp_lm3d[:,49:54, 1] += (0.03 - idexp_lm3d[:,49:54, 1].mean(dim=1) + idexp_lm3d[:,61:64, 1].mean(dim=1)).unsqueeze(1).repeat([1,5])
        idexp_lm3d[:,range(59,54,-1), 1] += (-0.03 - idexp_lm3d[:,range(59,54,-1), 1].mean(dim=1) + idexp_lm3d[:,range(67,64,-1), 1].mean(dim=1)).unsqueeze(1).repeat([1,5])

        if freeze_as_first_frame:
            idexp_lm3d[:, 48:68,] = idexp_lm3d[0, 48:68].unsqueeze(0).clone().repeat([num_frames, 1,1])*0
        return idexp_lm3d.cpu()

    def close_eyes_for_idexp_lm3d(self, idexp_lm3d):
        idexp_lm3d = idexp_lm3d.reshape([-1, 68, 3])
        eps = 0.003
        idexp_lm3d[:,37:39, 1] = (idexp_lm3d[:,37:39, 1] + idexp_lm3d[:,range(41,39,-1), 1])/2 + eps
        idexp_lm3d[:,range(41,39,-1), 1] = (idexp_lm3d[:,37:39, 1] + idexp_lm3d[:,range(41,39,-1), 1])/2 - eps

        idexp_lm3d[:,43:45, 1] = (idexp_lm3d[:,43:45, 1] + idexp_lm3d[:,range(47,45,-1), 1])/2 + eps
        idexp_lm3d[:,range(47,45,-1), 1] = (idexp_lm3d[:,43:45, 1] + idexp_lm3d[:,range(47,45,-1), 1])/2 - eps

        return idexp_lm3d

if __name__ == '__main__':
    import cv2

    font = cv2.FONT_HERSHEY_SIMPLEX

    face_mesh_helper = Face3DHelper('deep_3drecon/BFM')
    coeff_npy = 'data/coeff_fit_mp/crop_nana_003_coeff_fit_mp.npy'
    coeff_dict = np.load(coeff_npy, allow_pickle=True).tolist()
    lm3d = face_mesh_helper.reconstruct_lm2d(torch.tensor(coeff_dict['id']).cuda(), torch.tensor(coeff_dict['exp']).cuda(), torch.tensor(coeff_dict['euler']).cuda(), torch.tensor(coeff_dict['trans']).cuda() )

    WH = 512
    lm3d = (lm3d * WH).cpu().int().numpy()
    eye_idx = list(range(36,48))
    mouth_idx = list(range(48,68))
    import imageio
    debug_name = 'debug_lm3d.mp4'
    writer = imageio.get_writer(debug_name, fps=25)
    for i_img in range(len(lm3d)):
        lm2d = lm3d[i_img ,:, :2] # [68, 2]
        img = np.ones([WH, WH, 3], dtype=np.uint8) * 255
        for i in range(len(lm2d)):
            x, y = lm2d[i]
            if i in eye_idx:
                color = (0,0,255)
            elif i in mouth_idx:
                color = (0,255,0)
            else:
                color = (255,0,0)
            img = cv2.circle(img, center=(x,y), radius=3, color=color, thickness=-1)
            img = cv2.putText(img, f"{i}", org=(x,y), fontFace=font, fontScale=0.3, color=(255,0,0))
        writer.append_data(img)
    writer.close()
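
A short sanity check on the rotation convention used by Face3DHelper.compute_rotation above: it composes R = Rz @ Ry @ Rx and returns the transpose, so landmarks stored as row vectors are rotated with `face @ rot` (as in reconstruct_lm3d). The sketch below rebuilds the same matrices for a single pose and verifies that the row-vector form agrees with the usual column-vector form; it is an illustration, not code from this upload:

import math
import torch

x, y, z = 0.1, -0.2, 0.3  # euler angles in radians
cx, sx, cy, sy, cz, sz = math.cos(x), math.sin(x), math.cos(y), math.sin(y), math.cos(z), math.sin(z)
Rx = torch.tensor([[1., 0., 0.], [0., cx, -sx], [0., sx, cx]])
Ry = torch.tensor([[cy, 0., sy], [0., 1., 0.], [-sy, 0., cy]])
Rz = torch.tensor([[cz, -sz, 0.], [sz, cz, 0.], [0., 0., 1.]])
R = Rz @ Ry @ Rx                     # same composition as compute_rotation
pts = torch.randn(5, 3)              # row-vector landmarks, like `face` in reconstruct_lm3d
assert torch.allclose(pts @ R.T, (R @ pts.T).T)  # rot.permute(0, 2, 1) plays the role of R.T here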
deep_3drecon/BFM/.gitkeep
ADDED
File without changes
deep_3drecon/bfm_left_eye_faces.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9651756ea2c0fac069a1edf858ed1f125eddc358fa74c529a370c1e7b5730d28
size 4680
deep_3drecon/bfm_right_eye_faces.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:28cb5bbacf578d30a3d5006ec28c617fe5a3ecaeeeb87d9433a884e0f0301a2e
size 4648
deep_3drecon/deep_3drecon_models/bfm.py
ADDED
@@ -0,0 +1,426 @@
1 |
+
"""This script defines the parametric 3d face model for Deep3DFaceRecon_pytorch
|
2 |
+
"""
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from scipy.io import loadmat
|
8 |
+
import os
|
9 |
+
# from utils.commons.tensor_utils import convert_like
|
10 |
+
|
11 |
+
|
12 |
+
def perspective_projection(focal, center):
|
13 |
+
# return p.T (N, 3) @ (3, 3)
|
14 |
+
return np.array([
|
15 |
+
focal, 0, center,
|
16 |
+
0, focal, center,
|
17 |
+
0, 0, 1
|
18 |
+
]).reshape([3, 3]).astype(np.float32).transpose() # note the transpose here!
|
19 |
+
|
20 |
+
class SH:
|
21 |
+
def __init__(self):
|
22 |
+
self.a = [np.pi, 2 * np.pi / np.sqrt(3.), 2 * np.pi / np.sqrt(8.)]
|
23 |
+
self.c = [1/np.sqrt(4 * np.pi), np.sqrt(3.) / np.sqrt(4 * np.pi), 3 * np.sqrt(5.) / np.sqrt(12 * np.pi)]
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
+
class ParametricFaceModel:
|
28 |
+
def __init__(self,
|
29 |
+
bfm_folder='./BFM',
|
30 |
+
recenter=True,
|
31 |
+
camera_distance=10.,
|
32 |
+
init_lit=np.array([
|
33 |
+
0.8, 0, 0, 0, 0, 0, 0, 0, 0
|
34 |
+
]),
|
35 |
+
focal=1015.,
|
36 |
+
center=112.,
|
37 |
+
is_train=True,
|
38 |
+
default_name='BFM_model_front.mat',
|
39 |
+
keypoint_mode='mediapipe'):
|
40 |
+
|
41 |
+
model = loadmat(os.path.join(bfm_folder, default_name))
|
42 |
+
# mean face shape. [3*N,1]
|
43 |
+
self.mean_shape = model['meanshape'].astype(np.float32)
|
44 |
+
# identity basis. [3*N,80]
|
45 |
+
self.id_base = model['idBase'].astype(np.float32)
|
46 |
+
# expression basis. [3*N,64]
|
47 |
+
self.exp_base = model['exBase'].astype(np.float32)
|
48 |
+
# mean face texture. [3*N,1] (0-255)
|
49 |
+
self.mean_tex = model['meantex'].astype(np.float32)
|
50 |
+
# texture basis. [3*N,80]
|
51 |
+
self.tex_base = model['texBase'].astype(np.float32)
|
52 |
+
# face indices for each vertex that lies in. starts from 0. [N,8]
|
53 |
+
self.point_buf = model['point_buf'].astype(np.int64) - 1
|
54 |
+
# vertex indices for each face. starts from 0. [F,3]
|
55 |
+
self.face_buf = model['tri'].astype(np.int64) - 1
|
56 |
+
# vertex indices for 68 landmarks. starts from 0. [68,1]
|
57 |
+
if keypoint_mode == 'mediapipe':
|
58 |
+
self.keypoints = np.load("deep_3drecon/BFM/index_mp468_from_mesh35709.npy").astype(np.int64)
|
59 |
+
unmatch_mask = self.keypoints < 0
|
60 |
+
self.keypoints[unmatch_mask] = 0
|
61 |
+
else:
|
62 |
+
self.keypoints = np.squeeze(model['keypoints']).astype(np.int64) - 1
|
63 |
+
|
64 |
+
if is_train:
|
65 |
+
# vertex indices for small face region to compute photometric error. starts from 0.
|
66 |
+
self.front_mask = np.squeeze(model['frontmask2_idx']).astype(np.int64) - 1
|
67 |
+
# vertex indices for each face from small face region. starts from 0. [f,3]
|
68 |
+
self.front_face_buf = model['tri_mask2'].astype(np.int64) - 1
|
69 |
+
# vertex indices for pre-defined skin region to compute reflectance loss
|
70 |
+
self.skin_mask = np.squeeze(model['skinmask'])
|
71 |
+
|
72 |
+
if recenter:
|
73 |
+
mean_shape = self.mean_shape.reshape([-1, 3])
|
74 |
+
mean_shape = mean_shape - np.mean(mean_shape, axis=0, keepdims=True)
|
75 |
+
self.mean_shape = mean_shape.reshape([-1, 1])
|
76 |
+
|
77 |
+
self.key_mean_shape = self.mean_shape.reshape([-1, 3])[self.keypoints, :].reshape([-1, 3])
|
78 |
+
self.key_id_base = self.id_base.reshape([-1, 3,80])[self.keypoints, :].reshape([-1, 80])
|
79 |
+
self.key_exp_base = self.exp_base.reshape([-1, 3, 64])[self.keypoints, :].reshape([-1, 64])
|
80 |
+
|
81 |
+
self.focal = focal
|
82 |
+
self.center = center
|
83 |
+
self.persc_proj = perspective_projection(focal, center)
|
84 |
+
self.device = 'cpu'
|
85 |
+
self.camera_distance = camera_distance
|
86 |
+
self.SH = SH()
|
87 |
+
self.init_lit = init_lit.reshape([1, 1, -1]).astype(np.float32)
|
88 |
+
|
89 |
+
self.initialized = False
|
90 |
+
|
91 |
+
def to(self, device):
|
92 |
+
self.device = device
|
93 |
+
for key, value in self.__dict__.items():
|
94 |
+
if type(value).__module__ == np.__name__:
|
95 |
+
setattr(self, key, torch.tensor(value).to(device))
|
96 |
+
self.initialized = True
|
97 |
+
return self
|
98 |
+
|
99 |
+
def compute_shape(self, id_coeff, exp_coeff):
|
100 |
+
"""
|
101 |
+
Return:
|
102 |
+
face_shape -- torch.tensor, size (B, N, 3)
|
103 |
+
|
104 |
+
Parameters:
|
105 |
+
id_coeff -- torch.tensor, size (B, 80), identity coeffs
|
106 |
+
exp_coeff -- torch.tensor, size (B, 64), expression coeffs
|
107 |
+
"""
|
108 |
+
batch_size = id_coeff.shape[0]
|
109 |
+
id_part = torch.einsum('ij,aj->ai', self.id_base, id_coeff)
|
110 |
+
exp_part = torch.einsum('ij,aj->ai', self.exp_base, exp_coeff)
|
111 |
+
face_shape = id_part + exp_part + self.mean_shape.reshape([1, -1])
|
112 |
+
return face_shape.reshape([batch_size, -1, 3])
|
113 |
+
|
114 |
+
def compute_key_shape(self, id_coeff, exp_coeff):
|
115 |
+
"""
|
116 |
+
Return:
|
117 |
+
face_shape -- torch.tensor, size (B, N, 3)
|
118 |
+
|
119 |
+
Parameters:
|
120 |
+
id_coeff -- torch.tensor, size (B, 80), identity coeffs
|
121 |
+
exp_coeff -- torch.tensor, size (B, 64), expression coeffs
|
122 |
+
"""
|
123 |
+
batch_size = id_coeff.shape[0]
|
124 |
+
id_part = torch.einsum('ij,aj->ai', self.key_id_base, id_coeff)
|
125 |
+
exp_part = torch.einsum('ij,aj->ai', self.key_exp_base, exp_coeff)
|
126 |
+
face_shape = id_part + exp_part + self.key_mean_shape.reshape([1, -1])
|
127 |
+
return face_shape.reshape([batch_size, -1, 3])
|
128 |
+
|
129 |
+
def compute_texture(self, tex_coeff, normalize=True):
|
130 |
+
"""
|
131 |
+
Return:
|
132 |
+
face_texture -- torch.tensor, size (B, N, 3), in RGB order, range (0, 1.)
|
133 |
+
|
134 |
+
Parameters:
|
135 |
+
tex_coeff -- torch.tensor, size (B, 80)
|
136 |
+
"""
|
137 |
+
batch_size = tex_coeff.shape[0]
|
138 |
+
face_texture = torch.einsum('ij,aj->ai', self.tex_base, tex_coeff) + self.mean_tex
|
139 |
+
if normalize:
|
140 |
+
face_texture = face_texture / 255.
|
141 |
+
return face_texture.reshape([batch_size, -1, 3])
|
142 |
+
|
143 |
+
|
144 |
+
def compute_norm(self, face_shape):
|
145 |
+
"""
|
146 |
+
Return:
|
147 |
+
vertex_norm -- torch.tensor, size (B, N, 3)
|
148 |
+
|
149 |
+
Parameters:
|
150 |
+
face_shape -- torch.tensor, size (B, N, 3)
|
151 |
+
"""
|
152 |
+
|
153 |
+
v1 = face_shape[:, self.face_buf[:, 0]]
|
154 |
+
v2 = face_shape[:, self.face_buf[:, 1]]
|
155 |
+
v3 = face_shape[:, self.face_buf[:, 2]]
|
156 |
+
e1 = v1 - v2
|
157 |
+
e2 = v2 - v3
|
158 |
+
face_norm = torch.cross(e1, e2, dim=-1)
|
159 |
+
face_norm = F.normalize(face_norm, dim=-1, p=2)
|
160 |
+
face_norm = torch.cat([face_norm, torch.zeros(face_norm.shape[0], 1, 3).to(self.device)], dim=1)
|
161 |
+
|
162 |
+
vertex_norm = torch.sum(face_norm[:, self.point_buf], dim=2)
|
163 |
+
vertex_norm = F.normalize(vertex_norm, dim=-1, p=2)
|
164 |
+
return vertex_norm
|
165 |
+
|
166 |
+
|
167 |
+
def compute_color(self, face_texture, face_norm, gamma):
|
168 |
+
"""
|
169 |
+
Return:
|
170 |
+
face_color -- torch.tensor, size (B, N, 3), range (0, 1.)
|
171 |
+
|
172 |
+
Parameters:
|
173 |
+
face_texture -- torch.tensor, size (B, N, 3), from texture model, range (0, 1.)
|
174 |
+
face_norm -- torch.tensor, size (B, N, 3), rotated face normal
|
175 |
+
gamma -- torch.tensor, size (B, 27), SH coeffs
|
176 |
+
"""
|
177 |
+
batch_size = gamma.shape[0]
|
178 |
+
v_num = face_texture.shape[1]
|
179 |
+
a, c = self.SH.a, self.SH.c
|
180 |
+
gamma = gamma.reshape([batch_size, 3, 9])
|
181 |
+
gamma = gamma + self.init_lit
|
182 |
+
gamma = gamma.permute(0, 2, 1)
|
183 |
+
Y = torch.cat([
|
184 |
+
a[0] * c[0] * torch.ones_like(face_norm[..., :1]).to(self.device),
|
185 |
+
-a[1] * c[1] * face_norm[..., 1:2],
|
186 |
+
a[1] * c[1] * face_norm[..., 2:],
|
187 |
+
-a[1] * c[1] * face_norm[..., :1],
|
188 |
+
a[2] * c[2] * face_norm[..., :1] * face_norm[..., 1:2],
|
189 |
+
-a[2] * c[2] * face_norm[..., 1:2] * face_norm[..., 2:],
|
190 |
+
0.5 * a[2] * c[2] / np.sqrt(3.) * (3 * face_norm[..., 2:] ** 2 - 1),
|
191 |
+
-a[2] * c[2] * face_norm[..., :1] * face_norm[..., 2:],
|
192 |
+
0.5 * a[2] * c[2] * (face_norm[..., :1] ** 2 - face_norm[..., 1:2] ** 2)
|
193 |
+
], dim=-1)
|
194 |
+
r = Y @ gamma[..., :1]
|
195 |
+
g = Y @ gamma[..., 1:2]
|
196 |
+
b = Y @ gamma[..., 2:]
|
197 |
+
face_color = torch.cat([r, g, b], dim=-1) * face_texture
|
198 |
+
return face_color
|
199 |
+
|
200 |
+
@staticmethod
|
201 |
+
def compute_rotation(angles, device='cpu'):
|
202 |
+
"""
|
203 |
+
Return:
|
204 |
+
rot -- torch.tensor, size (B, 3, 3) pts @ trans_mat
|
205 |
+
|
206 |
+
Parameters:
|
207 |
+
angles -- torch.tensor, size (B, 3), radian
|
208 |
+
"""
|
209 |
+
|
210 |
+
batch_size = angles.shape[0]
|
211 |
+
angles = angles.to(device)
|
212 |
+
ones = torch.ones([batch_size, 1]).to(device)
|
213 |
+
zeros = torch.zeros([batch_size, 1]).to(device)
|
214 |
+
x, y, z = angles[:, :1], angles[:, 1:2], angles[:, 2:],
|
215 |
+
|
216 |
+
rot_x = torch.cat([
|
217 |
+
ones, zeros, zeros,
|
218 |
+
zeros, torch.cos(x), -torch.sin(x),
|
219 |
+
zeros, torch.sin(x), torch.cos(x)
|
220 |
+
], dim=1).reshape([batch_size, 3, 3])
|
221 |
+
|
222 |
+
rot_y = torch.cat([
|
223 |
+
torch.cos(y), zeros, torch.sin(y),
|
224 |
+
zeros, ones, zeros,
|
225 |
+
-torch.sin(y), zeros, torch.cos(y)
|
226 |
+
], dim=1).reshape([batch_size, 3, 3])
|
227 |
+
|
228 |
+
rot_z = torch.cat([
|
229 |
+
torch.cos(z), -torch.sin(z), zeros,
|
230 |
+
torch.sin(z), torch.cos(z), zeros,
|
231 |
+
zeros, zeros, ones
|
232 |
+
], dim=1).reshape([batch_size, 3, 3])
|
233 |
+
|
234 |
+
rot = rot_z @ rot_y @ rot_x
|
235 |
+
return rot.permute(0, 2, 1)
|
236 |
+
|
237 |
+
|
238 |
+
def to_camera(self, face_shape):
|
239 |
+
face_shape[..., -1] = self.camera_distance - face_shape[..., -1] # reverse the depth axis and add a fixed camera-distance offset
|
240 |
+
return face_shape
|
241 |
+
|
242 |
+
def to_image(self, face_shape):
|
243 |
+
"""
|
244 |
+
Return:
|
245 |
+
face_proj -- torch.tensor, size (B, N, 2), y direction is opposite to v direction
|
246 |
+
|
247 |
+
Parameters:
|
248 |
+
face_shape -- torch.tensor, size (B, N, 3)
|
249 |
+
"""
|
250 |
+
# to image_plane
|
251 |
+
face_proj = face_shape @ self.persc_proj
|
252 |
+
face_proj = face_proj[..., :2] / face_proj[..., 2:]
|
253 |
+
|
254 |
+
return face_proj
|
255 |
+
|
256 |
+
|
257 |
+
def transform(self, face_shape, rot, trans):
|
258 |
+
"""
|
259 |
+
Return:
|
260 |
+
face_shape -- torch.tensor, size (B, N, 3) pts @ rot + trans
|
261 |
+
|
262 |
+
Parameters:
|
263 |
+
face_shape -- torch.tensor, size (B, N, 3)
|
264 |
+
rot -- torch.tensor, size (B, 3, 3)
|
265 |
+
trans -- torch.tensor, size (B, 3)
|
266 |
+
"""
|
267 |
+
return face_shape @ rot + trans.unsqueeze(1)
|
268 |
+
|
269 |
+
|
270 |
+
def get_landmarks(self, face_proj):
|
271 |
+
"""
|
272 |
+
Return:
|
273 |
+
face_lms -- torch.tensor, size (B, 68, 2)
|
274 |
+
|
275 |
+
Parameters:
|
276 |
+
face_proj -- torch.tensor, size (B, N, 2)
|
277 |
+
"""
|
278 |
+
return face_proj[:, self.keypoints]
|
279 |
+
|
280 |
+
def split_coeff(self, coeffs):
|
281 |
+
"""
|
282 |
+
Return:
|
283 |
+
coeffs_dict -- a dict of torch.tensors
|
284 |
+
|
285 |
+
Parameters:
|
286 |
+
coeffs -- torch.tensor, size (B, 256)
|
287 |
+
"""
|
288 |
+
id_coeffs = coeffs[:, :80]
|
289 |
+
exp_coeffs = coeffs[:, 80: 144]
|
290 |
+
tex_coeffs = coeffs[:, 144: 224]
|
291 |
+
angles = coeffs[:, 224: 227]
|
292 |
+
gammas = coeffs[:, 227: 254]
|
293 |
+
translations = coeffs[:, 254:]
|
294 |
+
return {
|
295 |
+
'id': id_coeffs,
|
296 |
+
'exp': exp_coeffs,
|
297 |
+
'tex': tex_coeffs,
|
298 |
+
'angle': angles,
|
299 |
+
'gamma': gammas,
|
300 |
+
'trans': translations
|
301 |
+
}
|
302 |
+
def compute_for_render(self, coeffs):
|
303 |
+
"""
|
304 |
+
Return:
|
305 |
+
face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
|
306 |
+
face_color -- torch.tensor, size (B, N, 3), in RGB order
|
307 |
+
landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
|
308 |
+
Parameters:
|
309 |
+
coeffs -- torch.tensor, size (B, 257)
|
310 |
+
"""
|
311 |
+
coef_dict = self.split_coeff(coeffs)
|
312 |
+
face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp'])
|
313 |
+
rotation = self.compute_rotation(coef_dict['angle'], device=self.device)
|
314 |
+
|
315 |
+
|
316 |
+
face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans'])
|
317 |
+
face_vertex = self.to_camera(face_shape_transformed)
|
318 |
+
|
319 |
+
face_proj = self.to_image(face_vertex)
|
320 |
+
landmark = self.get_landmarks(face_proj)
|
321 |
+
|
322 |
+
face_texture = self.compute_texture(coef_dict['tex'])
|
323 |
+
face_norm = self.compute_norm(face_shape)
|
324 |
+
face_norm_roted = face_norm @ rotation
|
325 |
+
face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma'])
|
326 |
+
|
327 |
+
return face_vertex, face_texture, face_color, landmark
|
328 |
+
|
329 |
+
def compute_face_vertex(self, id, exp, angle, trans):
|
330 |
+
"""
|
331 |
+
Return:
|
332 |
+
face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
|
333 |
+
face_color -- torch.tensor, size (B, N, 3), in RGB order
|
334 |
+
landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
|
335 |
+
Parameters:
|
336 |
+
coeffs -- torch.tensor, size (B, 257)
|
337 |
+
"""
|
338 |
+
if not self.initialized:
|
339 |
+
self.to(id.device)
|
340 |
+
face_shape = self.compute_shape(id, exp)
|
341 |
+
rotation = self.compute_rotation(angle, device=self.device)
|
342 |
+
face_shape_transformed = self.transform(face_shape, rotation, trans)
|
343 |
+
face_vertex = self.to_camera(face_shape_transformed)
|
344 |
+
return face_vertex
|
345 |
+
|
346 |
+
def compute_for_landmark_fit(self, id, exp, angles, trans, ret=None):
|
347 |
+
"""
|
348 |
+
Return:
|
349 |
+
face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
|
350 |
+
face_color -- torch.tensor, size (B, N, 3), in RGB order
|
351 |
+
landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
|
352 |
+
Parameters:
|
353 |
+
coeffs -- torch.tensor, size (B, 257)
|
354 |
+
"""
|
355 |
+
face_shape = self.compute_key_shape(id, exp)
|
356 |
+
rotation = self.compute_rotation(angles, device=self.device)
|
357 |
+
|
358 |
+
face_shape_transformed = self.transform(face_shape, rotation, trans)
|
359 |
+
face_vertex = self.to_camera(face_shape_transformed)
|
360 |
+
|
361 |
+
face_proj = self.to_image(face_vertex)
|
362 |
+
landmark = face_proj
|
363 |
+
return landmark
|
364 |
+
|
365 |
+
def compute_for_landmark_fit_nerf(self, id, exp, angles, trans, ret=None):
|
366 |
+
"""
|
367 |
+
Return:
|
368 |
+
face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
|
369 |
+
face_color -- torch.tensor, size (B, N, 3), in RGB order
|
370 |
+
landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
|
371 |
+
Parameters:
|
372 |
+
coeffs -- torch.tensor, size (B, 257)
|
373 |
+
"""
|
374 |
+
face_shape = self.compute_key_shape(id, exp)
|
375 |
+
rotation = self.compute_rotation(angles, device=self.device)
|
376 |
+
|
377 |
+
face_shape_transformed = self.transform(face_shape, rotation, trans)
|
378 |
+
face_vertex = face_shape_transformed # no to_camera
|
379 |
+
|
380 |
+
face_proj = self.to_image(face_vertex)
|
381 |
+
landmark = face_proj
|
382 |
+
return landmark
|
383 |
+
|
384 |
+
# def compute_for_landmark_fit(self, id, exp, angles, trans, ret={}):
|
385 |
+
# """
|
386 |
+
# Return:
|
387 |
+
# face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
|
388 |
+
# face_color -- torch.tensor, size (B, N, 3), in RGB order
|
389 |
+
# landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
|
390 |
+
# Parameters:
|
391 |
+
# coeffs -- torch.tensor, size (B, 257)
|
392 |
+
# """
|
393 |
+
# face_shape = self.compute_shape(id, exp)
|
394 |
+
# rotation = self.compute_rotation(angles)
|
395 |
+
|
396 |
+
# face_shape_transformed = self.transform(face_shape, rotation, trans)
|
397 |
+
# face_vertex = self.to_camera(face_shape_transformed)
|
398 |
+
|
399 |
+
# face_proj = self.to_image(face_vertex)
|
400 |
+
# landmark = self.get_landmarks(face_proj)
|
401 |
+
# return landmark
|
402 |
+
|
403 |
+
def compute_for_render_fit(self, id, exp, angles, trans, tex, gamma):
|
404 |
+
"""
|
405 |
+
Return:
|
406 |
+
face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
|
407 |
+
face_color -- torch.tensor, size (B, N, 3), in RGB order
|
408 |
+
landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
|
409 |
+
Parameters:
|
410 |
+
coeffs -- torch.tensor, size (B, 257)
|
411 |
+
"""
|
412 |
+
face_shape = self.compute_shape(id, exp)
|
413 |
+
rotation = self.compute_rotation(angles, device=self.device)
|
414 |
+
|
415 |
+
face_shape_transformed = self.transform(face_shape, rotation, trans)
|
416 |
+
face_vertex = self.to_camera(face_shape_transformed)
|
417 |
+
|
418 |
+
face_proj = self.to_image(face_vertex)
|
419 |
+
landmark = self.get_landmarks(face_proj)
|
420 |
+
|
421 |
+
face_texture = self.compute_texture(tex)
|
422 |
+
face_norm = self.compute_norm(face_shape)
|
423 |
+
face_norm_roted = face_norm @ rotation
|
424 |
+
face_color = self.compute_color(face_texture, face_norm_roted, gamma)
|
425 |
+
|
426 |
+
return face_color, face_vertex, landmark
|
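A minimal usage sketch of the `ParametricFaceModel` above. This is hypothetical: it assumes `BFM_model_front.mat` has already been placed under `deep_3drecon/BFM`, and it uses the 68-landmark branch by passing any `keypoint_mode` other than `'mediapipe'` (the value `'lm68'` is just a placeholder).

```python
import torch
from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel

# Hypothetical sanity check; requires the BFM assets to be downloaded first.
face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM', keypoint_mode='lm68').to('cpu')
coeffs = torch.zeros([1, 257])  # id(80) + exp(64) + tex(80) + angle(3) + gamma(27) + trans(3)
face_vertex, face_texture, face_color, landmark = face_model.compute_for_render(coeffs)
print(face_vertex.shape, landmark.shape)  # [1, N, 3] camera-space vertices, [1, 68, 2] projected landmarks
```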
deep_3drecon/ncc_code.npy
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:da54a620c0981d43cc9f30b3d8b3f5d4beb0ec0e27127a1ef3fb62ea50913609
|
3 |
+
size 428636
|
deep_3drecon/secc_renderer.py
ADDED
@@ -0,0 +1,78 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import numpy as np
|
4 |
+
from einops import rearrange
|
5 |
+
|
6 |
+
from deep_3drecon.util.mesh_renderer import MeshRenderer
|
7 |
+
from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
|
8 |
+
|
9 |
+
|
10 |
+
class SECC_Renderer(nn.Module):
|
11 |
+
def __init__(self, rasterize_size=None, device="cuda"):
|
12 |
+
super().__init__()
|
13 |
+
self.face_model = ParametricFaceModel('deep_3drecon/BFM')
|
14 |
+
self.fov = 2 * np.arctan(self.face_model.center / self.face_model.focal) * 180 / np.pi
|
15 |
+
self.znear = 5.
|
16 |
+
self.zfar = 15.
|
17 |
+
if rasterize_size is None:
|
18 |
+
rasterize_size = 2*self.face_model.center
|
19 |
+
self.face_renderer = MeshRenderer(rasterize_fov=self.fov, znear=self.znear, zfar=self.zfar, rasterize_size=rasterize_size, use_opengl=False).cuda()
|
20 |
+
face_feat = np.load("deep_3drecon/ncc_code.npy", allow_pickle=True)
|
21 |
+
self.face_feat = torch.tensor(face_feat.T).unsqueeze(0).to(device=device)
|
22 |
+
|
23 |
+
del_index_re = np.load('deep_3drecon/bfm_right_eye_faces.npy')
|
24 |
+
del_index_re = del_index_re - 1
|
25 |
+
del_index_le = np.load('deep_3drecon/bfm_left_eye_faces.npy')
|
26 |
+
del_index_le = del_index_le - 1
|
27 |
+
face_buf_list = []
|
28 |
+
for i in range(self.face_model.face_buf.shape[0]):
|
29 |
+
if i not in del_index_re and i not in del_index_le:
|
30 |
+
face_buf_list.append(self.face_model.face_buf[i])
|
31 |
+
face_buf_arr = np.array(face_buf_list)
|
32 |
+
self.face_buf = torch.tensor(face_buf_arr).to(device=device)
|
33 |
+
|
34 |
+
def forward(self, id, exp, euler, trans):
|
35 |
+
"""
|
36 |
+
id, exp, euler, trans: [B, C] or [B, T, C]
|
37 |
+
return:
|
38 |
+
MASK: [B, 1, 512, 512], values in {0.0, 1.0}, where 1.0 denotes the face region
|
39 |
+
SECC MAP: [B, 3, 512, 512], value[0~1]
|
40 |
+
if input is BTC format, return [B, C, T, H, W]
|
41 |
+
"""
|
42 |
+
bs = id.shape[0]
|
43 |
+
is_btc_flag = id.ndim == 3
|
44 |
+
if is_btc_flag:
|
45 |
+
t = id.shape[1]
|
46 |
+
bs = bs * t
|
47 |
+
id, exp, euler, trans = id.reshape([bs,-1]), exp.reshape([bs,-1]), euler.reshape([bs,-1]), trans.reshape([bs,-1])
|
48 |
+
|
49 |
+
face_vertex = self.face_model.compute_face_vertex(id, exp, euler, trans)
|
50 |
+
face_mask, _, secc_face = self.face_renderer(
|
51 |
+
face_vertex, self.face_buf.unsqueeze(0).repeat([bs, 1, 1]), feat=self.face_feat.repeat([bs,1,1]))
|
52 |
+
secc_face = (secc_face - 0.5) / 0.5 # scale to -1~1
|
53 |
+
|
54 |
+
if is_btc_flag:
|
55 |
+
bs = bs // t
|
56 |
+
face_mask = rearrange(face_mask, "(n t) c h w -> n c t h w", n=bs, t=t)
|
57 |
+
secc_face = rearrange(secc_face, "(n t) c h w -> n c t h w", n=bs, t=t)
|
58 |
+
return face_mask, secc_face
|
59 |
+
|
60 |
+
|
61 |
+
if __name__ == '__main__':
|
62 |
+
import imageio
|
63 |
+
|
64 |
+
renderer = SECC_Renderer(rasterize_size=512)
|
65 |
+
ret = np.load("data/processed/videos/May/vid_coeff_fit.npy", allow_pickle=True).tolist()
|
66 |
+
idx = 6
|
67 |
+
id = torch.tensor(ret['id']).cuda()[idx:idx+1]
|
68 |
+
exp = torch.tensor(ret['exp']).cuda()[idx:idx+1]
|
69 |
+
angle = torch.tensor(ret['euler']).cuda()[idx:idx+1]
|
70 |
+
trans = torch.tensor(ret['trans']).cuda()[idx:idx+1]
|
71 |
+
mask, secc = renderer(id, exp, angle*0, trans*0) # [1, 1, 512, 512], [1, 3, 512, 512]
|
72 |
+
|
73 |
+
out_mask = mask[0].permute(1,2,0)
|
74 |
+
out_mask = (out_mask * 127.5 + 127.5).int().cpu().numpy()
|
75 |
+
imageio.imwrite("out_mask.png", out_mask)
|
76 |
+
out_img = secc[0].permute(1,2,0)
|
77 |
+
out_img = (out_img * 127.5 + 127.5).int().cpu().numpy()
|
78 |
+
imageio.imwrite("out_secc.png", out_img)
|
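A minimal sketch of driving `SECC_Renderer` with batched `[B, T, C]` coefficient sequences, which, per the `forward` docstring above, come back as `[B, C, T, H, W]` maps. The shapes are hypothetical, and it assumes a CUDA device plus the BFM/`ncc_code.npy` assets referenced above are in place.

```python
import torch
from deep_3drecon.secc_renderer import SECC_Renderer

renderer = SECC_Renderer(rasterize_size=256)  # assumes deep_3drecon/BFM and ncc_code.npy exist locally
B, T = 1, 4
id    = torch.zeros([B, T, 80]).cuda()
exp   = torch.zeros([B, T, 64]).cuda()
euler = torch.zeros([B, T, 3]).cuda()
trans = torch.zeros([B, T, 3]).cuda()
mask, secc = renderer(id, exp, euler, trans)
print(mask.shape, secc.shape)  # [1, 1, 4, 256, 256] and [1, 3, 4, 256, 256]
```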
deep_3drecon/util/mesh_renderer.py
ADDED
@@ -0,0 +1,131 @@
1 |
+
"""This script is the differentiable renderer for Deep3DFaceRecon_pytorch
|
2 |
+
Note that the antialiasing step is missing in the current version.
|
3 |
+
"""
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import kornia
|
7 |
+
from kornia.geometry.camera import pixel2cam
|
8 |
+
import numpy as np
|
9 |
+
from typing import List
|
10 |
+
from scipy.io import loadmat
|
11 |
+
from torch import nn
|
12 |
+
import traceback
|
13 |
+
|
14 |
+
try:
|
15 |
+
import pytorch3d.ops
|
16 |
+
from pytorch3d.structures import Meshes
|
17 |
+
from pytorch3d.renderer import (
|
18 |
+
look_at_view_transform,
|
19 |
+
FoVPerspectiveCameras,
|
20 |
+
DirectionalLights,
|
21 |
+
RasterizationSettings,
|
22 |
+
MeshRenderer,
|
23 |
+
MeshRasterizer,
|
24 |
+
SoftPhongShader,
|
25 |
+
TexturesUV,
|
26 |
+
)
|
27 |
+
except:
|
28 |
+
traceback.print_exc()
|
29 |
+
# def ndc_projection(x=0.1, n=1.0, f=50.0):
|
30 |
+
# return np.array([[n/x, 0, 0, 0],
|
31 |
+
# [ 0, n/-x, 0, 0],
|
32 |
+
# [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)],
|
33 |
+
# [ 0, 0, -1, 0]]).astype(np.float32)
|
34 |
+
|
35 |
+
class MeshRenderer(nn.Module):
|
36 |
+
def __init__(self,
|
37 |
+
rasterize_fov,
|
38 |
+
znear=0.1,
|
39 |
+
zfar=10,
|
40 |
+
rasterize_size=224,**args):
|
41 |
+
super(MeshRenderer, self).__init__()
|
42 |
+
|
43 |
+
# x = np.tan(np.deg2rad(rasterize_fov * 0.5)) * znear
|
44 |
+
# self.ndc_proj = torch.tensor(ndc_projection(x=x, n=znear, f=zfar)).matmul(
|
45 |
+
# torch.diag(torch.tensor([1., -1, -1, 1])))
|
46 |
+
self.rasterize_size = rasterize_size
|
47 |
+
self.fov = rasterize_fov
|
48 |
+
self.znear = znear
|
49 |
+
self.zfar = zfar
|
50 |
+
|
51 |
+
self.rasterizer = None
|
52 |
+
|
53 |
+
def forward(self, vertex, tri, feat=None):
|
54 |
+
"""
|
55 |
+
Return:
|
56 |
+
mask -- torch.tensor, size (B, 1, H, W)
|
57 |
+
depth -- torch.tensor, size (B, 1, H, W)
|
58 |
+
features(optional) -- torch.tensor, size (B, C, H, W) if feat is not None
|
59 |
+
|
60 |
+
Parameters:
|
61 |
+
vertex -- torch.tensor, size (B, N, 3)
|
62 |
+
tri -- torch.tensor, size (B, M, 3) or (M, 3), triangles
|
63 |
+
feat(optional) -- torch.tensor, size (B, N ,C), features
|
64 |
+
"""
|
65 |
+
device = vertex.device
|
66 |
+
rsize = int(self.rasterize_size)
|
67 |
+
# ndc_proj = self.ndc_proj.to(device)
|
68 |
+
# transform to homogeneous coordinates of 3d vertices, the direction of y is the same as v
|
69 |
+
if vertex.shape[-1] == 3:
|
70 |
+
vertex = torch.cat([vertex, torch.ones([*vertex.shape[:2], 1]).to(device)], dim=-1)
|
71 |
+
vertex[..., 0] = -vertex[..., 0]
|
72 |
+
|
73 |
+
|
74 |
+
# vertex_ndc = vertex @ ndc_proj.t()
|
75 |
+
if self.rasterizer is None:
|
76 |
+
self.rasterizer = MeshRasterizer()
|
77 |
+
print("create rasterizer on device cuda:%d"%device.index)
|
78 |
+
|
79 |
+
# ranges = None
|
80 |
+
# if isinstance(tri, List) or len(tri.shape) == 3:
|
81 |
+
# vum = vertex_ndc.shape[1]
|
82 |
+
# fnum = torch.tensor([f.shape[0] for f in tri]).unsqueeze(1).to(device)
|
83 |
+
# fstartidx = torch.cumsum(fnum, dim=0) - fnum
|
84 |
+
# ranges = torch.cat([fstartidx, fnum], axis=1).type(torch.int32).cpu()
|
85 |
+
# for i in range(tri.shape[0]):
|
86 |
+
# tri[i] = tri[i] + i*vum
|
87 |
+
# vertex_ndc = torch.cat(vertex_ndc, dim=0)
|
88 |
+
# tri = torch.cat(tri, dim=0)
|
89 |
+
|
90 |
+
# for range_mode vetex: [B*N, 4], tri: [B*M, 3], for instance_mode vetex: [B, N, 4], tri: [M, 3]
|
91 |
+
tri = tri.type(torch.int32).contiguous()
|
92 |
+
|
93 |
+
# rasterize
|
94 |
+
cameras = FoVPerspectiveCameras(
|
95 |
+
device=device,
|
96 |
+
fov=self.fov,
|
97 |
+
znear=self.znear,
|
98 |
+
zfar=self.zfar,
|
99 |
+
)
|
100 |
+
|
101 |
+
raster_settings = RasterizationSettings(
|
102 |
+
image_size=rsize
|
103 |
+
)
|
104 |
+
|
105 |
+
# print(vertex.shape, tri.shape)
|
106 |
+
if tri.ndim == 2:
|
107 |
+
tri = tri.unsqueeze(0)
|
108 |
+
mesh = Meshes(vertex.contiguous()[...,:3], tri)
|
109 |
+
|
110 |
+
fragments = self.rasterizer(mesh, cameras = cameras, raster_settings = raster_settings)
|
111 |
+
rast_out = fragments.pix_to_face.squeeze(-1)
|
112 |
+
depth = fragments.zbuf
|
113 |
+
|
114 |
+
# render depth
|
115 |
+
depth = depth.permute(0, 3, 1, 2)
|
116 |
+
mask = (rast_out > 0).float().unsqueeze(1)
|
117 |
+
depth = mask * depth
|
118 |
+
|
119 |
+
|
120 |
+
image = None
|
121 |
+
if feat is not None:
|
122 |
+
attributes = feat.reshape(-1,3)[mesh.faces_packed()]
|
123 |
+
image = pytorch3d.ops.interpolate_face_attributes(fragments.pix_to_face,
|
124 |
+
fragments.bary_coords,
|
125 |
+
attributes)
|
126 |
+
# print(image.shape)
|
127 |
+
image = image.squeeze(-2).permute(0, 3, 1, 2)
|
128 |
+
image = mask * image
|
129 |
+
|
130 |
+
return mask, depth, image
|
131 |
+
|
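A minimal sketch of calling the `MeshRenderer` above on two hand-made triangles. The values are hypothetical; it assumes pytorch3d is installed and a CUDA device is available, and it only inspects the returned shapes (with `feat=None` the third return value is `None`).

```python
import torch
from deep_3drecon.util.mesh_renderer import MeshRenderer

renderer = MeshRenderer(rasterize_fov=12.6, znear=5., zfar=15., rasterize_size=224)
# Two triangles roughly 10 units in front of the camera (within [znear, zfar]).
vertex = torch.tensor([[[-0.5, -0.5, 10.], [0.5, -0.5, 10.],
                        [0.0, 0.5, 10.], [0.0, -0.8, 10.]]]).cuda()  # [B=1, N=4, 3]
tri = torch.tensor([[0, 1, 2], [0, 1, 3]]).cuda()                    # [M=2, 3]
mask, depth, image = renderer(vertex, tri)   # image is None because feat is not given
print(mask.shape, depth.shape)               # [1, 1, 224, 224] each
```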
docs/prepare_env/install_guide-zh.md
ADDED
@@ -0,0 +1,35 @@
1 |
+
# 环境配置
|
2 |
+
[English Doc](./install_guide.md)
|
3 |
+
|
4 |
+
本文档陈述了搭建Real3D-Portrait Python环境的步骤,我们使用了Conda来管理依赖。
|
5 |
+
|
6 |
+
以下配置已在 A100/V100 + CUDA11.7 中进行了验证。
|
7 |
+
|
8 |
+
|
9 |
+
# 1. 安装CUDA
|
10 |
+
我们推荐安装CUDA `11.7`,其他CUDA版本(例如`10.2`、`12.x`)也可能有效。
|
11 |
+
|
12 |
+
# 2. 安装Python依赖
|
13 |
+
```
|
14 |
+
cd <Real3DPortraitRoot>
|
15 |
+
source <CondaRoot>/bin/activate
|
16 |
+
conda create -n real3dportrait python=3.9
|
17 |
+
conda activate real3dportrait
|
18 |
+
conda install conda-forge::ffmpeg # ffmpeg with libx264 codec to turn images to video
|
19 |
+
|
20 |
+
# 我们推荐安装torch2.0.1+cuda11.7.
|
21 |
+
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
|
22 |
+
|
23 |
+
# 从源代码安装,需要比较长的时间 (如果遇到各种time-out问题,建议使用代理)
|
24 |
+
pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable"
|
25 |
+
|
26 |
+
# MMCV安装
|
27 |
+
pip install cython
|
28 |
+
pip install openmim==0.3.9
|
29 |
+
mim install mmcv==2.1.0 # 使用mim来加速mmcv安装
|
30 |
+
|
31 |
+
# 其他依赖项
|
32 |
+
pip install -r docs/prepare_env/requirements.txt -v
|
33 |
+
|
34 |
+
```
|
35 |
+
|
docs/prepare_env/install_guide.md
ADDED
@@ -0,0 +1,34 @@
1 |
+
# Prepare the Environment
|
2 |
+
[中文文档](./install_guide-zh.md)
|
3 |
+
|
4 |
+
This guide describes how to build a Python environment for Real3D-Portrait with Conda.
|
5 |
+
|
6 |
+
The following installation process has been verified on A100/V100 + CUDA11.7.
|
7 |
+
|
8 |
+
|
9 |
+
# 1. Install CUDA
|
10 |
+
We recommend installing CUDA `11.7` (verified on various types of GPUs), but other CUDA versions (such as `10.2`, `12.x`) may also work well.
|
11 |
+
|
12 |
+
# 2. Install Python Packages
|
13 |
+
```
|
14 |
+
cd <Real3DPortraitRoot>
|
15 |
+
source <CondaRoot>/bin/activate
|
16 |
+
conda create -n real3dportrait python=3.9
|
17 |
+
conda activate real3dportrait
|
18 |
+
conda install conda-forge::ffmpeg # ffmpeg with libx264 codec to turn images to video
|
19 |
+
|
20 |
+
### We recommend torch2.0.1+cuda11.7.
|
21 |
+
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
|
22 |
+
|
23 |
+
# Build from source; this may take a long time (a proxy is recommended if you encounter time-out problems)
|
24 |
+
pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable"
|
25 |
+
|
26 |
+
# MMCV for some network structure
|
27 |
+
pip install cython
|
28 |
+
pip install openmim==0.3.9
|
29 |
+
mim install mmcv==2.1.0 # use mim to speed up installation for mmcv
|
30 |
+
|
31 |
+
# other dependencies
|
32 |
+
pip install -r docs/prepare_env/requirements.txt -v
|
33 |
+
|
34 |
+
```
|
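After finishing the steps above, a quick, hypothetical sanity check that the core packages are importable inside the `real3dportrait` environment (not part of the official guide):

```python
# Minimal post-install check; run with the real3dportrait env activated.
import torch
import pytorch3d
import mmcv

print("torch", torch.__version__, "| cuda available:", torch.cuda.is_available())
print("pytorch3d", pytorch3d.__version__, "| mmcv", mmcv.__version__)
```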
docs/prepare_env/requirements.txt
ADDED
@@ -0,0 +1,75 @@
1 |
+
Cython
|
2 |
+
numpy # ==1.23.0
|
3 |
+
numba==0.56.4
|
4 |
+
pandas
|
5 |
+
transformers
|
6 |
+
scipy==1.11.1 # required by cal_fid. https://github.com/mseitzer/pytorch-fid/issues/103
|
7 |
+
scikit-learn
|
8 |
+
scikit-image
|
9 |
+
# tensorflow # optional, version can be chosen flexibly; this refers to the GPU build
|
10 |
+
tensorboard
|
11 |
+
tensorboardX
|
12 |
+
python_speech_features
|
13 |
+
resampy
|
14 |
+
opencv_python
|
15 |
+
face_alignment
|
16 |
+
matplotlib
|
17 |
+
configargparse
|
18 |
+
librosa==0.9.2
|
19 |
+
praat-parselmouth # ==0.4.3
|
20 |
+
trimesh
|
21 |
+
kornia==0.5.0
|
22 |
+
PyMCubes
|
23 |
+
lpips
|
24 |
+
setuptools # ==59.5.0
|
25 |
+
ffmpeg-python
|
26 |
+
moviepy
|
27 |
+
dearpygui
|
28 |
+
ninja
|
29 |
+
# pyaudio # for extracting esperanto
|
30 |
+
mediapipe
|
31 |
+
protobuf
|
32 |
+
decord
|
33 |
+
soundfile
|
34 |
+
pillow
|
35 |
+
# torch # it's better to install torch with conda
|
36 |
+
av
|
37 |
+
timm
|
38 |
+
pretrainedmodels
|
39 |
+
faiss-cpu # for fast nearest camera pose retrieval
|
40 |
+
einops
|
41 |
+
# mmcv # use mim install is faster
|
42 |
+
|
43 |
+
# conditional flow matching
|
44 |
+
beartype
|
45 |
+
torchode
|
46 |
+
torchdiffeq
|
47 |
+
|
48 |
+
# tts
|
49 |
+
cython
|
50 |
+
textgrid
|
51 |
+
pyloudnorm
|
52 |
+
websocket-client
|
53 |
+
pyworld==0.2.1rc0
|
54 |
+
pypinyin==0.42.0
|
55 |
+
webrtcvad
|
56 |
+
torchshow
|
57 |
+
|
58 |
+
# cal spk sim
|
59 |
+
s3prl
|
60 |
+
fire
|
61 |
+
|
62 |
+
# cal LMD
|
63 |
+
dlib
|
64 |
+
|
65 |
+
# debug
|
66 |
+
ipykernel
|
67 |
+
|
68 |
+
# lama
|
69 |
+
hydra-core
|
70 |
+
pytorch_lightning
|
71 |
+
setproctitle
|
72 |
+
|
73 |
+
# Gradio GUI
|
74 |
+
httpx==0.23.3
|
75 |
+
gradio==4.16.0
|
inference/app_real3dportrait.py
ADDED
@@ -0,0 +1,244 @@
1 |
+
import os, sys
|
2 |
+
import argparse
|
3 |
+
import gradio as gr
|
4 |
+
from inference.real3d_infer import GeneFace2Infer
|
5 |
+
from utils.commons.hparams import hparams
|
6 |
+
|
7 |
+
class Inferer(GeneFace2Infer):
|
8 |
+
def infer_once_args(self, *args, **kargs):
|
9 |
+
assert len(kargs) == 0
|
10 |
+
keys = [
|
11 |
+
'src_image_name',
|
12 |
+
'drv_audio_name',
|
13 |
+
'drv_pose_name',
|
14 |
+
'bg_image_name',
|
15 |
+
'blink_mode',
|
16 |
+
'temperature',
|
17 |
+
'mouth_amp',
|
18 |
+
'out_mode',
|
19 |
+
'map_to_init_pose',
|
20 |
+
'hold_eye_opened',
|
21 |
+
'head_torso_threshold',
|
22 |
+
'a2m_ckpt',
|
23 |
+
'head_ckpt',
|
24 |
+
'torso_ckpt',
|
25 |
+
]
|
26 |
+
inp = {}
|
27 |
+
out_name = None
|
28 |
+
info = ""
|
29 |
+
|
30 |
+
try: # try to catch errors and jump to return
|
31 |
+
for key_index in range(len(keys)):
|
32 |
+
key = keys[key_index]
|
33 |
+
inp[key] = args[key_index]
|
34 |
+
if '_name' in key:
|
35 |
+
inp[key] = inp[key] if inp[key] is not None else ''
|
36 |
+
|
37 |
+
if inp['src_image_name'] == '':
|
38 |
+
info = "Input Error: Source image is REQUIRED!"
|
39 |
+
raise ValueError
|
40 |
+
if inp['drv_audio_name'] == '' and inp['drv_pose_name'] == '':
|
41 |
+
info = "Input Error: At least one of driving audio or video is REQUIRED!"
|
42 |
+
raise ValueError
|
43 |
+
|
44 |
+
|
45 |
+
if inp['drv_audio_name'] == '' and inp['drv_pose_name'] != '':
|
46 |
+
inp['drv_audio_name'] = inp['drv_pose_name']
|
47 |
+
print("No audio input, we use driving pose video for video driving")
|
48 |
+
|
49 |
+
if inp['drv_pose_name'] == '':
|
50 |
+
inp['drv_pose_name'] = 'static'
|
51 |
+
|
52 |
+
reload_flag = False
|
53 |
+
if inp['a2m_ckpt'] != self.audio2secc_dir:
|
54 |
+
print("Changes of a2m_ckpt detected, reloading model")
|
55 |
+
reload_flag = True
|
56 |
+
if inp['head_ckpt'] != self.head_model_dir:
|
57 |
+
print("Changes of head_ckpt detected, reloading model")
|
58 |
+
reload_flag = True
|
59 |
+
if inp['torso_ckpt'] != self.torso_model_dir:
|
60 |
+
print("Changes of torso_ckpt detected, reloading model")
|
61 |
+
reload_flag = True
|
62 |
+
|
63 |
+
inp['out_name'] = ''
|
64 |
+
inp['seed'] = 42
|
65 |
+
|
66 |
+
print(f"infer inputs : {inp}")
|
67 |
+
if self.secc2video_hparams['htbsr_head_threshold'] != inp['head_torso_threshold']:
|
68 |
+
print("Changes of head_torso_threshold detected, reloading model")
|
69 |
+
reload_flag = True
|
70 |
+
|
71 |
+
try:
|
72 |
+
if reload_flag:
|
73 |
+
self.__init__(inp['a2m_ckpt'], inp['head_ckpt'], inp['torso_ckpt'], inp=inp, device=self.device)
|
74 |
+
except Exception as e:
|
75 |
+
content = f"{e}"
|
76 |
+
info = f"Reload ERROR: {content}"
|
77 |
+
raise ValueError
|
78 |
+
try:
|
79 |
+
out_name = self.infer_once(inp)
|
80 |
+
except Exception as e:
|
81 |
+
content = f"{e}"
|
82 |
+
info = f"Inference ERROR: {content}"
|
83 |
+
raise ValueError
|
84 |
+
except Exception as e:
|
85 |
+
if info == "": # unexpected errors
|
86 |
+
content = f"{e}"
|
87 |
+
info = f"WebUI ERROR: {content}"
|
88 |
+
|
89 |
+
# output part
|
90 |
+
if len(info) > 0 : # there is errors
|
91 |
+
print(info)
|
92 |
+
info_gr = gr.update(visible=True, value=info)
|
93 |
+
else: # no errors
|
94 |
+
info_gr = gr.update(visible=False, value=info)
|
95 |
+
if out_name is not None and len(out_name) > 0 and os.path.exists(out_name): # good output
|
96 |
+
print(f"Succefully generated in {out_name}")
|
97 |
+
video_gr = gr.update(visible=True, value=out_name)
|
98 |
+
else:
|
99 |
+
print(f"Failed to generate")
|
100 |
+
video_gr = gr.update(visible=True, value=out_name)
|
101 |
+
|
102 |
+
return video_gr, info_gr
|
103 |
+
|
104 |
+
def toggle_audio_file(choice):
|
105 |
+
if choice == False:
|
106 |
+
return gr.update(visible=True), gr.update(visible=False)
|
107 |
+
else:
|
108 |
+
return gr.update(visible=False), gr.update(visible=True)
|
109 |
+
|
110 |
+
def ref_video_fn(path_of_ref_video):
|
111 |
+
if path_of_ref_video is not None:
|
112 |
+
return gr.update(value=True)
|
113 |
+
else:
|
114 |
+
return gr.update(value=False)
|
115 |
+
|
116 |
+
def real3dportrait_demo(
|
117 |
+
audio2secc_dir,
|
118 |
+
head_model_dir,
|
119 |
+
torso_model_dir,
|
120 |
+
device = 'cuda',
|
121 |
+
warpfn = None,
|
122 |
+
):
|
123 |
+
|
124 |
+
sep_line = "-" * 40
|
125 |
+
|
126 |
+
infer_obj = Inferer(
|
127 |
+
audio2secc_dir=audio2secc_dir,
|
128 |
+
head_model_dir=head_model_dir,
|
129 |
+
torso_model_dir=torso_model_dir,
|
130 |
+
device=device,
|
131 |
+
)
|
132 |
+
|
133 |
+
print(sep_line)
|
134 |
+
print("Model loading is finished.")
|
135 |
+
print(sep_line)
|
136 |
+
with gr.Blocks(analytics_enabled=False) as real3dportrait_interface:
|
137 |
+
gr.Markdown("\
|
138 |
+
<div align='center'> <h2> Real3D-Portrait: One-shot Realistic 3D Talking Portrait Synthesis (ICLR 2024 Spotlight) </span> </h2> \
|
139 |
+
<a style='font-size:18px;color: #a0a0a0' href='https://arxiv.org/pdf/2401.08503.pdf'>Arxiv</a> \
|
140 |
+
<a style='font-size:18px;color: #a0a0a0' href='https://real3dportrait.github.io/'>Homepage</a> \
|
141 |
+
<a style='font-size:18px;color: #a0a0a0' href='https://baidu.com'> Github </div>")
|
142 |
+
|
143 |
+
sources = None
|
144 |
+
with gr.Row():
|
145 |
+
with gr.Column(variant='panel'):
|
146 |
+
with gr.Tabs(elem_id="source_image"):
|
147 |
+
with gr.TabItem('Upload image'):
|
148 |
+
with gr.Row():
|
149 |
+
src_image_name = gr.Image(label="Source image (required)", sources=sources, type="filepath", value="data/raw/examples/Macron.png")
|
150 |
+
with gr.Tabs(elem_id="driven_audio"):
|
151 |
+
with gr.TabItem('Upload audio'):
|
152 |
+
with gr.Column(variant='panel'):
|
153 |
+
drv_audio_name = gr.Audio(label="Input audio (required for audio-driven)", sources=sources, type="filepath", value="data/raw/examples/Obama_5s.wav")
|
154 |
+
with gr.Tabs(elem_id="driven_pose"):
|
155 |
+
with gr.TabItem('Upload video'):
|
156 |
+
with gr.Column(variant='panel'):
|
157 |
+
drv_pose_name = gr.Video(label="Driven Pose (required for video-driven, optional for audio-driven)", sources=sources, value="data/raw/examples/May_5s.mp4")
|
158 |
+
with gr.Tabs(elem_id="bg_image"):
|
159 |
+
with gr.TabItem('Upload image'):
|
160 |
+
with gr.Row():
|
161 |
+
bg_image_name = gr.Image(label="Background image (optional)", sources=sources, type="filepath", value="data/raw/examples/bg.png")
|
162 |
+
|
163 |
+
|
164 |
+
with gr.Column(variant='panel'):
|
165 |
+
with gr.Tabs(elem_id="checkbox"):
|
166 |
+
with gr.TabItem('General Settings'):
|
167 |
+
with gr.Column(variant='panel'):
|
168 |
+
|
169 |
+
blink_mode = gr.Radio(['none', 'period'], value='period', label='blink mode', info="whether to blink periodically") #
|
170 |
+
temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.025, label="temperature", value=0.2, info='audio to secc temperature',)
|
171 |
+
mouth_amp = gr.Slider(minimum=0.0, maximum=1.0, step=0.025, label="mouth amplitude", value=0.45, info='higher -> mouth will open wider, defaults to 0.45',)
|
172 |
+
out_mode = gr.Radio(['final', 'concat_debug'], value='final', label='output layout', info="final: only the final output; concat_debug: final output concatenated with internal features")
|
173 |
+
map_to_init_pose = gr.Checkbox(label="Whether to map pose of first frame to initial pose")
|
174 |
+
hold_eye_opened = gr.Checkbox(label="Whether to maintain eyes always open")
|
175 |
+
head_torso_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.025, label="head torso threshold", value=0.7, info='increase it if you find ghosting around the hair in the output, defaults to 0.7',)
|
176 |
+
|
177 |
+
submit = gr.Button('Generate', elem_id="generate", variant='primary')
|
178 |
+
|
179 |
+
with gr.Tabs(elem_id="genearted_video"):
|
180 |
+
info_box = gr.Textbox(label="Error", interactive=False, visible=False)
|
181 |
+
gen_video = gr.Video(label="Generated video", format="mp4", visible=True)
|
182 |
+
with gr.Column(variant='panel'):
|
183 |
+
with gr.Tabs(elem_id="checkbox"):
|
184 |
+
with gr.TabItem('Checkpoints'):
|
185 |
+
with gr.Column(variant='panel'):
|
186 |
+
ckpt_info_box = gr.Textbox(value="Please select \"ckpt\" under the checkpoint folder ", interactive=False, visible=True, show_label=False)
|
187 |
+
audio2secc_dir = gr.FileExplorer(glob="checkpoints/**/*.ckpt", value=audio2secc_dir, file_count='single', label='audio2secc model ckpt path or directory')
|
188 |
+
head_model_dir = gr.FileExplorer(glob="checkpoints/**/*.ckpt", value=head_model_dir, file_count='single', label='head model ckpt path or directory (will be ignored if torso model is set)')
|
189 |
+
torso_model_dir = gr.FileExplorer(glob="checkpoints/**/*.ckpt", value=torso_model_dir, file_count='single', label='torso model ckpt path or directory')
|
190 |
+
# audio2secc_dir = gr.Textbox(audio2secc_dir, max_lines=1, label='audio2secc model ckpt path or directory (will be ignored if torso model is set)')
|
191 |
+
# head_model_dir = gr.Textbox(head_model_dir, max_lines=1, label='head model ckpt path or directory (will be ignored if torso model is set)')
|
192 |
+
# torso_model_dir = gr.Textbox(torso_model_dir, max_lines=1, label='torso model ckpt path or directory')
|
193 |
+
|
194 |
+
|
195 |
+
fn = infer_obj.infer_once_args
|
196 |
+
if warpfn:
|
197 |
+
fn = warpfn(fn)
|
198 |
+
submit.click(
|
199 |
+
fn=fn,
|
200 |
+
inputs=[
|
201 |
+
src_image_name,
|
202 |
+
drv_audio_name,
|
203 |
+
drv_pose_name,
|
204 |
+
bg_image_name,
|
205 |
+
blink_mode,
|
206 |
+
temperature,
|
207 |
+
mouth_amp,
|
208 |
+
out_mode,
|
209 |
+
map_to_init_pose,
|
210 |
+
hold_eye_opened,
|
211 |
+
head_torso_threshold,
|
212 |
+
audio2secc_dir,
|
213 |
+
head_model_dir,
|
214 |
+
torso_model_dir,
|
215 |
+
],
|
216 |
+
outputs=[
|
217 |
+
gen_video,
|
218 |
+
info_box,
|
219 |
+
],
|
220 |
+
)
|
221 |
+
|
222 |
+
print(sep_line)
|
223 |
+
print("Gradio page is constructed.")
|
224 |
+
print(sep_line)
|
225 |
+
|
226 |
+
return real3dportrait_interface
|
227 |
+
|
228 |
+
if __name__ == "__main__":
|
229 |
+
parser = argparse.ArgumentParser()
|
230 |
+
parser.add_argument("--a2m_ckpt", type=str, default='checkpoints/240126_real3dportrait_orig/audio2secc_vae/model_ckpt_steps_400000.ckpt')
|
231 |
+
parser.add_argument("--head_ckpt", type=str, default='')
|
232 |
+
parser.add_argument("--torso_ckpt", type=str, default='checkpoints/240126_real3dportrait_orig/secc2plane_torso_orig/model_ckpt_steps_100000.ckpt')
|
233 |
+
parser.add_argument("--port", type=int, default=None)
|
234 |
+
args = parser.parse_args()
|
235 |
+
demo = real3dportrait_demo(
|
236 |
+
audio2secc_dir=args.a2m_ckpt,
|
237 |
+
head_model_dir=args.head_ckpt,
|
238 |
+
torso_model_dir=args.torso_ckpt,
|
239 |
+
device='cuda:0',
|
240 |
+
warpfn=None,
|
241 |
+
)
|
242 |
+
demo.queue()
|
243 |
+
demo.launch(server_port=args.port)
|
244 |
+
|
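For reference, a hypothetical programmatic launch of the demo defined above, mirroring the `__main__` block but with an explicit port; the checkpoint paths are the argparse defaults and must exist locally.

```python
from inference.app_real3dportrait import real3dportrait_demo

demo = real3dportrait_demo(
    audio2secc_dir='checkpoints/240126_real3dportrait_orig/audio2secc_vae/model_ckpt_steps_400000.ckpt',
    head_model_dir='',
    torso_model_dir='checkpoints/240126_real3dportrait_orig/secc2plane_torso_orig/model_ckpt_steps_100000.ckpt',
    device='cuda:0',
)
demo.queue()
demo.launch(server_port=7860)  # pick any free port
```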
inference/edit_secc.py
ADDED
@@ -0,0 +1,147 @@
1 |
+
import cv2
|
2 |
+
import torch
|
3 |
+
from utils.commons.image_utils import dilate, erode
|
4 |
+
from sklearn.neighbors import NearestNeighbors
|
5 |
+
import copy
|
6 |
+
import numpy as np
|
7 |
+
from utils.commons.meters import Timer
|
8 |
+
|
9 |
+
def hold_eye_opened_for_secc(img):
|
10 |
+
img = img.permute(1,2,0).cpu().numpy()
|
11 |
+
img = ((img +1)/2*255).astype(np.uint)
|
12 |
+
face_mask = (img[...,0] != 0) & (img[...,1] != 0) & (img[...,2] != 0)
|
13 |
+
face_xys = np.stack(np.nonzero(face_mask)).transpose(1, 0) # [N,2] coordinates of face pixels
|
14 |
+
h,w = face_mask.shape
|
15 |
+
# get face and eye mask
|
16 |
+
left_eye_prior_reigon = np.zeros([h,w], dtype=bool)
|
17 |
+
right_eye_prior_reigon = np.zeros([h,w], dtype=bool)
|
18 |
+
left_eye_prior_reigon[h//4:h//2, w//4:w//2] = True
|
19 |
+
right_eye_prior_reigon[h//4:h//2, w//2:w//4*3] = True
|
20 |
+
eye_prior_reigon = left_eye_prior_reigon | right_eye_prior_reigon
|
21 |
+
coarse_eye_mask = (~ face_mask) & eye_prior_reigon
|
22 |
+
coarse_eye_xys = np.stack(np.nonzero(coarse_eye_mask)).transpose(1, 0) # [N,2] coordinates of coarse eye pixels
|
23 |
+
|
24 |
+
opened_eye_mask = cv2.imread('inference/os_avatar/opened_eye_mask.png')
|
25 |
+
opened_eye_mask = torch.nn.functional.interpolate(torch.tensor(opened_eye_mask).permute(2,0,1).unsqueeze(0), size=(img.shape[0], img.shape[1]), mode='nearest')[0].permute(1,2,0).sum(-1).bool().cpu() # [512,512,3]
|
26 |
+
coarse_opened_eye_xys = np.stack(np.nonzero(opened_eye_mask)) # coordinates of opened-eye pixels
|
27 |
+
|
28 |
+
nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(coarse_eye_xys)
|
29 |
+
dists, _ = nbrs.kneighbors(coarse_opened_eye_xys) # [N, 1] distance to the nearest coarse-eye pixel
|
30 |
+
# print(dists.max())
|
31 |
+
non_opened_eye_pixs = dists > max(dists.max()*0.75, 4) # 大于这个距离的opened eye部分会被合上
|
32 |
+
non_opened_eye_pixs = non_opened_eye_pixs.reshape([-1])
|
33 |
+
opened_eye_xys_to_erode = coarse_opened_eye_xys[non_opened_eye_pixs]
|
34 |
+
opened_eye_mask[opened_eye_xys_to_erode[...,0], opened_eye_xys_to_erode[...,1]] = False # shrink the mask by a few pixels at the face-eye boundary for smoothness
|
35 |
+
|
36 |
+
img[opened_eye_mask] = 0
|
37 |
+
return torch.tensor(img.astype(np.float32) / 127.5 - 1).permute(2,0,1)
|
38 |
+
|
39 |
+
|
40 |
+
# def hold_eye_opened_for_secc(img):
|
41 |
+
# img = copy.copy(img)
|
42 |
+
# eye_mask = cv2.imread('inference/os_avatar/opened_eye_mask.png')
|
43 |
+
# eye_mask = torch.nn.functional.interpolate(torch.tensor(eye_mask).permute(2,0,1).unsqueeze(0), size=(img.shape[-2], img.shape[-1]), mode='nearest')[0].bool().to(img.device) # [3,512,512]
|
44 |
+
# img[eye_mask] = -1
|
45 |
+
# return img
|
46 |
+
|
47 |
+
def blink_eye_for_secc(img, close_eye_percent=0.5):
|
48 |
+
"""
|
49 |
+
secc_img: [3,h,w], tensor, -1~1
|
50 |
+
"""
|
51 |
+
img = img.permute(1,2,0).cpu().numpy()
|
52 |
+
img = ((img +1)/2*255).astype(np.uint)
|
53 |
+
assert close_eye_percent <= 1.0 and close_eye_percent >= 0.
|
54 |
+
if close_eye_percent == 0: return torch.tensor(img.astype(np.float32) / 127.5 - 1).permute(2,0,1)
|
55 |
+
img = copy.deepcopy(img)
|
56 |
+
face_mask = (img[...,0] != 0) & (img[...,1] != 0) & (img[...,2] != 0)
|
57 |
+
h,w = face_mask.shape
|
58 |
+
|
59 |
+
# get face and eye mask
|
60 |
+
left_eye_prior_reigon = np.zeros([h,w], dtype=bool)
|
61 |
+
right_eye_prior_reigon = np.zeros([h,w], dtype=bool)
|
62 |
+
left_eye_prior_reigon[h//4:h//2, w//4:w//2] = True
|
63 |
+
right_eye_prior_reigon[h//4:h//2, w//2:w//4*3] = True
|
64 |
+
eye_prior_reigon = left_eye_prior_reigon | right_eye_prior_reigon
|
65 |
+
coarse_eye_mask = (~ face_mask) & eye_prior_reigon
|
66 |
+
coarse_left_eye_mask = (~ face_mask) & left_eye_prior_reigon
|
67 |
+
coarse_right_eye_mask = (~ face_mask) & right_eye_prior_reigon
|
68 |
+
coarse_eye_xys = np.stack(np.nonzero(coarse_eye_mask)).transpose(1, 0) # [N,2] coordinates of eye pixels
|
69 |
+
min_h = coarse_eye_xys[:, 0].min()
|
70 |
+
max_h = coarse_eye_xys[:, 0].max()
|
71 |
+
coarse_left_eye_xys = np.stack(np.nonzero(coarse_left_eye_mask)).transpose(1, 0) # [N,2] coordinates of left-eye pixels
|
72 |
+
left_min_w = coarse_left_eye_xys[:, 1].min()
|
73 |
+
left_max_w = coarse_left_eye_xys[:, 1].max()
|
74 |
+
coarse_right_eye_xys = np.stack(np.nonzero(coarse_right_eye_mask)).transpose(1, 0) # [N,2] coordinates of right-eye pixels
|
75 |
+
right_min_w = coarse_right_eye_xys[:, 1].min()
|
76 |
+
right_max_w = coarse_right_eye_xys[:, 1].max()
|
77 |
+
|
78 |
+
# shrink the set of face_xys to consider, to reduce the cost of the KNN lookup
|
79 |
+
left_eye_prior_reigon = np.zeros([h,w], dtype=bool)
|
80 |
+
more_room = 4 # too small a margin causes artifacts
|
81 |
+
left_eye_prior_reigon[min_h-more_room:max_h+more_room, left_min_w-more_room:left_max_w+more_room] = True
|
82 |
+
right_eye_prior_reigon = np.zeros([h,w], dtype=bool)
|
83 |
+
right_eye_prior_reigon[min_h-more_room:max_h+more_room, right_min_w-more_room:right_max_w+more_room] = True
|
84 |
+
eye_prior_reigon = left_eye_prior_reigon | right_eye_prior_reigon
|
85 |
+
|
86 |
+
around_eye_face_mask = face_mask & eye_prior_reigon
|
87 |
+
face_mask = around_eye_face_mask
|
88 |
+
face_xys = np.stack(np.nonzero(around_eye_face_mask)).transpose(1, 0) # [N,2] coordinates of face pixels around the eyes
|
89 |
+
|
90 |
+
nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(coarse_eye_xys)
|
91 |
+
dists, _ = nbrs.kneighbors(face_xys) # [N, 1] distance from each face pixel to its nearest eye pixel
|
92 |
+
face_pixs = dists > 5 # 只有距离最近的eye pixel大于5的才被认为是face,过小会导致一些问题
|
93 |
+
face_pixs = face_pixs.reshape([-1])
|
94 |
+
face_xys_to_erode = face_xys[~face_pixs]
|
95 |
+
face_mask[face_xys_to_erode[...,0], face_xys_to_erode[...,1]] = False # shrink the mask by a few pixels at the face-eye boundary for smoothness
|
96 |
+
eye_mask = (~ face_mask) & eye_prior_reigon
|
97 |
+
|
98 |
+
h_grid = np.mgrid[0:h, 0:w][0]
|
99 |
+
eye_num_pixel_along_w_axis = eye_mask.sum(axis=0)
|
100 |
+
eye_mask_along_w_axis = eye_num_pixel_along_w_axis != 0
|
101 |
+
|
102 |
+
tmp_h_grid = h_grid.copy()
|
103 |
+
tmp_h_grid[~eye_mask] = 0
|
104 |
+
eye_mean_h_coord_along_w_axis = tmp_h_grid.sum(axis=0) / np.clip(eye_num_pixel_along_w_axis, a_min=1, a_max=h)
|
105 |
+
tmp_h_grid = h_grid.copy()
|
106 |
+
tmp_h_grid[~eye_mask] = 99999
|
107 |
+
eye_min_h_coord_along_w_axis = tmp_h_grid.min(axis=0)
|
108 |
+
tmp_h_grid = h_grid.copy()
|
109 |
+
tmp_h_grid[~eye_mask] = -99999
|
110 |
+
eye_max_h_coord_along_w_axis = tmp_h_grid.max(axis=0)
|
111 |
+
|
112 |
+
eye_low_h_coord_along_w_axis = close_eye_percent * eye_mean_h_coord_along_w_axis + (1-close_eye_percent) * eye_min_h_coord_along_w_axis # upper eye
|
113 |
+
eye_high_h_coord_along_w_axis = close_eye_percent * eye_mean_h_coord_along_w_axis + (1-close_eye_percent) * eye_max_h_coord_along_w_axis # lower eye
|
114 |
+
|
115 |
+
tmp_h_grid = h_grid.copy()
|
116 |
+
tmp_h_grid[~eye_mask] = 99999
|
117 |
+
upper_eye_blink_mask = tmp_h_grid <= eye_low_h_coord_along_w_axis
|
118 |
+
tmp_h_grid = h_grid.copy()
|
119 |
+
tmp_h_grid[~eye_mask] = -99999
|
120 |
+
lower_eye_blink_mask = tmp_h_grid >= eye_high_h_coord_along_w_axis
|
121 |
+
eye_blink_mask = upper_eye_blink_mask | lower_eye_blink_mask
|
122 |
+
|
123 |
+
face_xys = np.stack(np.nonzero(around_eye_face_mask)).transpose(1, 0) # [N,2] coordinates of face pixels around the eyes
|
124 |
+
eye_blink_xys = np.stack(np.nonzero(eye_blink_mask)).transpose(1, 0) # [N,2] coordinates of eye pixels to be filled
|
125 |
+
nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(face_xys)
|
126 |
+
distances, indices = nbrs.kneighbors(eye_blink_xys)
|
127 |
+
bg_fg_xys = face_xys[indices[:, 0]]
|
128 |
+
img[eye_blink_xys[:, 0], eye_blink_xys[:, 1], :] = img[bg_fg_xys[:, 0], bg_fg_xys[:, 1], :]
|
129 |
+
return torch.tensor(img.astype(np.float32) / 127.5 - 1).permute(2,0,1)
|
130 |
+
|
131 |
+
|
132 |
+
if __name__ == '__main__':
|
133 |
+
import imageio
|
134 |
+
import tqdm
|
135 |
+
img = cv2.imread("assets/cano_secc.png")
|
136 |
+
img = img / 127.5 - 1
|
137 |
+
img = torch.FloatTensor(img).permute(2, 0, 1)
|
138 |
+
fps = 25
|
139 |
+
writer = imageio.get_writer('demo_blink.mp4', fps=fps)
|
140 |
+
|
141 |
+
for i in tqdm.trange(33):
|
142 |
+
blink_percent = 0.03 * i
|
143 |
+
with Timer("Blink", True):
|
144 |
+
out_img = blink_eye_for_secc(img, blink_percent)
|
145 |
+
out_img = ((out_img.permute(1,2,0)+1)*127.5).int().numpy()
|
146 |
+
writer.append_data(out_img)
|
147 |
+
writer.close()
|
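A toy sketch (made-up arrays) of the nearest-neighbour fill-in trick used by `blink_eye_for_secc` above: each pixel inside the "blink" region is painted with the colour of its nearest face pixel, found via a kd-tree.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

img = np.random.rand(8, 8, 3)                  # stand-in for the SECC map
blink_mask = np.zeros([8, 8], dtype=bool)
blink_mask[3:5, 3:5] = True                    # pixels to "close"
face_xys = np.stack(np.nonzero(~blink_mask)).transpose(1, 0)   # [N_face, 2]
blink_xys = np.stack(np.nonzero(blink_mask)).transpose(1, 0)   # [N_blink, 2]
nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(face_xys)
_, indices = nbrs.kneighbors(blink_xys)
src_xys = face_xys[indices[:, 0]]              # nearest face pixel for each blink pixel
img[blink_xys[:, 0], blink_xys[:, 1]] = img[src_xys[:, 0], src_xys[:, 1]]
```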
inference/infer_utils.py
ADDED
@@ -0,0 +1,154 @@
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import librosa
|
5 |
+
import numpy as np
|
6 |
+
import importlib
|
7 |
+
import tqdm
|
8 |
+
import copy
|
9 |
+
import cv2
|
10 |
+
from scipy.spatial.transform import Rotation
|
11 |
+
|
12 |
+
|
13 |
+
def load_img_to_512_hwc_array(img_name):
|
14 |
+
img = cv2.imread(img_name)
|
15 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
16 |
+
img = cv2.resize(img, (512, 512))
|
17 |
+
return img
|
18 |
+
|
19 |
+
def load_img_to_normalized_512_bchw_tensor(img_name):
|
20 |
+
img = load_img_to_512_hwc_array(img_name)
|
21 |
+
img = ((torch.tensor(img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2) # [b,c,h,w]
|
22 |
+
return img
|
23 |
+
|
24 |
+
def mirror_index(index, len_seq):
|
25 |
+
"""
|
26 |
+
get the mirror index when indexing a sequence and the index is larger than len_seq
|
27 |
+
args:
|
28 |
+
index: int
|
29 |
+
len_pose: int
|
30 |
+
return:
|
31 |
+
mirror_index: int
|
32 |
+
"""
|
33 |
+
turn = index // len_seq
|
34 |
+
res = index % len_seq
|
35 |
+
if turn % 2 == 0:
|
36 |
+
return res # forward indexing
|
37 |
+
else:
|
38 |
+
return len_seq - res - 1 # reverse indexing
|
39 |
+
|
40 |
+
def smooth_camera_sequence(camera, kernel_size=7):
|
41 |
+
"""
|
42 |
+
smooth the camera trajectory (i.e., rotation & translation)...
|
43 |
+
args:
|
44 |
+
camera: [N, 25] or [N, 16]. np.ndarray
|
45 |
+
kernel_size: int
|
46 |
+
return:
|
47 |
+
smoothed_camera: [N, 25] or [N, 16]. np.ndarray
|
48 |
+
"""
|
49 |
+
# poses: [N, 25], numpy array
|
50 |
+
N = camera.shape[0]
|
51 |
+
K = kernel_size // 2
|
52 |
+
poses = camera[:, :16].reshape([-1, 4, 4]).copy()
|
53 |
+
trans = poses[:, :3, 3].copy() # [N, 3]
|
54 |
+
rots = poses[:, :3, :3].copy() # [N, 3, 3]
|
55 |
+
|
56 |
+
for i in range(N):
|
57 |
+
start = max(0, i - K)
|
58 |
+
end = min(N, i + K + 1)
|
59 |
+
poses[i, :3, 3] = trans[start:end].mean(0)
|
60 |
+
try:
|
61 |
+
poses[i, :3, :3] = Rotation.from_matrix(rots[start:end]).mean().as_matrix()
|
62 |
+
except:
|
63 |
+
if i == 0:
|
64 |
+
poses[i, :3, :3] = rots[i]
|
65 |
+
else:
|
66 |
+
poses[i, :3, :3] = poses[i-1, :3, :3]
|
67 |
+
poses = poses.reshape([-1, 16])
|
68 |
+
camera[:, :16] = poses
|
69 |
+
return camera
|
70 |
+
|
71 |
+
def smooth_features_xd(in_tensor, kernel_size=7):
|
72 |
+
"""
|
73 |
+
smooth the feature maps
|
74 |
+
args:
|
75 |
+
in_tensor: [T, c], [T, c, h, w] or [T, c1, c2, h, w]
|
76 |
+
kernel_size: int
|
77 |
+
return:
|
78 |
+
out_tensor: [T, c], [T, c, h, w] or [T, c1, c2, h, w]
|
79 |
+
"""
|
80 |
+
t = in_tensor.shape[0]
|
81 |
+
ndim = in_tensor.ndim
|
82 |
+
pad = (kernel_size- 1)//2
|
83 |
+
in_tensor = torch.cat([torch.flip(in_tensor[0:pad], dims=[0]), in_tensor, torch.flip(in_tensor[t-pad:t], dims=[0])], dim=0)
|
84 |
+
if ndim == 2: # tc
|
85 |
+
_,c = in_tensor.shape
|
86 |
+
in_tensor = in_tensor.permute(1,0).reshape([-1,1,t+2*pad]) # [c, 1, t]
|
87 |
+
elif ndim == 4: # tchw
|
88 |
+
_,c,h,w = in_tensor.shape
|
89 |
+
in_tensor = in_tensor.permute(1,2,3,0).reshape([-1,1,t+2*pad]) # [c, 1, t]
|
90 |
+
elif ndim == 5: # tcchw, like deformation
|
91 |
+
_,c1,c2, h,w = in_tensor.shape
|
92 |
+
in_tensor = in_tensor.permute(1,2,3,4,0).reshape([-1,1,t+2*pad]) # [c, 1, t]
|
93 |
+
else: raise NotImplementedError()
|
94 |
+
avg_kernel = 1 / kernel_size * torch.Tensor([1.]*kernel_size).reshape([1,1,kernel_size]).float().to(in_tensor.device) # [1, 1, kw]
|
95 |
+
out_tensor = F.conv1d(in_tensor, avg_kernel)
|
96 |
+
if ndim == 2: # tc
|
97 |
+
return out_tensor.reshape([c,t]).permute(1,0)
|
98 |
+
elif ndim == 4: # tchw
|
99 |
+
return out_tensor.reshape([c,h,w,t]).permute(3,0,1,2)
|
100 |
+
elif ndim == 5: # tcchw, like deformation
|
101 |
+
return out_tensor.reshape([c1,c2,h,w,t]).permute(4,0,1,2,3)
|
102 |
+
|
103 |
+
|
104 |
+
def extract_audio_motion_from_ref_video(video_name):
|
105 |
+
def save_wav16k(audio_name):
|
106 |
+
supported_types = ('.wav', '.mp3', '.mp4', '.avi')
|
107 |
+
assert audio_name.endswith(supported_types), f"Now we only support {','.join(supported_types)} as audio source!"
|
108 |
+
wav16k_name = audio_name[:-4] + '_16k.wav'
|
109 |
+
extract_wav_cmd = f"ffmpeg -i {audio_name} -f wav -ar 16000 -v quiet -y {wav16k_name} -y"
|
110 |
+
os.system(extract_wav_cmd)
|
111 |
+
print(f"Extracted wav file (16khz) from {audio_name} to {wav16k_name}.")
|
112 |
+
return wav16k_name
|
113 |
+
|
114 |
+
def get_f0( wav16k_name):
|
115 |
+
from data_gen.utils.process_audio.extract_mel_f0 import extract_mel_from_fname, extract_f0_from_wav_and_mel
|
116 |
+
wav, mel = extract_mel_from_fname(wav16k_name)
|
117 |
+
f0, f0_coarse = extract_f0_from_wav_and_mel(wav, mel)
|
118 |
+
f0 = f0.reshape([-1,1])
|
119 |
+
f0 = torch.tensor(f0)
|
120 |
+
return f0
|
121 |
+
|
122 |
+
def get_hubert(wav16k_name):
|
123 |
+
from data_gen.utils.process_audio.extract_hubert import get_hubert_from_16k_wav
|
124 |
+
hubert = get_hubert_from_16k_wav(wav16k_name).detach().numpy()
|
125 |
+
len_mel = hubert.shape[0]
|
126 |
+
x_multiply = 8
|
127 |
+
if len_mel % x_multiply == 0:
|
128 |
+
num_to_pad = 0
|
129 |
+
else:
|
130 |
+
num_to_pad = x_multiply - len_mel % x_multiply
|
131 |
+
hubert = np.pad(hubert, pad_width=((0,num_to_pad), (0,0)))
|
132 |
+
hubert = torch.tensor(hubert)
|
133 |
+
return hubert
|
134 |
+
|
135 |
+
def get_exp(video_name):
|
136 |
+
from data_gen.utils.process_video.fit_3dmm_landmark import fit_3dmm_for_a_video
|
137 |
+
drv_motion_coeff_dict = fit_3dmm_for_a_video(video_name, save=False)
|
138 |
+
exp = torch.tensor(drv_motion_coeff_dict['exp'])
|
139 |
+
return exp
|
140 |
+
|
141 |
+
wav16k_name = save_wav16k(video_name)
|
142 |
+
f0 = get_f0(wav16k_name)
|
143 |
+
hubert = get_hubert(wav16k_name)
|
144 |
+
os.system(f"rm {wav16k_name}")
|
145 |
+
exp = get_exp(video_name)
|
146 |
+
target_length = min(len(exp), len(hubert)//2, len(f0)//2)
|
147 |
+
exp = exp[:target_length]
|
148 |
+
f0 = f0[:target_length*2]
|
149 |
+
hubert = hubert[:target_length*2]
|
150 |
+
return exp.unsqueeze(0), hubert.unsqueeze(0), f0.unsqueeze(0)
|
151 |
+
|
152 |
+
|
153 |
+
if __name__ == '__main__':
|
154 |
+
extract_audio_motion_from_ref_video('data/raw/videos/crop_0213.mp4')
|
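Two of these helpers are used throughout the inference scripts: mirror_index turns a short driving sequence into a ping-pong loop so it can cover an arbitrarily long output, and smooth_features_xd is a temporal moving average implemented with a 1-D convolution. A minimal usage sketch, assuming the repository root is on PYTHONPATH and with illustrative shapes:

import torch
from inference.infer_utils import mirror_index, smooth_features_xd

idxs = [mirror_index(i, len_seq=4) for i in range(10)]
# -> [0, 1, 2, 3, 3, 2, 1, 0, 0, 1]: forward, then reversed, then forward again

feats = torch.randn(50, 64)                        # e.g. 50 frames of 64-dim motion codes
smoothed = smooth_features_xd(feats, kernel_size=7)
print(smoothed.shape)                              # torch.Size([50, 64]): same length, smoothed along time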
inference/real3d_infer.py
ADDED
@@ -0,0 +1,542 @@
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torchshow as ts
|
5 |
+
import librosa
|
6 |
+
import random
|
7 |
+
import time
|
8 |
+
import numpy as np
|
9 |
+
import importlib
|
10 |
+
import tqdm
|
11 |
+
import copy
|
12 |
+
import cv2
|
13 |
+
|
14 |
+
# common utils
|
15 |
+
from utils.commons.hparams import hparams, set_hparams
|
16 |
+
from utils.commons.tensor_utils import move_to_cuda, convert_to_tensor
|
17 |
+
from utils.commons.ckpt_utils import load_ckpt, get_last_checkpoint
|
18 |
+
# 3DMM-related utils
|
19 |
+
from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
|
20 |
+
from data_util.face3d_helper import Face3DHelper
|
21 |
+
from data_gen.utils.process_image.fit_3dmm_landmark import fit_3dmm_for_a_image
|
22 |
+
from data_gen.utils.process_video.fit_3dmm_landmark import fit_3dmm_for_a_video
|
23 |
+
from deep_3drecon.secc_renderer import SECC_Renderer
|
24 |
+
from data_gen.eg3d.convert_to_eg3d_convention import get_eg3d_convention_camera_pose_intrinsic
|
25 |
+
# Face Parsing
|
26 |
+
from data_gen.utils.mp_feature_extractors.mp_segmenter import MediapipeSegmenter
|
27 |
+
from data_gen.utils.process_video.extract_segment_imgs import inpaint_torso_job, extract_background
|
28 |
+
# other inference utils
|
29 |
+
from inference.infer_utils import mirror_index, load_img_to_512_hwc_array, load_img_to_normalized_512_bchw_tensor
|
30 |
+
from inference.infer_utils import smooth_camera_sequence, smooth_features_xd
|
31 |
+
from inference.edit_secc import blink_eye_for_secc
|
32 |
+
|
33 |
+
|
34 |
+
def read_first_frame_from_a_video(vid_name):
|
35 |
+
frames = []
|
36 |
+
cap = cv2.VideoCapture(vid_name)
|
37 |
+
ret, frame_bgr = cap.read()
|
38 |
+
frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
|
39 |
+
return frame_rgb
|
40 |
+
|
41 |
+
def analyze_weights_img(gen_output):
|
42 |
+
img_raw = gen_output['image_raw']
|
43 |
+
mask_005_to_03 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<0.3).repeat([1,3,1,1])
|
44 |
+
mask_005_to_05 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<0.5).repeat([1,3,1,1])
|
45 |
+
mask_005_to_07 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<0.7).repeat([1,3,1,1])
|
46 |
+
mask_005_to_09 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<0.9).repeat([1,3,1,1])
|
47 |
+
mask_005_to_10 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<1.0).repeat([1,3,1,1])
|
48 |
+
|
49 |
+
img_raw_005_to_03 = img_raw.clone()
|
50 |
+
img_raw_005_to_03[~mask_005_to_03] = -1
|
51 |
+
img_raw_005_to_05 = img_raw.clone()
|
52 |
+
img_raw_005_to_05[~mask_005_to_05] = -1
|
53 |
+
img_raw_005_to_07 = img_raw.clone()
|
54 |
+
img_raw_005_to_07[~mask_005_to_07] = -1
|
55 |
+
img_raw_005_to_09 = img_raw.clone()
|
56 |
+
img_raw_005_to_09[~mask_005_to_09] = -1
|
57 |
+
img_raw_005_to_10 = img_raw.clone()
|
58 |
+
img_raw_005_to_10[~mask_005_to_10] = -1
|
59 |
+
ts.save([img_raw_005_to_03[0], img_raw_005_to_05[0], img_raw_005_to_07[0], img_raw_005_to_09[0], img_raw_005_to_10[0]])
|
60 |
+
|
61 |
+
class GeneFace2Infer:
|
62 |
+
def __init__(self, audio2secc_dir, head_model_dir, torso_model_dir, device=None, inp=None):
|
63 |
+
if device is None:
|
64 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
65 |
+
self.device = device
|
66 |
+
self.audio2secc_model = self.load_audio2secc(audio2secc_dir)
|
67 |
+
self.secc2video_model = self.load_secc2video(head_model_dir, torso_model_dir, inp)
|
68 |
+
self.audio2secc_model.to(device).eval()
|
69 |
+
self.secc2video_model.to(device).eval()
|
70 |
+
self.seg_model = MediapipeSegmenter()
|
71 |
+
self.secc_renderer = SECC_Renderer(512)
|
72 |
+
self.face3d_helper = Face3DHelper(use_gpu=True, keypoint_mode='lm68')
|
73 |
+
self.mp_face3d_helper = Face3DHelper(use_gpu=True, keypoint_mode='mediapipe')
|
74 |
+
|
75 |
+
def load_audio2secc(self, audio2secc_dir):
|
76 |
+
config_name = f"{audio2secc_dir}/config.yaml" if not audio2secc_dir.endswith(".ckpt") else f"{os.path.dirname(audio2secc_dir)}/config.yaml"
|
77 |
+
set_hparams(f"{config_name}", print_hparams=False)
|
78 |
+
self.audio2secc_dir = audio2secc_dir
|
79 |
+
self.audio2secc_hparams = copy.deepcopy(hparams)
|
80 |
+
from modules.audio2motion.vae import VAEModel, PitchContourVAEModel
|
81 |
+
if self.audio2secc_hparams['audio_type'] == 'hubert':
|
82 |
+
audio_in_dim = 1024
|
83 |
+
elif self.audio2secc_hparams['audio_type'] == 'mfcc':
|
84 |
+
audio_in_dim = 13
|
85 |
+
|
86 |
+
if 'icl' in hparams['task_cls']:
|
87 |
+
self.use_icl_audio2motion = True
|
88 |
+
model = InContextAudio2MotionModel(hparams['icl_model_type'], hparams=self.audio2secc_hparams)
|
89 |
+
else:
|
90 |
+
self.use_icl_audio2motion = False
|
91 |
+
if hparams.get("use_pitch", False) is True:
|
92 |
+
model = PitchContourVAEModel(hparams, in_out_dim=64, audio_in_dim=audio_in_dim)
|
93 |
+
else:
|
94 |
+
model = VAEModel(in_out_dim=64, audio_in_dim=audio_in_dim)
|
95 |
+
load_ckpt(model, f"{audio2secc_dir}", model_name='model', strict=True)
|
96 |
+
return model
|
97 |
+
|
98 |
+
def load_secc2video(self, head_model_dir, torso_model_dir, inp):
|
99 |
+
if inp is None:
|
100 |
+
inp = {}
|
101 |
+
self.head_model_dir = head_model_dir
|
102 |
+
self.torso_model_dir = torso_model_dir
|
103 |
+
if torso_model_dir != '':
|
104 |
+
if torso_model_dir.endswith(".ckpt"):
|
105 |
+
set_hparams(f"{os.path.dirname(torso_model_dir)}/config.yaml", print_hparams=False)
|
106 |
+
else:
|
107 |
+
set_hparams(f"{torso_model_dir}/config.yaml", print_hparams=False)
|
108 |
+
if inp.get('head_torso_threshold', None) is not None:
|
109 |
+
hparams['htbsr_head_threshold'] = inp['head_torso_threshold']
|
110 |
+
self.secc2video_hparams = copy.deepcopy(hparams)
|
111 |
+
from modules.real3d.secc_img2plane_torso import OSAvatarSECC_Img2plane_Torso
|
112 |
+
model = OSAvatarSECC_Img2plane_Torso()
|
113 |
+
load_ckpt(model, f"{torso_model_dir}", model_name='model', strict=False)
|
114 |
+
if head_model_dir != '':
|
115 |
+
print("| Warning: Assigned --torso_ckpt which also contains head, but --head_ckpt is also assigned, skipping the --head_ckpt.")
|
116 |
+
else:
|
117 |
+
from modules.real3d.secc_img2plane_torso import OSAvatarSECC_Img2plane
|
118 |
+
if head_model_dir.endswith(".ckpt"):
|
119 |
+
set_hparams(f"{os.path.dirname(head_model_dir)}/config.yaml", print_hparams=False)
|
120 |
+
else:
|
121 |
+
set_hparams(f"{head_model_dir}/config.yaml", print_hparams=False)
|
122 |
+
if inp.get('head_torso_threshold', None) is not None:
|
123 |
+
hparams['htbsr_head_threshold'] = inp['head_torso_threshold']
|
124 |
+
self.secc2video_hparams = copy.deepcopy(hparams)
|
125 |
+
model = OSAvatarSECC_Img2plane()
|
126 |
+
load_ckpt(model, f"{head_model_dir}", model_name='model', strict=False)
|
127 |
+
return model
|
128 |
+
|
129 |
+
def infer_once(self, inp):
|
130 |
+
self.inp = inp
|
131 |
+
samples = self.prepare_batch_from_inp(inp)
|
132 |
+
seed = inp['seed'] if inp['seed'] is not None else int(time.time())
|
133 |
+
random.seed(seed)
|
134 |
+
torch.manual_seed(seed)
|
135 |
+
np.random.seed(seed)
|
136 |
+
out_name = self.forward_system(samples, inp)
|
137 |
+
return out_name
|
138 |
+
|
139 |
+
def prepare_batch_from_inp(self, inp):
|
140 |
+
"""
|
141 |
+
:param inp: {'audio_source_name': (str)}
|
142 |
+
:return: a dict that contains the condition feature of NeRF
|
143 |
+
"""
|
144 |
+
sample = {}
|
145 |
+
# Process Driving Motion
|
146 |
+
if inp['drv_audio_name'][-4:] in ['.wav', '.mp3']:
|
147 |
+
self.save_wav16k(inp['drv_audio_name'])
|
148 |
+
if self.audio2secc_hparams['audio_type'] == 'hubert':
|
149 |
+
hubert = self.get_hubert(self.wav16k_name)
|
150 |
+
elif self.audio2secc_hparams['audio_type'] == 'mfcc':
|
151 |
+
hubert = self.get_mfcc(self.wav16k_name) / 100
|
152 |
+
|
153 |
+
f0 = self.get_f0(self.wav16k_name)
|
154 |
+
if f0.shape[0] > len(hubert):
|
155 |
+
f0 = f0[:len(hubert)]
|
156 |
+
else:
|
157 |
+
num_to_pad = len(hubert) - len(f0)
|
158 |
+
f0 = np.pad(f0, pad_width=((0,num_to_pad), (0,0)))
|
159 |
+
t_x = hubert.shape[0]
|
160 |
+
x_mask = torch.ones([1, t_x]).float() # mask for audio frames
|
161 |
+
y_mask = torch.ones([1, t_x//2]).float() # mask for motion/image frames
|
162 |
+
sample.update({
|
163 |
+
'hubert': torch.from_numpy(hubert).float().unsqueeze(0).cuda(),
|
164 |
+
'f0': torch.from_numpy(f0).float().reshape([1,-1]).cuda(),
|
165 |
+
'x_mask': x_mask.cuda(),
|
166 |
+
'y_mask': y_mask.cuda(),
|
167 |
+
})
|
168 |
+
sample['blink'] = torch.zeros([1, t_x, 1]).long().cuda()
|
169 |
+
sample['audio'] = sample['hubert']
|
170 |
+
sample['eye_amp'] = torch.ones([1, 1]).cuda() * 1.0
|
171 |
+
sample['mouth_amp'] = torch.ones([1, 1]).cuda() * inp['mouth_amp']
|
172 |
+
elif inp['drv_audio_name'][-4:] in ['.mp4']:
|
173 |
+
drv_motion_coeff_dict = fit_3dmm_for_a_video(inp['drv_audio_name'], save=False)
|
174 |
+
drv_motion_coeff_dict = convert_to_tensor(drv_motion_coeff_dict)
|
175 |
+
t_x = drv_motion_coeff_dict['exp'].shape[0] * 2
|
176 |
+
self.drv_motion_coeff_dict = drv_motion_coeff_dict
|
177 |
+
elif inp['drv_audio_name'][-4:] in ['.npy']:
|
178 |
+
drv_motion_coeff_dict = np.load(inp['drv_audio_name'], allow_pickle=True).tolist()
|
179 |
+
drv_motion_coeff_dict = convert_to_tensor(drv_motion_coeff_dict)
|
180 |
+
t_x = drv_motion_coeff_dict['exp'].shape[0] * 2
|
181 |
+
self.drv_motion_coeff_dict = drv_motion_coeff_dict
|
182 |
+
|
183 |
+
# Face Parsing
|
184 |
+
image_name = inp['src_image_name']
|
185 |
+
if image_name.endswith(".mp4"):
|
186 |
+
img = read_first_frame_from_a_video(image_name)
|
187 |
+
image_name = inp['src_image_name'] = image_name[:-4] + '.png'
|
188 |
+
cv2.imwrite(image_name, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
|
189 |
+
sample['ref_gt_img'] = load_img_to_normalized_512_bchw_tensor(image_name).cuda()
|
190 |
+
img = load_img_to_512_hwc_array(image_name)
|
191 |
+
segmap = self.seg_model._cal_seg_map(img)
|
192 |
+
sample['segmap'] = torch.tensor(segmap).float().unsqueeze(0).cuda()
|
193 |
+
head_img = self.seg_model._seg_out_img_with_segmap(img, segmap, mode='head')[0]
|
194 |
+
sample['ref_head_img'] = ((torch.tensor(head_img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2).cuda() # [b,c,h,w]
|
195 |
+
inpaint_torso_img, _, _, _ = inpaint_torso_job(img, segmap)
|
196 |
+
sample['ref_torso_img'] = ((torch.tensor(inpaint_torso_img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2).cuda() # [b,c,h,w]
|
197 |
+
|
198 |
+
if inp['bg_image_name'] == '':
|
199 |
+
bg_img = extract_background([img], [segmap], 'knn')
|
200 |
+
else:
|
201 |
+
bg_img = cv2.imread(inp['bg_image_name'])
|
202 |
+
bg_img = cv2.cvtColor(bg_img, cv2.COLOR_BGR2RGB)
|
203 |
+
bg_img = cv2.resize(bg_img, (512,512))
|
204 |
+
sample['bg_img'] = ((torch.tensor(bg_img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2).cuda() # [b,c,h,w]
|
205 |
+
|
206 |
+
# 3DMM, get identity code and camera pose
|
207 |
+
coeff_dict = fit_3dmm_for_a_image(image_name, save=False)
|
208 |
+
assert coeff_dict is not None
|
209 |
+
src_id = torch.tensor(coeff_dict['id']).reshape([1,80]).cuda()
|
210 |
+
src_exp = torch.tensor(coeff_dict['exp']).reshape([1,64]).cuda()
|
211 |
+
src_euler = torch.tensor(coeff_dict['euler']).reshape([1,3]).cuda()
|
212 |
+
src_trans = torch.tensor(coeff_dict['trans']).reshape([1,3]).cuda()
|
213 |
+
sample['id'] = src_id.repeat([t_x//2,1])
|
214 |
+
|
215 |
+
# get the src_kp for torso model
|
216 |
+
src_kp = self.face3d_helper.reconstruct_lm2d(src_id, src_exp, src_euler, src_trans) # [1, 68, 2]
|
217 |
+
src_kp = (src_kp-0.5) / 0.5 # rescale to -1~1
|
218 |
+
sample['src_kp'] = torch.clamp(src_kp, -1, 1).repeat([t_x//2,1,1])
|
219 |
+
|
220 |
+
# get camera pose file
|
221 |
+
# random.seed(time.time())
|
222 |
+
inp['drv_pose_name'] = inp['drv_pose_name']
|
223 |
+
print(f"| To extract pose from {inp['drv_pose_name']}")
|
224 |
+
|
225 |
+
# extract camera pose
|
226 |
+
if inp['drv_pose_name'] == 'static':
|
227 |
+
sample['euler'] = torch.tensor(coeff_dict['euler']).reshape([1,3]).cuda().repeat([t_x//2,1]) # default static pose
|
228 |
+
sample['trans'] = torch.tensor(coeff_dict['trans']).reshape([1,3]).cuda().repeat([t_x//2,1])
|
229 |
+
else: # from file
|
230 |
+
if inp['drv_pose_name'].endswith('.mp4'):
|
231 |
+
# extract coeff from video
|
232 |
+
drv_pose_coeff_dict = fit_3dmm_for_a_video(inp['drv_pose_name'], save=False)
|
233 |
+
else:
|
234 |
+
# load from npy
|
235 |
+
drv_pose_coeff_dict = np.load(inp['drv_pose_name'], allow_pickle=True).tolist()
|
236 |
+
print(f"| Extracted pose from {inp['drv_pose_name']}")
|
237 |
+
eulers = convert_to_tensor(drv_pose_coeff_dict['euler']).reshape([-1,3]).cuda()
|
238 |
+
trans = convert_to_tensor(drv_pose_coeff_dict['trans']).reshape([-1,3]).cuda()
|
239 |
+
len_pose = len(eulers)
|
240 |
+
index_lst = [mirror_index(i, len_pose) for i in range(t_x//2)]
|
241 |
+
sample['euler'] = eulers[index_lst]
|
242 |
+
sample['trans'] = trans[index_lst]
|
243 |
+
|
244 |
+
# fix the z axis
|
245 |
+
sample['trans'][:, -1] = sample['trans'][0:1, -1].repeat([sample['trans'].shape[0]])
|
246 |
+
|
247 |
+
# mapping to the init pose
|
248 |
+
if inp.get("map_to_init_pose", 'False') == 'True':
|
249 |
+
diff_euler = torch.tensor(coeff_dict['euler']).reshape([1,3]).cuda() - sample['euler'][0:1]
|
250 |
+
sample['euler'] = sample['euler'] + diff_euler
|
251 |
+
diff_trans = torch.tensor(coeff_dict['trans']).reshape([1,3]).cuda() - sample['trans'][0:1]
|
252 |
+
sample['trans'] = sample['trans'] + diff_trans
|
253 |
+
|
254 |
+
# prepare camera
|
255 |
+
camera_ret = get_eg3d_convention_camera_pose_intrinsic({'euler':sample['euler'].cpu(), 'trans':sample['trans'].cpu()})
|
256 |
+
c2w, intrinsics = camera_ret['c2w'], camera_ret['intrinsics']
|
257 |
+
# smooth camera
|
258 |
+
camera_smo_ksize = 7
|
259 |
+
camera = np.concatenate([c2w.reshape([-1,16]), intrinsics.reshape([-1,9])], axis=-1)
|
260 |
+
camera = smooth_camera_sequence(camera, kernel_size=camera_smo_ksize) # [T, 25]
|
261 |
+
camera = torch.tensor(camera).cuda().float()
|
262 |
+
sample['camera'] = camera
|
263 |
+
|
264 |
+
return sample
|
265 |
+
|
266 |
+
@torch.no_grad()
|
267 |
+
def get_hubert(self, wav16k_name):
|
268 |
+
from data_gen.utils.process_audio.extract_hubert import get_hubert_from_16k_wav
|
269 |
+
hubert = get_hubert_from_16k_wav(wav16k_name).detach().numpy()
|
270 |
+
len_mel = hubert.shape[0]
|
271 |
+
x_multiply = 8
|
272 |
+
if len_mel % x_multiply == 0:
|
273 |
+
num_to_pad = 0
|
274 |
+
else:
|
275 |
+
num_to_pad = x_multiply - len_mel % x_multiply
|
276 |
+
hubert = np.pad(hubert, pad_width=((0,num_to_pad), (0,0)))
|
277 |
+
return hubert
|
278 |
+
|
279 |
+
def get_mfcc(self, wav16k_name):
|
280 |
+
from utils.audio import librosa_wav2mfcc
|
281 |
+
hparams['fft_size'] = 1200
|
282 |
+
hparams['win_size'] = 1200
|
283 |
+
hparams['hop_size'] = 480
|
284 |
+
hparams['audio_num_mel_bins'] = 80
|
285 |
+
hparams['fmin'] = 80
|
286 |
+
hparams['fmax'] = 12000
|
287 |
+
hparams['audio_sample_rate'] = 24000
|
288 |
+
mfcc = librosa_wav2mfcc(wav16k_name,
|
289 |
+
fft_size=hparams['fft_size'],
|
290 |
+
hop_size=hparams['hop_size'],
|
291 |
+
win_length=hparams['win_size'],
|
292 |
+
num_mels=hparams['audio_num_mel_bins'],
|
293 |
+
fmin=hparams['fmin'],
|
294 |
+
fmax=hparams['fmax'],
|
295 |
+
sample_rate=hparams['audio_sample_rate'],
|
296 |
+
center=True)
|
297 |
+
mfcc = np.array(mfcc).reshape([-1, 13])
|
298 |
+
len_mel = mfcc.shape[0]
|
299 |
+
x_multiply = 8
|
300 |
+
if len_mel % x_multiply == 0:
|
301 |
+
num_to_pad = 0
|
302 |
+
else:
|
303 |
+
num_to_pad = x_multiply - len_mel % x_multiply
|
304 |
+
mfcc = np.pad(mfcc, pad_width=((0,num_to_pad), (0,0)))
|
305 |
+
return mfcc
|
306 |
+
|
307 |
+
@torch.no_grad()
|
308 |
+
def forward_audio2secc(self, batch, inp=None):
|
309 |
+
if inp['drv_audio_name'][-4:] in ['.wav', '.mp3']:
|
310 |
+
# audio-to-exp
|
311 |
+
ret = {}
|
312 |
+
pred = self.audio2secc_model.forward(batch, ret=ret,train=False, temperature=inp['temperature'],)
|
313 |
+
print("| audio-to-motion finished")
|
314 |
+
if pred.shape[-1] == 144:
|
315 |
+
id = ret['pred'][0][:,:80]
|
316 |
+
exp = ret['pred'][0][:,80:]
|
317 |
+
else:
|
318 |
+
id = batch['id']
|
319 |
+
exp = ret['pred'][0]
|
320 |
+
if len(id) < len(exp): # happens when use ICL
|
321 |
+
id = torch.cat([id, id[0].unsqueeze(0).repeat([len(exp)-len(id),1])])
|
322 |
+
batch['id'] = id
|
323 |
+
batch['exp'] = exp
|
324 |
+
else:
|
325 |
+
drv_motion_coeff_dict = self.drv_motion_coeff_dict
|
326 |
+
batch['exp'] = torch.FloatTensor(drv_motion_coeff_dict['exp']).cuda()
|
327 |
+
|
328 |
+
batch = self.get_driving_motion(batch['id'], batch['exp'], batch['euler'], batch['trans'], batch, inp)
|
329 |
+
if self.use_icl_audio2motion:
|
330 |
+
self.audio2secc_model.empty_context()
|
331 |
+
return batch
|
332 |
+
|
333 |
+
@torch.no_grad()
|
334 |
+
def get_driving_motion(self, id, exp, euler, trans, batch, inp):
|
335 |
+
zero_eulers = torch.zeros([id.shape[0], 3]).to(id.device)
|
336 |
+
zero_trans = torch.zeros([id.shape[0], 3]).to(exp.device)
|
337 |
+
# render the secc given the id,exp
|
338 |
+
with torch.no_grad():
|
339 |
+
chunk_size = 50
|
340 |
+
drv_secc_color_lst = []
|
341 |
+
num_iters = len(id)//chunk_size if len(id)%chunk_size == 0 else len(id)//chunk_size+1
|
342 |
+
for i in tqdm.trange(num_iters, desc="rendering drv secc"):
|
343 |
+
torch.cuda.empty_cache()
|
344 |
+
face_mask, drv_secc_color = self.secc_renderer(id[i*chunk_size:(i+1)*chunk_size], exp[i*chunk_size:(i+1)*chunk_size], zero_eulers[i*chunk_size:(i+1)*chunk_size], zero_trans[i*chunk_size:(i+1)*chunk_size])
|
345 |
+
drv_secc_color_lst.append(drv_secc_color.cpu())
|
346 |
+
drv_secc_colors = torch.cat(drv_secc_color_lst, dim=0)
|
347 |
+
_, src_secc_color = self.secc_renderer(id[0:1], exp[0:1], zero_eulers[0:1], zero_trans[0:1])
|
348 |
+
_, cano_secc_color = self.secc_renderer(id[0:1], exp[0:1]*0, zero_eulers[0:1], zero_trans[0:1])
|
349 |
+
batch['drv_secc'] = drv_secc_colors.cuda()
|
350 |
+
batch['src_secc'] = src_secc_color.cuda()
|
351 |
+
batch['cano_secc'] = cano_secc_color.cuda()
|
352 |
+
|
353 |
+
# blinking secc
|
354 |
+
if inp['blink_mode'] == 'period':
|
355 |
+
period = 5 # second
|
356 |
+
|
357 |
+
for i in tqdm.trange(len(drv_secc_colors),desc="blinking secc"):
|
358 |
+
if i % (25*period) == 0:
|
359 |
+
blink_dur_frames = random.randint(8, 12)
|
360 |
+
for offset in range(blink_dur_frames):
|
361 |
+
j = offset + i
|
362 |
+
if j >= len(drv_secc_colors)-1: break
|
363 |
+
def blink_percent_fn(t, T):
|
364 |
+
return -4/T**2 * t**2 + 4/T * t
|
365 |
+
blink_percent = blink_percent_fn(offset, blink_dur_frames)
|
366 |
+
secc = batch['drv_secc'][j]
|
367 |
+
out_secc = blink_eye_for_secc(secc, blink_percent)
|
368 |
+
out_secc = out_secc.cuda()
|
369 |
+
batch['drv_secc'][j] = out_secc
|
370 |
+
|
371 |
+
# get the drv_kp for torso model, using the transformed trajectory
|
372 |
+
drv_kp = self.face3d_helper.reconstruct_lm2d(id, exp, euler, trans) # [T, 68, 2]
|
373 |
+
|
374 |
+
drv_kp = (drv_kp-0.5) / 0.5 # rescale to -1~1
|
375 |
+
batch['drv_kp'] = torch.clamp(drv_kp, -1, 1)
|
376 |
+
return batch
|
377 |
+
|
378 |
+
@torch.no_grad()
|
379 |
+
def forward_secc2video(self, batch, inp=None):
|
380 |
+
num_frames = len(batch['drv_secc'])
|
381 |
+
camera = batch['camera']
|
382 |
+
src_kps = batch['src_kp']
|
383 |
+
drv_kps = batch['drv_kp']
|
384 |
+
cano_secc_color = batch['cano_secc']
|
385 |
+
src_secc_color = batch['src_secc']
|
386 |
+
drv_secc_colors = batch['drv_secc']
|
387 |
+
ref_img_gt = batch['ref_gt_img']
|
388 |
+
ref_img_head = batch['ref_head_img']
|
389 |
+
ref_torso_img = batch['ref_torso_img']
|
390 |
+
bg_img = batch['bg_img']
|
391 |
+
segmap = batch['segmap']
|
392 |
+
|
393 |
+
# smooth torso drv_kp
|
394 |
+
torso_smo_ksize = 7
|
395 |
+
drv_kps = smooth_features_xd(drv_kps.reshape([-1, 68*2]), kernel_size=torso_smo_ksize).reshape([-1, 68, 2])
|
396 |
+
|
397 |
+
# forward renderer
|
398 |
+
img_raw_lst = []
|
399 |
+
img_lst = []
|
400 |
+
depth_img_lst = []
|
401 |
+
with torch.no_grad():
|
402 |
+
for i in tqdm.trange(num_frames, desc="Real3D-Portrait is rendering frames"):
|
403 |
+
kp_src = torch.cat([src_kps[i:i+1].reshape([1, 68, 2]), torch.zeros([1, 68,1]).to(src_kps.device)],dim=-1)
|
404 |
+
kp_drv = torch.cat([drv_kps[i:i+1].reshape([1, 68, 2]), torch.zeros([1, 68,1]).to(drv_kps.device)],dim=-1)
|
405 |
+
cond={'cond_cano': cano_secc_color,'cond_src': src_secc_color, 'cond_tgt': drv_secc_colors[i:i+1].cuda(),
|
406 |
+
'ref_torso_img': ref_torso_img, 'bg_img': bg_img, 'segmap': segmap,
|
407 |
+
'kp_s': kp_src, 'kp_d': kp_drv}
|
408 |
+
if i == 0:
|
409 |
+
gen_output = self.secc2video_model.forward(img=ref_img_head, camera=camera[i:i+1], cond=cond, ret={}, cache_backbone=True, use_cached_backbone=False)
|
410 |
+
else:
|
411 |
+
gen_output = self.secc2video_model.forward(img=ref_img_head, camera=camera[i:i+1], cond=cond, ret={}, cache_backbone=False, use_cached_backbone=True)
|
412 |
+
img_lst.append(gen_output['image'])
|
413 |
+
img_raw_lst.append(gen_output['image_raw'])
|
414 |
+
depth_img_lst.append(gen_output['image_depth'])
|
415 |
+
|
416 |
+
# save demo video
|
417 |
+
depth_imgs = torch.cat(depth_img_lst)
|
418 |
+
imgs = torch.cat(img_lst)
|
419 |
+
imgs_raw = torch.cat(img_raw_lst)
|
420 |
+
secc_img = torch.cat([torch.nn.functional.interpolate(drv_secc_colors[i:i+1], (512,512)) for i in range(num_frames)])
|
421 |
+
|
422 |
+
if inp['out_mode'] == 'concat_debug':
|
423 |
+
secc_img = secc_img.cpu()
|
424 |
+
secc_img = ((secc_img + 1) * 127.5).permute(0, 2, 3, 1).int().numpy()
|
425 |
+
|
426 |
+
depth_img = F.interpolate(depth_imgs, (512,512)).cpu()
|
427 |
+
depth_img = depth_img.repeat([1,3,1,1])
|
428 |
+
depth_img = (depth_img - depth_img.min()) / (depth_img.max() - depth_img.min())
|
429 |
+
depth_img = depth_img * 2 - 1
|
430 |
+
depth_img = depth_img.clamp(-1,1)
|
431 |
+
|
432 |
+
secc_img = secc_img / 127.5 - 1
|
433 |
+
secc_img = torch.from_numpy(secc_img).permute(0, 3, 1, 2)
|
434 |
+
imgs = torch.cat([ref_img_gt.repeat([imgs.shape[0],1,1,1]).cpu(), secc_img, F.interpolate(imgs_raw, (512,512)).cpu(), depth_img, imgs.cpu()], dim=-1)
|
435 |
+
elif inp['out_mode'] == 'final':
|
436 |
+
imgs = imgs.cpu()
|
437 |
+
elif inp['out_mode'] == 'debug':
|
438 |
+
raise NotImplementedError("to do: save separate videos")
|
439 |
+
imgs = imgs.clamp(-1,1)
|
440 |
+
|
441 |
+
import imageio
|
442 |
+
debug_name = 'demo.mp4'
|
443 |
+
out_imgs = ((imgs.permute(0, 2, 3, 1) + 1)/2 * 255).int().cpu().numpy().astype(np.uint8)
|
444 |
+
writer = imageio.get_writer(debug_name, fps=25, format='FFMPEG', codec='h264')
|
445 |
+
|
446 |
+
for i in tqdm.trange(len(out_imgs), desc="Imageio is saving video"):
|
447 |
+
writer.append_data(out_imgs[i])
|
448 |
+
writer.close()
|
449 |
+
|
450 |
+
out_fname = 'infer_out/tmp/' + os.path.basename(inp['src_image_name'])[:-4] + '_' + os.path.basename(inp['drv_pose_name'])[:-4] + '.mp4' if inp['out_name'] == '' else inp['out_name']
|
451 |
+
try:
|
452 |
+
os.makedirs(os.path.dirname(out_fname), exist_ok=True)
|
453 |
+
except: pass
|
454 |
+
if inp['drv_audio_name'][-4:] in ['.wav', '.mp3']:
|
455 |
+
os.system(f"ffmpeg -i {debug_name} -i {self.wav16k_name} -y -v quiet -shortest {out_fname}")
|
456 |
+
os.system(f"rm {debug_name}")
|
457 |
+
os.system(f"rm {self.wav16k_name}")
|
458 |
+
else:
|
459 |
+
ret = os.system(f"ffmpeg -i {debug_name} -i {inp['drv_audio_name']} -map 0:v -map 1:a -y -v quiet -shortest {out_fname}")
|
460 |
+
if ret != 0: # failed to extract audio from drv_audio_name, so output the video without an audio track
|
461 |
+
os.system(f"mv {debug_name} {out_fname}")
|
462 |
+
print(f"Saved at {out_fname}")
|
463 |
+
return out_fname
|
464 |
+
|
465 |
+
@torch.no_grad()
|
466 |
+
def forward_system(self, batch, inp):
|
467 |
+
self.forward_audio2secc(batch, inp)
|
468 |
+
out_fname = self.forward_secc2video(batch, inp)
|
469 |
+
return out_fname
|
470 |
+
|
471 |
+
@classmethod
|
472 |
+
def example_run(cls, inp=None):
|
473 |
+
inp_tmp = {
|
474 |
+
'drv_audio_name': 'data/raw/val_wavs/zozo.wav',
|
475 |
+
'src_image_name': 'data/raw/val_imgs/Macron.png'
|
476 |
+
}
|
477 |
+
if inp is not None:
|
478 |
+
inp_tmp.update(inp)
|
479 |
+
inp = inp_tmp
|
480 |
+
|
481 |
+
infer_instance = cls(inp['a2m_ckpt'], inp['head_ckpt'], inp['torso_ckpt'], inp=inp)
|
482 |
+
infer_instance.infer_once(inp)
|
483 |
+
|
484 |
+
##############
|
485 |
+
# IO-related
|
486 |
+
##############
|
487 |
+
def save_wav16k(self, audio_name):
|
488 |
+
supported_types = ('.wav', '.mp3', '.mp4', '.avi')
|
489 |
+
assert audio_name.endswith(supported_types), f"Now we only support {','.join(supported_types)} as audio source!"
|
490 |
+
wav16k_name = audio_name[:-4] + '_16k.wav'
|
491 |
+
self.wav16k_name = wav16k_name
|
492 |
+
extract_wav_cmd = f"ffmpeg -i {audio_name} -f wav -ar 16000 -v quiet -y {wav16k_name} -y"
|
493 |
+
os.system(extract_wav_cmd)
|
494 |
+
print(f"Extracted wav file (16khz) from {audio_name} to {wav16k_name}.")
|
495 |
+
|
496 |
+
def get_f0(self, wav16k_name):
|
497 |
+
from data_gen.utils.process_audio.extract_mel_f0 import extract_mel_from_fname, extract_f0_from_wav_and_mel
|
498 |
+
wav, mel = extract_mel_from_fname(self.wav16k_name)
|
499 |
+
f0, f0_coarse = extract_f0_from_wav_and_mel(wav, mel)
|
500 |
+
f0 = f0.reshape([-1,1])
|
501 |
+
return f0
|
502 |
+
|
503 |
+
if __name__ == '__main__':
|
504 |
+
import argparse, glob, tqdm
|
505 |
+
parser = argparse.ArgumentParser()
|
506 |
+
parser.add_argument("--a2m_ckpt", default='checkpoints/240126_real3dportrait_orig/audio2secc_vae', type=str)
|
507 |
+
parser.add_argument("--head_ckpt", default='', type=str)
|
508 |
+
parser.add_argument("--torso_ckpt", default='checkpoints/240126_real3dportrait_orig/secc2plane_torso_orig', type=str)
|
509 |
+
parser.add_argument("--src_img", default='', type=str) # data/raw/examples/Macron.png
|
510 |
+
parser.add_argument("--bg_img", default='', type=str) # data/raw/examples/bg.png
|
511 |
+
parser.add_argument("--drv_aud", default='', type=str) # data/raw/examples/Obama_5s.wav
|
512 |
+
parser.add_argument("--drv_pose", default='static', type=str) # data/raw/examples/May_5s.mp4
|
513 |
+
parser.add_argument("--blink_mode", default='none', type=str) # none | period
|
514 |
+
parser.add_argument("--temperature", default=0.2, type=float) # sampling temperature in audio2motion, higher -> more diverse, less accurate
|
515 |
+
parser.add_argument("--mouth_amp", default=0.45, type=float) # scale of predicted mouth, enabled in audio-driven
|
516 |
+
parser.add_argument("--head_torso_threshold", default=0.9, type=float, help="0.1~1.0, turn up this value if the hair is translucent")
|
517 |
+
parser.add_argument("--out_name", default='') # output filename
|
518 |
+
parser.add_argument("--out_mode", default='final') # final: only output talking head video; concat_debug: talking head with internel features
|
519 |
+
parser.add_argument("--map_to_init_pose", default='True') # whether to map the pose of first frame to source image
|
520 |
+
parser.add_argument("--seed", default=None, type=int) # random seed, default None to use time.time()
|
521 |
+
|
522 |
+
args = parser.parse_args()
|
523 |
+
|
524 |
+
inp = {
|
525 |
+
'a2m_ckpt': args.a2m_ckpt,
|
526 |
+
'head_ckpt': args.head_ckpt,
|
527 |
+
'torso_ckpt': args.torso_ckpt,
|
528 |
+
'src_image_name': args.src_img,
|
529 |
+
'bg_image_name': args.bg_img,
|
530 |
+
'drv_audio_name': args.drv_aud,
|
531 |
+
'drv_pose_name': args.drv_pose,
|
532 |
+
'blink_mode': args.blink_mode,
|
533 |
+
'temperature': args.temperature,
|
534 |
+
'mouth_amp': args.mouth_amp,
|
535 |
+
'out_name': args.out_name,
|
536 |
+
'out_mode': args.out_mode,
|
537 |
+
'map_to_init_pose': args.map_to_init_pose,
|
538 |
+
'head_torso_threshold': args.head_torso_threshold,
|
539 |
+
'seed': args.seed,
|
540 |
+
}
|
541 |
+
|
542 |
+
GeneFace2Infer.example_run(inp)
|
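Besides the command line entry point above, the same pipeline can be driven programmatically through GeneFace2Infer.example_run. A minimal sketch, assuming the default checkpoint folders from the argparse defaults and the example assets named in the help comments (adjust the paths to your setup; out_name here is made up):

from inference.real3d_infer import GeneFace2Infer

inp = {
    'a2m_ckpt': 'checkpoints/240126_real3dportrait_orig/audio2secc_vae',
    'head_ckpt': '',
    'torso_ckpt': 'checkpoints/240126_real3dportrait_orig/secc2plane_torso_orig',
    'src_image_name': 'data/raw/examples/Macron.png',
    'bg_image_name': '',
    'drv_audio_name': 'data/raw/examples/Obama_5s.wav',
    'drv_pose_name': 'static',
    'blink_mode': 'period',
    'temperature': 0.2,
    'mouth_amp': 0.45,
    'out_name': 'infer_out/macron_obama5s.mp4',   # hypothetical output path
    'out_mode': 'final',
    'map_to_init_pose': 'True',
    'head_torso_threshold': 0.9,
    'seed': None,
}
GeneFace2Infer.example_run(inp)

With blink_mode='period' the script inserts a blink roughly every 5 seconds; within each 8-12 frame blink the closure amount follows the parabola 4t/T - 4t^2/T^2, so the eyes are fully closed at the midpoint and open at both ends.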
insta.sh
ADDED
@@ -0,0 +1,18 @@
1 |
+
|
2 |
+
#conda create -n real3dportrait python=3.9
|
3 |
+
#conda activate real3dportrait
|
4 |
+
conda install conda-forge::ffmpeg # ffmpeg with libx264 codec to turn images to video
|
5 |
+
|
6 |
+
### We recommend torch2.0.1+cuda11.7.
|
7 |
+
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
|
8 |
+
|
9 |
+
# Build from source, it may take a long time (Proxy is recommended if encountering the time-out problem)
|
10 |
+
pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable"
|
11 |
+
|
12 |
+
# MMCV for some network structure
|
13 |
+
pip install cython
|
14 |
+
pip install openmim==0.3.9
|
15 |
+
mim install mmcv==2.1.0 # use mim to speed up installation for mmcv
|
16 |
+
|
17 |
+
# other dependencies
|
18 |
+
pip install -r docs/prepare_env/requirements.txt -v
|
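A quick sanity check after running the steps above, as a sketch (not part of the script): it only confirms that the CUDA build of torch, the source-built pytorch3d, and mmcv import cleanly.

import torch, pytorch3d, mmcv
print(torch.__version__, torch.cuda.is_available())   # expect 2.0.1 and True on a CUDA 11.7 machine
print(pytorch3d.__version__)
print(mmcv.__version__)                                # expect 2.1.0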
modules/audio2motion/cnn_models.py
ADDED
@@ -0,0 +1,359 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
def init_weights_func(m):
|
7 |
+
classname = m.__class__.__name__
|
8 |
+
if classname.find("Conv1d") != -1:
|
9 |
+
torch.nn.init.xavier_uniform_(m.weight)
|
10 |
+
|
11 |
+
|
12 |
+
class LambdaLayer(nn.Module):
|
13 |
+
def __init__(self, lambd):
|
14 |
+
super(LambdaLayer, self).__init__()
|
15 |
+
self.lambd = lambd
|
16 |
+
|
17 |
+
def forward(self, x):
|
18 |
+
return self.lambd(x)
|
19 |
+
|
20 |
+
|
21 |
+
class LayerNorm(torch.nn.LayerNorm):
|
22 |
+
"""Layer normalization module.
|
23 |
+
:param int nout: output dim size
|
24 |
+
:param int dim: dimension to be normalized
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(self, nout, dim=-1, eps=1e-5):
|
28 |
+
"""Construct an LayerNorm object."""
|
29 |
+
super(LayerNorm, self).__init__(nout, eps=eps)
|
30 |
+
self.dim = dim
|
31 |
+
|
32 |
+
def forward(self, x):
|
33 |
+
"""Apply layer normalization.
|
34 |
+
:param torch.Tensor x: input tensor
|
35 |
+
:return: layer normalized tensor
|
36 |
+
:rtype torch.Tensor
|
37 |
+
"""
|
38 |
+
if self.dim == -1:
|
39 |
+
return super(LayerNorm, self).forward(x)
|
40 |
+
return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
|
41 |
+
|
42 |
+
|
43 |
+
|
44 |
+
class ResidualBlock(nn.Module):
|
45 |
+
"""Implements conv->PReLU->norm n-times"""
|
46 |
+
|
47 |
+
def __init__(self, channels, kernel_size, dilation, n=2, norm_type='bn', dropout=0.0,
|
48 |
+
c_multiple=2, ln_eps=1e-12, bias=False):
|
49 |
+
super(ResidualBlock, self).__init__()
|
50 |
+
|
51 |
+
if norm_type == 'bn':
|
52 |
+
norm_builder = lambda: nn.BatchNorm1d(channels)
|
53 |
+
elif norm_type == 'in':
|
54 |
+
norm_builder = lambda: nn.InstanceNorm1d(channels, affine=True)
|
55 |
+
elif norm_type == 'gn':
|
56 |
+
norm_builder = lambda: nn.GroupNorm(8, channels)
|
57 |
+
elif norm_type == 'ln':
|
58 |
+
norm_builder = lambda: LayerNorm(channels, dim=1, eps=ln_eps)
|
59 |
+
else:
|
60 |
+
norm_builder = lambda: nn.Identity()
|
61 |
+
|
62 |
+
self.blocks = [
|
63 |
+
nn.Sequential(
|
64 |
+
norm_builder(),
|
65 |
+
nn.Conv1d(channels, c_multiple * channels, kernel_size, dilation=dilation,
|
66 |
+
padding=(dilation * (kernel_size - 1)) // 2, bias=bias),
|
67 |
+
LambdaLayer(lambda x: x * kernel_size ** -0.5),
|
68 |
+
nn.GELU(),
|
69 |
+
nn.Conv1d(c_multiple * channels, channels, 1, dilation=dilation, bias=bias),
|
70 |
+
)
|
71 |
+
for _ in range(n)
|
72 |
+
]
|
73 |
+
|
74 |
+
self.blocks = nn.ModuleList(self.blocks)
|
75 |
+
self.dropout = dropout
|
76 |
+
|
77 |
+
def forward(self, x):
|
78 |
+
nonpadding = (x.abs().sum(1) > 0).float()[:, None, :]
|
79 |
+
for b in self.blocks:
|
80 |
+
x_ = b(x)
|
81 |
+
if self.dropout > 0 and self.training:
|
82 |
+
x_ = F.dropout(x_, self.dropout, training=self.training)
|
83 |
+
x = x + x_
|
84 |
+
x = x * nonpadding
|
85 |
+
return x
|
86 |
+
|
87 |
+
|
88 |
+
class ConvBlocks(nn.Module):
|
89 |
+
"""Decodes the expanded phoneme encoding into spectrograms"""
|
90 |
+
|
91 |
+
def __init__(self, channels, out_dims, dilations, kernel_size,
|
92 |
+
norm_type='ln', layers_in_block=2, c_multiple=2,
|
93 |
+
dropout=0.0, ln_eps=1e-5, init_weights=True, is_BTC=True, bias=False):
|
94 |
+
super(ConvBlocks, self).__init__()
|
95 |
+
self.is_BTC = is_BTC
|
96 |
+
self.res_blocks = nn.Sequential(
|
97 |
+
*[ResidualBlock(channels, kernel_size, d,
|
98 |
+
n=layers_in_block, norm_type=norm_type, c_multiple=c_multiple,
|
99 |
+
dropout=dropout, ln_eps=ln_eps, bias=bias)
|
100 |
+
for d in dilations],
|
101 |
+
)
|
102 |
+
if norm_type == 'bn':
|
103 |
+
norm = nn.BatchNorm1d(channels)
|
104 |
+
elif norm_type == 'in':
|
105 |
+
norm = nn.InstanceNorm1d(channels, affine=True)
|
106 |
+
elif norm_type == 'gn':
|
107 |
+
norm = nn.GroupNorm(8, channels)
|
108 |
+
elif norm_type == 'ln':
|
109 |
+
norm = LayerNorm(channels, dim=1, eps=ln_eps)
|
110 |
+
self.last_norm = norm
|
111 |
+
self.post_net1 = nn.Conv1d(channels, out_dims, kernel_size=3, padding=1, bias=bias)
|
112 |
+
if init_weights:
|
113 |
+
self.apply(init_weights_func)
|
114 |
+
|
115 |
+
def forward(self, x):
|
116 |
+
"""
|
117 |
+
|
118 |
+
:param x: [B, T, H]
|
119 |
+
:return: [B, T, H]
|
120 |
+
"""
|
121 |
+
if self.is_BTC:
|
122 |
+
x = x.transpose(1, 2) # [B, C, T]
|
123 |
+
nonpadding = (x.abs().sum(1) > 0).float()[:, None, :]
|
124 |
+
x = self.res_blocks(x) * nonpadding
|
125 |
+
x = self.last_norm(x) * nonpadding
|
126 |
+
x = self.post_net1(x) * nonpadding
|
127 |
+
if self.is_BTC:
|
128 |
+
x = x.transpose(1, 2)
|
129 |
+
return x
|
130 |
+
|
131 |
+
|
132 |
+
class SeqLevelConvolutionalModel(nn.Module):
|
133 |
+
def __init__(self, out_dim=64, dropout=0.5, audio_feat_type='ppg', backbone_type='unet', norm_type='bn'):
|
134 |
+
nn.Module.__init__(self)
|
135 |
+
self.audio_feat_type = audio_feat_type
|
136 |
+
if audio_feat_type == 'ppg':
|
137 |
+
self.audio_encoder = nn.Sequential(*[
|
138 |
+
nn.Conv1d(29, 48, 3, 1, 1, bias=False),
|
139 |
+
nn.BatchNorm1d(48) if norm_type=='bn' else LayerNorm(48, dim=1),
|
140 |
+
nn.GELU(),
|
141 |
+
nn.Conv1d(48, 48, 3, 1, 1, bias=False)
|
142 |
+
])
|
143 |
+
self.energy_encoder = nn.Sequential(*[
|
144 |
+
nn.Conv1d(1, 16, 3, 1, 1, bias=False),
|
145 |
+
nn.BatchNorm1d(16) if norm_type=='bn' else LayerNorm(16, dim=1),
|
146 |
+
nn.GELU(),
|
147 |
+
nn.Conv1d(16, 16, 3, 1, 1, bias=False)
|
148 |
+
])
|
149 |
+
elif audio_feat_type == 'mel':
|
150 |
+
self.mel_encoder = nn.Sequential(*[
|
151 |
+
nn.Conv1d(80, 64, 3, 1, 1, bias=False),
|
152 |
+
nn.BatchNorm1d(64) if norm_type=='bn' else LayerNorm(64, dim=1),
|
153 |
+
nn.GELU(),
|
154 |
+
nn.Conv1d(64, 64, 3, 1, 1, bias=False)
|
155 |
+
])
|
156 |
+
else:
|
157 |
+
raise NotImplementedError("now only ppg or mel are supported!")
|
158 |
+
|
159 |
+
self.style_encoder = nn.Sequential(*[
|
160 |
+
nn.Linear(135, 256),
|
161 |
+
nn.GELU(),
|
162 |
+
nn.Linear(256, 256)
|
163 |
+
])
|
164 |
+
|
165 |
+
if backbone_type == 'resnet':
|
166 |
+
self.backbone = ResNetBackbone()
|
167 |
+
elif backbone_type == 'unet':
|
168 |
+
self.backbone = UNetBackbone()
|
169 |
+
elif backbone_type == 'resblocks':
|
170 |
+
self.backbone = ResBlocksBackbone()
|
171 |
+
else:
|
172 |
+
raise NotImplementedError("Now only resnet and unet are supported!")
|
173 |
+
|
174 |
+
self.out_layer = nn.Sequential(
|
175 |
+
nn.BatchNorm1d(512) if norm_type=='bn' else LayerNorm(512, dim=1),
|
176 |
+
nn.Conv1d(512, 64, 3, 1, 1, bias=False),
|
177 |
+
nn.PReLU(),
|
178 |
+
nn.Conv1d(64, out_dim, 3, 1, 1, bias=False)
|
179 |
+
)
|
180 |
+
self.feat_dropout = nn.Dropout(p=dropout)
|
181 |
+
|
182 |
+
@property
|
183 |
+
def device(self):
|
184 |
+
return self.backbone.parameters().__next__().device
|
185 |
+
|
186 |
+
def forward(self, batch, ret, log_dict=None):
|
187 |
+
style, x_mask = batch['style'].to(self.device), batch['x_mask'].to(self.device)
|
188 |
+
style_feat = self.style_encoder(style) # [B,C=135] => [B,C=128]
|
189 |
+
|
190 |
+
if self.audio_feat_type == 'ppg':
|
191 |
+
audio, energy = batch['audio'].to(self.device), batch['energy'].to(self.device)
|
192 |
+
audio_feat = self.audio_encoder(audio.transpose(1,2)).transpose(1,2) * x_mask.unsqueeze(2) # [B,T,C=29] => [B,T,C=48]
|
193 |
+
energy_feat = self.energy_encoder(energy.transpose(1,2)).transpose(1,2) * x_mask.unsqueeze(2) # [B,T,C=1] => [B,T,C=16]
|
194 |
+
feat = torch.cat([audio_feat, energy_feat], dim=2) # [B,T,C=48+16]
|
195 |
+
elif self.audio_feat_type == 'mel':
|
196 |
+
mel = batch['mel'].to(self.device)
|
197 |
+
feat = self.mel_encoder(mel.transpose(1,2)).transpose(1,2) * x_mask.unsqueeze(2) # [B,T,C=64]
|
198 |
+
|
199 |
+
feat, x_mask = self.backbone(x=feat, sty=style_feat, x_mask=x_mask)
|
200 |
+
|
201 |
+
out = self.out_layer(feat.transpose(1,2)).transpose(1,2) * x_mask.unsqueeze(2) # [B,T//2,C=256] => [B,T//2,C=64]
|
202 |
+
|
203 |
+
ret['pred'] = out
|
204 |
+
ret['mask'] = x_mask
|
205 |
+
return out
|
206 |
+
|
207 |
+
|
208 |
+
class ResBlocksBackbone(nn.Module):
|
209 |
+
def __init__(self, in_dim=64, out_dim=512, p_dropout=0.5, norm_type='bn'):
|
210 |
+
super(ResBlocksBackbone,self).__init__()
|
211 |
+
self.resblocks_0 = ConvBlocks(channels=in_dim, out_dims=64, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False)
|
212 |
+
self.resblocks_1 = ConvBlocks(channels=64, out_dims=128, dilations=[1]*4, kernel_size=3, norm_type=norm_type, is_BTC=False)
|
213 |
+
self.resblocks_2 = ConvBlocks(channels=128, out_dims=256, dilations=[1]*14, kernel_size=3, norm_type=norm_type, is_BTC=False)
|
214 |
+
self.resblocks_3 = ConvBlocks(channels=512, out_dims=512, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False)
|
215 |
+
self.resblocks_4 = ConvBlocks(channels=512, out_dims=out_dim, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False)
|
216 |
+
|
217 |
+
self.downsampler = LambdaLayer(lambda x: F.interpolate(x, scale_factor=0.5, mode='linear'))
|
218 |
+
self.upsampler = LambdaLayer(lambda x: F.interpolate(x, scale_factor=4, mode='linear'))
|
219 |
+
|
220 |
+
self.dropout = nn.Dropout(p=p_dropout)
|
221 |
+
|
222 |
+
def forward(self, x, sty, x_mask=1.):
|
223 |
+
"""
|
224 |
+
x: [B, T, C]
|
225 |
+
sty: [B, C=256]
|
226 |
+
x_mask: [B, T]
|
227 |
+
ret: [B, T/2, C]
|
228 |
+
"""
|
229 |
+
x = x.transpose(1, 2) # [B, C, T]
|
230 |
+
x_mask = x_mask[:, None, :] # [B, 1, T]
|
231 |
+
|
232 |
+
x = self.resblocks_0(x) * x_mask # [B, C, T]
|
233 |
+
|
234 |
+
x_mask = self.downsampler(x_mask) # [B, 1, T/2]
|
235 |
+
x = self.downsampler(x) * x_mask # [B, C, T/2]
|
236 |
+
x = self.resblocks_1(x) * x_mask # [B, C, T/2]
|
237 |
+
x = self.resblocks_2(x) * x_mask # [B, C, T/2]
|
238 |
+
|
239 |
+
x = self.dropout(x.transpose(1,2)).transpose(1,2)
|
240 |
+
sty = sty[:, :, None].repeat([1,1,x_mask.shape[2]]) # [B,C=256,T/2]
|
241 |
+
x = torch.cat([x, sty], dim=1) # [B, C=256+256, T/2]
|
242 |
+
|
243 |
+
x = self.resblocks_3(x) * x_mask # [B, C, T/2]
|
244 |
+
x = self.resblocks_4(x) * x_mask # [B, C, T/2]
|
245 |
+
|
246 |
+
x = x.transpose(1,2)
|
247 |
+
x_mask = x_mask.squeeze(1)
|
248 |
+
return x, x_mask
|
249 |
+
|
250 |
+
|
251 |
+
|
252 |
+
class ResNetBackbone(nn.Module):
|
253 |
+
def __init__(self, in_dim=64, out_dim=512, p_dropout=0.5, norm_type='bn'):
|
254 |
+
super(ResNetBackbone,self).__init__()
|
255 |
+
self.resblocks_0 = ConvBlocks(channels=in_dim, out_dims=64, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False)
|
256 |
+
self.resblocks_1 = ConvBlocks(channels=64, out_dims=128, dilations=[1]*4, kernel_size=3, norm_type=norm_type, is_BTC=False)
|
257 |
+
self.resblocks_2 = ConvBlocks(channels=128, out_dims=256, dilations=[1]*14, kernel_size=3, norm_type=norm_type, is_BTC=False)
|
258 |
+
self.resblocks_3 = ConvBlocks(channels=512, out_dims=512, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False)
|
259 |
+
self.resblocks_4 = ConvBlocks(channels=512, out_dims=out_dim, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False)
|
260 |
+
|
261 |
+
self.downsampler = LambdaLayer(lambda x: F.interpolate(x, scale_factor=0.5, mode='linear'))
|
262 |
+
self.upsampler = LambdaLayer(lambda x: F.interpolate(x, scale_factor=4, mode='linear'))
|
263 |
+
|
264 |
+
self.dropout = nn.Dropout(p=p_dropout)
|
265 |
+
|
266 |
+
def forward(self, x, sty, x_mask=1.):
|
267 |
+
"""
|
268 |
+
x: [B, T, C]
|
269 |
+
sty: [B, C=256]
|
270 |
+
x_mask: [B, T]
|
271 |
+
ret: [B, T/2, C]
|
272 |
+
"""
|
273 |
+
x = x.transpose(1, 2) # [B, C, T]
|
274 |
+
x_mask = x_mask[:, None, :] # [B, 1, T]
|
275 |
+
|
276 |
+
x = self.resblocks_0(x) * x_mask # [B, C, T]
|
277 |
+
|
278 |
+
x_mask = self.downsampler(x_mask) # [B, 1, T/2]
|
279 |
+
x = self.downsampler(x) * x_mask # [B, C, T/2]
|
280 |
+
x = self.resblocks_1(x) * x_mask # [B, C, T/2]
|
281 |
+
|
282 |
+
x_mask = self.downsampler(x_mask) # [B, 1, T/4]
|
283 |
+
x = self.downsampler(x) * x_mask # [B, C, T/4]
|
284 |
+
x = self.resblocks_2(x) * x_mask # [B, C, T/4]
|
285 |
+
|
286 |
+
x_mask = self.downsampler(x_mask) # [B, 1, T/8]
|
287 |
+
x = self.downsampler(x) * x_mask # [B, C, T/8]
|
288 |
+
x = self.dropout(x.transpose(1,2)).transpose(1,2)
|
289 |
+
sty = sty[:, :, None].repeat([1,1,x_mask.shape[2]]) # [B,C=256,T/8]
|
290 |
+
x = torch.cat([x, sty], dim=1) # [B, C=256+256, T/8]
|
291 |
+
x = self.resblocks_3(x) * x_mask # [B, C, T/8]
|
292 |
+
|
293 |
+
x_mask = self.upsampler(x_mask) # [B, 1, T/2]
|
294 |
+
x = self.upsampler(x) * x_mask # [B, C, T/2]
|
295 |
+
x = self.resblocks_4(x) * x_mask # [B, C, T/2]
|
296 |
+
|
297 |
+
x = x.transpose(1,2)
|
298 |
+
x_mask = x_mask.squeeze(1)
|
299 |
+
return x, x_mask
|
300 |
+
|
301 |
+
|
302 |
+
class UNetBackbone(nn.Module):
|
303 |
+
def __init__(self, in_dim=64, out_dim=512, p_dropout=0.5, norm_type='bn'):
|
304 |
+
super(UNetBackbone, self).__init__()
|
305 |
+
self.resblocks_0 = ConvBlocks(channels=in_dim, out_dims=64, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False)
|
306 |
+
self.resblocks_1 = ConvBlocks(channels=64, out_dims=128, dilations=[1]*4, kernel_size=3, norm_type=norm_type, is_BTC=False)
|
307 |
+
self.resblocks_2 = ConvBlocks(channels=128, out_dims=256, dilations=[1]*8, kernel_size=3, norm_type=norm_type, is_BTC=False)
|
308 |
+
self.resblocks_3 = ConvBlocks(channels=512, out_dims=512, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False)
|
309 |
+
self.resblocks_4 = ConvBlocks(channels=768, out_dims=512, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False) # [768 = c3(512) + c2(256)]
|
310 |
+
self.resblocks_5 = ConvBlocks(channels=640, out_dims=out_dim, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False) # [640 = c4(512) + c1(128)]
|
311 |
+
|
312 |
+
self.downsampler = nn.Upsample(scale_factor=0.5, mode='linear')
|
313 |
+
self.upsampler = nn.Upsample(scale_factor=2, mode='linear')
|
314 |
+
self.dropout = nn.Dropout(p=p_dropout)
|
315 |
+
|
316 |
+
def forward(self, x, sty, x_mask=1.):
|
317 |
+
"""
|
318 |
+
x: [B, T, C]
|
319 |
+
sty: [B, C=256]
|
320 |
+
x_mask: [B, T]
|
321 |
+
ret: [B, T/2, C]
|
322 |
+
"""
|
323 |
+
x = x.transpose(1, 2) # [B, C, T]
|
324 |
+
x_mask = x_mask[:, None, :] # [B, 1, T]
|
325 |
+
|
326 |
+
x0 = self.resblocks_0(x) * x_mask # [B, C, T]
|
327 |
+
|
328 |
+
x_mask = self.downsampler(x_mask) # [B, 1, T/2]
|
329 |
+
x = self.downsampler(x0) * x_mask # [B, C, T/2]
|
330 |
+
x1 = self.resblocks_1(x) * x_mask # [B, C, T/2]
|
331 |
+
|
332 |
+
x_mask = self.downsampler(x_mask) # [B, 1, T/4]
|
333 |
+
x = self.downsampler(x1) * x_mask # [B, C, T/4]
|
334 |
+
x2 = self.resblocks_2(x) * x_mask # [B, C, T/4]
|
335 |
+
|
336 |
+
x_mask = self.downsampler(x_mask) # [B, 1, T/8]
|
337 |
+
x = self.downsampler(x2) * x_mask # [B, C, T/8]
|
338 |
+
x = self.dropout(x.transpose(1,2)).transpose(1,2)
|
339 |
+
sty = sty[:, :, None].repeat([1,1,x_mask.shape[2]]) # [B,C=256,T/8]
|
340 |
+
x = torch.cat([x, sty], dim=1) # [B, C=256+256, T/8]
|
341 |
+
x3 = self.resblocks_3(x) * x_mask # [B, C, T/8]
|
342 |
+
|
343 |
+
x_mask = self.upsampler(x_mask) # [B, 1, T/4]
|
344 |
+
x = self.upsampler(x3) * x_mask # [B, C, T/4]
|
345 |
+
x = torch.cat([x, self.dropout(x2.transpose(1,2)).transpose(1,2)], dim=1) # skip connection from the T/4 level
|
346 |
+
x4 = self.resblocks_4(x) * x_mask # [B, C, T/4]
|
347 |
+
|
348 |
+
x_mask = self.upsampler(x_mask) # [B, 1, T/2]
|
349 |
+
x = self.upsampler(x4) * x_mask # [B, C, T/2]
|
350 |
+
x = torch.cat([x, self.dropout(x1.transpose(1,2)).transpose(1,2)], dim=1)
|
351 |
+
x5 = self.resblocks_5(x) * x_mask # [B, C, T/2]
|
352 |
+
|
353 |
+
x = x5.transpose(1,2)
|
354 |
+
x_mask = x_mask.squeeze(1)
|
355 |
+
return x, x_mask
|
356 |
+
|
357 |
+
|
358 |
+
if __name__ == '__main__':
|
359 |
+
pass
|
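A minimal shape check for the modules above, with arbitrary dimensions (the sequence length is kept a multiple of 8, matching the padding applied to the audio features elsewhere in the repo): ConvBlocks preserves the temporal length, while UNetBackbone returns a sequence downsampled to T/2 together with the matching mask.

import torch
from modules.audio2motion.cnn_models import ConvBlocks, UNetBackbone

x = torch.randn(2, 96, 64)                        # [B, T, C]
blocks = ConvBlocks(channels=64, out_dims=64, dilations=[1, 2, 4], kernel_size=3)
print(blocks(x).shape)                            # torch.Size([2, 96, 64])

backbone = UNetBackbone(in_dim=64, out_dim=512)
sty = torch.randn(2, 256)                         # style embedding
x_mask = torch.ones(2, 96)
out, out_mask = backbone(x, sty, x_mask)
print(out.shape, out_mask.shape)                  # torch.Size([2, 48, 512]) torch.Size([2, 48])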
modules/audio2motion/flow_base.py
ADDED
@@ -0,0 +1,838 @@
import scipy
from scipy import linalg
from torch.nn import functional as F
import torch
from torch import nn
import numpy as np

import modules.audio2motion.utils as utils
from modules.audio2motion.transformer_models import FFTBlocks
from utils.commons.hparams import hparams


def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


# Non-causal WaveNet-style dilated-conv stack, used as the transform network inside the coupling blocks.
class WN(torch.nn.Module):
    def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0,
                 p_dropout=0, share_cond_layers=False):
        super(WN, self).__init__()
        assert (kernel_size % 2 == 1)
        assert (hidden_channels % 2 == 0)
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout
        self.share_cond_layers = share_cond_layers

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()

        self.drop = nn.Dropout(p_dropout)

        self.use_adapters = hparams.get("use_adapters", False)
        if self.use_adapters:
            self.adapter_layers = torch.nn.ModuleList()

        if gin_channels != 0 and not share_cond_layers:
            cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')

        for i in range(n_layers):
            dilation = dilation_rate ** i
            padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size,
                                       dilation=dilation, padding=padding)
            in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
            self.in_layers.append(in_layer)

            # last one is not necessary
            if i < n_layers - 1:
                res_skip_channels = 2 * hidden_channels
            else:
                res_skip_channels = hidden_channels

            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
            self.res_skip_layers.append(res_skip_layer)

            if self.use_adapters:
                # MlpAdapter is only referenced when the optional `use_adapters` hparam is enabled.
                adapter_layer = MlpAdapter(in_out_dim=res_skip_channels, hid_dim=res_skip_channels // 4)
                self.adapter_layers.append(adapter_layer)

    def forward(self, x, x_mask=None, g=None, **kwargs):
        output = torch.zeros_like(x)
        n_channels_tensor = torch.IntTensor([self.hidden_channels])

        if g is not None and not self.share_cond_layers:
            g = self.cond_layer(g)

        for i in range(self.n_layers):
            x_in = self.in_layers[i](x)
            x_in = self.drop(x_in)
            if g is not None:
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :]
            else:
                g_l = torch.zeros_like(x_in)

            acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)

            res_skip_acts = self.res_skip_layers[i](acts)
            if self.use_adapters:
                res_skip_acts = self.adapter_layers[i](res_skip_acts.transpose(1, 2)).transpose(1, 2)
            if i < self.n_layers - 1:
                x = (x + res_skip_acts[:, :self.hidden_channels, :]) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels:, :]
            else:
                output = output + res_skip_acts
        return output * x_mask

    def remove_weight_norm(self):
        def remove_weight_norm(m):
            try:
                nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(remove_weight_norm)

    def enable_adapters(self):
        if not self.use_adapters:
            return
        for adapter_layer in self.adapter_layers:
            adapter_layer.enable()

    def disable_adapters(self):
        if not self.use_adapters:
            return
        for adapter_layer in self.adapter_layers:
            adapter_layer.disable()


class Permute(nn.Module):
    def __init__(self, *args):
        super(Permute, self).__init__()
        self.args = args

    def forward(self, x):
        return x.permute(self.args)


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-4):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        n_dims = len(x.shape)
        mean = torch.mean(x, 1, keepdim=True)
        variance = torch.mean((x - mean) ** 2, 1, keepdim=True)

        x = (x - mean) * torch.rsqrt(variance + self.eps)

        shape = [1, -1] + [1] * (n_dims - 2)
        x = x * self.gamma.view(*shape) + self.beta.view(*shape)
        return x


class ConvReluNorm(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        assert n_layers > 1, "Number of layers should be larger than 1."

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask


# Activation normalization with optional data-dependent initialization (ddi), as in Glow.
class ActNorm(nn.Module):
    def __init__(self, channels, ddi=False, **kwargs):
        super().__init__()
        self.channels = channels
        self.initialized = not ddi

        self.logs = nn.Parameter(torch.zeros(1, channels, 1))
        self.bias = nn.Parameter(torch.zeros(1, channels, 1))

    def forward(self, x, x_mask=None, reverse=False, **kwargs):
        if x_mask is None:
            x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype)
        x_len = torch.sum(x_mask, [1, 2])
        if not self.initialized:
            self.initialize(x, x_mask)
            self.initialized = True

        if reverse:
            z = (x - self.bias) * torch.exp(-self.logs) * x_mask
            logdet = torch.sum(-self.logs) * x_len
        else:
            z = (self.bias + torch.exp(self.logs) * x) * x_mask
            logdet = torch.sum(self.logs) * x_len  # [b]
        return z, logdet

    def store_inverse(self):
        pass

    def set_ddi(self, ddi):
        self.initialized = not ddi

    def initialize(self, x, x_mask):
        with torch.no_grad():
            denom = torch.sum(x_mask, [0, 2])
            m = torch.sum(x * x_mask, [0, 2]) / denom
            m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom
            v = m_sq - (m ** 2)
            logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))

            bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
            logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype)

            self.bias.data.copy_(bias_init)
            self.logs.data.copy_(logs_init)


# Invertible 1x1 convolution over channel groups of size n_split, with optional LU parameterization.
class InvConvNear(nn.Module):
    def __init__(self, channels, n_split=4, no_jacobian=False, lu=True, n_sqz=2, **kwargs):
        super().__init__()
        assert (n_split % 2 == 0)
        self.channels = channels
        self.n_split = n_split
        self.n_sqz = n_sqz
        self.no_jacobian = no_jacobian

        w_init = torch.qr(torch.FloatTensor(self.n_split, self.n_split).normal_())[0]
        if torch.det(w_init) < 0:
            w_init[:, 0] = -1 * w_init[:, 0]
        self.lu = lu
        if lu:
            # LU decomposition can slightly speed up the inverse
            np_p, np_l, np_u = linalg.lu(w_init)
            np_s = np.diag(np_u)
            np_sign_s = np.sign(np_s)
            np_log_s = np.log(np.abs(np_s))
            np_u = np.triu(np_u, k=1)
            l_mask = np.tril(np.ones(w_init.shape, dtype=float), -1)
            eye = np.eye(*w_init.shape, dtype=float)

            self.register_buffer('p', torch.Tensor(np_p.astype(float)))
            self.register_buffer('sign_s', torch.Tensor(np_sign_s.astype(float)))
            self.l = nn.Parameter(torch.Tensor(np_l.astype(float)), requires_grad=True)
            self.log_s = nn.Parameter(torch.Tensor(np_log_s.astype(float)), requires_grad=True)
            self.u = nn.Parameter(torch.Tensor(np_u.astype(float)), requires_grad=True)
            self.register_buffer('l_mask', torch.Tensor(l_mask))
            self.register_buffer('eye', torch.Tensor(eye))
        else:
            self.weight = nn.Parameter(w_init)

    def forward(self, x, x_mask=None, reverse=False, **kwargs):
        b, c, t = x.size()
        assert (c % self.n_split == 0)
        if x_mask is None:
            x_mask = 1
            x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t
        else:
            x_len = torch.sum(x_mask, [1, 2])

        x = x.view(b, self.n_sqz, c // self.n_split, self.n_split // self.n_sqz, t)
        x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.n_split, c // self.n_split, t)

        if self.lu:
            self.weight, log_s = self._get_weight()
            logdet = log_s.sum()
            logdet = logdet * (c / self.n_split) * x_len
        else:
            logdet = torch.logdet(self.weight) * (c / self.n_split) * x_len  # [b]

        if reverse:
            if hasattr(self, "weight_inv"):
                weight = self.weight_inv
            else:
                weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
            logdet = -logdet
        else:
            weight = self.weight
            if self.no_jacobian:
                logdet = 0

        weight = weight.view(self.n_split, self.n_split, 1, 1)
        z = F.conv2d(x, weight)

        z = z.view(b, self.n_sqz, self.n_split // self.n_sqz, c // self.n_split, t)
        z = z.permute(0, 1, 3, 2, 4).contiguous().view(b, c, t) * x_mask
        return z, logdet

    def _get_weight(self):
        l, log_s, u = self.l, self.log_s, self.u
        l = l * self.l_mask + self.eye
        u = u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(log_s))
        weight = torch.matmul(self.p, torch.matmul(l, u))
        return weight, log_s

    def store_inverse(self):
        weight, _ = self._get_weight()
        self.weight_inv = torch.inverse(weight.float()).to(next(self.parameters()).device)


class InvConv(nn.Module):
    def __init__(self, channels, no_jacobian=False, lu=True, **kwargs):
        super().__init__()
        w_shape = [channels, channels]
        w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(float)
        LU_decomposed = lu
        if not LU_decomposed:
            # Sample a random orthogonal matrix:
            self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init)))
        else:
            np_p, np_l, np_u = linalg.lu(w_init)
            np_s = np.diag(np_u)
            np_sign_s = np.sign(np_s)
            np_log_s = np.log(np.abs(np_s))
            np_u = np.triu(np_u, k=1)
            l_mask = np.tril(np.ones(w_shape, dtype=float), -1)
            eye = np.eye(*w_shape, dtype=float)

            self.register_buffer('p', torch.Tensor(np_p.astype(float)))
            self.register_buffer('sign_s', torch.Tensor(np_sign_s.astype(float)))
            self.l = nn.Parameter(torch.Tensor(np_l.astype(float)))
            self.log_s = nn.Parameter(torch.Tensor(np_log_s.astype(float)))
            self.u = nn.Parameter(torch.Tensor(np_u.astype(float)))
            self.l_mask = torch.Tensor(l_mask)
            self.eye = torch.Tensor(eye)
        self.w_shape = w_shape
        self.LU = LU_decomposed
        self.weight = None

    def get_weight(self, device, reverse):
        w_shape = self.w_shape
        self.p = self.p.to(device)
        self.sign_s = self.sign_s.to(device)
        self.l_mask = self.l_mask.to(device)
        self.eye = self.eye.to(device)
        l = self.l * self.l_mask + self.eye
        u = self.u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(self.log_s))
        dlogdet = self.log_s.sum()
        if not reverse:
            w = torch.matmul(self.p, torch.matmul(l, u))
        else:
            l = torch.inverse(l.double()).float()
            u = torch.inverse(u.double()).float()
            w = torch.matmul(u, torch.matmul(l, self.p.inverse()))
        return w.view(w_shape[0], w_shape[1], 1), dlogdet

    def forward(self, x, x_mask=None, reverse=False, **kwargs):
        """
        log-det = log|abs(|W|)| * pixels
        """
        b, c, t = x.size()
        if x_mask is None:
            x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t
        else:
            x_len = torch.sum(x_mask, [1, 2])
        logdet = 0
        if not reverse:
            weight, dlogdet = self.get_weight(x.device, reverse)
            z = F.conv1d(x, weight)
            if logdet is not None:
                logdet = logdet + dlogdet * x_len
            return z, logdet
        else:
            if self.weight is None:
                weight, dlogdet = self.get_weight(x.device, reverse)
            else:
                weight, dlogdet = self.weight, self.dlogdet
            z = F.conv1d(x, weight)
            if logdet is not None:
                logdet = logdet - dlogdet * x_len
            return z, logdet

    def store_inverse(self):
        self.weight, self.dlogdet = self.get_weight('cuda', reverse=True)


class Flip(nn.Module):
    def forward(self, x, *args, reverse=False, **kwargs):
        x = torch.flip(x, [1])
        logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
        return x, logdet

    def store_inverse(self):
        pass


# Affine coupling: the first half of the channels predicts a scale/shift for the second half.
class CouplingBlock(nn.Module):
    def __init__(self, in_channels, hidden_channels, kernel_size, dilation_rate, n_layers,
                 gin_channels=0, p_dropout=0, sigmoid_scale=False,
                 share_cond_layers=False, wn=None):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout
        self.sigmoid_scale = sigmoid_scale

        start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1)
        start = torch.nn.utils.weight_norm(start)
        self.start = start
        # Initializing last layer to 0 makes the affine coupling layers
        # do nothing at first. This helps with training stability
        end = torch.nn.Conv1d(hidden_channels, in_channels, 1)
        end.weight.data.zero_()
        end.bias.data.zero_()
        self.end = end
        self.wn = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels,
                     p_dropout, share_cond_layers)
        if wn is not None:
            self.wn.in_layers = wn.in_layers
            self.wn.res_skip_layers = wn.res_skip_layers

    def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs):
        if x_mask is None:
            x_mask = 1
        x_0, x_1 = x[:, :self.in_channels // 2], x[:, self.in_channels // 2:]

        x = self.start(x_0) * x_mask
        x = self.wn(x, x_mask, g)
        out = self.end(x)

        z_0 = x_0
        m = out[:, :self.in_channels // 2, :]
        logs = out[:, self.in_channels // 2:, :]
        if self.sigmoid_scale:
            logs = torch.log(1e-6 + torch.sigmoid(logs + 2))
        if reverse:
            z_1 = (x_1 - m) * torch.exp(-logs) * x_mask
            logdet = torch.sum(-logs * x_mask, [1, 2])
        else:
            z_1 = (m + torch.exp(logs) * x_1) * x_mask
            logdet = torch.sum(logs * x_mask, [1, 2])
        z = torch.cat([z_0, z_1], 1)
        return z, logdet

    def store_inverse(self):
        self.wn.remove_weight_norm()


class GlowFFTBlocks(FFTBlocks):
    def __init__(self, hidden_size=128, gin_channels=256, num_layers=2, ffn_kernel_size=5,
                 dropout=None, num_heads=4, use_pos_embed=True, use_last_norm=True,
                 norm='ln', use_pos_embed_alpha=True):
        super().__init__(hidden_size, num_layers, ffn_kernel_size, dropout, num_heads, use_pos_embed,
                         use_last_norm, norm, use_pos_embed_alpha)
        self.inp_proj = nn.Conv1d(hidden_size + gin_channels, hidden_size, 1)

    def forward(self, x, x_mask=None, g=None):
        """
        :param x: [B, C_x, T]
        :param x_mask: [B, 1, T]
        :param g: [B, C_g, T]
        :return: [B, C_x, T]
        """
        if g is not None:
            x = self.inp_proj(torch.cat([x, g], 1))
        x = x.transpose(1, 2)
        x = super(GlowFFTBlocks, self).forward(x, x_mask[:, 0] == 0)
        x = x.transpose(1, 2)
        return x


class TransformerCouplingBlock(nn.Module):
    def __init__(self, in_channels, hidden_channels, n_layers,
                 gin_channels=0, p_dropout=0, sigmoid_scale=False):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout
        self.sigmoid_scale = sigmoid_scale

        start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1)
        self.start = start
        # Initializing last layer to 0 makes the affine coupling layers
        # do nothing at first. This helps with training stability
        end = torch.nn.Conv1d(hidden_channels, in_channels, 1)
        end.weight.data.zero_()
        end.bias.data.zero_()
        self.end = end
        self.fft_blocks = GlowFFTBlocks(
            hidden_size=hidden_channels,
            ffn_kernel_size=3,
            gin_channels=gin_channels,
            num_layers=n_layers)

    def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs):
        if x_mask is None:
            x_mask = 1
        x_0, x_1 = x[:, :self.in_channels // 2], x[:, self.in_channels // 2:]

        x = self.start(x_0) * x_mask
        x = self.fft_blocks(x, x_mask, g)
        out = self.end(x)

        z_0 = x_0
        m = out[:, :self.in_channels // 2, :]
        logs = out[:, self.in_channels // 2:, :]
        if self.sigmoid_scale:
            logs = torch.log(1e-6 + torch.sigmoid(logs + 2))
        if reverse:
            z_1 = (x_1 - m) * torch.exp(-logs) * x_mask
            logdet = torch.sum(-logs * x_mask, [1, 2])
        else:
            z_1 = (m + torch.exp(logs) * x_1) * x_mask
            logdet = torch.sum(logs * x_mask, [1, 2])
        z = torch.cat([z_0, z_1], 1)
        return z, logdet

    def store_inverse(self):
        pass


class FreqFFTCouplingBlock(nn.Module):
    def __init__(self, in_channels, hidden_channels, n_layers,
                 gin_channels=0, p_dropout=0, sigmoid_scale=False):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout
        self.sigmoid_scale = sigmoid_scale

        hs = hidden_channels
        stride = 8
        self.start = torch.nn.Conv2d(3, hs, kernel_size=stride * 2,
                                     stride=stride, padding=stride // 2)
        end = nn.ConvTranspose2d(hs, 2, kernel_size=stride, stride=stride)
        end.weight.data.zero_()
        end.bias.data.zero_()
        self.end = nn.Sequential(
            nn.Conv2d(hs * 3, hs, 3, 1, 1),
            nn.ReLU(),
            nn.GroupNorm(4, hs),
            nn.Conv2d(hs, hs, 3, 1, 1),
            end
        )
        self.fft_v = FFTBlocks(hidden_size=hs, ffn_kernel_size=1, num_layers=n_layers)
        self.fft_h = nn.Sequential(
            nn.Conv1d(hs, hs, 3, 1, 1),
            nn.ReLU(),
            nn.Conv1d(hs, hs, 3, 1, 1),
        )
        self.fft_g = nn.Sequential(
            nn.Conv1d(
                gin_channels - 160, hs, kernel_size=stride * 2, stride=stride, padding=stride // 2),
            Permute(0, 2, 1),
            FFTBlocks(hidden_size=hs, ffn_kernel_size=1, num_layers=n_layers),
            Permute(0, 2, 1),
        )

    def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs):
        g_, _ = utils.unsqueeze(g)
        g_mel = g_[:, :80]
        g_txt = g_[:, 80:]
        g_mel, _ = utils.squeeze(g_mel)
        g_txt, _ = utils.squeeze(g_txt)  # [B, C, T]

        if x_mask is None:
            x_mask = 1
        x_0, x_1 = x[:, :self.in_channels // 2], x[:, self.in_channels // 2:]
        x = torch.stack([x_0, g_mel[:, :80], g_mel[:, 80:]], 1)
        x = self.start(x)  # [B, C, N_bins, T]
        B, C, N_bins, T = x.shape

        x_v = self.fft_v(x.permute(0, 3, 2, 1).reshape(B * T, N_bins, C))
        x_v = x_v.reshape(B, T, N_bins, -1).permute(0, 3, 2, 1)
        # x_v = x

        x_h = self.fft_h(x.permute(0, 2, 1, 3).reshape(B * N_bins, C, T))
        x_h = x_h.reshape(B, N_bins, -1, T).permute(0, 2, 1, 3)
        # x_h = x

        x_g = self.fft_g(g_txt)[:, :, None, :].repeat(1, 1, 10, 1)
        x = torch.cat([x_v, x_h, x_g], 1)
        out = self.end(x)

        z_0 = x_0
        m = out[:, 0]
        logs = out[:, 1]
        if self.sigmoid_scale:
            logs = torch.log(1e-6 + torch.sigmoid(logs + 2))
        if reverse:
            z_1 = (x_1 - m) * torch.exp(-logs) * x_mask
            logdet = torch.sum(-logs * x_mask, [1, 2])
        else:
            z_1 = (m + torch.exp(logs) * x_1) * x_mask
            logdet = torch.sum(logs * x_mask, [1, 2])
        z = torch.cat([z_0, z_1], 1)
        return z, logdet

    def store_inverse(self):
        pass


class ResidualCouplingLayer(nn.Module):
    def __init__(self,
                 channels,
                 hidden_channels,
                 kernel_size,
                 dilation_rate,
                 n_layers,
                 p_dropout=0,
                 gin_channels=0,
                 mean_only=False,
                 nn_type='wn'):
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.half_channels = channels // 2
        self.mean_only = mean_only

        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
        if nn_type == 'wn':
            self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout,
                          gin_channels=gin_channels)
        # elif nn_type == 'conv':
        #     self.enc = ConditionalConvBlocks(
        #         hidden_channels, gin_channels, hidden_channels, [1] * n_layers, kernel_size,
        #         layers_in_block=1, is_BTC=False)
        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        h = self.pre(x0) * x_mask
        h = self.enc(h, x_mask=x_mask, g=g)
        stats = self.post(h) * x_mask
        if not self.mean_only:
            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
        else:
            m = stats
            logs = torch.zeros_like(m)

        if not reverse:
            x1 = m + x1 * torch.exp(logs) * x_mask
            x = torch.cat([x0, x1], 1)
            logdet = torch.sum(logs, [1, 2])
            return x, logdet
        else:
            x1 = (x1 - m) * torch.exp(-logs) * x_mask
            x = torch.cat([x0, x1], 1)
            logdet = -torch.sum(logs, [1, 2])
            return x, logdet


class ResidualCouplingBlock(nn.Module):
    def __init__(self,
                 channels,
                 hidden_channels,
                 kernel_size,
                 dilation_rate,
                 n_layers,
                 n_flows=4,
                 gin_channels=0,
                 nn_type='wn'):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for i in range(n_flows):
            self.flows.append(
                ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers,
                                      gin_channels=gin_channels, mean_only=True, nn_type=nn_type))
            self.flows.append(Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in reversed(self.flows):
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        return x


class Glow(nn.Module):
    def __init__(self,
                 in_channels,
                 hidden_channels,
                 kernel_size,
                 dilation_rate,
                 n_blocks,
                 n_layers,
                 p_dropout=0.,
                 n_split=4,
                 n_sqz=2,
                 sigmoid_scale=False,
                 gin_channels=0,
                 inv_conv_type='near',
                 share_cond_layers=False,
                 share_wn_layers=0,
                 ):
        super().__init__()
        """
        Note that regularization such as weight decay can lead to NaN errors!
        """

        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_blocks = n_blocks
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        self.n_split = n_split
        self.n_sqz = n_sqz
        self.sigmoid_scale = sigmoid_scale
        self.gin_channels = gin_channels
        self.share_cond_layers = share_cond_layers
        if gin_channels != 0 and share_cond_layers:
            cond_layer = torch.nn.Conv1d(gin_channels * n_sqz, 2 * hidden_channels * n_layers, 1)
            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
        wn = None
        self.flows = nn.ModuleList()
        for b in range(n_blocks):
            self.flows.append(ActNorm(channels=in_channels * n_sqz))
            if inv_conv_type == 'near':
                self.flows.append(InvConvNear(channels=in_channels * n_sqz, n_split=n_split, n_sqz=n_sqz))
            if inv_conv_type == 'invconv':
                self.flows.append(InvConv(channels=in_channels * n_sqz))
            if share_wn_layers > 0:
                if b % share_wn_layers == 0:
                    wn = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels * n_sqz,
                            p_dropout, share_cond_layers)
            self.flows.append(
                CouplingBlock(
                    in_channels * n_sqz,
                    hidden_channels,
                    kernel_size=kernel_size,
                    dilation_rate=dilation_rate,
                    n_layers=n_layers,
                    gin_channels=gin_channels * n_sqz,
                    p_dropout=p_dropout,
                    sigmoid_scale=sigmoid_scale,
                    share_cond_layers=share_cond_layers,
                    wn=wn
                ))

    def forward(self, x, x_mask=None, g=None, reverse=False, return_hiddens=False):
        """
        x: [B,T,C]
        x_mask: [B,T]
        g: [B,T,C]
        """
        x = x.transpose(1, 2)
        x_mask = x_mask.unsqueeze(1)
        if g is not None:
            g = g.transpose(1, 2)

        logdet_tot = 0
        if not reverse:
            flows = self.flows
        else:
            flows = reversed(self.flows)
        if return_hiddens:
            hs = []
        if self.n_sqz > 1:
            x, x_mask_ = utils.squeeze(x, x_mask, self.n_sqz)
            if g is not None:
                g, _ = utils.squeeze(g, x_mask, self.n_sqz)
            x_mask = x_mask_
        if self.share_cond_layers and g is not None:
            g = self.cond_layer(g)
        for f in flows:
            x, logdet = f(x, x_mask, g=g, reverse=reverse)
            if return_hiddens:
                hs.append(x)
            logdet_tot += logdet
        if self.n_sqz > 1:
            x, x_mask = utils.unsqueeze(x, x_mask, self.n_sqz)

        x = x.transpose(1, 2)
        if return_hiddens:
            return x, logdet_tot, hs
        return x, logdet_tot

    def store_inverse(self):
        def remove_weight_norm(m):
            try:
                nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(remove_weight_norm)
        for f in self.flows:
            f.store_inverse()


if __name__ == '__main__':
    model = Glow(in_channels=64,
                 hidden_channels=128,
                 kernel_size=5,
                 dilation_rate=1,
                 n_blocks=12,
                 n_layers=4,
                 p_dropout=0.0,
                 n_split=4,
                 n_sqz=2,
                 sigmoid_scale=False,
                 gin_channels=80
                 )
    exp = torch.rand([1, 1440, 64])
    mel = torch.rand([1, 1440, 80])
    x_mask = torch.ones([1, 1440], dtype=torch.float32)
    y, logdet = model(exp, x_mask, g=mel, reverse=False)
    # run the flow in reverse to map the latent back to the expression space
    pred_exp, logdet = model(y, x_mask, g=mel, reverse=True)
    # y: [b, t, c=64]
    print(y.shape, pred_exp.shape)
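A hedged usage note (not part of the uploaded file): since Glow.forward returns a pair (z, logdet) with z of shape [B, T, C] and logdet of shape [B], a maximum-likelihood training loss can be sketched as below, assuming a standard-normal prior over z. Names and shapes follow the __main__ demo above; T must be divisible by n_sqz. At inference time the same model is typically run with reverse=True to map prior samples back to expression coefficients.

# A minimal sketch, assuming a standard-normal base distribution (not the repo's training code).
import math
import torch

def glow_nll(model, exp, mel, x_mask):
    z, logdet = model(exp, x_mask, g=mel, reverse=False)   # z: [B, T, C], logdet: [B]
    mask = x_mask.unsqueeze(-1)                            # [B, T, 1]
    log_prior = -0.5 * (z ** 2 + math.log(2 * math.pi))    # element-wise log N(0, I)
    log_prior = (log_prior * mask).sum(dim=[1, 2])         # [B]
    log_px = log_prior + logdet                            # change-of-variables objective
    return -(log_px / mask.sum(dim=[1, 2])).mean()         # mean NLL per valid element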
modules/audio2motion/multi_length_disc.py
ADDED
@@ -0,0 +1,340 @@
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from modules.audio2motion.cnn_models import LambdaLayer


class Discriminator1DFactory(nn.Module):
    def __init__(self, time_length, kernel_size=3, in_dim=1, hidden_size=128, norm_type='bn'):
        super(Discriminator1DFactory, self).__init__()
        padding = kernel_size // 2

        def discriminator_block(in_filters, out_filters, first=False):
            """
            Input: (B, c, T)
            Output:(B, c, T//2)
            """
            conv = nn.Conv1d(in_filters, out_filters, kernel_size, 2, padding)
            block = [
                conv,  # padding = kernel//2
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout2d(0.25)
            ]
            if norm_type == 'bn' and not first:
                block.append(nn.BatchNorm1d(out_filters, 0.8))
            if norm_type == 'in' and not first:
                block.append(nn.InstanceNorm1d(out_filters, affine=True))
            block = nn.Sequential(*block)
            return block

        if time_length >= 8:
            self.model = nn.ModuleList([
                discriminator_block(in_dim, hidden_size, first=True),
                discriminator_block(hidden_size, hidden_size),
                discriminator_block(hidden_size, hidden_size),
            ])
            ds_size = time_length // (2 ** 3)
        elif time_length == 3:
            self.model = nn.ModuleList([
                nn.Sequential(*[
                    nn.Conv1d(in_dim, hidden_size, 3, 1, 0),
                    nn.LeakyReLU(0.2, inplace=True),
                    nn.Dropout2d(0.25),
                    nn.Conv1d(hidden_size, hidden_size, 1, 1, 0),
                    nn.LeakyReLU(0.2, inplace=True),
                    nn.Dropout2d(0.25),
                    nn.BatchNorm1d(hidden_size, 0.8),
                    nn.Conv1d(hidden_size, hidden_size, 1, 1, 0),
                    nn.LeakyReLU(0.2, inplace=True),
                    nn.Dropout2d(0.25),
                    nn.BatchNorm1d(hidden_size, 0.8)
                ])
            ])
            ds_size = 1
        elif time_length == 1:
            self.model = nn.ModuleList([
                nn.Sequential(*[
                    nn.Linear(in_dim, hidden_size),
                    nn.LeakyReLU(0.2, inplace=True),
                    nn.Dropout2d(0.25),
                    nn.Linear(hidden_size, hidden_size),
                    nn.LeakyReLU(0.2, inplace=True),
                    nn.Dropout2d(0.25),
                ])
            ])
            ds_size = 1

        self.adv_layer = nn.Linear(hidden_size * ds_size, 1)

    def forward(self, x):
        """
        :param x: [B, C, T]
        :return: validity: [B, 1], h: List of hiddens
        """
        h = []
        if x.shape[-1] == 1:
            x = x.squeeze(-1)
        for l in self.model:
            x = l(x)
            h.append(x)
        if x.ndim == 2:
            b, ct = x.shape
            use_sigmoid = True
        else:
            b, c, t = x.shape
            ct = c * t
            use_sigmoid = False
        x = x.view(b, ct)
        validity = self.adv_layer(x)  # [B, 1]
        if use_sigmoid:
            validity = torch.sigmoid(validity)
        return validity, h


class CosineDiscriminator1DFactory(nn.Module):
    def __init__(self, time_length, kernel_size=3, in_dim=1, hidden_size=128, norm_type='bn'):
        super().__init__()
        padding = kernel_size // 2

        def discriminator_block(in_filters, out_filters, first=False):
            """
            Input: (B, c, T)
            Output:(B, c, T//2)
            """
            conv = nn.Conv1d(in_filters, out_filters, kernel_size, 2, padding)
            block = [
                conv,  # padding = kernel//2
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout2d(0.25)
            ]
            if norm_type == 'bn' and not first:
                block.append(nn.BatchNorm1d(out_filters, 0.8))
            if norm_type == 'in' and not first:
                block.append(nn.InstanceNorm1d(out_filters, affine=True))
            block = nn.Sequential(*block)
            return block

        self.model1 = nn.ModuleList([
            discriminator_block(in_dim, hidden_size, first=True),
            discriminator_block(hidden_size, hidden_size),
            discriminator_block(hidden_size, hidden_size),
        ])

        self.model2 = nn.ModuleList([
            discriminator_block(in_dim, hidden_size, first=True),
            discriminator_block(hidden_size, hidden_size),
            discriminator_block(hidden_size, hidden_size),
        ])

        self.relu = nn.ReLU()

    def forward(self, x1, x2):
        """
        :param x1: [B, C, T]
        :param x2: [B, C, T]
        :return: validity: [B, 1], h: List of hiddens
        """
        h1, h2 = [], []
        for l in self.model1:
            x1 = l(x1)
            h1.append(x1)
        for l in self.model2:
            x2 = l(x2)
            h2.append(x2)
        b, c, t = x1.shape
        x1 = x1.view(b, c * t)
        x2 = x2.view(b, c * t)
        x1 = self.relu(x1)
        x2 = self.relu(x2)
        # x1 = F.normalize(x1, p=2, dim=1)
        # x2 = F.normalize(x2, p=2, dim=1)
        validity = F.cosine_similarity(x1, x2)
        return validity, [h1, h2]


# Applies a window discriminator at several temporal lengths and aggregates their scores.
class MultiWindowDiscriminator(nn.Module):
    def __init__(self, time_lengths, cond_dim=80, in_dim=64, kernel_size=3, hidden_size=128, disc_type='standard', norm_type='bn', reduction='sum'):
        super(MultiWindowDiscriminator, self).__init__()
        self.win_lengths = time_lengths
        self.reduction = reduction
        self.disc_type = disc_type

        if cond_dim > 0:
            self.use_cond = True
            self.cond_proj_layers = nn.ModuleList()
            self.in_proj_layers = nn.ModuleList()
        else:
            self.use_cond = False

        self.conv_layers = nn.ModuleList()
        for time_length in time_lengths:
            conv_layer = [
                Discriminator1DFactory(
                    time_length, kernel_size, in_dim=64, hidden_size=hidden_size,
                    norm_type=norm_type) if self.disc_type == 'standard'
                else CosineDiscriminator1DFactory(time_length, kernel_size, in_dim=64,
                                                  hidden_size=hidden_size, norm_type=norm_type)
            ]
            self.conv_layers += conv_layer
            if self.use_cond:
                self.cond_proj_layers.append(nn.Linear(cond_dim, 64))
                self.in_proj_layers.append(nn.Linear(in_dim, 64))

    def clip(self, x, cond, x_len, win_length, start_frames=None):
        '''Randomly clip x (and cond) to win_length frames.
        Args:
            x (tensor) : (B, T, C).
            cond (tensor) : (B, T, H).
            x_len (tensor) : (B,).
            win_length (int): target clip length

        Returns:
            x_batch (tensor): (B, win_length, C), or None if any sequence is shorter than win_length.
            c_batch (tensor): (B, win_length, H), or None.
            start_frames (list): start frame of each clip.
        '''
        clip_from_same_frame = start_frames is None
        T_start = 0
        # T_end = x_len.max() - win_length
        T_end = x_len.min() - win_length
        if T_end < 0:
            return None, None, start_frames
        T_end = T_end.item()
        if start_frames is None:
            start_frame = np.random.randint(low=T_start, high=T_end + 1)
            start_frames = [start_frame] * x.size(0)
        else:
            start_frame = start_frames[0]

        if clip_from_same_frame:
            x_batch = x[:, start_frame: start_frame + win_length, :]
            c_batch = cond[:, start_frame: start_frame + win_length, :] if cond is not None else None
        else:
            x_lst = []
            c_lst = []
            for i, start_frame in enumerate(start_frames):
                x_lst.append(x[i, start_frame: start_frame + win_length, :])
                if cond is not None:
                    c_lst.append(cond[i, start_frame: start_frame + win_length, :])
            x_batch = torch.stack(x_lst, dim=0)
            if cond is None:
                c_batch = None
            else:
                c_batch = torch.stack(c_lst, dim=0)
        return x_batch, c_batch, start_frames

    def forward(self, x, x_len, cond=None, start_frames_wins=None):
        '''
        Args:
            x (tensor): input sequence, (B, T, C).
            x_len (tensor): valid length of each sequence, (B,).

        Returns:
            tensor : (B).
        '''
        validity = []
        if start_frames_wins is None:
            start_frames_wins = [None] * len(self.conv_layers)
        h = []
        for i, start_frames in zip(range(len(self.conv_layers)), start_frames_wins):
            x_clip, c_clip, start_frames = self.clip(
                x, cond, x_len, self.win_lengths[i], start_frames)  # (B, win_length, C)
            start_frames_wins[i] = start_frames
            if x_clip is None:
                continue
            if self.disc_type == 'standard':
                if self.use_cond:
                    x_clip = self.in_proj_layers[i](x_clip)  # (B, T, C)
                    c_clip = self.cond_proj_layers[i](c_clip)
                    x_clip = x_clip + c_clip
                validity_pred, h_ = self.conv_layers[i](x_clip.transpose(1, 2))
            elif self.disc_type == 'cosine':
                assert self.use_cond is True
                x_clip = self.in_proj_layers[i](x_clip)  # (B, T, C)
                c_clip = self.cond_proj_layers[i](c_clip)
                validity_pred, h_ = self.conv_layers[i](x_clip.transpose(1, 2), c_clip.transpose(1, 2))
            else:
                raise NotImplementedError

            h += h_
            validity.append(validity_pred)
        if len(validity) != len(self.conv_layers):
            return None, start_frames_wins, h
        if self.reduction == 'sum':
            validity = sum(validity)  # [B]
        elif self.reduction == 'stack':
            validity = torch.stack(validity, -1)  # [B, W_L]
        return validity, start_frames_wins, h


class Discriminator(nn.Module):
    def __init__(self, x_dim=80, y_dim=64, disc_type='standard',
                 uncond_disc=False, kernel_size=3, hidden_size=128, norm_type='bn', reduction='sum', time_lengths=(8, 16, 32)):
        """Multi-window discriminator over predicted facial motion, optionally conditioned on mel features.

        Args:
            x_dim (int, optional): dim of the audio features. Defaults to 80, corresponding to mel-spec.
            y_dim (int, optional): dim of the facial coeff. Defaults to 64, corresponding to exp; other options can be 7 (pose) or 71 (exp+pose).
            disc_type (str, optional): 'standard' or 'cosine'. Defaults to 'standard'.
            uncond_disc (bool, optional): if True, drop the audio condition. Defaults to False.
            kernel_size (int, optional): conv kernel size. Defaults to 3.
            hidden_size (int, optional): hidden channels. Defaults to 128.
            norm_type (str, optional): 'bn' or 'in'. Defaults to 'bn'.
            reduction (str, optional): 'sum' or 'stack' over windows. Defaults to 'sum'.
            time_lengths (tuple, optional): the window sizes in frames. Defaults to (8, 16, 32).
        """
        super(Discriminator, self).__init__()
        self.time_lengths = time_lengths
        self.x_dim, self.y_dim = x_dim, y_dim
        self.disc_type = disc_type
        self.reduction = reduction
        self.uncond_disc = uncond_disc

        if uncond_disc:
            self.x_dim = 0
            cond_dim = 0

        else:
            cond_dim = 64
            self.mel_encoder = nn.Sequential(*[
                nn.Conv1d(self.x_dim, 64, 3, 1, 1, bias=False),
                nn.BatchNorm1d(64),
                nn.GELU(),
                nn.Conv1d(64, cond_dim, 3, 1, 1, bias=False)
            ])

        self.disc = MultiWindowDiscriminator(
            time_lengths=self.time_lengths,
            in_dim=self.y_dim,
            cond_dim=cond_dim,
            kernel_size=kernel_size,
            hidden_size=hidden_size, norm_type=norm_type,
            reduction=reduction,
            disc_type=disc_type
        )
        self.downsampler = LambdaLayer(lambda x: F.interpolate(x.transpose(1, 2), scale_factor=0.5, mode='nearest').transpose(1, 2))

    @property
    def device(self):
        return self.disc.parameters().__next__().device

    def forward(self, x, batch, start_frames_wins=None):
        """
        :param x: [B, T, C]
        :param batch: dict containing the conditioning 'mel', [B, T_mel, x_dim]
        :return: disc_confidence
        """
        x = x.to(self.device)
        if not self.uncond_disc:
            mel = self.downsampler(batch['mel'].to(self.device))
            mel_feat = self.mel_encoder(mel.transpose(1, 2)).transpose(1, 2)
        else:
            mel_feat = None
        x_len = x.sum(-1).ne(0).int().sum([1])
        disc_confidence, start_frames_wins, h = self.disc(x, x_len, mel_feat, start_frames_wins=start_frames_wins)
        return disc_confidence
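A hedged usage sketch (not part of the uploaded file) of how the conditional Discriminator might be called during adversarial training. The 2:1 mel-to-motion frame ratio is an assumption implied by the 0.5x `downsampler`; shapes and the `batch` dict key follow the forward() above.

# Minimal sketch: score a batch of expression sequences against their mel condition.
import torch

disc = Discriminator(x_dim=80, y_dim=64, time_lengths=(8, 16, 32))
exp = torch.randn(2, 100, 64)             # predicted (or ground-truth) expression coeffs, [B, T, 64]
batch = {'mel': torch.randn(2, 200, 80)}  # mel at roughly twice the motion frame rate (assumption)
score = disc(exp, batch)                  # summed validity over the three window lengths, [B, 1]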
modules/audio2motion/transformer_base.py
ADDED
@@ -0,0 +1,988 @@
import math
import torch
from torch import nn
from torch.nn import Parameter
import torch.onnx.operators
import torch.nn.functional as F
from collections import defaultdict


def make_positions(tensor, padding_idx):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1. Padding symbols are ignored.
    """
    # The series of casts and type-conversions here are carefully
    # balanced to both work with ONNX export and XLA. In particular XLA
    # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
    # how to handle the dtype kwarg in cumsum.
    mask = tensor.ne(padding_idx).int()
    return (
        torch.cumsum(mask, dim=1).type_as(mask) * mask
    ).long() + padding_idx


def softmax(x, dim):
    return F.softmax(x, dim=dim, dtype=torch.float32)


INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)


def _get_full_incremental_state_key(module_instance, key):
    module_name = module_instance.__class__.__name__

    # assign a unique ID to each module instance, so that incremental state is
    # not shared across module instances
    if not hasattr(module_instance, '_instance_id'):
        INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1
        module_instance._instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name]

    return '{}.{}.{}'.format(module_name, module_instance._instance_id, key)


def get_incremental_state(module, incremental_state, key):
    """Helper for getting incremental state for an nn.Module."""
    full_key = _get_full_incremental_state_key(module, key)
    if incremental_state is None or full_key not in incremental_state:
        return None
    return incremental_state[full_key]


def set_incremental_state(module, incremental_state, key, value):
    """Helper for setting incremental state for an nn.Module."""
    if incremental_state is not None:
        full_key = _get_full_incremental_state_key(module, key)
        incremental_state[full_key] = value


class Reshape(nn.Module):
    def __init__(self, *args):
        super(Reshape, self).__init__()
        self.shape = args

    def forward(self, x):
        return x.view(self.shape)


class Permute(nn.Module):
    def __init__(self, *args):
        super(Permute, self).__init__()
        self.args = args

    def forward(self, x):
        return x.permute(self.args)


class LinearNorm(torch.nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super(LinearNorm, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, x):
        return self.linear_layer(x)


class ConvNorm(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
        super(ConvNorm, self).__init__()
        if padding is None:
            assert (kernel_size % 2 == 1)
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)

        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, signal):
        conv_signal = self.conv(signal)
        return conv_signal


def Embedding(num_embeddings, embedding_dim, padding_idx=None):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
    if padding_idx is not None:
        nn.init.constant_(m.weight[padding_idx], 0)
    return m


class GroupNorm1DTBC(nn.GroupNorm):
    def forward(self, input):
        return super(GroupNorm1DTBC, self).forward(input.permute(1, 2, 0)).permute(2, 0, 1)


def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
    if not export and torch.cuda.is_available():
        try:
            from apex.normalization import FusedLayerNorm
            return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
        except ImportError:
            pass
    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)


def Linear(in_features, out_features, bias=True):
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)
    if bias:
        nn.init.constant_(m.bias, 0.)
    return m


class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.

    Padding symbols are ignored.
    """

    def __init__(self, embedding_dim, padding_idx, init_size=1024):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size,
            embedding_dim,
            padding_idx,
        )
        self.register_buffer('_float_tensor', torch.FloatTensor(1))

    @staticmethod
    def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
        """Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = input.shape[:2]
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            # recompute/expand embeddings if needed
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos,
                self.embedding_dim,
                self.padding_idx,
            )
        self.weights = self.weights.to(self._float_tensor)

        if incremental_state is not None:
            # positions is the same for every token when decoding a single step
            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)

        positions = make_positions(input, self.padding_idx) if positions is None else positions
        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()

    def max_positions(self):
        """Maximum number of supported positions."""
        return int(1e5)  # an arbitrary large number


class ConvTBC(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding=0):
        super(ConvTBC, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding

        self.weight = torch.nn.Parameter(torch.Tensor(
            self.kernel_size, in_channels, out_channels))
        self.bias = torch.nn.Parameter(torch.Tensor(out_channels))

    def forward(self, input):
        return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding)


class MultiheadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
                 add_bias_kv=False, add_zero_attn=False, self_attention=False,
                 encoder_decoder_attention=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
                                                             'value to be of the same size'

        if self.qkv_same_dim:
            self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
        else:
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
|
249 |
+
|
250 |
+
if bias:
|
251 |
+
self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
|
252 |
+
else:
|
253 |
+
self.register_parameter('in_proj_bias', None)
|
254 |
+
|
255 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
256 |
+
|
257 |
+
if add_bias_kv:
|
258 |
+
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
|
259 |
+
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
|
260 |
+
else:
|
261 |
+
self.bias_k = self.bias_v = None
|
262 |
+
|
263 |
+
self.add_zero_attn = add_zero_attn
|
264 |
+
|
265 |
+
self.reset_parameters()
|
266 |
+
|
267 |
+
self.enable_torch_version = False
|
268 |
+
if hasattr(F, "multi_head_attention_forward"):
|
269 |
+
self.enable_torch_version = True
|
270 |
+
else:
|
271 |
+
self.enable_torch_version = False
|
272 |
+
self.last_attn_probs = None
|
273 |
+
|
274 |
+
def reset_parameters(self):
|
275 |
+
if self.qkv_same_dim:
|
276 |
+
nn.init.xavier_uniform_(self.in_proj_weight)
|
277 |
+
else:
|
278 |
+
nn.init.xavier_uniform_(self.k_proj_weight)
|
279 |
+
nn.init.xavier_uniform_(self.v_proj_weight)
|
280 |
+
nn.init.xavier_uniform_(self.q_proj_weight)
|
281 |
+
|
282 |
+
nn.init.xavier_uniform_(self.out_proj.weight)
|
283 |
+
if self.in_proj_bias is not None:
|
284 |
+
nn.init.constant_(self.in_proj_bias, 0.)
|
285 |
+
nn.init.constant_(self.out_proj.bias, 0.)
|
286 |
+
if self.bias_k is not None:
|
287 |
+
nn.init.xavier_normal_(self.bias_k)
|
288 |
+
if self.bias_v is not None:
|
289 |
+
nn.init.xavier_normal_(self.bias_v)
|
290 |
+
|
291 |
+
def forward(
|
292 |
+
self,
|
293 |
+
query, key, value,
|
294 |
+
key_padding_mask=None,
|
295 |
+
incremental_state=None,
|
296 |
+
need_weights=True,
|
297 |
+
static_kv=False,
|
298 |
+
attn_mask=None,
|
299 |
+
before_softmax=False,
|
300 |
+
need_head_weights=False,
|
301 |
+
enc_dec_attn_constraint_mask=None,
|
302 |
+
reset_attn_weight=None
|
303 |
+
):
|
304 |
+
"""Input shape: Time x Batch x Channel
|
305 |
+
|
306 |
+
Args:
|
307 |
+
key_padding_mask (ByteTensor, optional): mask to exclude
|
308 |
+
keys that are pads, of shape `(batch, src_len)`, where
|
309 |
+
padding elements are indicated by 1s.
|
310 |
+
need_weights (bool, optional): return the attention weights,
|
311 |
+
averaged over heads (default: False).
|
312 |
+
attn_mask (ByteTensor, optional): typically used to
|
313 |
+
implement causal attention, where the mask prevents the
|
314 |
+
attention from looking forward in time (default: None).
|
315 |
+
before_softmax (bool, optional): return the raw attention
|
316 |
+
weights and values before the attention softmax.
|
317 |
+
need_head_weights (bool, optional): return the attention
|
318 |
+
weights for each head. Implies *need_weights*. Default:
|
319 |
+
return the average attention weights over all heads.
|
320 |
+
"""
|
321 |
+
if need_head_weights:
|
322 |
+
need_weights = True
|
323 |
+
|
324 |
+
tgt_len, bsz, embed_dim = query.size()
|
325 |
+
assert embed_dim == self.embed_dim
|
326 |
+
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
327 |
+
if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None:
|
328 |
+
if self.qkv_same_dim:
|
329 |
+
return F.multi_head_attention_forward(query, key, value,
|
330 |
+
self.embed_dim, self.num_heads,
|
331 |
+
self.in_proj_weight,
|
332 |
+
self.in_proj_bias, self.bias_k, self.bias_v,
|
333 |
+
self.add_zero_attn, self.dropout,
|
334 |
+
self.out_proj.weight, self.out_proj.bias,
|
335 |
+
self.training, key_padding_mask, need_weights,
|
336 |
+
attn_mask)
|
337 |
+
else:
|
338 |
+
return F.multi_head_attention_forward(query, key, value,
|
339 |
+
self.embed_dim, self.num_heads,
|
340 |
+
torch.empty([0]),
|
341 |
+
self.in_proj_bias, self.bias_k, self.bias_v,
|
342 |
+
self.add_zero_attn, self.dropout,
|
343 |
+
self.out_proj.weight, self.out_proj.bias,
|
344 |
+
self.training, key_padding_mask, need_weights,
|
345 |
+
attn_mask, use_separate_proj_weight=True,
|
346 |
+
q_proj_weight=self.q_proj_weight,
|
347 |
+
k_proj_weight=self.k_proj_weight,
|
348 |
+
v_proj_weight=self.v_proj_weight)
|
349 |
+
|
350 |
+
if incremental_state is not None:
|
351 |
+
saved_state = self._get_input_buffer(incremental_state)
|
352 |
+
if 'prev_key' in saved_state:
|
353 |
+
# previous time steps are cached - no need to recompute
|
354 |
+
# key and value if they are static
|
355 |
+
if static_kv:
|
356 |
+
assert self.encoder_decoder_attention and not self.self_attention
|
357 |
+
key = value = None
|
358 |
+
else:
|
359 |
+
saved_state = None
|
360 |
+
|
361 |
+
if self.self_attention:
|
362 |
+
# self-attention
|
363 |
+
q, k, v = self.in_proj_qkv(query)
|
364 |
+
elif self.encoder_decoder_attention:
|
365 |
+
# encoder-decoder attention
|
366 |
+
q = self.in_proj_q(query)
|
367 |
+
if key is None:
|
368 |
+
assert value is None
|
369 |
+
k = v = None
|
370 |
+
else:
|
371 |
+
k = self.in_proj_k(key)
|
372 |
+
v = self.in_proj_v(key)
|
373 |
+
|
374 |
+
else:
|
375 |
+
q = self.in_proj_q(query)
|
376 |
+
k = self.in_proj_k(key)
|
377 |
+
v = self.in_proj_v(value)
|
378 |
+
q *= self.scaling
|
379 |
+
|
380 |
+
if self.bias_k is not None:
|
381 |
+
assert self.bias_v is not None
|
382 |
+
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
|
383 |
+
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
|
384 |
+
if attn_mask is not None:
|
385 |
+
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
|
386 |
+
if key_padding_mask is not None:
|
387 |
+
key_padding_mask = torch.cat(
|
388 |
+
[key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
|
389 |
+
|
390 |
+
q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
391 |
+
if k is not None:
|
392 |
+
k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
393 |
+
if v is not None:
|
394 |
+
v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
395 |
+
|
396 |
+
if saved_state is not None:
|
397 |
+
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
398 |
+
if 'prev_key' in saved_state:
|
399 |
+
prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
|
400 |
+
if static_kv:
|
401 |
+
k = prev_key
|
402 |
+
else:
|
403 |
+
k = torch.cat((prev_key, k), dim=1)
|
404 |
+
if 'prev_value' in saved_state:
|
405 |
+
prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
|
406 |
+
if static_kv:
|
407 |
+
v = prev_value
|
408 |
+
else:
|
409 |
+
v = torch.cat((prev_value, v), dim=1)
|
410 |
+
if 'prev_key_padding_mask' in saved_state and saved_state['prev_key_padding_mask'] is not None:
|
411 |
+
prev_key_padding_mask = saved_state['prev_key_padding_mask']
|
412 |
+
if static_kv:
|
413 |
+
key_padding_mask = prev_key_padding_mask
|
414 |
+
else:
|
415 |
+
key_padding_mask = torch.cat((prev_key_padding_mask, key_padding_mask), dim=1)
|
416 |
+
|
417 |
+
saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
|
418 |
+
saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
|
419 |
+
saved_state['prev_key_padding_mask'] = key_padding_mask
|
420 |
+
|
421 |
+
self._set_input_buffer(incremental_state, saved_state)
|
422 |
+
|
423 |
+
src_len = k.size(1)
|
424 |
+
|
425 |
+
# This is part of a workaround to get around fork/join parallelism
|
426 |
+
# not supporting Optional types.
|
427 |
+
if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
|
428 |
+
key_padding_mask = None
|
429 |
+
|
430 |
+
if key_padding_mask is not None:
|
431 |
+
assert key_padding_mask.size(0) == bsz
|
432 |
+
assert key_padding_mask.size(1) == src_len
|
433 |
+
|
434 |
+
if self.add_zero_attn:
|
435 |
+
src_len += 1
|
436 |
+
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
|
437 |
+
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
|
438 |
+
if attn_mask is not None:
|
439 |
+
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
|
440 |
+
if key_padding_mask is not None:
|
441 |
+
key_padding_mask = torch.cat(
|
442 |
+
[key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
|
443 |
+
|
444 |
+
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
445 |
+
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
|
446 |
+
|
447 |
+
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
448 |
+
|
449 |
+
if attn_mask is not None:
|
450 |
+
if len(attn_mask.shape) == 2:
|
451 |
+
attn_mask = attn_mask.unsqueeze(0)
|
452 |
+
elif len(attn_mask.shape) == 3:
|
453 |
+
attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
|
454 |
+
bsz * self.num_heads, tgt_len, src_len)
|
455 |
+
attn_weights = attn_weights + attn_mask
|
456 |
+
|
457 |
+
if enc_dec_attn_constraint_mask is not None: # bs x head x L_kv
|
458 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
459 |
+
attn_weights = attn_weights.masked_fill(
|
460 |
+
enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
|
461 |
+
-1e8,
|
462 |
+
)
|
463 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
464 |
+
|
465 |
+
if key_padding_mask is not None:
|
466 |
+
# don't attend to padding symbols
|
467 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
468 |
+
attn_weights = attn_weights.masked_fill(
|
469 |
+
key_padding_mask.unsqueeze(1).unsqueeze(2),
|
470 |
+
-1e8,
|
471 |
+
)
|
472 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
473 |
+
|
474 |
+
attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
475 |
+
|
476 |
+
if before_softmax:
|
477 |
+
return attn_weights, v
|
478 |
+
|
479 |
+
attn_weights_float = softmax(attn_weights, dim=-1)
|
480 |
+
attn_weights = attn_weights_float.type_as(attn_weights)
|
481 |
+
attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
|
482 |
+
|
483 |
+
if reset_attn_weight is not None:
|
484 |
+
if reset_attn_weight:
|
485 |
+
self.last_attn_probs = attn_probs.detach()
|
486 |
+
else:
|
487 |
+
assert self.last_attn_probs is not None
|
488 |
+
attn_probs = self.last_attn_probs
|
489 |
+
attn = torch.bmm(attn_probs, v)
|
490 |
+
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
491 |
+
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
492 |
+
attn = self.out_proj(attn)
|
493 |
+
|
494 |
+
if need_weights:
|
495 |
+
attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
|
496 |
+
if not need_head_weights:
|
497 |
+
# average attention weights over heads
|
498 |
+
attn_weights = attn_weights.mean(dim=0)
|
499 |
+
else:
|
500 |
+
attn_weights = None
|
501 |
+
|
502 |
+
return attn, (attn_weights, attn_logits)
|
503 |
+
|
504 |
+
def in_proj_qkv(self, query):
|
505 |
+
return self._in_proj(query).chunk(3, dim=-1)
|
506 |
+
|
507 |
+
def in_proj_q(self, query):
|
508 |
+
if self.qkv_same_dim:
|
509 |
+
return self._in_proj(query, end=self.embed_dim)
|
510 |
+
else:
|
511 |
+
bias = self.in_proj_bias
|
512 |
+
if bias is not None:
|
513 |
+
bias = bias[:self.embed_dim]
|
514 |
+
return F.linear(query, self.q_proj_weight, bias)
|
515 |
+
|
516 |
+
def in_proj_k(self, key):
|
517 |
+
if self.qkv_same_dim:
|
518 |
+
return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
|
519 |
+
else:
|
520 |
+
weight = self.k_proj_weight
|
521 |
+
bias = self.in_proj_bias
|
522 |
+
if bias is not None:
|
523 |
+
bias = bias[self.embed_dim:2 * self.embed_dim]
|
524 |
+
return F.linear(key, weight, bias)
|
525 |
+
|
526 |
+
def in_proj_v(self, value):
|
527 |
+
if self.qkv_same_dim:
|
528 |
+
return self._in_proj(value, start=2 * self.embed_dim)
|
529 |
+
else:
|
530 |
+
weight = self.v_proj_weight
|
531 |
+
bias = self.in_proj_bias
|
532 |
+
if bias is not None:
|
533 |
+
bias = bias[2 * self.embed_dim:]
|
534 |
+
return F.linear(value, weight, bias)
|
535 |
+
|
536 |
+
def _in_proj(self, input, start=0, end=None):
|
537 |
+
weight = self.in_proj_weight
|
538 |
+
bias = self.in_proj_bias
|
539 |
+
weight = weight[start:end, :]
|
540 |
+
if bias is not None:
|
541 |
+
bias = bias[start:end]
|
542 |
+
return F.linear(input, weight, bias)
|
543 |
+
|
544 |
+
def _get_input_buffer(self, incremental_state):
|
545 |
+
return get_incremental_state(
|
546 |
+
self,
|
547 |
+
incremental_state,
|
548 |
+
'attn_state',
|
549 |
+
) or {}
|
550 |
+
|
551 |
+
def _set_input_buffer(self, incremental_state, buffer):
|
552 |
+
set_incremental_state(
|
553 |
+
self,
|
554 |
+
incremental_state,
|
555 |
+
'attn_state',
|
556 |
+
buffer,
|
557 |
+
)
|
558 |
+
|
559 |
+
def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
|
560 |
+
return attn_weights
|
561 |
+
|
562 |
+
def clear_buffer(self, incremental_state=None):
|
563 |
+
if incremental_state is not None:
|
564 |
+
saved_state = self._get_input_buffer(incremental_state)
|
565 |
+
if 'prev_key' in saved_state:
|
566 |
+
del saved_state['prev_key']
|
567 |
+
if 'prev_value' in saved_state:
|
568 |
+
del saved_state['prev_value']
|
569 |
+
self._set_input_buffer(incremental_state, saved_state)
|
570 |
+
|
571 |
+
|
572 |
+
class Swish(torch.autograd.Function):
|
573 |
+
@staticmethod
|
574 |
+
def forward(ctx, i):
|
575 |
+
result = i * torch.sigmoid(i)
|
576 |
+
ctx.save_for_backward(i)
|
577 |
+
return result
|
578 |
+
|
579 |
+
@staticmethod
|
580 |
+
def backward(ctx, grad_output):
|
581 |
+
i = ctx.saved_variables[0]
|
582 |
+
sigmoid_i = torch.sigmoid(i)
|
583 |
+
return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
|
584 |
+
|
585 |
+
|
586 |
+
class CustomSwish(nn.Module):
|
587 |
+
def forward(self, input_tensor):
|
588 |
+
return Swish.apply(input_tensor)
|
589 |
+
|
590 |
+
|
591 |
+
class TransformerFFNLayer(nn.Module):
|
592 |
+
def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu'):
|
593 |
+
super().__init__()
|
594 |
+
self.kernel_size = kernel_size
|
595 |
+
self.dropout = dropout
|
596 |
+
self.act = act
|
597 |
+
if padding == 'SAME':
|
598 |
+
self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size, padding=kernel_size // 2)
|
599 |
+
elif padding == 'LEFT':
|
600 |
+
self.ffn_1 = nn.Sequential(
|
601 |
+
nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
|
602 |
+
nn.Conv1d(hidden_size, filter_size, kernel_size)
|
603 |
+
)
|
604 |
+
self.ffn_2 = Linear(filter_size, hidden_size)
|
605 |
+
if self.act == 'swish':
|
606 |
+
self.swish_fn = CustomSwish()
|
607 |
+
|
608 |
+
def forward(self, x, incremental_state=None):
|
609 |
+
# x: T x B x C
|
610 |
+
if incremental_state is not None:
|
611 |
+
saved_state = self._get_input_buffer(incremental_state)
|
612 |
+
if 'prev_input' in saved_state:
|
613 |
+
prev_input = saved_state['prev_input']
|
614 |
+
x = torch.cat((prev_input, x), dim=0)
|
615 |
+
x = x[-self.kernel_size:]
|
616 |
+
saved_state['prev_input'] = x
|
617 |
+
self._set_input_buffer(incremental_state, saved_state)
|
618 |
+
|
619 |
+
x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
|
620 |
+
x = x * self.kernel_size ** -0.5
|
621 |
+
|
622 |
+
if incremental_state is not None:
|
623 |
+
x = x[-1:]
|
624 |
+
if self.act == 'gelu':
|
625 |
+
x = F.gelu(x)
|
626 |
+
if self.act == 'relu':
|
627 |
+
x = F.relu(x)
|
628 |
+
if self.act == 'swish':
|
629 |
+
x = self.swish_fn(x)
|
630 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
631 |
+
x = self.ffn_2(x)
|
632 |
+
return x
|
633 |
+
|
634 |
+
def _get_input_buffer(self, incremental_state):
|
635 |
+
return get_incremental_state(
|
636 |
+
self,
|
637 |
+
incremental_state,
|
638 |
+
'f',
|
639 |
+
) or {}
|
640 |
+
|
641 |
+
def _set_input_buffer(self, incremental_state, buffer):
|
642 |
+
set_incremental_state(
|
643 |
+
self,
|
644 |
+
incremental_state,
|
645 |
+
'f',
|
646 |
+
buffer,
|
647 |
+
)
|
648 |
+
|
649 |
+
def clear_buffer(self, incremental_state):
|
650 |
+
if incremental_state is not None:
|
651 |
+
saved_state = self._get_input_buffer(incremental_state)
|
652 |
+
if 'prev_input' in saved_state:
|
653 |
+
del saved_state['prev_input']
|
654 |
+
self._set_input_buffer(incremental_state, saved_state)
|
655 |
+
|
656 |
+
|
657 |
+
class BatchNorm1dTBC(nn.Module):
|
658 |
+
def __init__(self, c):
|
659 |
+
super(BatchNorm1dTBC, self).__init__()
|
660 |
+
self.bn = nn.BatchNorm1d(c)
|
661 |
+
|
662 |
+
def forward(self, x):
|
663 |
+
"""
|
664 |
+
|
665 |
+
:param x: [T, B, C]
|
666 |
+
:return: [T, B, C]
|
667 |
+
"""
|
668 |
+
x = x.permute(1, 2, 0) # [B, C, T]
|
669 |
+
x = self.bn(x) # [B, C, T]
|
670 |
+
x = x.permute(2, 0, 1) # [T, B, C]
|
671 |
+
return x
|
672 |
+
|
673 |
+
|
674 |
+
class EncSALayer(nn.Module):
|
675 |
+
def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
|
676 |
+
relu_dropout=0.1, kernel_size=9, padding='SAME', norm='ln', act='gelu'):
|
677 |
+
super().__init__()
|
678 |
+
self.c = c
|
679 |
+
self.dropout = dropout
|
680 |
+
self.num_heads = num_heads
|
681 |
+
if num_heads > 0:
|
682 |
+
if norm == 'ln':
|
683 |
+
self.layer_norm1 = LayerNorm(c)
|
684 |
+
elif norm == 'bn':
|
685 |
+
self.layer_norm1 = BatchNorm1dTBC(c)
|
686 |
+
elif norm == 'gn':
|
687 |
+
self.layer_norm1 = GroupNorm1DTBC(8, c)
|
688 |
+
self.self_attn = MultiheadAttention(
|
689 |
+
self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False)
|
690 |
+
if norm == 'ln':
|
691 |
+
self.layer_norm2 = LayerNorm(c)
|
692 |
+
elif norm == 'bn':
|
693 |
+
self.layer_norm2 = BatchNorm1dTBC(c)
|
694 |
+
elif norm == 'gn':
|
695 |
+
self.layer_norm2 = GroupNorm1DTBC(8, c)
|
696 |
+
self.ffn = TransformerFFNLayer(
|
697 |
+
c, 4 * c, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act)
|
698 |
+
|
699 |
+
def forward(self, x, encoder_padding_mask=None, **kwargs):
|
700 |
+
layer_norm_training = kwargs.get('layer_norm_training', None)
|
701 |
+
if layer_norm_training is not None:
|
702 |
+
self.layer_norm1.training = layer_norm_training
|
703 |
+
self.layer_norm2.training = layer_norm_training
|
704 |
+
if self.num_heads > 0:
|
705 |
+
residual = x
|
706 |
+
x = self.layer_norm1(x)
|
707 |
+
x, _, = self.self_attn(
|
708 |
+
query=x,
|
709 |
+
key=x,
|
710 |
+
value=x,
|
711 |
+
key_padding_mask=encoder_padding_mask
|
712 |
+
)
|
713 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
714 |
+
x = residual + x
|
715 |
+
x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
|
716 |
+
|
717 |
+
residual = x
|
718 |
+
x = self.layer_norm2(x)
|
719 |
+
x = self.ffn(x)
|
720 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
721 |
+
x = residual + x
|
722 |
+
x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
|
723 |
+
return x
|
724 |
+
|
725 |
+
|
726 |
+
class DecSALayer(nn.Module):
|
727 |
+
def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1,
|
728 |
+
kernel_size=9, act='gelu', norm='ln'):
|
729 |
+
super().__init__()
|
730 |
+
self.c = c
|
731 |
+
self.dropout = dropout
|
732 |
+
if norm == 'ln':
|
733 |
+
self.layer_norm1 = LayerNorm(c)
|
734 |
+
elif norm == 'gn':
|
735 |
+
self.layer_norm1 = GroupNorm1DTBC(8, c)
|
736 |
+
self.self_attn = MultiheadAttention(
|
737 |
+
c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
|
738 |
+
)
|
739 |
+
if norm == 'ln':
|
740 |
+
self.layer_norm2 = LayerNorm(c)
|
741 |
+
elif norm == 'gn':
|
742 |
+
self.layer_norm2 = GroupNorm1DTBC(8, c)
|
743 |
+
self.encoder_attn = MultiheadAttention(
|
744 |
+
c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False,
|
745 |
+
)
|
746 |
+
if norm == 'ln':
|
747 |
+
self.layer_norm3 = LayerNorm(c)
|
748 |
+
elif norm == 'gn':
|
749 |
+
self.layer_norm3 = GroupNorm1DTBC(8, c)
|
750 |
+
self.ffn = TransformerFFNLayer(
|
751 |
+
c, 4 * c, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)
|
752 |
+
|
753 |
+
def forward(
|
754 |
+
self,
|
755 |
+
x,
|
756 |
+
encoder_out=None,
|
757 |
+
encoder_padding_mask=None,
|
758 |
+
incremental_state=None,
|
759 |
+
self_attn_mask=None,
|
760 |
+
self_attn_padding_mask=None,
|
761 |
+
attn_out=None,
|
762 |
+
reset_attn_weight=None,
|
763 |
+
**kwargs,
|
764 |
+
):
|
765 |
+
layer_norm_training = kwargs.get('layer_norm_training', None)
|
766 |
+
if layer_norm_training is not None:
|
767 |
+
self.layer_norm1.training = layer_norm_training
|
768 |
+
self.layer_norm2.training = layer_norm_training
|
769 |
+
self.layer_norm3.training = layer_norm_training
|
770 |
+
residual = x
|
771 |
+
x = self.layer_norm1(x)
|
772 |
+
x, _ = self.self_attn(
|
773 |
+
query=x,
|
774 |
+
key=x,
|
775 |
+
value=x,
|
776 |
+
key_padding_mask=self_attn_padding_mask,
|
777 |
+
incremental_state=incremental_state,
|
778 |
+
attn_mask=self_attn_mask
|
779 |
+
)
|
780 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
781 |
+
x = residual + x
|
782 |
+
|
783 |
+
attn_logits = None
|
784 |
+
if encoder_out is not None or attn_out is not None:
|
785 |
+
residual = x
|
786 |
+
x = self.layer_norm2(x)
|
787 |
+
if encoder_out is not None:
|
788 |
+
x, attn = self.encoder_attn(
|
789 |
+
query=x,
|
790 |
+
key=encoder_out,
|
791 |
+
value=encoder_out,
|
792 |
+
key_padding_mask=encoder_padding_mask,
|
793 |
+
incremental_state=incremental_state,
|
794 |
+
static_kv=True,
|
795 |
+
enc_dec_attn_constraint_mask=get_incremental_state(self, incremental_state,
|
796 |
+
'enc_dec_attn_constraint_mask'),
|
797 |
+
reset_attn_weight=reset_attn_weight
|
798 |
+
)
|
799 |
+
attn_logits = attn[1]
|
800 |
+
elif attn_out is not None:
|
801 |
+
x = self.encoder_attn.in_proj_v(attn_out)
|
802 |
+
if encoder_out is not None or attn_out is not None:
|
803 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
804 |
+
x = residual + x
|
805 |
+
|
806 |
+
residual = x
|
807 |
+
x = self.layer_norm3(x)
|
808 |
+
x = self.ffn(x, incremental_state=incremental_state)
|
809 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
810 |
+
x = residual + x
|
811 |
+
return x, attn_logits
|
812 |
+
|
813 |
+
def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
|
814 |
+
self.encoder_attn.clear_buffer(incremental_state)
|
815 |
+
self.ffn.clear_buffer(incremental_state)
|
816 |
+
|
817 |
+
def set_buffer(self, name, tensor, incremental_state):
|
818 |
+
return set_incremental_state(self, incremental_state, name, tensor)
|
819 |
+
|
820 |
+
|
821 |
+
class ConvBlock(nn.Module):
|
822 |
+
def __init__(self, idim=80, n_chans=256, kernel_size=3, stride=1, norm='gn', dropout=0):
|
823 |
+
super().__init__()
|
824 |
+
self.conv = ConvNorm(idim, n_chans, kernel_size, stride=stride)
|
825 |
+
self.norm = norm
|
826 |
+
if self.norm == 'bn':
|
827 |
+
self.norm = nn.BatchNorm1d(n_chans)
|
828 |
+
elif self.norm == 'in':
|
829 |
+
self.norm = nn.InstanceNorm1d(n_chans, affine=True)
|
830 |
+
elif self.norm == 'gn':
|
831 |
+
self.norm = nn.GroupNorm(n_chans // 16, n_chans)
|
832 |
+
elif self.norm == 'ln':
|
833 |
+
self.norm = LayerNorm(n_chans // 16, n_chans)
|
834 |
+
elif self.norm == 'wn':
|
835 |
+
self.conv = torch.nn.utils.weight_norm(self.conv.conv)
|
836 |
+
self.dropout = nn.Dropout(dropout)
|
837 |
+
self.relu = nn.ReLU()
|
838 |
+
|
839 |
+
def forward(self, x):
|
840 |
+
"""
|
841 |
+
|
842 |
+
:param x: [B, C, T]
|
843 |
+
:return: [B, C, T]
|
844 |
+
"""
|
845 |
+
x = self.conv(x)
|
846 |
+
if not isinstance(self.norm, str):
|
847 |
+
if self.norm == 'none':
|
848 |
+
pass
|
849 |
+
elif self.norm == 'ln':
|
850 |
+
x = self.norm(x.transpose(1, 2)).transpose(1, 2)
|
851 |
+
else:
|
852 |
+
x = self.norm(x)
|
853 |
+
x = self.relu(x)
|
854 |
+
x = self.dropout(x)
|
855 |
+
return x
|
856 |
+
|
857 |
+
|
858 |
+
class ConvStacks(nn.Module):
|
859 |
+
def __init__(self, idim=80, n_layers=5, n_chans=256, odim=32, kernel_size=5, norm='gn',
|
860 |
+
dropout=0, strides=None, res=True):
|
861 |
+
super().__init__()
|
862 |
+
self.conv = torch.nn.ModuleList()
|
863 |
+
self.kernel_size = kernel_size
|
864 |
+
self.res = res
|
865 |
+
self.in_proj = Linear(idim, n_chans)
|
866 |
+
if strides is None:
|
867 |
+
strides = [1] * n_layers
|
868 |
+
else:
|
869 |
+
assert len(strides) == n_layers
|
870 |
+
for idx in range(n_layers):
|
871 |
+
self.conv.append(ConvBlock(
|
872 |
+
n_chans, n_chans, kernel_size, stride=strides[idx], norm=norm, dropout=dropout))
|
873 |
+
self.out_proj = Linear(n_chans, odim)
|
874 |
+
|
875 |
+
def forward(self, x, return_hiddens=False):
|
876 |
+
"""
|
877 |
+
|
878 |
+
:param x: [B, T, H]
|
879 |
+
:return: [B, T, H]
|
880 |
+
"""
|
881 |
+
x = self.in_proj(x)
|
882 |
+
x = x.transpose(1, -1) # (B, idim, Tmax)
|
883 |
+
hiddens = []
|
884 |
+
for f in self.conv:
|
885 |
+
x_ = f(x)
|
886 |
+
x = x + x_ if self.res else x_ # (B, C, Tmax)
|
887 |
+
hiddens.append(x)
|
888 |
+
x = x.transpose(1, -1)
|
889 |
+
x = self.out_proj(x) # (B, Tmax, H)
|
890 |
+
if return_hiddens:
|
891 |
+
hiddens = torch.stack(hiddens, 1) # [B, L, C, T]
|
892 |
+
return x, hiddens
|
893 |
+
return x
|
894 |
+
|
895 |
+
|
896 |
+
class ConvGlobalStacks(nn.Module):
|
897 |
+
def __init__(self, idim=80, n_layers=5, n_chans=256, odim=32, kernel_size=5, norm='gn', dropout=0,
|
898 |
+
strides=[2, 2, 2, 2, 2]):
|
899 |
+
super().__init__()
|
900 |
+
self.conv = torch.nn.ModuleList()
|
901 |
+
self.pooling = torch.nn.ModuleList()
|
902 |
+
self.kernel_size = kernel_size
|
903 |
+
self.in_proj = Linear(idim, n_chans)
|
904 |
+
for idx in range(n_layers):
|
905 |
+
self.conv.append(ConvBlock(n_chans, n_chans, kernel_size, stride=strides[idx],
|
906 |
+
norm=norm, dropout=dropout))
|
907 |
+
self.pooling.append(nn.MaxPool1d(strides[idx]))
|
908 |
+
self.out_proj = Linear(n_chans, odim)
|
909 |
+
|
910 |
+
def forward(self, x):
|
911 |
+
"""
|
912 |
+
|
913 |
+
:param x: [B, T, H]
|
914 |
+
:return: [B, T, H]
|
915 |
+
"""
|
916 |
+
x = self.in_proj(x)
|
917 |
+
x = x.transpose(1, -1) # (B, idim, Tmax)
|
918 |
+
for f, p in zip(self.conv, self.pooling):
|
919 |
+
x = f(x) # (B, C, T)
|
920 |
+
x = x.transpose(1, -1)
|
921 |
+
x = self.out_proj(x.mean(1)) # (B, H)
|
922 |
+
return x
|
923 |
+
|
924 |
+
|
925 |
+
class ConvDecoder(nn.Module):
|
926 |
+
def __init__(self, c, dropout, kernel_size=9, act='gelu'):
|
927 |
+
super().__init__()
|
928 |
+
self.c = c
|
929 |
+
self.dropout = dropout
|
930 |
+
|
931 |
+
self.pre_convs = nn.ModuleList()
|
932 |
+
self.pre_lns = nn.ModuleList()
|
933 |
+
for i in range(2):
|
934 |
+
self.pre_convs.append(TransformerFFNLayer(
|
935 |
+
c, c * 2, padding='LEFT', kernel_size=kernel_size, dropout=dropout, act=act))
|
936 |
+
self.pre_lns.append(LayerNorm(c))
|
937 |
+
|
938 |
+
self.layer_norm_attn = LayerNorm(c)
|
939 |
+
self.encoder_attn = MultiheadAttention(c, 1, encoder_decoder_attention=True, bias=False)
|
940 |
+
|
941 |
+
self.post_convs = nn.ModuleList()
|
942 |
+
self.post_lns = nn.ModuleList()
|
943 |
+
for i in range(8):
|
944 |
+
self.post_convs.append(TransformerFFNLayer(
|
945 |
+
c, c * 2, padding='LEFT', kernel_size=kernel_size, dropout=dropout, act=act))
|
946 |
+
self.post_lns.append(LayerNorm(c))
|
947 |
+
|
948 |
+
def forward(
|
949 |
+
self,
|
950 |
+
x,
|
951 |
+
encoder_out=None,
|
952 |
+
encoder_padding_mask=None,
|
953 |
+
incremental_state=None,
|
954 |
+
**kwargs,
|
955 |
+
):
|
956 |
+
attn_logits = None
|
957 |
+
for conv, ln in zip(self.pre_convs, self.pre_lns):
|
958 |
+
residual = x
|
959 |
+
x = ln(x)
|
960 |
+
x = conv(x) + residual
|
961 |
+
if encoder_out is not None:
|
962 |
+
residual = x
|
963 |
+
x = self.layer_norm_attn(x)
|
964 |
+
x, attn = self.encoder_attn(
|
965 |
+
query=x,
|
966 |
+
key=encoder_out,
|
967 |
+
value=encoder_out,
|
968 |
+
key_padding_mask=encoder_padding_mask,
|
969 |
+
incremental_state=incremental_state,
|
970 |
+
static_kv=True,
|
971 |
+
enc_dec_attn_constraint_mask=get_incremental_state(self, incremental_state,
|
972 |
+
'enc_dec_attn_constraint_mask'),
|
973 |
+
)
|
974 |
+
attn_logits = attn[1]
|
975 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
976 |
+
x = residual + x
|
977 |
+
for conv, ln in zip(self.post_convs, self.post_lns):
|
978 |
+
residual = x
|
979 |
+
x = ln(x)
|
980 |
+
x = conv(x) + residual
|
981 |
+
return x, attn_logits
|
982 |
+
|
983 |
+
def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
|
984 |
+
self.encoder_attn.clear_buffer(incremental_state)
|
985 |
+
self.ffn.clear_buffer(incremental_state)
|
986 |
+
|
987 |
+
def set_buffer(self, name, tensor, incremental_state):
|
988 |
+
return set_incremental_state(self, incremental_state, name, tensor)
|
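A minimal usage sketch for the layers above (not part of the diff): it assumes the repository root is on PYTHONPATH and that the usual imports at the top of transformer_base.py (torch, nn, F, math, make_positions, softmax, the incremental-state helpers) resolve; it only checks the [T, B, C] layout and the [B, T] padding-mask convention that EncSALayer expects.

# Sketch only: exercises EncSALayer on random data and checks the output shape.
import torch
from modules.audio2motion.transformer_base import EncSALayer

layer = EncSALayer(c=64, num_heads=2, dropout=0.1, norm='ln')
x = torch.randn(100, 4, 64)                        # [T=100, B=4, C=64]
pad_mask = torch.zeros(4, 100, dtype=torch.bool)   # [B, T]; True marks padded frames
pad_mask[3, 80:] = True                            # last sample is only 80 frames long
y = layer(x, encoder_padding_mask=pad_mask)
print(y.shape)                                     # expected: torch.Size([100, 4, 64])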
modules/audio2motion/transformer_models.py
ADDED
@@ -0,0 +1,208 @@
from numpy import isin
import torch
import torch.nn as nn
from modules.audio2motion.transformer_base import *

DEFAULT_MAX_SOURCE_POSITIONS = 2000
DEFAULT_MAX_TARGET_POSITIONS = 2000


class TransformerEncoderLayer(nn.Module):
    def __init__(self, hidden_size, dropout, kernel_size=None, num_heads=2, norm='ln'):
        super().__init__()
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.num_heads = num_heads
        self.op = EncSALayer(
            hidden_size, num_heads, dropout=dropout,
            attention_dropout=0.0, relu_dropout=dropout,
            kernel_size=kernel_size
            if kernel_size is not None else 9,
            padding='SAME',
            norm=norm, act='gelu'
        )

    def forward(self, x, **kwargs):
        return self.op(x, **kwargs)


######################
# fastspeech modules
######################
class LayerNorm(torch.nn.LayerNorm):
    """Layer normalization module.
    :param int nout: output dim size
    :param int dim: dimension to be normalized
    """

    def __init__(self, nout, dim=-1, eps=1e-5):
        """Construct a LayerNorm object."""
        super(LayerNorm, self).__init__(nout, eps=eps)
        self.dim = dim

    def forward(self, x):
        """Apply layer normalization.
        :param torch.Tensor x: input tensor
        :return: layer normalized tensor
        :rtype torch.Tensor
        """
        if self.dim == -1:
            return super(LayerNorm, self).forward(x)
        return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)


class FFTBlocks(nn.Module):
    def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout=None,
                 num_heads=2, use_pos_embed=True, use_last_norm=True, norm='ln',
                 use_pos_embed_alpha=True):
        super().__init__()
        self.num_layers = num_layers
        embed_dim = self.hidden_size = hidden_size
        self.dropout = dropout if dropout is not None else 0.1
        self.use_pos_embed = use_pos_embed
        self.use_last_norm = use_last_norm
        if use_pos_embed:
            self.max_source_positions = DEFAULT_MAX_TARGET_POSITIONS
            self.padding_idx = 0
            self.pos_embed_alpha = nn.Parameter(torch.Tensor([1])) if use_pos_embed_alpha else 1
            self.embed_positions = SinusoidalPositionalEmbedding(
                embed_dim, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
            )

        self.layers = nn.ModuleList([])
        self.layers.extend([
            TransformerEncoderLayer(self.hidden_size, self.dropout,
                                    kernel_size=ffn_kernel_size, num_heads=num_heads,
                                    norm=norm)
            for _ in range(self.num_layers)
        ])
        if self.use_last_norm:
            if norm == 'ln':
                self.layer_norm = nn.LayerNorm(embed_dim)
            elif norm == 'bn':
                self.layer_norm = BatchNorm1dTBC(embed_dim)
            elif norm == 'gn':
                self.layer_norm = GroupNorm1DTBC(8, embed_dim)
        else:
            self.layer_norm = None

    def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens=False):
        """
        :param x: [B, T, C]
        :param padding_mask: [B, T]
        :return: [B, T, C] or [L, B, T, C]
        """
        padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask
        nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None]  # [T, B, 1]
        if self.use_pos_embed:
            positions = self.pos_embed_alpha * self.embed_positions(x[..., 0])
            x = x + positions
        x = F.dropout(x, p=self.dropout, training=self.training)
        # B x T x C -> T x B x C
        x = x.transpose(0, 1) * nonpadding_mask_TB
        hiddens = []
        for layer in self.layers:
            x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB
            hiddens.append(x)
        if self.use_last_norm:
            x = self.layer_norm(x) * nonpadding_mask_TB
        if return_hiddens:
            x = torch.stack(hiddens, 0)  # [L, T, B, C]
            x = x.transpose(1, 2)  # [L, B, T, C]
        else:
            x = x.transpose(0, 1)  # [B, T, C]
        return x

class SequentialSA(nn.Module):
    def __init__(self, layers):
        super(SequentialSA, self).__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, x, x_mask):
        """
        x: [batch, T, H]
        x_mask: [batch, T]
        """
        pad_mask = 1. - x_mask
        for layer in self.layers:
            if isinstance(layer, EncSALayer):
                x = x.permute(1, 0, 2)
                x = layer(x, pad_mask)
                x = x.permute(1, 0, 2)
            elif isinstance(layer, nn.Linear):
                x = layer(x) * x_mask.unsqueeze(2)
            elif isinstance(layer, nn.AvgPool1d):
                x = x.permute(0, 2, 1)
                x = layer(x)
                x = x.permute(0, 2, 1)
            elif isinstance(layer, nn.PReLU):
                bs, t, hid = x.shape
                x = x.reshape([bs * t, hid])
                x = layer(x)
                x = x.reshape([bs, t, hid])
            else:  # ReLU
                x = layer(x)

        return x

class TransformerStyleFusionModel(nn.Module):
    def __init__(self, num_heads=4, dropout=0.1, out_dim=64):
        super(TransformerStyleFusionModel, self).__init__()
        self.audio_layer = SequentialSA([
            nn.Linear(29, 48),
            nn.ReLU(48),
            nn.Linear(48, 128),
        ])

        self.energy_layer = SequentialSA([
            nn.Linear(1, 16),
            nn.ReLU(16),
            nn.Linear(16, 64),
        ])

        self.backbone1 = FFTBlocks(hidden_size=192, num_layers=3)

        self.sty_encoder = nn.Sequential(*[
            nn.Linear(135, 64),
            nn.ReLU(),
            nn.Linear(64, 128)
        ])

        self.backbone2 = FFTBlocks(hidden_size=320, num_layers=3)

        self.out_layer = SequentialSA([
            nn.AvgPool1d(kernel_size=2, stride=2, padding=0),  # [b,hid,t_audio]=>[b,hid,t_audio//2]
            nn.Linear(320, out_dim),
            nn.PReLU(out_dim),
            nn.Linear(out_dim, out_dim),
        ])

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, audio, energy, style, x_mask, y_mask):
        pad_mask = 1. - x_mask
        audio_feat = self.audio_layer(audio, x_mask)
        energy_feat = self.energy_layer(energy, x_mask)
        feat = torch.cat((audio_feat, energy_feat), dim=-1)  # [batch, T, H=128+64]
        feat = self.backbone1(feat, pad_mask)
        feat = self.dropout(feat)

        sty_feat = self.sty_encoder(style)  # [batch, 135] => [batch, H=128]
        sty_feat = sty_feat.unsqueeze(1).repeat(1, feat.shape[1], 1)  # [batch, T, H=128]

        feat = torch.cat([feat, sty_feat], dim=-1)  # [batch, T, H=192+128]
        feat = self.backbone2(feat, pad_mask)  # [batch, T, H=320]
        out = self.out_layer(feat, y_mask)  # [batch, T//2, H=out_dim]

        return out

if __name__ == '__main__':
    model = TransformerStyleFusionModel()
    audio = torch.rand(4, 200, 29)  # [B, T, H]
    energy = torch.rand(4, 200, 1)  # [B, T, H]
    style = torch.ones(4, 135)  # [B, H]
    x_mask = torch.ones(4, 200)  # [B, T]
    x_mask[3, 10:] = 0
    y_mask = torch.ones(4, 100)  # [B, T//2]; the AvgPool1d in out_layer halves the temporal axis
    ret = model(audio, energy, style, x_mask, y_mask)
    print(ret.shape)
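A minimal sketch for FFTBlocks on its own (not part of the diff, repository root assumed on PYTHONPATH): it relies on the convention coded in forward() above, where frames whose features are all zero are treated as padding when no explicit mask is passed.

# Sketch only: runs FFTBlocks on random [B, T, C] features and checks the output shape.
import torch
from modules.audio2motion.transformer_models import FFTBlocks

blocks = FFTBlocks(hidden_size=192, num_layers=3)
feat = torch.randn(2, 50, 192)   # [B, T, C]
feat[1, 30:] = 0.                # all-zero frames are inferred as padding
out = blocks(feat)
print(out.shape)                 # expected: torch.Size([2, 50, 192])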
modules/audio2motion/utils.py
ADDED
@@ -0,0 +1,29 @@
import torch


def squeeze(x, x_mask=None, n_sqz=2):
    b, c, t = x.size()

    t = (t // n_sqz) * n_sqz
    x = x[:, :, :t]
    x_sqz = x.view(b, c, t // n_sqz, n_sqz)
    x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz)

    if x_mask is not None:
        x_mask = x_mask[:, :, n_sqz - 1::n_sqz]
    else:
        x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype)
    return x_sqz * x_mask, x_mask


def unsqueeze(x, x_mask=None, n_sqz=2):
    b, c, t = x.size()

    x_unsqz = x.view(b, n_sqz, c // n_sqz, t)
    x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz)

    if x_mask is not None:
        x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz)
    else:
        x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype)
    return x_unsqz * x_mask, x_mask