vinthony committed on
Commit
a22eb82
0 Parent(s):
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +34 -0
  2. LICENSE +21 -0
  3. README.md +194 -0
  4. app.py +86 -0
  5. config/auido2exp.yaml +58 -0
  6. config/auido2pose.yaml +49 -0
  7. config/facerender.yaml +45 -0
  8. inference.py +134 -0
  9. modules/__pycache__/gfpgan_inference.cpython-38.pyc +0 -0
  10. modules/__pycache__/gfpgan_inference.cpython-39.pyc +0 -0
  11. modules/__pycache__/sadtalker_test.cpython-38.pyc +0 -0
  12. modules/__pycache__/sadtalker_test.cpython-39.pyc +0 -0
  13. modules/__pycache__/text2speech.cpython-38.pyc +0 -0
  14. modules/__pycache__/text2speech.cpython-39.pyc +0 -0
  15. modules/gfpgan_inference.py +36 -0
  16. modules/sadtalker_test.py +95 -0
  17. modules/text2speech.py +12 -0
  18. packages.txt +1 -0
  19. requirements.txt +17 -0
  20. src/__pycache__/generate_batch.cpython-38.pyc +0 -0
  21. src/__pycache__/generate_facerender_batch.cpython-38.pyc +0 -0
  22. src/__pycache__/test_audio2coeff.cpython-38.pyc +0 -0
  23. src/audio2exp_models/__pycache__/audio2exp.cpython-38.pyc +0 -0
  24. src/audio2exp_models/__pycache__/networks.cpython-38.pyc +0 -0
  25. src/audio2exp_models/audio2exp.py +30 -0
  26. src/audio2exp_models/networks.py +74 -0
  27. src/audio2pose_models/__pycache__/audio2pose.cpython-38.pyc +0 -0
  28. src/audio2pose_models/__pycache__/audio_encoder.cpython-38.pyc +0 -0
  29. src/audio2pose_models/__pycache__/cvae.cpython-38.pyc +0 -0
  30. src/audio2pose_models/__pycache__/discriminator.cpython-38.pyc +0 -0
  31. src/audio2pose_models/__pycache__/networks.cpython-38.pyc +0 -0
  32. src/audio2pose_models/__pycache__/res_unet.cpython-38.pyc +0 -0
  33. src/audio2pose_models/audio2pose.py +93 -0
  34. src/audio2pose_models/audio_encoder.py +64 -0
  35. src/audio2pose_models/cvae.py +149 -0
  36. src/audio2pose_models/discriminator.py +76 -0
  37. src/audio2pose_models/networks.py +140 -0
  38. src/audio2pose_models/res_unet.py +65 -0
  39. src/config/auido2exp.yaml +58 -0
  40. src/config/auido2pose.yaml +49 -0
  41. src/config/facerender.yaml +45 -0
  42. src/face3d/__pycache__/extract_kp_videos.cpython-38.pyc +0 -0
  43. src/face3d/__pycache__/visualize.cpython-38.pyc +0 -0
  44. src/face3d/data/__init__.py +116 -0
  45. src/face3d/data/base_dataset.py +125 -0
  46. src/face3d/data/flist_dataset.py +125 -0
  47. src/face3d/data/image_folder.py +66 -0
  48. src/face3d/data/template_dataset.py +75 -0
  49. src/face3d/extract_kp_videos.py +107 -0
  50. src/face3d/models/__init__.py +67 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Tencent AI Lab
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,194 @@
1
+ <div align="center">
2
+
3
+ <h2> 😭 SadTalker: <span style="font-size:12px">Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation </span> </h2>
4
+
5
+ <a href='https://arxiv.org/abs/2211.12194'><img src='https://img.shields.io/badge/ArXiv-2211.12194-red'></a> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href='https://sadtalker.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Winfredy/SadTalker/blob/main/quick_demo.ipynb)
6
+
7
+ <div>
8
+ <a target='_blank'>Wenxuan Zhang <sup>*,1,2</sup> </a>&emsp;
9
+ <a href='https://vinthony.github.io/' target='_blank'>Xiaodong Cun <sup>*,2</sup></a>&emsp;
10
+ <a href='https://xuanwangvc.github.io/' target='_blank'>Xuan Wang <sup>3</sup></a>&emsp;
11
+ <a href='https://yzhang2016.github.io/' target='_blank'>Yong Zhang <sup>2</sup></a>&emsp;
12
+ <a href='https://xishen0220.github.io/' target='_blank'>Xi Shen <sup>2</sup></a>&emsp; </br>
13
+ <a href='https://yuguo-xjtu.github.io/' target='_blank'>Yu Guo<sup>1</sup> </a>&emsp;
14
+ <a href='https://scholar.google.com/citations?hl=zh-CN&user=4oXBp9UAAAAJ' target='_blank'>Ying Shan <sup>2</sup> </a>&emsp;
15
+ <a target='_blank'>Fei Wang <sup>1</sup> </a>&emsp;
16
+ </div>
17
+ <br>
18
+ <div>
19
+ <sup>1</sup> Xi'an Jiaotong University &emsp; <sup>2</sup> Tencent AI Lab &emsp; <sup>3</sup> Ant Group &emsp;
20
+ </div>
21
+ <br>
22
+ <i><strong><a href='https://arxiv.org/abs/2211.12194' target='_blank'>CVPR 2023</a></strong></i>
23
+ <br>
24
+ <br>
25
+
26
+ ![sadtalker](https://user-images.githubusercontent.com/4397546/222490039-b1f6156b-bf00-405b-9fda-0c9a9156f991.gif)
27
+
28
+ <b>TL;DR: A realistic and stylized talking head video generation method from a single image and audio.</b>
29
+
30
+ <br>
31
+
32
+ </div>
33
+
34
+
35
+ ## 📋 Changelog
36
+
37
+
38
+ - __2023.03.22__: Launch new feature: generating 3D face animation from a single image. New applications built on it will follow.
39
+
40
+ - __2023.03.22__: Launch new feature: `still mode`, where only small head motions are produced; enable it via `python inference.py --still`.
41
+ - __2023.03.18__: Support `expression intensity`: you can now change the intensity of the generated motion via `python inference.py --expression_scale 1.3` (values > 1 give stronger motion).
42
+
43
+ - __2023.03.18__: Reorganized the data folders; the checkpoints can now be downloaded automatically using `bash scripts/download_models.sh`.
44
+ - __2023.03.18__: We have officially integrated [GFPGAN](https://github.com/TencentARC/GFPGAN) for face enhancement; use `python inference.py --enhancer gfpgan` for better visual quality.
45
+ - __2023.03.14__: Pinned the `joblib` package version to remove errors when using `librosa`; the [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Winfredy/SadTalker/blob/main/quick_demo.ipynb) notebook is online!
46
+ &nbsp;&nbsp;&nbsp;&nbsp; <details><summary> Previous Changelogs</summary>
47
+ - 2023.03.06 Solve some bugs in code and errors in installation
48
+ - 2023.03.03 Release the test code for audio-driven single image animation!
49
+ - 2023.02.28 SadTalker has been accepted by CVPR 2023!
50
+
51
+ </details>
52
+
53
+ ## 🎼 Pipeline
54
+ ![main_of_sadtalker](https://user-images.githubusercontent.com/4397546/222490596-4c8a2115-49a7-42ad-a2c3-3bb3288a5f36.png)
55
+
56
+
57
+ ## 🚧 TODO
58
+
59
+ - [x] Generating 2D face from a single Image.
60
+ - [x] Generating 3D face from Audio.
61
+ - [x] Generating 4D free-view talking examples from audio and a single image.
62
+ - [x] Gradio/Colab Demo.
63
+ - [ ] Full body/image Generation.
64
+ - [ ] Training code for each component.
65
+ - [ ] Audio-driven Anime Avatar.
66
+ - [ ] Integrate ChatGPT for a conversation demo 🤔
67
+ - [ ] Integrate with stable-diffusion-webui. (stay tuned!)
68
+
69
+ https://user-images.githubusercontent.com/4397546/222513483-89161f58-83d0-40e4-8e41-96c32b47bd4e.mp4
70
+
71
+
72
+ ## 🔮 Inference Demo!
73
+
74
+ #### Dependency Installation
75
+
76
+ <details><summary>CLICK ME</summary>
77
+
78
+ ```
79
+ git clone https://github.com/Winfredy/SadTalker.git
80
+ cd SadTalker
81
+ conda create -n sadtalker python=3.8
82
+ source activate sadtalker
83
+ pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
84
+ conda install ffmpeg
85
+ pip install dlib-bin # dlib-bin installs much faster than building dlib from source; alternatively: conda install dlib
86
+ pip install -r requirements.txt
87
+
88
+ ### install GFPGAN for the enhancer
89
+ pip install git+https://github.com/TencentARC/GFPGAN
90
+
91
+ ```
92
+
93
+ </details>
94
+
95
+ #### Trained Models
96
+ <details><summary>CLICK ME</summary>
97
+
98
+ You can run the following script to put all the models in the right place.
99
+
100
+ ```bash
101
+ bash scripts/download_models.sh
102
+ ```
103
+
104
+ OR download our pre-trained models from [Google Drive](https://drive.google.com/drive/folders/1Wd88VDoLhVzYsQ30_qDVluQr_Xm46yHT?usp=sharing) or our [GitHub release page](https://github.com/Winfredy/SadTalker/releases/tag/v0.0.1), and then put them in `./checkpoints`.
105
+
106
+ | Model | Description
107
+ | :--- | :----------
108
+ |checkpoints/auido2exp_00300-model.pth | Pre-trained ExpNet in Sadtalker.
109
+ |checkpoints/auido2pose_00140-model.pth | Pre-trained PoseVAE in Sadtalker.
110
+ |checkpoints/mapping_00229-model.pth.tar | Pre-trained MappingNet in Sadtalker.
111
+ |checkpoints/facevid2vid_00189-model.pth.tar | Pre-trained face-vid2vid model from [the reappearance of face-vid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis).
112
+ |checkpoints/epoch_20.pth | Pre-trained 3DMM extractor in [Deep3DFaceReconstruction](https://github.com/microsoft/Deep3DFaceReconstruction).
113
+ |checkpoints/wav2lip.pth | Highly accurate lip-sync model in [Wav2lip](https://github.com/Rudrabha/Wav2Lip).
114
+ |checkpoints/shape_predictor_68_face_landmarks.dat | Face landmark model used in [dlib](http://dlib.net/).
115
+ |checkpoints/BFM | 3DMM library file.
116
+ |checkpoints/hub | Face detection models used in [face alignment](https://github.com/1adrianb/face-alignment).
117
+
118
+ </details>
119
+
120
+ #### Generating 2D face from a single Image
121
+
122
+ ```bash
123
+ python inference.py --driven_audio <audio.wav> \
124
+ --source_image <video.mp4 or picture.png> \
125
+ --batch_size <default is 2, a larger batch runs faster> \
126
+ --expression_scale <default is 1.0, a larger value will make the motion stronger> \
127
+ --result_dir <a folder to store the results> \
128
+ --enhancer <default is None, you can choose gfpgan or RestoreFormer>
129
+ ```
130
+
131
+ <!-- ###### The effectiveness of the enhancer `gfpgan`. -->
132
+
133
+ | basic | w/ still mode | w/ exp_scale 1.3 | w/ gfpgan |
134
+ |:-------------: |:-------------: |:-------------: |:-------------: |
135
+ | <video src="https://user-images.githubusercontent.com/4397546/226097707-bef1dd41-403e-48d3-a6e6-6adf923843af.mp4"></video> | <video src='https://user-images.githubusercontent.com/4397546/226804933-b717229f-1919-4bd5-b6af-bea7ab66cad3.mp4'></video> | <video style='width:256px' src="https://user-images.githubusercontent.com/4397546/226806013-7752c308-8235-4e7a-9465-72d8fc1aa03d.mp4"></video> | <video style='width:256px' src="https://user-images.githubusercontent.com/4397546/226097717-12a1a2a1-ac0f-428d-b2cb-bd6917aff73e.mp4"></video> |
136
+
137
+ > Please turn on the audio manually; GitHub's embedded video player mutes it by default.
138
+
139
+
140
+ <!-- <video src="./docs/art_0##japanese_still.mp4"></video> -->
141
+
142
+
143
+ #### Generating 3D face from Audio
144
+
145
+
146
+ | Input | Animated 3d face |
147
+ |:-------------: | :-------------: |
148
+ | <img src='examples/source_image/art_0.png' width='200px'> | <video src="https://user-images.githubusercontent.com/4397546/226856847-5a6a0a4d-a5ec-49e2-9b05-3206db65e8e3.mp4"></video> |
149
+
150
+ > Please turn on the audio manually; GitHub's embedded video player mutes it by default.
151
+
152
+ More details on generating the 3D face can be found [here](docs/face3d.md).
153
+
154
+ #### Generating 4D free-view talking examples from audio and a single image
155
+
156
+ We use `camera_yaw`, `camera_pitch`, and `camera_roll` to control the camera pose. For example, `--camera_yaw -20 30 10` means the camera yaw changes from -20 to 30 degrees and then from 30 to 10.
157
+ ```bash
158
+ python inference.py --driven_audio <audio.wav> \
159
+ --source_image <video.mp4 or picture.png> \
160
+ --result_dir <a file to store results> \
161
+ --camera_yaw -20 30 10
162
+ ```
163
+ ![free_view](https://github.com/Winfredy/SadTalker/blob/main/docs/free_view_result.gif)
164
+
165
+
166
+ ## 🛎 Citation
167
+
168
+ If you find our work useful in your research, please consider citing:
169
+
170
+ ```bibtex
171
+ @article{zhang2022sadtalker,
172
+ title={SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation},
173
+ author={Zhang, Wenxuan and Cun, Xiaodong and Wang, Xuan and Zhang, Yong and Shen, Xi and Guo, Yu and Shan, Ying and Wang, Fei},
174
+ journal={arXiv preprint arXiv:2211.12194},
175
+ year={2022}
176
+ }
177
+ ```
178
+
179
+ ## 💗 Acknowledgements
180
+
181
+ Facerender code borrows heavily from [zhanglonghao's reproduction of face-vid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis) and [PIRender](https://github.com/RenYurui/PIRender). We thank the authors for sharing their wonderful code. During training, we also use the models from [Deep3DFaceReconstruction](https://github.com/microsoft/Deep3DFaceReconstruction) and [Wav2lip](https://github.com/Rudrabha/Wav2Lip). We thank the authors for their wonderful work.
182
+
183
+
184
+ ## 🥂 Related Works
185
+ - [StyleHEAT: One-Shot High-Resolution Editable Talking Face Generation via Pre-trained StyleGAN (ECCV 2022)](https://github.com/FeiiYin/StyleHEAT)
186
+ - [CodeTalker: Speech-Driven 3D Facial Animation with Discrete Motion Prior (CVPR 2023)](https://github.com/Doubiiu/CodeTalker)
187
+ - [VideoReTalking: Audio-based Lip Synchronization for Talking Head Video Editing In the Wild (SIGGRAPH Asia 2022)](https://github.com/vinthony/video-retalking)
188
+ - [DPE: Disentanglement of Pose and Expression for General Video Portrait Editing (CVPR 2023)](https://github.com/Carlyx/DPE)
189
+ - [3D GAN Inversion with Facial Symmetry Prior (CVPR 2023)](https://github.com/FeiiYin/SPI/)
190
+ - [T2M-GPT: Generating Human Motion from Textual Descriptions with Discrete Representations (CVPR 2023)](https://github.com/Mael-zys/T2M-GPT)
191
+
192
+ ## 📢 Disclaimer
193
+
194
+ This is not an official product of Tencent. This repository can only be used for personal/research/non-commercial purposes.
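Following up on the Trained Models table above, here is a small illustrative check (not part of the repository) that the downloaded checkpoints are in place; `./checkpoints` is the default `--checkpoint_dir` of `inference.py`, and the filenames mirror that table.

```python
# Sketch: verify that the checkpoints listed in the "Trained Models" table exist.
import os

CHECKPOINT_DIR = './checkpoints'  # default --checkpoint_dir used by inference.py
expected = [
    'auido2exp_00300-model.pth', 'auido2pose_00140-model.pth',
    'mapping_00229-model.pth.tar', 'facevid2vid_00189-model.pth.tar',
    'epoch_20.pth', 'wav2lip.pth', 'shape_predictor_68_face_landmarks.dat',
]
missing = [name for name in expected
           if not os.path.exists(os.path.join(CHECKPOINT_DIR, name))]
print('missing checkpoints:', missing or 'none')
```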
app.py ADDED
@@ -0,0 +1,86 @@
1
+ import os, sys
2
+ import tempfile
3
+ import gradio as gr
4
+ from modules.text2speech import text2speech
5
+ from modules.gfpgan_inference import gfpgan
6
+ from modules.sadtalker_test import SadTalker
7
+
8
+ def get_driven_audio(audio):
9
+ if os.path.isfile(audio):
10
+ return audio
11
+ else:
12
+ save_path = tempfile.NamedTemporaryFile(
13
+ delete=False,
14
+ suffix=("." + "wav"),
15
+ )
16
+ gen_audio = text2speech(audio, save_path.name)
17
+ return gen_audio, gen_audio
18
+
19
+ def get_source_image(image):
20
+ return image
21
+
22
+ def sadtalker_demo(result_dir):
23
+
24
+ sad_talker = SadTalker()
25
+ with gr.Blocks(analytics_enabled=False) as sadtalker_interface:
26
+ with gr.Row().style(equal_height=False):
27
+ with gr.Column(variant='panel'):
28
+ with gr.Tabs(elem_id="sadtalker_source_image"):
29
+ source_image = gr.Image(visible=False, type="filepath")
30
+ with gr.TabItem('Upload image'):
31
+ with gr.Row():
32
+ input_image = gr.Image(label="Source image", source="upload", type="filepath").style(height=256,width=256)
33
+ submit_image = gr.Button('Submit', variant='primary')
34
+ submit_image.click(fn=get_source_image, inputs=input_image, outputs=source_image)
35
+
36
+ with gr.Tabs(elem_id="sadtalker_driven_audio"):
37
+ driven_audio = gr.Audio(visible=False, type="filepath")
38
+ with gr.TabItem('Upload audio'):
39
+ with gr.Column(variant='panel'):
40
+ input_audio1 = gr.Audio(label="Input audio", source="upload", type="filepath")
41
+ submit_audio_1 = gr.Button('Submit', variant='primary')
42
+ submit_audio_1.click(fn=get_driven_audio, inputs=input_audio1, outputs=driven_audio)
43
+
44
+ with gr.TabItem('Microphone'):
45
+ with gr.Column(variant='panel'):
46
+ input_audio2 = gr.Audio(label="Recording audio", source="microphone", type="filepath")
47
+ submit_audio_2 = gr.Button('Submit', variant='primary')
48
+ submit_audio_2.click(fn=get_driven_audio, inputs=input_audio2, outputs=driven_audio)
49
+
50
+ with gr.TabItem('TTS'):
51
+ with gr.Column(variant='panel'):
52
+ with gr.Row().style(equal_height=False):
53
+ input_text = gr.Textbox(label="Input text", lines=5, value="Please enter some text in English")
54
+ input_audio3 = gr.Audio(label="Generated audio", type="filepath")
55
+ submit_audio_3 = gr.Button('Submit', variant='primary')
56
+ submit_audio_3.click(fn=get_driven_audio, inputs=input_text, outputs=[input_audio3, driven_audio])
57
+
58
+ with gr.Column(variant='panel'):
59
+ gen_video = gr.Video(label="Generated video", format="mp4").style(height=256,width=256)
60
+ gen_text = gr.Textbox(visible=False)
61
+ submit = gr.Button('Generate', elem_id="sadtalker_generate", variant='primary')
62
+ scale = gr.Slider(minimum=1, maximum=8, step=1, label="GFPGAN scale", value=1)
63
+ new_video = gr.Video(label="New video", format="mp4").style(height=512,width=512)
64
+ change_scale = gr.Button('Restore video', elem_id="restore_video", variant='primary')
65
+
66
+ submit.click(
67
+ fn=sad_talker.test,
68
+ inputs=[source_image,
69
+ driven_audio,
70
+ gr.Textbox(value=result_dir, visible=False)],
71
+ outputs=[gen_video, gen_text]
72
+ )
73
+ change_scale.click(gfpgan, [scale, gen_text], new_video)
74
+
75
+ return sadtalker_interface
76
+
77
+
78
+ if __name__ == "__main__":
79
+
80
+ current_code_path = sys.argv[0]
81
+ current_root_dir = os.path.split(current_code_path)[0]
82
+ sadtalker_result_dir = os.path.join(current_root_dir, 'results', 'sadtalker')
83
+ demo = sadtalker_demo(sadtalker_result_dir)
84
+ demo.launch()
85
+
86
+
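For reference, a minimal sketch of reusing `sadtalker_demo` from `app.py` with a custom results directory; the results path and `share=True` are assumptions, not part of the commit.

```python
# Sketch: launch the Gradio demo defined in app.py with a custom results folder.
from app import sadtalker_demo

demo = sadtalker_demo('./results/sadtalker')
demo.launch(share=True)  # share=True asks Gradio for a temporary public URL
```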
config/auido2exp.yaml ADDED
@@ -0,0 +1,58 @@
1
+ DATASET:
2
+ TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt
3
+ EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt
4
+ TRAIN_BATCH_SIZE: 32
5
+ EVAL_BATCH_SIZE: 32
6
+ EXP: True
7
+ EXP_DIM: 64
8
+ FRAME_LEN: 32
9
+ COEFF_LEN: 73
10
+ NUM_CLASSES: 46
11
+ AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
12
+ COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm
13
+ LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
14
+ DEBUG: True
15
+ NUM_REPEATS: 2
16
+ T: 40
17
+
18
+
19
+ MODEL:
20
+ FRAMEWORK: V2
21
+ AUDIOENCODER:
22
+ LEAKY_RELU: True
23
+ NORM: 'IN'
24
+ DISCRIMINATOR:
25
+ LEAKY_RELU: False
26
+ INPUT_CHANNELS: 6
27
+ CVAE:
28
+ AUDIO_EMB_IN_SIZE: 512
29
+ AUDIO_EMB_OUT_SIZE: 128
30
+ SEQ_LEN: 32
31
+ LATENT_SIZE: 256
32
+ ENCODER_LAYER_SIZES: [192, 1024]
33
+ DECODER_LAYER_SIZES: [1024, 192]
34
+
35
+
36
+ TRAIN:
37
+ MAX_EPOCH: 300
38
+ GENERATOR:
39
+ LR: 2.0e-5
40
+ DISCRIMINATOR:
41
+ LR: 1.0e-5
42
+ LOSS:
43
+ W_FEAT: 0
44
+ W_COEFF_EXP: 2
45
+ W_LM: 1.0e-2
46
+ W_LM_MOUTH: 0
47
+ W_REG: 0
48
+ W_SYNC: 0
49
+ W_COLOR: 0
50
+ W_EXPRESSION: 0
51
+ W_LIPREADING: 0.01
52
+ W_LIPREADING_VV: 0
53
+ W_EYE_BLINK: 4
54
+
55
+ TAG:
56
+ NAME: small_dataset
57
+
58
+
config/auido2pose.yaml ADDED
@@ -0,0 +1,49 @@
1
+ DATASET:
2
+ TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt
3
+ EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt
4
+ TRAIN_BATCH_SIZE: 64
5
+ EVAL_BATCH_SIZE: 1
6
+ EXP: True
7
+ EXP_DIM: 64
8
+ FRAME_LEN: 32
9
+ COEFF_LEN: 73
10
+ NUM_CLASSES: 46
11
+ AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
12
+ COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
13
+ DEBUG: True
14
+
15
+
16
+ MODEL:
17
+ AUDIOENCODER:
18
+ LEAKY_RELU: True
19
+ NORM: 'IN'
20
+ DISCRIMINATOR:
21
+ LEAKY_RELU: False
22
+ INPUT_CHANNELS: 6
23
+ CVAE:
24
+ AUDIO_EMB_IN_SIZE: 512
25
+ AUDIO_EMB_OUT_SIZE: 6
26
+ SEQ_LEN: 32
27
+ LATENT_SIZE: 64
28
+ ENCODER_LAYER_SIZES: [192, 128]
29
+ DECODER_LAYER_SIZES: [128, 192]
30
+
31
+
32
+ TRAIN:
33
+ MAX_EPOCH: 150
34
+ GENERATOR:
35
+ LR: 1.0e-4
36
+ DISCRIMINATOR:
37
+ LR: 1.0e-4
38
+ LOSS:
39
+ LAMBDA_REG: 1
40
+ LAMBDA_LANDMARKS: 0
41
+ LAMBDA_VERTICES: 0
42
+ LAMBDA_GAN_MOTION: 0.7
43
+ LAMBDA_GAN_COEFF: 0
44
+ LAMBDA_KL: 1
45
+
46
+ TAG:
47
+ NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder
48
+
49
+
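One way to inspect these YAML configs is through yacs, which is pinned in `requirements.txt` and matches how the models read `cfg.MODEL.*` fields. A minimal loading sketch, assuming it is run from the repository root:

```python
# Sketch: inspect config/auido2pose.yaml as a yacs CfgNode.
from yacs.config import CfgNode as CN

with open('config/auido2pose.yaml') as f:
    cfg = CN.load_cfg(f)

print(cfg.MODEL.CVAE.LATENT_SIZE)   # 64
print(cfg.DATASET.NUM_CLASSES)      # 46
```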
config/facerender.yaml ADDED
@@ -0,0 +1,45 @@
1
+ model_params:
2
+ common_params:
3
+ num_kp: 15
4
+ image_channel: 3
5
+ feature_channel: 32
6
+ estimate_jacobian: False # True
7
+ kp_detector_params:
8
+ temperature: 0.1
9
+ block_expansion: 32
10
+ max_features: 1024
11
+ scale_factor: 0.25 # 0.25
12
+ num_blocks: 5
13
+ reshape_channel: 16384 # 16384 = 1024 * 16
14
+ reshape_depth: 16
15
+ he_estimator_params:
16
+ block_expansion: 64
17
+ max_features: 2048
18
+ num_bins: 66
19
+ generator_params:
20
+ block_expansion: 64
21
+ max_features: 512
22
+ num_down_blocks: 2
23
+ reshape_channel: 32
24
+ reshape_depth: 16 # 512 = 32 * 16
25
+ num_resblocks: 6
26
+ estimate_occlusion_map: True
27
+ dense_motion_params:
28
+ block_expansion: 32
29
+ max_features: 1024
30
+ num_blocks: 5
31
+ reshape_depth: 16
32
+ compress: 4
33
+ discriminator_params:
34
+ scales: [1]
35
+ block_expansion: 32
36
+ max_features: 512
37
+ num_blocks: 4
38
+ sn: True
39
+ mapping_params:
40
+ coeff_nc: 70
41
+ descriptor_nc: 1024
42
+ layer: 3
43
+ num_kp: 15
44
+ num_bins: 66
45
+
inference.py ADDED
@@ -0,0 +1,134 @@
1
+ import torch
2
+ from time import strftime
3
+ import os, sys, time
4
+ from argparse import ArgumentParser
5
+
6
+ from src.utils.preprocess import CropAndExtract
7
+ from src.test_audio2coeff import Audio2Coeff
8
+ from src.facerender.animate import AnimateFromCoeff
9
+ from src.generate_batch import get_data
10
+ from src.generate_facerender_batch import get_facerender_data
11
+
12
+ def main(args):
13
+ #torch.backends.cudnn.enabled = False
14
+
15
+ pic_path = args.source_image
16
+ audio_path = args.driven_audio
17
+ save_dir = os.path.join(args.result_dir, strftime("%Y_%m_%d_%H.%M.%S"))
18
+ os.makedirs(save_dir, exist_ok=True)
19
+ pose_style = args.pose_style
20
+ device = args.device
21
+ batch_size = args.batch_size
22
+ camera_yaw_list = args.camera_yaw
23
+ camera_pitch_list = args.camera_pitch
24
+ camera_roll_list = args.camera_roll
25
+
26
+ current_code_path = sys.argv[0]
27
+ current_root_path = os.path.split(current_code_path)[0]
28
+
29
+ os.environ['TORCH_HOME']=os.path.join(current_root_path, args.checkpoint_dir)
30
+
31
+ path_of_lm_croper = os.path.join(current_root_path, args.checkpoint_dir, 'shape_predictor_68_face_landmarks.dat')
32
+ path_of_net_recon_model = os.path.join(current_root_path, args.checkpoint_dir, 'epoch_20.pth')
33
+ dir_of_BFM_fitting = os.path.join(current_root_path, args.checkpoint_dir, 'BFM_Fitting')
34
+ wav2lip_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'wav2lip.pth')
35
+
36
+ audio2pose_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'auido2pose_00140-model.pth')
37
+ audio2pose_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2pose.yaml')
38
+
39
+ audio2exp_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'auido2exp_00300-model.pth')
40
+ audio2exp_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2exp.yaml')
41
+
42
+ free_view_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'facevid2vid_00189-model.pth.tar')
43
+ mapping_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'mapping_00229-model.pth.tar')
44
+ facerender_yaml_path = os.path.join(current_root_path, 'src', 'config', 'facerender.yaml')
45
+
46
+ #init model
47
+ print(path_of_net_recon_model)
48
+ preprocess_model = CropAndExtract(path_of_lm_croper, path_of_net_recon_model, dir_of_BFM_fitting, device)
49
+
50
+ print(audio2pose_checkpoint)
51
+ print(audio2exp_checkpoint)
52
+ audio_to_coeff = Audio2Coeff(audio2pose_checkpoint, audio2pose_yaml_path,
53
+ audio2exp_checkpoint, audio2exp_yaml_path,
54
+ wav2lip_checkpoint, device)
55
+
56
+ print(free_view_checkpoint)
57
+ print(mapping_checkpoint)
58
+ animate_from_coeff = AnimateFromCoeff(free_view_checkpoint, mapping_checkpoint,
59
+ facerender_yaml_path, device)
60
+
61
+ #crop image and extract 3dmm from image
62
+ first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
63
+ os.makedirs(first_frame_dir, exist_ok=True)
64
+ first_coeff_path, crop_pic_path = preprocess_model.generate(pic_path, first_frame_dir)
65
+ if first_coeff_path is None:
66
+ print("Can't get the coeffs of the input")
67
+ return
68
+
69
+ #audio2coeff
70
+ batch = get_data(first_coeff_path, audio_path, device)
71
+ coeff_path = audio_to_coeff.generate(batch, save_dir, pose_style)
72
+
73
+ # 3dface render
74
+ if args.face3dvis:
75
+ from src.face3d.visualize import gen_composed_video
76
+ gen_composed_video(args, device, first_coeff_path, coeff_path, audio_path, os.path.join(save_dir, '3dface.mp4'))
77
+
78
+ #coeff2video
79
+ data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path,
80
+ batch_size, camera_yaw_list, camera_pitch_list, camera_roll_list,
81
+ expression_scale=args.expression_scale, still_mode=args.still)
82
+
83
+ animate_from_coeff.generate(data, save_dir, enhancer=args.enhancer)
84
+ video_name = data['video_name']
85
+
86
+ if args.enhancer is not None:
87
+ print(f'The generated video is named {video_name}_enhanced in {save_dir}')
88
+ else:
89
+ print(f'The generated video is named {video_name} in {save_dir}')
90
+
91
+ return os.path.join(save_dir, video_name+'.mp4'), os.path.join(save_dir, video_name+'.mp4')
92
+
93
+
94
+ if __name__ == '__main__':
95
+
96
+ parser = ArgumentParser()
97
+ parser.add_argument("--driven_audio", default='./examples/driven_audio/japanese.wav', help="path to driven audio")
98
+ parser.add_argument("--source_image", default='./examples/source_image/art_0.png', help="path to source image")
99
+ parser.add_argument("--checkpoint_dir", default='./checkpoints', help="path to output")
100
+ parser.add_argument("--result_dir", default='./results', help="path to output")
101
+ parser.add_argument("--pose_style", type=int, default=0, help="input pose style from [0, 46)")
102
+ parser.add_argument("--batch_size", type=int, default=2, help="the batch size of facerender")
103
+ parser.add_argument("--expression_scale", type=float, default=1., help="the batch size of facerender")
104
+ parser.add_argument('--camera_yaw', nargs='+', type=int, default=[0], help="the camera yaw degree")
105
+ parser.add_argument('--camera_pitch', nargs='+', type=int, default=[0], help="the camera pitch degree")
106
+ parser.add_argument('--camera_roll', nargs='+', type=int, default=[0], help="the camera roll degree")
107
+ parser.add_argument('--enhancer', type=str, default=None, help="Face enhancer, [GFPGAN]")
108
+ parser.add_argument("--cpu", dest="cpu", action="store_true")
109
+ parser.add_argument("--face3dvis", action="store_true", help="generate 3d face and 3d landmarks")
110
+ parser.add_argument("--still", action="store_true")
111
+
112
+ # net structure and parameters
113
+ parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='not used')
114
+ parser.add_argument('--init_path', type=str, default=None, help='not used')
115
+ parser.add_argument('--use_last_fc',default=False, help='zero initialize the last fc')
116
+ parser.add_argument('--bfm_folder', type=str, default='./checkpoints/BFM_Fitting/')
117
+ parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model')
118
+
119
+ # default renderer parameters
120
+ parser.add_argument('--focal', type=float, default=1015.)
121
+ parser.add_argument('--center', type=float, default=112.)
122
+ parser.add_argument('--camera_d', type=float, default=10.)
123
+ parser.add_argument('--z_near', type=float, default=5.)
124
+ parser.add_argument('--z_far', type=float, default=15.)
125
+
126
+ args = parser.parse_args()
127
+
128
+ if torch.cuda.is_available() and not args.cpu:
129
+ args.device = "cuda"
130
+ else:
131
+ args.device = "cpu"
132
+
133
+ main(args)
134
+
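A minimal end-to-end invocation sketch; the paths are simply the script's own argparse defaults, and the call shells out to the CLI defined above.

```python
# Sketch: run inference.py with its default example inputs and still mode enabled.
import subprocess

subprocess.run([
    'python', 'inference.py',
    '--driven_audio', './examples/driven_audio/japanese.wav',
    '--source_image', './examples/source_image/art_0.png',
    '--result_dir', './results',
    '--still',
], check=True)
```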
modules/__pycache__/gfpgan_inference.cpython-38.pyc ADDED
Binary file (1.36 kB). View file
modules/__pycache__/gfpgan_inference.cpython-39.pyc ADDED
Binary file (1.4 kB). View file
modules/__pycache__/sadtalker_test.cpython-38.pyc ADDED
Binary file (3.1 kB). View file
modules/__pycache__/sadtalker_test.cpython-39.pyc ADDED
Binary file (3.98 kB). View file
modules/__pycache__/text2speech.cpython-38.pyc ADDED
Binary file (473 Bytes). View file
modules/__pycache__/text2speech.cpython-39.pyc ADDED
Binary file (477 Bytes). View file
modules/gfpgan_inference.py ADDED
@@ -0,0 +1,36 @@
1
+ import os,sys
2
+
3
+ def gfpgan(scale, origin_mp4_path):
4
+ current_code_path = sys.argv[0]
5
+ current_root_path = os.path.split(current_code_path)[0]
6
+ print(current_root_path)
7
+ gfpgan_code_path = current_root_path+'/repositories/GFPGAN/inference_gfpgan.py'
8
+ print(gfpgan_code_path)
9
+
10
+ #video2pic
11
+ result_dir = os.path.split(origin_mp4_path)[0]
12
+ video_name = os.path.split(origin_mp4_path)[1]
13
+ video_name = video_name.split('.')[0]
14
+ print(video_name)
15
+ str_scale = str(scale).replace('.', '_')
16
+ output_mp4_path = os.path.join(result_dir, video_name+'##'+str_scale+'.mp4')
17
+ temp_output_mp4_path = os.path.join(result_dir, 'temp_'+video_name+'##'+str_scale+'.mp4')
18
+
19
+ audio_name = video_name.split('##')[-1]
20
+ audio_path = os.path.join(result_dir, audio_name+'.wav')
21
+ temp_pic_dir1 = os.path.join(result_dir, video_name)
22
+ temp_pic_dir2 = os.path.join(result_dir, video_name+'##'+str_scale)
23
+ os.makedirs(temp_pic_dir1, exist_ok=True)
24
+ os.makedirs(temp_pic_dir2, exist_ok=True)
25
+ cmd1 = 'ffmpeg -i \"{}\" -start_number 0 \"{}\"/%06d.png -loglevel error -y'.format(origin_mp4_path, temp_pic_dir1)
26
+ os.system(cmd1)
27
+ cmd2 = f'python {gfpgan_code_path} -i {temp_pic_dir1} -o {temp_pic_dir2} -s {scale}'
28
+ os.system(cmd2)
29
+ cmd3 = f'ffmpeg -r 25 -f image2 -i {temp_pic_dir2}/%06d.png -vcodec libx264 -crf 25 -pix_fmt yuv420p {temp_output_mp4_path}'
30
+ os.system(cmd3)
31
+ cmd4 = f'ffmpeg -y -i {temp_output_mp4_path} -i {audio_path} -vcodec copy {output_mp4_path}'
32
+ os.system(cmd4)
33
+ #shutil.rmtree(temp_pic_dir1)
34
+ #shutil.rmtree(temp_pic_dir2)
35
+
36
+ return output_mp4_path
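A hypothetical usage sketch for `gfpgan()` above. The input path is an assumption; it must follow the `video##audio.mp4` naming that SadTalker produces (the function reads the audio name from the part after `##`), and it assumes the GFPGAN repository has been cloned into `./repositories/GFPGAN` as the module expects.

```python
# Sketch: enhance an existing SadTalker result at 2x scale with GFPGAN.
from modules.gfpgan_inference import gfpgan

# Hypothetical result path; '##' separates the video and audio names.
input_mp4 = './results/sadtalker/2023_03_22_12.00.00/source##japanese.mp4'
enhanced_mp4 = gfpgan(2, input_mp4)
print(enhanced_mp4)  # .../source##japanese##2.mp4
```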
modules/sadtalker_test.py ADDED
@@ -0,0 +1,95 @@
1
+ import torch
2
+ from time import gmtime, strftime
3
+ import os, sys, shutil
4
+ from argparse import ArgumentParser
5
+ from src.utils.preprocess import CropAndExtract
6
+ from src.test_audio2coeff import Audio2Coeff
7
+ from src.facerender.animate import AnimateFromCoeff
8
+ from src.generate_batch import get_data
9
+ from src.generate_facerender_batch import get_facerender_data
10
+
11
+ from modules.text2speech import text2speech
12
+
13
+ class SadTalker():
14
+
15
+ def __init__(self, checkpoint_path='checkpoints'):
16
+
17
+ if torch.cuda.is_available() :
18
+ device = "cuda"
19
+ else:
20
+ device = "cpu"
21
+
22
+ current_code_path = sys.argv[0]
23
+ modules_path = os.path.split(current_code_path)[0]
24
+
25
+ current_root_path = './'
26
+
27
+ os.environ['TORCH_HOME']=os.path.join(current_root_path, 'checkpoints')
28
+
29
+ path_of_lm_croper = os.path.join(current_root_path, 'checkpoints', 'shape_predictor_68_face_landmarks.dat')
30
+ path_of_net_recon_model = os.path.join(current_root_path, 'checkpoints', 'epoch_20.pth')
31
+ dir_of_BFM_fitting = os.path.join(current_root_path, 'checkpoints', 'BFM_Fitting')
32
+ wav2lip_checkpoint = os.path.join(current_root_path, 'checkpoints', 'wav2lip.pth')
33
+
34
+ audio2pose_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2pose_00140-model.pth')
35
+ audio2pose_yaml_path = os.path.join(current_root_path, 'config', 'auido2pose.yaml')
36
+
37
+ audio2exp_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2exp_00300-model.pth')
38
+ audio2exp_yaml_path = os.path.join(current_root_path, 'config', 'auido2exp.yaml')
39
+
40
+ free_view_checkpoint = os.path.join(current_root_path, 'checkpoints', 'facevid2vid_00189-model.pth.tar')
41
+ mapping_checkpoint = os.path.join(current_root_path, 'checkpoints', 'mapping_00229-model.pth.tar')
42
+ facerender_yaml_path = os.path.join(current_root_path, 'config', 'facerender.yaml')
43
+
44
+ #init model
45
+ print(path_of_lm_croper)
46
+ self.preprocess_model = CropAndExtract(path_of_lm_croper, path_of_net_recon_model, dir_of_BFM_fitting, device)
47
+
48
+ print(audio2pose_checkpoint)
49
+ self.audio_to_coeff = Audio2Coeff(audio2pose_checkpoint, audio2pose_yaml_path,
50
+ audio2exp_checkpoint, audio2exp_yaml_path, wav2lip_checkpoint, device)
51
+ print(free_view_checkpoint)
52
+ self.animate_from_coeff = AnimateFromCoeff(free_view_checkpoint, mapping_checkpoint,
53
+ facerender_yaml_path, device)
54
+ self.device = device
55
+
56
+ def test(self, source_image, driven_audio, result_dir):
57
+
58
+ time_tag = strftime("%Y_%m_%d_%H.%M.%S")
59
+ save_dir = os.path.join(result_dir, time_tag)
60
+ os.makedirs(save_dir, exist_ok=True)
61
+
62
+ input_dir = os.path.join(save_dir, 'input')
63
+ os.makedirs(input_dir, exist_ok=True)
64
+
65
+ print(source_image)
66
+ pic_path = os.path.join(input_dir, os.path.basename(source_image))
67
+ shutil.move(source_image, input_dir)
68
+
69
+ if os.path.isfile(driven_audio):
70
+ audio_path = os.path.join(input_dir, os.path.basename(driven_audio))
71
+ shutil.move(driven_audio, input_dir)
72
+ else:
73
+ # driven_audio is raw text here: synthesize speech into the input dir
+ audio_path = os.path.join(input_dir, 'tts.wav')
+ audio_path = text2speech(driven_audio, audio_path)
74
+
75
+
76
+ os.makedirs(save_dir, exist_ok=True)
77
+ pose_style = 0
78
+ #crop image and extract 3dmm from image
79
+ first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
80
+ os.makedirs(first_frame_dir, exist_ok=True)
81
+ first_coeff_path, crop_pic_path = self.preprocess_model.generate(pic_path, first_frame_dir)
82
+ if first_coeff_path is None:
83
+ raise AttributeError("No face is detected")
84
+
85
+ #audio2ceoff
86
+ batch = get_data(first_coeff_path, audio_path, self.device)
87
+ coeff_path = self.audio_to_coeff.generate(batch, save_dir, pose_style)
88
+ #coeff2video
89
+ batch_size = 4
90
+ data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size)
91
+ self.animate_from_coeff.generate(data, save_dir)
92
+ video_name = data['video_name']
93
+ print(f'The generated video is named {video_name} in {save_dir}')
94
+ return os.path.join(save_dir, video_name+'.mp4'), os.path.join(save_dir, video_name+'.mp4')
95
+
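A hypothetical usage sketch for the `SadTalker` class above; the image and audio paths are assumptions, and note that `test()` moves its inputs into the result directory, so pass copies if you want to keep the originals.

```python
# Sketch: drive SadTalker.test() directly, outside of the Gradio app.
from modules.sadtalker_test import SadTalker

talker = SadTalker(checkpoint_path='checkpoints')
video_path, _ = talker.test('examples/source_image/art_0.png',
                            'examples/driven_audio/japanese.wav',
                            'results/sadtalker')
print(video_path)
```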
modules/text2speech.py ADDED
@@ -0,0 +1,12 @@
1
+ import os
2
+
3
+ def text2speech(txt, audio_path):
4
+ print(txt)
5
+ cmd = f'tts --text "{txt}" --out_path {audio_path}'
6
+ print(cmd)
7
+ try:
8
+ os.system(cmd)
9
+ return audio_path
10
+ except:
11
+ print("Error: Failed convert txt to audio")
12
+ return None
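The module shells out to a `tts` command-line tool, so that binary must be on `PATH` (the commit does not pin which TTS package provides it). A hypothetical usage sketch; the output path is an assumption.

```python
# Sketch: synthesize a driving audio clip from text via the external `tts` CLI.
from modules.text2speech import text2speech

wav_path = text2speech("Hello, this is a SadTalker demo.", "/tmp/driven_audio.wav")
print(wav_path)  # /tmp/driven_audio.wav if the tts command succeeded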
packages.txt ADDED
@@ -0,0 +1 @@
1
+ ffmpeg
requirements.txt ADDED
@@ -0,0 +1,17 @@
1
+ numpy==1.23.4
2
+ face_alignment==1.3.5
3
+ imageio==2.19.3
4
+ imageio-ffmpeg==0.4.7
5
+ librosa==0.6.0
6
+ numba==0.48.0
7
+ resampy==0.3.1
8
+ pydub==0.25.1
9
+ scipy==1.5.3
10
+ kornia==0.6.8
11
+ tqdm
12
+ yacs==0.1.8
13
+ pyyaml
14
+ joblib==1.1.0
15
+ scikit-image==0.19.3
16
+ basicsr==1.4.2
17
+ facexlib==0.2.5
src/__pycache__/generate_batch.cpython-38.pyc ADDED
Binary file (2.81 kB). View file
src/__pycache__/generate_facerender_batch.cpython-38.pyc ADDED
Binary file (3.81 kB). View file
src/__pycache__/test_audio2coeff.cpython-38.pyc ADDED
Binary file (2.73 kB). View file
src/audio2exp_models/__pycache__/audio2exp.cpython-38.pyc ADDED
Binary file (1.07 kB). View file
src/audio2exp_models/__pycache__/networks.cpython-38.pyc ADDED
Binary file (2.14 kB). View file
src/audio2exp_models/audio2exp.py ADDED
@@ -0,0 +1,30 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class Audio2Exp(nn.Module):
6
+ def __init__(self, netG, cfg, device, prepare_training_loss=False):
7
+ super(Audio2Exp, self).__init__()
8
+ self.cfg = cfg
9
+ self.device = device
10
+ self.netG = netG.to(device)
11
+
12
+ def test(self, batch):
13
+
14
+ mel_input = batch['indiv_mels'] # bs T 1 80 16
15
+ bs = mel_input.shape[0]
16
+ T = mel_input.shape[1]
17
+
18
+ ref = batch['ref'][:, :, :64].repeat((1,T,1)) #bs T 64
19
+ ratio = batch['ratio_gt'] #bs T
20
+
21
+ audiox = mel_input.view(-1, 1, 80, 16) # bs*T 1 80 16
22
+ exp_coeff_pred = self.netG(audiox, ref, ratio) # bs T 64
23
+
24
+ # BS x T x 64
25
+ results_dict = {
26
+ 'exp_coeff_pred': exp_coeff_pred
27
+ }
28
+ return results_dict
29
+
30
+
src/audio2exp_models/networks.py ADDED
@@ -0,0 +1,74 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+
5
+ class Conv2d(nn.Module):
6
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act = True, *args, **kwargs):
7
+ super().__init__(*args, **kwargs)
8
+ self.conv_block = nn.Sequential(
9
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
10
+ nn.BatchNorm2d(cout)
11
+ )
12
+ self.act = nn.ReLU()
13
+ self.residual = residual
14
+ self.use_act = use_act
15
+
16
+ def forward(self, x):
17
+ out = self.conv_block(x)
18
+ if self.residual:
19
+ out += x
20
+
21
+ if self.use_act:
22
+ return self.act(out)
23
+ else:
24
+ return out
25
+
26
+ class SimpleWrapperV2(nn.Module):
27
+ def __init__(self) -> None:
28
+ super().__init__()
29
+ self.audio_encoder = nn.Sequential(
30
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
31
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
32
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
33
+
34
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
35
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
36
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
37
+
38
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
39
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
40
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
41
+
42
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
43
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
44
+
45
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
46
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
47
+ )
48
+
49
+ #### load the pre-trained audio_encoder
50
+ #self.audio_encoder = self.audio_encoder.to(device)
51
+ '''
52
+ wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict']
53
+ state_dict = self.audio_encoder.state_dict()
54
+
55
+ for k,v in wav2lip_state_dict.items():
56
+ if 'audio_encoder' in k:
57
+ print('init:', k)
58
+ state_dict[k.replace('module.audio_encoder.', '')] = v
59
+ self.audio_encoder.load_state_dict(state_dict)
60
+ '''
61
+
62
+ self.mapping1 = nn.Linear(512+64+1, 64)
63
+ #self.mapping2 = nn.Linear(30, 64)
64
+ #nn.init.constant_(self.mapping1.weight, 0.)
65
+ nn.init.constant_(self.mapping1.bias, 0.)
66
+
67
+ def forward(self, x, ref, ratio):
68
+ x = self.audio_encoder(x).view(x.size(0), -1)
69
+ ref_reshape = ref.reshape(x.size(0), -1)
70
+ ratio = ratio.reshape(x.size(0), -1)
71
+
72
+ y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1))
73
+ out = y.reshape(ref.shape[0], ref.shape[1], -1) #+ ref # residual
74
+ return out
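A shape-check sketch for `SimpleWrapperV2` (assumes it is run from the repository root so `src` is importable; the batch/frame sizes are arbitrary). It mirrors how `Audio2Exp.test` flattens the per-frame mel chunks before calling the network.

```python
# Sketch: verify the tensor shapes flowing through SimpleWrapperV2.
import torch
from src.audio2exp_models.networks import SimpleWrapperV2

bs, T = 2, 5
net = SimpleWrapperV2()
audiox = torch.randn(bs * T, 1, 80, 16)  # per-frame mel spectrogram chunks
ref = torch.randn(bs, T, 64)             # reference expression coefficients
ratio = torch.randn(bs, T)               # per-frame scalar ratio signal
out = net(audiox, ref, ratio)
print(out.shape)                         # torch.Size([2, 5, 64])
```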
src/audio2pose_models/__pycache__/audio2pose.cpython-38.pyc ADDED
Binary file (2.94 kB). View file
src/audio2pose_models/__pycache__/audio_encoder.cpython-38.pyc ADDED
Binary file (2.37 kB). View file
src/audio2pose_models/__pycache__/cvae.cpython-38.pyc ADDED
Binary file (4.69 kB). View file
src/audio2pose_models/__pycache__/discriminator.cpython-38.pyc ADDED
Binary file (2.45 kB). View file
src/audio2pose_models/__pycache__/networks.cpython-38.pyc ADDED
Binary file (4.74 kB). View file
src/audio2pose_models/__pycache__/res_unet.cpython-38.pyc ADDED
Binary file (1.91 kB). View file
src/audio2pose_models/audio2pose.py ADDED
@@ -0,0 +1,93 @@
1
+ import torch
2
+ from torch import nn
3
+ from src.audio2pose_models.cvae import CVAE
4
+ from src.audio2pose_models.discriminator import PoseSequenceDiscriminator
5
+ from src.audio2pose_models.audio_encoder import AudioEncoder
6
+
7
+ class Audio2Pose(nn.Module):
8
+ def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
9
+ super().__init__()
10
+ self.cfg = cfg
11
+ self.seq_len = cfg.MODEL.CVAE.SEQ_LEN
12
+ self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE
13
+ self.device = device
14
+
15
+ self.audio_encoder = AudioEncoder(wav2lip_checkpoint)
16
+ self.audio_encoder.eval()
17
+ for param in self.audio_encoder.parameters():
18
+ param.requires_grad = False
19
+
20
+ self.netG = CVAE(cfg)
21
+ self.netD_motion = PoseSequenceDiscriminator(cfg)
22
+
23
+ self.gan_criterion = nn.MSELoss()
24
+ self.reg_criterion = nn.L1Loss(reduction='none')
25
+ self.pair_criterion = nn.PairwiseDistance()
26
+ self.cosine_loss = nn.CosineSimilarity(dim=1)
27
+
28
+ def forward(self, x):
29
+
30
+ batch = {}
31
+ coeff_gt = x['gt'].cuda().squeeze(0) #bs frame_len+1 73
32
+ batch['pose_motion_gt'] = coeff_gt[:, 1:, -9:-3] - coeff_gt[:, :1, -9:-3] #bs frame_len 6
33
+ batch['ref'] = coeff_gt[:, 0, -9:-3] #bs 6
34
+ batch['class'] = x['class'].squeeze(0).cuda() # bs
35
+ indiv_mels= x['indiv_mels'].cuda().squeeze(0) # bs seq_len+1 80 16
36
+
37
+ # forward
38
+ audio_emb_list = []
39
+ audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2)) #bs seq_len 512
40
+ batch['audio_emb'] = audio_emb
41
+ batch = self.netG(batch)
42
+
43
+ pose_motion_pred = batch['pose_motion_pred'] # bs frame_len 6
44
+ pose_gt = coeff_gt[:, 1:, -9:-3].clone() # bs frame_len 6
45
+ pose_pred = coeff_gt[:, :1, -9:-3] + pose_motion_pred # bs frame_len 6
46
+
47
+ batch['pose_pred'] = pose_pred
48
+ batch['pose_gt'] = pose_gt
49
+
50
+ return batch
51
+
52
+ def test(self, x):
53
+
54
+ batch = {}
55
+ ref = x['ref'] #bs 1 70
56
+ batch['ref'] = x['ref'][:,0,-6:]
57
+ batch['class'] = x['class']
58
+ bs = ref.shape[0]
59
+
60
+ indiv_mels= x['indiv_mels'] # bs T 1 80 16
61
+ indiv_mels_use = indiv_mels[:, 1:] # we regard the ref as the first frame
62
+ num_frames = x['num_frames']
63
+ num_frames = int(num_frames) - 1
64
+
65
+ #
66
+ div = num_frames//self.seq_len
67
+ re = num_frames%self.seq_len
68
+ audio_emb_list = []
69
+ pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype,
70
+ device=batch['ref'].device)]
71
+
72
+ for i in range(div):
73
+ z = torch.randn(bs, self.latent_dim).to(ref.device)
74
+ batch['z'] = z
75
+ audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len,:,:,:]) #bs seq_len 512
76
+ batch['audio_emb'] = audio_emb
77
+ batch = self.netG.test(batch)
78
+ pose_motion_pred_list.append(batch['pose_motion_pred']) #list of bs seq_len 6
79
+ if re != 0:
80
+ z = torch.randn(bs, self.latent_dim).to(ref.device)
81
+ batch['z'] = z
82
+ audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:,:,:,:]) #bs seq_len 512
83
+ batch['audio_emb'] = audio_emb
84
+ batch = self.netG.test(batch)
85
+ pose_motion_pred_list.append(batch['pose_motion_pred'][:,-1*re:,:])
86
+
87
+ pose_motion_pred = torch.cat(pose_motion_pred_list, dim = 1)
88
+ batch['pose_motion_pred'] = pose_motion_pred
89
+
90
+ pose_pred = ref[:, :1, -6:] + pose_motion_pred # bs T 6
91
+
92
+ batch['pose_pred'] = pose_pred
93
+ return batch
src/audio2pose_models/audio_encoder.py ADDED
@@ -0,0 +1,64 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+ class Conv2d(nn.Module):
6
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
7
+ super().__init__(*args, **kwargs)
8
+ self.conv_block = nn.Sequential(
9
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
10
+ nn.BatchNorm2d(cout)
11
+ )
12
+ self.act = nn.ReLU()
13
+ self.residual = residual
14
+
15
+ def forward(self, x):
16
+ out = self.conv_block(x)
17
+ if self.residual:
18
+ out += x
19
+ return self.act(out)
20
+
21
+ class AudioEncoder(nn.Module):
22
+ def __init__(self, wav2lip_checkpoint):
23
+ super(AudioEncoder, self).__init__()
24
+
25
+ self.audio_encoder = nn.Sequential(
26
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
27
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
28
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
29
+
30
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
31
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
32
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
33
+
34
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
35
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
36
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
37
+
38
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
39
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
40
+
41
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
42
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
43
+
44
+ #### load the pre-trained audio_encoder\
45
+ wav2lip_state_dict = torch.load(wav2lip_checkpoint)['state_dict']
46
+ state_dict = self.audio_encoder.state_dict()
47
+
48
+ for k,v in wav2lip_state_dict.items():
49
+ if 'audio_encoder' in k:
50
+ state_dict[k.replace('module.audio_encoder.', '')] = v
51
+ self.audio_encoder.load_state_dict(state_dict)
52
+
53
+
54
+ def forward(self, audio_sequences):
55
+ # audio_sequences = (B, T, 1, 80, 16)
56
+ B = audio_sequences.size(0)
57
+
58
+ audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
59
+
60
+ audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
61
+ dim = audio_embedding.shape[1]
62
+ audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1))
63
+
64
+ return audio_embedding.squeeze(-1).squeeze(-1) #B seq_len+1 512
src/audio2pose_models/cvae.py ADDED
@@ -0,0 +1,149 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+ from src.audio2pose_models.res_unet import ResUnet
5
+
6
+ def class2onehot(idx, class_num):
7
+
8
+ assert torch.max(idx).item() < class_num
9
+ onehot = torch.zeros(idx.size(0), class_num).to(idx.device)
10
+ onehot.scatter_(1, idx, 1)
11
+ return onehot
12
+
13
+ class CVAE(nn.Module):
14
+ def __init__(self, cfg):
15
+ super().__init__()
16
+ encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES
17
+ decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES
18
+ latent_size = cfg.MODEL.CVAE.LATENT_SIZE
19
+ num_classes = cfg.DATASET.NUM_CLASSES
20
+ audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE
21
+ audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE
22
+ seq_len = cfg.MODEL.CVAE.SEQ_LEN
23
+
24
+ self.latent_size = latent_size
25
+
26
+ self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes,
27
+ audio_emb_in_size, audio_emb_out_size, seq_len)
28
+ self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes,
29
+ audio_emb_in_size, audio_emb_out_size, seq_len)
30
+ def reparameterize(self, mu, logvar):
31
+ std = torch.exp(0.5 * logvar)
32
+ eps = torch.randn_like(std)
33
+ return mu + eps * std
34
+
35
+ def forward(self, batch):
36
+ batch = self.encoder(batch)
37
+ mu = batch['mu']
38
+ logvar = batch['logvar']
39
+ z = self.reparameterize(mu, logvar)
40
+ batch['z'] = z
41
+ return self.decoder(batch)
42
+
43
+ def test(self, batch):
44
+ '''
45
+ class_id = batch['class']
46
+ z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device)
47
+ batch['z'] = z
48
+ '''
49
+ return self.decoder(batch)
50
+
51
+ class ENCODER(nn.Module):
52
+ def __init__(self, layer_sizes, latent_size, num_classes,
53
+ audio_emb_in_size, audio_emb_out_size, seq_len):
54
+ super().__init__()
55
+
56
+ self.resunet = ResUnet()
57
+ self.num_classes = num_classes
58
+ self.seq_len = seq_len
59
+
60
+ self.MLP = nn.Sequential()
61
+ layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6
62
+ for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
63
+ self.MLP.add_module(
64
+ name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
65
+ self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
66
+
67
+ self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
68
+ self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size)
69
+ self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
70
+
71
+ self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
72
+
73
+ def forward(self, batch):
74
+ class_id = batch['class']
75
+ pose_motion_gt = batch['pose_motion_gt'] #bs seq_len 6
76
+ ref = batch['ref'] #bs 6
77
+ bs = pose_motion_gt.shape[0]
78
+ audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size
79
+
80
+ #pose encode
81
+ pose_emb = self.resunet(pose_motion_gt.unsqueeze(1)) #bs 1 seq_len 6
82
+ pose_emb = pose_emb.reshape(bs, -1) #bs seq_len*6
83
+
84
+ #audio mapping
85
+ print(audio_in.shape)
86
+ audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size
87
+ audio_out = audio_out.reshape(bs, -1)
88
+
89
+ class_bias = self.classbias[class_id] #bs latent_size
90
+ x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1) #bs seq_len*(audio_emb_out_size+6)+latent_size
91
+ x_out = self.MLP(x_in)
92
+
93
+ mu = self.linear_means(x_out)
94
+ logvar = self.linear_logvar(x_out) #bs latent_size
95
+
96
+ batch.update({'mu':mu, 'logvar':logvar})
97
+ return batch
98
+
99
+ class DECODER(nn.Module):
100
+ def __init__(self, layer_sizes, latent_size, num_classes,
101
+ audio_emb_in_size, audio_emb_out_size, seq_len):
102
+ super().__init__()
103
+
104
+ self.resunet = ResUnet()
105
+ self.num_classes = num_classes
106
+ self.seq_len = seq_len
107
+
108
+ self.MLP = nn.Sequential()
109
+ input_size = latent_size + seq_len*audio_emb_out_size + 6
110
+ for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)):
111
+ self.MLP.add_module(
112
+ name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
113
+ if i+1 < len(layer_sizes):
114
+ self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
115
+ else:
116
+ self.MLP.add_module(name="sigmoid", module=nn.Sigmoid())
117
+
118
+ self.pose_linear = nn.Linear(6, 6)
119
+ self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
120
+
121
+ self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
122
+
123
+ def forward(self, batch):
124
+
125
+ z = batch['z'] #bs latent_size
126
+ bs = z.shape[0]
127
+ class_id = batch['class']
128
+ ref = batch['ref'] #bs 6
129
+ audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size
130
+ #print('audio_in: ', audio_in[:, :, :10])
131
+
132
+ audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size
133
+ #print('audio_out: ', audio_out[:, :, :10])
134
+ audio_out = audio_out.reshape([bs, -1]) # bs seq_len*audio_emb_out_size
135
+ class_bias = self.classbias[class_id] #bs latent_size
136
+
137
+ z = z + class_bias
138
+ x_in = torch.cat([ref, z, audio_out], dim=-1)
139
+ x_out = self.MLP(x_in) # bs layer_sizes[-1]
140
+ x_out = x_out.reshape((bs, self.seq_len, -1))
141
+
142
+ #print('x_out: ', x_out)
143
+
144
+ pose_emb = self.resunet(x_out.unsqueeze(1)) #bs 1 seq_len 6
145
+
146
+ pose_motion_pred = self.pose_linear(pose_emb.squeeze(1)) #bs seq_len 6
147
+
148
+ batch.update({'pose_motion_pred':pose_motion_pred})
149
+ return batch
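For reference, `CVAE.reparameterize` above implements the standard VAE reparameterization trick, drawing z = mu + eps * exp(0.5 * logvar) so the sampling step stays differentiable with respect to mu and logvar. A standalone sketch with arbitrary shapes:

```python
# Sketch: the reparameterization trick used by CVAE.reparameterize.
import torch

mu = torch.zeros(4, 64)              # predicted means (bs, latent_size)
logvar = torch.zeros(4, 64)          # predicted log-variances; exp(0) gives std = 1
eps = torch.randn_like(mu)           # noise sampled outside the computation graph
z = mu + eps * torch.exp(0.5 * logvar)
print(z.shape)                       # torch.Size([4, 64])
```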
src/audio2pose_models/discriminator.py ADDED
@@ -0,0 +1,76 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+
5
+ class ConvNormRelu(nn.Module):
6
+ def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False,
7
+ kernel_size=None, stride=None, padding=None, norm='BN', leaky=False):
8
+ super().__init__()
9
+ if kernel_size is None:
10
+ if downsample:
11
+ kernel_size, stride, padding = 4, 2, 1
12
+ else:
13
+ kernel_size, stride, padding = 3, 1, 1
14
+
15
+ if conv_type == '2d':
16
+ self.conv = nn.Conv2d(
17
+ in_channels,
18
+ out_channels,
19
+ kernel_size,
20
+ stride,
21
+ padding,
22
+ bias=False,
23
+ )
24
+ if norm == 'BN':
25
+ self.norm = nn.BatchNorm2d(out_channels)
26
+ elif norm == 'IN':
27
+ self.norm = nn.InstanceNorm2d(out_channels)
28
+ else:
29
+ raise NotImplementedError
30
+ elif conv_type == '1d':
31
+ self.conv = nn.Conv1d(
32
+ in_channels,
33
+ out_channels,
34
+ kernel_size,
35
+ stride,
36
+ padding,
37
+ bias=False,
38
+ )
39
+ if norm == 'BN':
40
+ self.norm = nn.BatchNorm1d(out_channels)
41
+ elif norm == 'IN':
42
+ self.norm = nn.InstanceNorm1d(out_channels)
43
+ else:
44
+ raise NotImplementedError
45
+ nn.init.kaiming_normal_(self.conv.weight)
46
+
47
+ self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True)
48
+
49
+ def forward(self, x):
50
+ x = self.conv(x)
51
+ if isinstance(self.norm, nn.InstanceNorm1d):
52
+ x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1)) # normalize on [C]
53
+ else:
54
+ x = self.norm(x)
55
+ x = self.act(x)
56
+ return x
57
+
58
+
59
+ class PoseSequenceDiscriminator(nn.Module):
60
+ def __init__(self, cfg):
61
+ super().__init__()
62
+ self.cfg = cfg
63
+ leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU
64
+
65
+ self.seq = nn.Sequential(
66
+ ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky), # B, 256, 64
67
+ ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky), # B, 512, 32
68
+ ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky), # B, 1024, 16
69
+ nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True) # B, 1, 16
70
+ )
71
+
72
+ def forward(self, x):
73
+ x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2)
74
+ x = self.seq(x)
75
+ x = x.squeeze(1)
76
+ return x
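A quick shape check for the discriminator, as a hedged sketch: the config is mocked with SimpleNamespace carrying only the two fields the class reads (the real values come from the YAML configs further down), and the tensor sizes are illustrative.

import torch
from types import SimpleNamespace

cfg = SimpleNamespace(MODEL=SimpleNamespace(
    DISCRIMINATOR=SimpleNamespace(LEAKY_RELU=False, INPUT_CHANNELS=6)))

disc = PoseSequenceDiscriminator(cfg)
poses = torch.randn(2, 64, 6)   # batch, seq_len, 6 pose coefficients
scores = disc(poses)            # the two stride-2 blocks quarter the sequence axis
print(scores.shape)             # torch.Size([2, 16])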
src/audio2pose_models/networks.py ADDED
@@ -0,0 +1,140 @@
1
+ import torch.nn as nn
2
+ import torch
3
+
4
+
5
+ class ResidualConv(nn.Module):
6
+ def __init__(self, input_dim, output_dim, stride, padding):
7
+ super(ResidualConv, self).__init__()
8
+
9
+ self.conv_block = nn.Sequential(
10
+ nn.BatchNorm2d(input_dim),
11
+ nn.ReLU(),
12
+ nn.Conv2d(
13
+ input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
14
+ ),
15
+ nn.BatchNorm2d(output_dim),
16
+ nn.ReLU(),
17
+ nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
18
+ )
19
+ self.conv_skip = nn.Sequential(
20
+ nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
21
+ nn.BatchNorm2d(output_dim),
22
+ )
23
+
24
+ def forward(self, x):
25
+
26
+ return self.conv_block(x) + self.conv_skip(x)
27
+
28
+
29
+ class Upsample(nn.Module):
30
+ def __init__(self, input_dim, output_dim, kernel, stride):
31
+ super(Upsample, self).__init__()
32
+
33
+ self.upsample = nn.ConvTranspose2d(
34
+ input_dim, output_dim, kernel_size=kernel, stride=stride
35
+ )
36
+
37
+ def forward(self, x):
38
+ return self.upsample(x)
39
+
40
+
41
+ class Squeeze_Excite_Block(nn.Module):
42
+ def __init__(self, channel, reduction=16):
43
+ super(Squeeze_Excite_Block, self).__init__()
44
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
45
+ self.fc = nn.Sequential(
46
+ nn.Linear(channel, channel // reduction, bias=False),
47
+ nn.ReLU(inplace=True),
48
+ nn.Linear(channel // reduction, channel, bias=False),
49
+ nn.Sigmoid(),
50
+ )
51
+
52
+ def forward(self, x):
53
+ b, c, _, _ = x.size()
54
+ y = self.avg_pool(x).view(b, c)
55
+ y = self.fc(y).view(b, c, 1, 1)
56
+ return x * y.expand_as(x)
57
+
58
+
59
+ class ASPP(nn.Module):
60
+ def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
61
+ super(ASPP, self).__init__()
62
+
63
+ self.aspp_block1 = nn.Sequential(
64
+ nn.Conv2d(
65
+ in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0]
66
+ ),
67
+ nn.ReLU(inplace=True),
68
+ nn.BatchNorm2d(out_dims),
69
+ )
70
+ self.aspp_block2 = nn.Sequential(
71
+ nn.Conv2d(
72
+ in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1]
73
+ ),
74
+ nn.ReLU(inplace=True),
75
+ nn.BatchNorm2d(out_dims),
76
+ )
77
+ self.aspp_block3 = nn.Sequential(
78
+ nn.Conv2d(
79
+ in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2]
80
+ ),
81
+ nn.ReLU(inplace=True),
82
+ nn.BatchNorm2d(out_dims),
83
+ )
84
+
85
+ self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)
86
+ self._init_weights()
87
+
88
+ def forward(self, x):
89
+ x1 = self.aspp_block1(x)
90
+ x2 = self.aspp_block2(x)
91
+ x3 = self.aspp_block3(x)
92
+ out = torch.cat([x1, x2, x3], dim=1)
93
+ return self.output(out)
94
+
95
+ def _init_weights(self):
96
+ for m in self.modules():
97
+ if isinstance(m, nn.Conv2d):
98
+ nn.init.kaiming_normal_(m.weight)
99
+ elif isinstance(m, nn.BatchNorm2d):
100
+ m.weight.data.fill_(1)
101
+ m.bias.data.zero_()
102
+
103
+
104
+ class Upsample_(nn.Module):
105
+ def __init__(self, scale=2):
106
+ super(Upsample_, self).__init__()
107
+
108
+ self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale)
109
+
110
+ def forward(self, x):
111
+ return self.upsample(x)
112
+
113
+
114
+ class AttentionBlock(nn.Module):
115
+ def __init__(self, input_encoder, input_decoder, output_dim):
116
+ super(AttentionBlock, self).__init__()
117
+
118
+ self.conv_encoder = nn.Sequential(
119
+ nn.BatchNorm2d(input_encoder),
120
+ nn.ReLU(),
121
+ nn.Conv2d(input_encoder, output_dim, 3, padding=1),
122
+ nn.MaxPool2d(2, 2),
123
+ )
124
+
125
+ self.conv_decoder = nn.Sequential(
126
+ nn.BatchNorm2d(input_decoder),
127
+ nn.ReLU(),
128
+ nn.Conv2d(input_decoder, output_dim, 3, padding=1),
129
+ )
130
+
131
+ self.conv_attn = nn.Sequential(
132
+ nn.BatchNorm2d(output_dim),
133
+ nn.ReLU(),
134
+ nn.Conv2d(output_dim, 1, 1),
135
+ )
136
+
137
+ def forward(self, x1, x2):
138
+ out = self.conv_encoder(x1) + self.conv_decoder(x2)
139
+ out = self.conv_attn(out)
140
+ return out * x2
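One detail worth noting for the next file: ResUnet passes stride=(2, 1) to ResidualConv so that only the temporal axis of a (batch, channels, seq_len, 6) pose tensor is downsampled while the 6 pose dimensions stay intact. A small shape check with illustrative sizes:

import torch

block = ResidualConv(input_dim=32, output_dim=64, stride=(2, 1), padding=1)
x = torch.randn(4, 32, 32, 6)   # batch, channels, seq_len, pose dims
print(block(x).shape)           # torch.Size([4, 64, 16, 6]): seq_len halved, pose dims kept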
src/audio2pose_models/res_unet.py ADDED
@@ -0,0 +1,65 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from src.audio2pose_models.networks import ResidualConv, Upsample
4
+
5
+
6
+ class ResUnet(nn.Module):
7
+ def __init__(self, channel=1, filters=[32, 64, 128, 256]):
8
+ super(ResUnet, self).__init__()
9
+
10
+ self.input_layer = nn.Sequential(
11
+ nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
12
+ nn.BatchNorm2d(filters[0]),
13
+ nn.ReLU(),
14
+ nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
15
+ )
16
+ self.input_skip = nn.Sequential(
17
+ nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
18
+ )
19
+
20
+ self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2,1), padding=1)
21
+ self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2,1), padding=1)
22
+
23
+ self.bridge = ResidualConv(filters[2], filters[3], stride=(2,1), padding=1)
24
+
25
+ self.upsample_1 = Upsample(filters[3], filters[3], kernel=(2,1), stride=(2,1))
26
+ self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1)
27
+
28
+ self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2,1), stride=(2,1))
29
+ self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1)
30
+
31
+ self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2,1), stride=(2,1))
32
+ self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1)
33
+
34
+ self.output_layer = nn.Sequential(
35
+ nn.Conv2d(filters[0], 1, 1, 1),
36
+ nn.Sigmoid(),
37
+ )
38
+
39
+ def forward(self, x):
40
+ # Encode
41
+ x1 = self.input_layer(x) + self.input_skip(x)
42
+ x2 = self.residual_conv_1(x1)
43
+ x3 = self.residual_conv_2(x2)
44
+ # Bridge
45
+ x4 = self.bridge(x3)
46
+
47
+ # Decode
48
+ x4 = self.upsample_1(x4)
49
+ x5 = torch.cat([x4, x3], dim=1)
50
+
51
+ x6 = self.up_residual_conv1(x5)
52
+
53
+ x6 = self.upsample_2(x6)
54
+ x7 = torch.cat([x6, x2], dim=1)
55
+
56
+ x8 = self.up_residual_conv2(x7)
57
+
58
+ x8 = self.upsample_3(x8)
59
+ x9 = torch.cat([x8, x1], dim=1)
60
+
61
+ x10 = self.up_residual_conv3(x9)
62
+
63
+ output = self.output_layer(x10)
64
+
65
+ return output
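End to end the network is shape-preserving, which is what lets the CVAE modules in cvae.py feed pose motion of shape (bs, 1, seq_len, 6) straight through it. A minimal round trip with the seq_len of 32 used in the configs (it must be divisible by 8 so the three stride-(2, 1) stages line up with the three transposed-conv upsamples):

import torch

net = ResUnet(channel=1)
pose = torch.randn(2, 1, 32, 6)   # batch, 1, seq_len, 6
print(net(pose).shape)            # torch.Size([2, 1, 32, 6]): same shape out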
src/config/auido2exp.yaml ADDED
@@ -0,0 +1,58 @@
1
+ DATASET:
2
+ TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt
3
+ EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt
4
+ TRAIN_BATCH_SIZE: 32
5
+ EVAL_BATCH_SIZE: 32
6
+ EXP: True
7
+ EXP_DIM: 64
8
+ FRAME_LEN: 32
9
+ COEFF_LEN: 73
10
+ NUM_CLASSES: 46
11
+ AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
12
+ COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm
13
+ LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
14
+ DEBUG: True
15
+ NUM_REPEATS: 2
16
+ T: 40
17
+
18
+
19
+ MODEL:
20
+ FRAMEWORK: V2
21
+ AUDIOENCODER:
22
+ LEAKY_RELU: True
23
+ NORM: 'IN'
24
+ DISCRIMINATOR:
25
+ LEAKY_RELU: False
26
+ INPUT_CHANNELS: 6
27
+ CVAE:
28
+ AUDIO_EMB_IN_SIZE: 512
29
+ AUDIO_EMB_OUT_SIZE: 128
30
+ SEQ_LEN: 32
31
+ LATENT_SIZE: 256
32
+ ENCODER_LAYER_SIZES: [192, 1024]
33
+ DECODER_LAYER_SIZES: [1024, 192]
34
+
35
+
36
+ TRAIN:
37
+ MAX_EPOCH: 300
38
+ GENERATOR:
39
+ LR: 2.0e-5
40
+ DISCRIMINATOR:
41
+ LR: 1.0e-5
42
+ LOSS:
43
+ W_FEAT: 0
44
+ W_COEFF_EXP: 2
45
+ W_LM: 1.0e-2
46
+ W_LM_MOUTH: 0
47
+ W_REG: 0
48
+ W_SYNC: 0
49
+ W_COLOR: 0
50
+ W_EXPRESSION: 0
51
+ W_LIPREADING: 0.01
52
+ W_LIPREADING_VV: 0
53
+ W_EYE_BLINK: 4
54
+
55
+ TAG:
56
+ NAME: small_dataset
57
+
58
+
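The attribute-style access in the model code (e.g. cfg.MODEL.DISCRIMINATOR.LEAKY_RELU in the discriminator above) suggests a yacs-style CfgNode. A hedged loading sketch, assuming yacs is the config library in use:

from yacs.config import CfgNode as CN

with open('src/config/auido2exp.yaml') as f:
    cfg = CN.load_cfg(f)        # parse the YAML into a nested, attribute-accessible node
cfg.freeze()                    # guard against accidental mutation at runtime
print(cfg.MODEL.CVAE.SEQ_LEN, cfg.TRAIN.GENERATOR.LR)   # 32 2e-05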
src/config/auido2pose.yaml ADDED
@@ -0,0 +1,49 @@
1
+ DATASET:
2
+ TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt
3
+ EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt
4
+ TRAIN_BATCH_SIZE: 64
5
+ EVAL_BATCH_SIZE: 1
6
+ EXP: True
7
+ EXP_DIM: 64
8
+ FRAME_LEN: 32
9
+ COEFF_LEN: 73
10
+ NUM_CLASSES: 46
11
+ AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
12
+ COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
13
+ DEBUG: True
14
+
15
+
16
+ MODEL:
17
+ AUDIOENCODER:
18
+ LEAKY_RELU: True
19
+ NORM: 'IN'
20
+ DISCRIMINATOR:
21
+ LEAKY_RELU: False
22
+ INPUT_CHANNELS: 6
23
+ CVAE:
24
+ AUDIO_EMB_IN_SIZE: 512
25
+ AUDIO_EMB_OUT_SIZE: 6
26
+ SEQ_LEN: 32
27
+ LATENT_SIZE: 64
28
+ ENCODER_LAYER_SIZES: [192, 128]
29
+ DECODER_LAYER_SIZES: [128, 192]
30
+
31
+
32
+ TRAIN:
33
+ MAX_EPOCH: 150
34
+ GENERATOR:
35
+ LR: 1.0e-4
36
+ DISCRIMINATOR:
37
+ LR: 1.0e-4
38
+ LOSS:
39
+ LAMBDA_REG: 1
40
+ LAMBDA_LANDMARKS: 0
41
+ LAMBDA_VERTICES: 0
42
+ LAMBDA_GAN_MOTION: 0.7
43
+ LAMBDA_GAN_COEFF: 0
44
+ LAMBDA_KL: 1
45
+
46
+ TAG:
47
+ NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder
48
+
49
+
src/config/facerender.yaml ADDED
@@ -0,0 +1,45 @@
1
+ model_params:
2
+ common_params:
3
+ num_kp: 15
4
+ image_channel: 3
5
+ feature_channel: 32
6
+ estimate_jacobian: False # True
7
+ kp_detector_params:
8
+ temperature: 0.1
9
+ block_expansion: 32
10
+ max_features: 1024
11
+ scale_factor: 0.25 # 0.25
12
+ num_blocks: 5
13
+ reshape_channel: 16384 # 16384 = 1024 * 16
14
+ reshape_depth: 16
15
+ he_estimator_params:
16
+ block_expansion: 64
17
+ max_features: 2048
18
+ num_bins: 66
19
+ generator_params:
20
+ block_expansion: 64
21
+ max_features: 512
22
+ num_down_blocks: 2
23
+ reshape_channel: 32
24
+ reshape_depth: 16 # 512 = 32 * 16
25
+ num_resblocks: 6
26
+ estimate_occlusion_map: True
27
+ dense_motion_params:
28
+ block_expansion: 32
29
+ max_features: 1024
30
+ num_blocks: 5
31
+ reshape_depth: 16
32
+ compress: 4
33
+ discriminator_params:
34
+ scales: [1]
35
+ block_expansion: 32
36
+ max_features: 512
37
+ num_blocks: 4
38
+ sn: True
39
+ mapping_params:
40
+ coeff_nc: 70
41
+ descriptor_nc: 1024
42
+ layer: 3
43
+ num_kp: 15
44
+ num_bins: 66
45
+
src/face3d/__pycache__/extract_kp_videos.cpython-38.pyc ADDED
Binary file (3.57 kB). View file
src/face3d/__pycache__/visualize.cpython-38.pyc ADDED
Binary file (1.7 kB). View file
src/face3d/data/__init__.py ADDED
@@ -0,0 +1,116 @@
1
+ """This package includes all the modules related to data loading and preprocessing
2
+
3
+ To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
4
+ You need to implement four functions:
5
+ -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6
+ -- <__len__>: return the size of dataset.
7
+ -- <__getitem__>: get a data point from data loader.
8
+ -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9
+
10
+ Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
11
+ See our template dataset class 'template_dataset.py' for more details.
12
+ """
13
+ import numpy as np
14
+ import importlib
15
+ import torch.utils.data
16
+ from face3d.data.base_dataset import BaseDataset
17
+
18
+
19
+ def find_dataset_using_name(dataset_name):
20
+ """Import the module "data/[dataset_name]_dataset.py".
21
+
22
+ In the file, the class called DatasetNameDataset() will
23
+ be instantiated. It has to be a subclass of BaseDataset,
24
+ and it is case-insensitive.
25
+ """
26
+ dataset_filename = "data." + dataset_name + "_dataset"
27
+ datasetlib = importlib.import_module(dataset_filename)
28
+
29
+ dataset = None
30
+ target_dataset_name = dataset_name.replace('_', '') + 'dataset'
31
+ for name, cls in datasetlib.__dict__.items():
32
+ if name.lower() == target_dataset_name.lower() \
33
+ and issubclass(cls, BaseDataset):
34
+ dataset = cls
35
+
36
+ if dataset is None:
37
+ raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
38
+
39
+ return dataset
40
+
41
+
42
+ def get_option_setter(dataset_name):
43
+ """Return the static method <modify_commandline_options> of the dataset class."""
44
+ dataset_class = find_dataset_using_name(dataset_name)
45
+ return dataset_class.modify_commandline_options
46
+
47
+
48
+ def create_dataset(opt, rank=0):
49
+ """Create a dataset given the option.
50
+
51
+ This function wraps the class CustomDatasetDataLoader.
52
+ This is the main interface between this package and 'train.py'/'test.py'
53
+
54
+ Example:
55
+ >>> from data import create_dataset
56
+ >>> dataset = create_dataset(opt)
57
+ """
58
+ data_loader = CustomDatasetDataLoader(opt, rank=rank)
59
+ dataset = data_loader.load_data()
60
+ return dataset
61
+
62
+ class CustomDatasetDataLoader():
63
+ """Wrapper class of Dataset class that performs multi-threaded data loading"""
64
+
65
+ def __init__(self, opt, rank=0):
66
+ """Initialize this class
67
+
68
+ Step 1: create a dataset instance given the name [dataset_mode]
69
+ Step 2: create a multi-threaded data loader.
70
+ """
71
+ self.opt = opt
72
+ dataset_class = find_dataset_using_name(opt.dataset_mode)
73
+ self.dataset = dataset_class(opt)
74
+ self.sampler = None
75
+ print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__))
76
+ if opt.use_ddp and opt.isTrain:
77
+ world_size = opt.world_size
78
+ self.sampler = torch.utils.data.distributed.DistributedSampler(
79
+ self.dataset,
80
+ num_replicas=world_size,
81
+ rank=rank,
82
+ shuffle=not opt.serial_batches
83
+ )
84
+ self.dataloader = torch.utils.data.DataLoader(
85
+ self.dataset,
86
+ sampler=self.sampler,
87
+ num_workers=int(opt.num_threads / world_size),
88
+ batch_size=int(opt.batch_size / world_size),
89
+ drop_last=True)
90
+ else:
91
+ self.dataloader = torch.utils.data.DataLoader(
92
+ self.dataset,
93
+ batch_size=opt.batch_size,
94
+ shuffle=(not opt.serial_batches) and opt.isTrain,
95
+ num_workers=int(opt.num_threads),
96
+ drop_last=True
97
+ )
98
+
99
+ def set_epoch(self, epoch):
100
+ self.dataset.current_epoch = epoch
101
+ if self.sampler is not None:
102
+ self.sampler.set_epoch(epoch)
103
+
104
+ def load_data(self):
105
+ return self
106
+
107
+ def __len__(self):
108
+ """Return the number of data in the dataset"""
109
+ return min(len(self.dataset), self.opt.max_dataset_size)
110
+
111
+ def __iter__(self):
112
+ """Return a batch of data"""
113
+ for i, data in enumerate(self.dataloader):
114
+ if i * self.opt.batch_size >= self.opt.max_dataset_size:
115
+ break
116
+ yield data
src/face3d/data/base_dataset.py ADDED
@@ -0,0 +1,125 @@
1
+ """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
2
+
3
+ It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
4
+ """
5
+ import random
6
+ import numpy as np
7
+ import torch.utils.data as data
8
+ from PIL import Image
9
+ import torchvision.transforms as transforms
10
+ from abc import ABC, abstractmethod
11
+
12
+
13
+ class BaseDataset(data.Dataset, ABC):
14
+ """This class is an abstract base class (ABC) for datasets.
15
+
16
+ To create a subclass, you need to implement the following four functions:
17
+ -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
18
+ -- <__len__>: return the size of dataset.
19
+ -- <__getitem__>: get a data point.
20
+ -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
21
+ """
22
+
23
+ def __init__(self, opt):
24
+ """Initialize the class; save the options in the class
25
+
26
+ Parameters:
27
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
28
+ """
29
+ self.opt = opt
30
+ # self.root = opt.dataroot
31
+ self.current_epoch = 0
32
+
33
+ @staticmethod
34
+ def modify_commandline_options(parser, is_train):
35
+ """Add new dataset-specific options, and rewrite default values for existing options.
36
+
37
+ Parameters:
38
+ parser -- original option parser
39
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
40
+
41
+ Returns:
42
+ the modified parser.
43
+ """
44
+ return parser
45
+
46
+ @abstractmethod
47
+ def __len__(self):
48
+ """Return the total number of images in the dataset."""
49
+ return 0
50
+
51
+ @abstractmethod
52
+ def __getitem__(self, index):
53
+ """Return a data point and its metadata information.
54
+
55
+ Parameters:
56
+ index - - a random integer for data indexing
57
+
58
+ Returns:
59
+ a dictionary of data with their names. It usually contains the data itself and its metadata information.
60
+ """
61
+ pass
62
+
63
+
64
+ def get_transform(grayscale=False):
65
+ transform_list = []
66
+ if grayscale:
67
+ transform_list.append(transforms.Grayscale(1))
68
+ transform_list += [transforms.ToTensor()]
69
+ return transforms.Compose(transform_list)
70
+
71
+ def get_affine_mat(opt, size):
72
+ shift_x, shift_y, scale, rot_angle, rot_rad, flip = 0., 0., 1., 0., 0., False  # rot_rad must default to 0. so rot_mat below is defined even without 'rot' augmentation
73
+ w, h = size
74
+
75
+ if 'shift' in opt.preprocess:
76
+ shift_pixs = int(opt.shift_pixs)
77
+ shift_x = random.randint(-shift_pixs, shift_pixs)
78
+ shift_y = random.randint(-shift_pixs, shift_pixs)
79
+ if 'scale' in opt.preprocess:
80
+ scale = 1 + opt.scale_delta * (2 * random.random() - 1)
81
+ if 'rot' in opt.preprocess:
82
+ rot_angle = opt.rot_angle * (2 * random.random() - 1)
83
+ rot_rad = -rot_angle * np.pi/180
84
+ if 'flip' in opt.preprocess:
85
+ flip = random.random() > 0.5
86
+
87
+ shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3])
88
+ flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3])
89
+ shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3])
90
+ rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3])
91
+ scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3])
92
+ shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3])
93
+
94
+ affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin
95
+ affine_inv = np.linalg.inv(affine)
96
+ return affine, affine_inv, flip
97
+
98
+ def apply_img_affine(img, affine_inv, method=Image.BICUBIC):
99
+ return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=method)
100
+
101
+ def apply_lm_affine(landmark, affine, flip, size):
102
+ _, h = size
103
+ lm = landmark.copy()
104
+ lm[:, 1] = h - 1 - lm[:, 1]
105
+ lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1)
106
+ lm = lm @ np.transpose(affine)
107
+ lm[:, :2] = lm[:, :2] / lm[:, 2:]
108
+ lm = lm[:, :2]
109
+ lm[:, 1] = h - 1 - lm[:, 1]
110
+ if flip:
111
+ lm_ = lm.copy()
112
+ lm_[:17] = lm[16::-1]
113
+ lm_[17:22] = lm[26:21:-1]
114
+ lm_[22:27] = lm[21:16:-1]
115
+ lm_[31:36] = lm[35:30:-1]
116
+ lm_[36:40] = lm[45:41:-1]
117
+ lm_[40:42] = lm[47:45:-1]
118
+ lm_[42:46] = lm[39:35:-1]
119
+ lm_[46:48] = lm[41:39:-1]
120
+ lm_[48:55] = lm[54:47:-1]
121
+ lm_[55:60] = lm[59:54:-1]
122
+ lm_[60:65] = lm[64:59:-1]
123
+ lm_[65:68] = lm[67:64:-1]
124
+ lm = lm_
125
+ return lm
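The three helpers are meant to be chained, which is exactly what FlistDataset._augmentation in the next file does. A hedged usage sketch, with placeholder opt fields and sizes chosen to exercise every branch of get_affine_mat:

import numpy as np
from types import SimpleNamespace
from PIL import Image

opt = SimpleNamespace(preprocess='shift_scale_rot_flip', shift_pixs=10,
                      scale_delta=0.1, rot_angle=10.0)
img = Image.new('RGB', (256, 256))                    # placeholder image
lm = np.zeros((68, 2), dtype=np.float32)              # placeholder 68-point landmarks

affine, affine_inv, flip = get_affine_mat(opt, img.size)
img_aug = apply_img_affine(img, affine_inv)           # the image is warped with the inverse matrix
lm_aug = apply_lm_affine(lm, affine, flip, img.size)  # landmarks use the forward matrix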
src/face3d/data/flist_dataset.py ADDED
@@ -0,0 +1,125 @@
1
+ """This script defines the custom dataset for Deep3DFaceRecon_pytorch
2
+ """
3
+
4
+ import os.path
5
+ from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine
6
+ from data.image_folder import make_dataset
7
+ from PIL import Image
8
+ import random
9
+ import util.util as util
10
+ import numpy as np
11
+ import json
12
+ import torch
13
+ from scipy.io import loadmat, savemat
14
+ import pickle
15
+ from util.preprocess import align_img, estimate_norm
16
+ from util.load_mats import load_lm3d
17
+
18
+
19
+ def default_flist_reader(flist):
20
+ """
21
+ flist format: impath label\nimpath label\n ...(same to caffe's filelist)
22
+ """
23
+ imlist = []
24
+ with open(flist, 'r') as rf:
25
+ for line in rf.readlines():
26
+ impath = line.strip()
27
+ imlist.append(impath)
28
+
29
+ return imlist
30
+
31
+ def jason_flist_reader(flist):
32
+ with open(flist, 'r') as fp:
33
+ info = json.load(fp)
34
+ return info
35
+
36
+ def parse_label(label):
37
+ return torch.tensor(np.array(label).astype(np.float32))
38
+
39
+
40
+ class FlistDataset(BaseDataset):
41
+ """
42
+ It requires one directory to host training images, e.g. '/path/to/data/train'.
43
+ You can train the model with the dataset flag '--dataroot /path/to/data'.
44
+ """
45
+
46
+ def __init__(self, opt):
47
+ """Initialize this dataset class.
48
+
49
+ Parameters:
50
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
51
+ """
52
+ BaseDataset.__init__(self, opt)
53
+
54
+ self.lm3d_std = load_lm3d(opt.bfm_folder)
55
+
56
+ msk_names = default_flist_reader(opt.flist)
57
+ self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]
58
+
59
+ self.size = len(self.msk_paths)
60
+ self.opt = opt
61
+
62
+ self.name = 'train' if opt.isTrain else 'val'
63
+ if '_' in opt.flist:
64
+ self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0]
65
+
66
+
67
+ def __getitem__(self, index):
68
+ """Return a data point and its metadata information.
69
+
70
+ Parameters:
71
+ index (int) -- a random integer for data indexing
72
+
73
+ Returns a dictionary that contains A, B, A_paths and B_paths
74
+ img (tensor) -- an image in the input domain
75
+ msk (tensor) -- its corresponding attention mask
76
+ lm (tensor) -- its corresponding 3d landmarks
77
+ im_paths (str) -- image paths
78
+ aug_flag (bool) -- a flag used to tell whether its raw or augmented
79
+ """
80
+ msk_path = self.msk_paths[index % self.size] # make sure index is within the range
81
+ img_path = msk_path.replace('mask/', '')
82
+ lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'
83
+
84
+ raw_img = Image.open(img_path).convert('RGB')
85
+ raw_msk = Image.open(msk_path).convert('RGB')
86
+ raw_lm = np.loadtxt(lm_path).astype(np.float32)
87
+
88
+ _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)
89
+
90
+ aug_flag = self.opt.use_aug and self.opt.isTrain
91
+ if aug_flag:
92
+ img, lm, msk = self._augmentation(img, lm, self.opt, msk)
93
+
94
+ _, H = img.size
95
+ M = estimate_norm(lm, H)
96
+ transform = get_transform()
97
+ img_tensor = transform(img)
98
+ msk_tensor = transform(msk)[:1, ...]
99
+ lm_tensor = parse_label(lm)
100
+ M_tensor = parse_label(M)
101
+
102
+
103
+ return {'imgs': img_tensor,
104
+ 'lms': lm_tensor,
105
+ 'msks': msk_tensor,
106
+ 'M': M_tensor,
107
+ 'im_paths': img_path,
108
+ 'aug_flag': aug_flag,
109
+ 'dataset': self.name}
110
+
111
+ def _augmentation(self, img, lm, opt, msk=None):
112
+ affine, affine_inv, flip = get_affine_mat(opt, img.size)
113
+ img = apply_img_affine(img, affine_inv)
114
+ lm = apply_lm_affine(lm, affine, flip, img.size)
115
+ if msk is not None:
116
+ msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
117
+ return img, lm, msk
118
+
119
+
120
+
121
+
122
+ def __len__(self):
123
+ """Return the total number of images in the dataset.
124
+ """
125
+ return self.size
src/face3d/data/image_folder.py ADDED
@@ -0,0 +1,66 @@
1
+ """A modified image folder class
2
+
3
+ We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
4
+ so that this class can load images from both current directory and its subdirectories.
5
+ """
6
+ import numpy as np
7
+ import torch.utils.data as data
8
+
9
+ from PIL import Image
10
+ import os
11
+ import os.path
12
+
13
+ IMG_EXTENSIONS = [
14
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
15
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
16
+ '.tif', '.TIF', '.tiff', '.TIFF',
17
+ ]
18
+
19
+
20
+ def is_image_file(filename):
21
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
22
+
23
+
24
+ def make_dataset(dir, max_dataset_size=float("inf")):
25
+ images = []
26
+ assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir
27
+
28
+ for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
29
+ for fname in fnames:
30
+ if is_image_file(fname):
31
+ path = os.path.join(root, fname)
32
+ images.append(path)
33
+ return images[:min(max_dataset_size, len(images))]
34
+
35
+
36
+ def default_loader(path):
37
+ return Image.open(path).convert('RGB')
38
+
39
+
40
+ class ImageFolder(data.Dataset):
41
+
42
+ def __init__(self, root, transform=None, return_paths=False,
43
+ loader=default_loader):
44
+ imgs = make_dataset(root)
45
+ if len(imgs) == 0:
46
+ raise(RuntimeError("Found 0 images in: " + root + "\n"
47
+ "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
48
+
49
+ self.root = root
50
+ self.imgs = imgs
51
+ self.transform = transform
52
+ self.return_paths = return_paths
53
+ self.loader = loader
54
+
55
+ def __getitem__(self, index):
56
+ path = self.imgs[index]
57
+ img = self.loader(path)
58
+ if self.transform is not None:
59
+ img = self.transform(img)
60
+ if self.return_paths:
61
+ return img, path
62
+ else:
63
+ return img
64
+
65
+ def __len__(self):
66
+ return len(self.imgs)
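Typical usage, as a sketch in which './examples' is a placeholder directory containing at least one image:

from torchvision import transforms

dataset = ImageFolder('./examples', transform=transforms.ToTensor(), return_paths=True)
img, path = dataset[0]      # CHW float tensor plus the path it was loaded from
print(img.shape, path)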
src/face3d/data/template_dataset.py ADDED
@@ -0,0 +1,75 @@
1
+ """Dataset class template
2
+
3
+ This module provides a template for users to implement custom datasets.
4
+ You can specify '--dataset_mode template' to use this dataset.
5
+ The class name should be consistent with both the filename and its dataset_mode option.
6
+ The filename should be <dataset_mode>_dataset.py
7
+ The class name should be <Dataset_mode>Dataset
8
+ You need to implement the following functions:
9
+ -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
10
+ -- <__init__>: Initialize this dataset class.
11
+ -- <__getitem__>: Return a data point and its metadata information.
12
+ -- <__len__>: Return the number of images.
13
+ """
14
+ from data.base_dataset import BaseDataset, get_transform
15
+ # from data.image_folder import make_dataset
16
+ # from PIL import Image
17
+
18
+
19
+ class TemplateDataset(BaseDataset):
20
+ """A template dataset class for you to implement custom datasets."""
21
+ @staticmethod
22
+ def modify_commandline_options(parser, is_train):
23
+ """Add new dataset-specific options, and rewrite default values for existing options.
24
+
25
+ Parameters:
26
+ parser -- original option parser
27
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
28
+
29
+ Returns:
30
+ the modified parser.
31
+ """
32
+ parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
33
+ parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values
34
+ return parser
35
+
36
+ def __init__(self, opt):
37
+ """Initialize this dataset class.
38
+
39
+ Parameters:
40
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
41
+
42
+ A few things can be done here.
43
+ - save the options (have been done in BaseDataset)
44
+ - get image paths and meta information of the dataset.
45
+ - define the image transformation.
46
+ """
47
+ # save the option and dataset root
48
+ BaseDataset.__init__(self, opt)
49
+ # get the image paths of your dataset;
50
+ self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
51
+ # define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function
52
+ self.transform = get_transform()  # get_transform in base_dataset.py only takes a grayscale flag, not opt
53
+
54
+ def __getitem__(self, index):
55
+ """Return a data point and its metadata information.
56
+
57
+ Parameters:
58
+ index -- a random integer for data indexing
59
+
60
+ Returns:
61
+ a dictionary of data with their names. It usually contains the data itself and its metadata information.
62
+
63
+ Step 1: get a random image path: e.g., path = self.image_paths[index]
64
+ Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
65
+ Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image)
66
+ Step 4: return a data point as a dictionary.
67
+ """
68
+ path = 'temp' # needs to be a string
69
+ data_A = None # needs to be a tensor
70
+ data_B = None # needs to be a tensor
71
+ return {'data_A': data_A, 'data_B': data_B, 'path': path}
72
+
73
+ def __len__(self):
74
+ """Return the total number of images."""
75
+ return len(self.image_paths)
src/face3d/extract_kp_videos.py ADDED
@@ -0,0 +1,107 @@
1
+ import os
2
+ import cv2
3
+ import time
4
+ import glob
5
+ import argparse
6
+ import face_alignment
7
+ import numpy as np
8
+ from PIL import Image
9
+ from tqdm import tqdm
10
+ from itertools import cycle
11
+
12
+ from torch.multiprocessing import Pool, Process, set_start_method
13
+
14
+ class KeypointExtractor():
15
+ def __init__(self):
16
+ self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D)
17
+
18
+ def extract_keypoint(self, images, name=None, info=True):
19
+ if isinstance(images, list):
20
+ keypoints = []
21
+ if info:
22
+ i_range = tqdm(images,desc='landmark Det:')
23
+ else:
24
+ i_range = images
25
+
26
+ for image in i_range:
27
+ current_kp = self.extract_keypoint(image)
28
+ if np.mean(current_kp) == -1 and keypoints:
29
+ keypoints.append(keypoints[-1])
30
+ else:
31
+ keypoints.append(current_kp[None])
32
+
33
+ keypoints = np.concatenate(keypoints, 0)
34
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
35
+ return keypoints
36
+ else:
37
+ while True:
38
+ try:
39
+ keypoints = self.detector.get_landmarks_from_image(np.array(images))[0]
40
+ break
41
+ except RuntimeError as e:
42
+ if str(e).startswith('CUDA'):
43
+ print("Warning: out of memory, sleep for 1s")
44
+ time.sleep(1)
45
+ else:
46
+ print(e)
47
+ break
48
+ except TypeError:
49
+ print('No face detected in this image')
50
+ shape = [68, 2]
51
+ keypoints = -1. * np.ones(shape)
52
+ break
53
+ if name is not None:
54
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
55
+ return keypoints
56
+
57
+ def read_video(filename):
58
+ frames = []
59
+ cap = cv2.VideoCapture(filename)
60
+ while cap.isOpened():
61
+ ret, frame = cap.read()
62
+ if ret:
63
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
64
+ frame = Image.fromarray(frame)
65
+ frames.append(frame)
66
+ else:
67
+ break
68
+ cap.release()
69
+ return frames
70
+
71
+ def run(data):
72
+ filename, opt, device = data
73
+ os.environ['CUDA_VISIBLE_DEVICES'] = device
74
+ kp_extractor = KeypointExtractor()
75
+ images = read_video(filename)
76
+ name = filename.split('/')[-2:]
77
+ os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
78
+ kp_extractor.extract_keypoint(
79
+ images,
80
+ name=os.path.join(opt.output_dir, name[-2], name[-1])
81
+ )
82
+
83
+ if __name__ == '__main__':
84
+ set_start_method('spawn')
85
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
86
+ parser.add_argument('--input_dir', type=str, help='the folder of the input files')
87
+ parser.add_argument('--output_dir', type=str, help='the folder of the output files')
88
+ parser.add_argument('--device_ids', type=str, default='0,1')
89
+ parser.add_argument('--workers', type=int, default=4)
90
+
91
+ opt = parser.parse_args()
92
+ filenames = list()
93
+ VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
94
+ VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
95
+ extensions = VIDEO_EXTENSIONS
96
+
97
+ for ext in extensions:
98
+ os.listdir(f'{opt.input_dir}')
99
+ print(f'{opt.input_dir}/*.{ext}')
100
+ filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}'))
101
+ print('Total number of videos:', len(filenames))
102
+ pool = Pool(opt.workers)
103
+ args_list = cycle([opt])
104
+ device_ids = opt.device_ids.split(",")
105
+ device_ids = cycle(device_ids)
106
+ for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
107
+ pass
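Besides the batch entry point above, the extractor works on a single image; a hedged sketch in which 'face.png' and the output name are placeholders:

from PIL import Image

kp_extractor = KeypointExtractor()
image = Image.open('face.png').convert('RGB')
# Returns a (68, 2) landmark array and also writes it, flattened, to 'face_kp.txt'.
keypoints = kp_extractor.extract_keypoint(image, name='face_kp')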
src/face3d/models/__init__.py ADDED
@@ -0,0 +1,67 @@
1
+ """This package contains modules related to objective functions, optimizations, and network architectures.
2
+
3
+ To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4
+ You need to implement the following five functions:
5
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6
+ -- <set_input>: unpack data from dataset and apply preprocessing.
7
+ -- <forward>: produce intermediate results.
8
+ -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10
+
11
+ In the function <__init__>, you need to define four lists:
12
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
13
+ -- self.model_names (str list): define networks used in our training.
14
+ -- self.visual_names (str list): specify the images that you want to display and save.
15
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage.
16
+
17
+ Now you can use the model class by specifying flag '--model dummy'.
18
+ See our template model class 'template_model.py' for more details.
19
+ """
20
+
21
+ import importlib
22
+ from src.face3d.models.base_model import BaseModel
23
+
24
+
25
+ def find_model_using_name(model_name):
26
+ """Import the module "models/[model_name]_model.py".
27
+
28
+ In the file, the class called DatasetNameModel() will
29
+ be instantiated. It has to be a subclass of BaseModel,
30
+ and it is case-insensitive.
31
+ """
32
+ model_filename = "face3d.models." + model_name + "_model"
33
+ modellib = importlib.import_module(model_filename)
34
+ model = None
35
+ target_model_name = model_name.replace('_', '') + 'model'
36
+ for name, cls in modellib.__dict__.items():
37
+ if name.lower() == target_model_name.lower() \
38
+ and issubclass(cls, BaseModel):
39
+ model = cls
40
+
41
+ if model is None:
42
+ print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
43
+ exit(0)
44
+
45
+ return model
46
+
47
+
48
+ def get_option_setter(model_name):
49
+ """Return the static method <modify_commandline_options> of the model class."""
50
+ model_class = find_model_using_name(model_name)
51
+ return model_class.modify_commandline_options
52
+
53
+
54
+ def create_model(opt):
55
+ """Create a model given the option.
56
+
57
+ This function wraps the model class.
58
+ This is the main interface between this package and 'train.py'/'test.py'
59
+
60
+ Example:
61
+ >>> from models import create_model
62
+ >>> model = create_model(opt)
63
+ """
64
+ model = find_model_using_name(opt.model)
65
+ instance = model(opt)
66
+ print("model [%s] was created" % type(instance).__name__)
67
+ return instance