vinthony committed
Commit
a22eb82
0 Parent(s):
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +34 -0
  2. LICENSE +21 -0
  3. README.md +194 -0
  4. app.py +86 -0
  5. config/auido2exp.yaml +58 -0
  6. config/auido2pose.yaml +49 -0
  7. config/facerender.yaml +45 -0
  8. inference.py +134 -0
  9. modules/__pycache__/gfpgan_inference.cpython-38.pyc +0 -0
  10. modules/__pycache__/gfpgan_inference.cpython-39.pyc +0 -0
  11. modules/__pycache__/sadtalker_test.cpython-38.pyc +0 -0
  12. modules/__pycache__/sadtalker_test.cpython-39.pyc +0 -0
  13. modules/__pycache__/text2speech.cpython-38.pyc +0 -0
  14. modules/__pycache__/text2speech.cpython-39.pyc +0 -0
  15. modules/gfpgan_inference.py +36 -0
  16. modules/sadtalker_test.py +95 -0
  17. modules/text2speech.py +12 -0
  18. packages.txt +1 -0
  19. requirements.txt +17 -0
  20. src/__pycache__/generate_batch.cpython-38.pyc +0 -0
  21. src/__pycache__/generate_facerender_batch.cpython-38.pyc +0 -0
  22. src/__pycache__/test_audio2coeff.cpython-38.pyc +0 -0
  23. src/audio2exp_models/__pycache__/audio2exp.cpython-38.pyc +0 -0
  24. src/audio2exp_models/__pycache__/networks.cpython-38.pyc +0 -0
  25. src/audio2exp_models/audio2exp.py +30 -0
  26. src/audio2exp_models/networks.py +74 -0
  27. src/audio2pose_models/__pycache__/audio2pose.cpython-38.pyc +0 -0
  28. src/audio2pose_models/__pycache__/audio_encoder.cpython-38.pyc +0 -0
  29. src/audio2pose_models/__pycache__/cvae.cpython-38.pyc +0 -0
  30. src/audio2pose_models/__pycache__/discriminator.cpython-38.pyc +0 -0
  31. src/audio2pose_models/__pycache__/networks.cpython-38.pyc +0 -0
  32. src/audio2pose_models/__pycache__/res_unet.cpython-38.pyc +0 -0
  33. src/audio2pose_models/audio2pose.py +93 -0
  34. src/audio2pose_models/audio_encoder.py +64 -0
  35. src/audio2pose_models/cvae.py +149 -0
  36. src/audio2pose_models/discriminator.py +76 -0
  37. src/audio2pose_models/networks.py +140 -0
  38. src/audio2pose_models/res_unet.py +65 -0
  39. src/config/auido2exp.yaml +58 -0
  40. src/config/auido2pose.yaml +49 -0
  41. src/config/facerender.yaml +45 -0
  42. src/face3d/__pycache__/extract_kp_videos.cpython-38.pyc +0 -0
  43. src/face3d/__pycache__/visualize.cpython-38.pyc +0 -0
  44. src/face3d/data/__init__.py +116 -0
  45. src/face3d/data/base_dataset.py +125 -0
  46. src/face3d/data/flist_dataset.py +125 -0
  47. src/face3d/data/image_folder.py +66 -0
  48. src/face3d/data/template_dataset.py +75 -0
  49. src/face3d/extract_kp_videos.py +107 -0
  50. src/face3d/models/__init__.py +67 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 Tencent AI Lab
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,194 @@
+ <div align="center">
+
+ <h2> 😭 SadTalker: <span style="font-size:12px">Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation</span> </h2>
+
+ <a href='https://arxiv.org/abs/2211.12194'><img src='https://img.shields.io/badge/ArXiv-2211.12194-red'></a> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href='https://sadtalker.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Winfredy/SadTalker/blob/main/quick_demo.ipynb)
+
+ <div>
+ <a target='_blank'>Wenxuan Zhang <sup>*,1,2</sup></a>&emsp;
+ <a href='https://vinthony.github.io/' target='_blank'>Xiaodong Cun <sup>*,2</sup></a>&emsp;
+ <a href='https://xuanwangvc.github.io/' target='_blank'>Xuan Wang <sup>3</sup></a>&emsp;
+ <a href='https://yzhang2016.github.io/' target='_blank'>Yong Zhang <sup>2</sup></a>&emsp;
+ <a href='https://xishen0220.github.io/' target='_blank'>Xi Shen <sup>2</sup></a>&emsp; </br>
+ <a href='https://yuguo-xjtu.github.io/' target='_blank'>Yu Guo <sup>1</sup></a>&emsp;
+ <a href='https://scholar.google.com/citations?hl=zh-CN&user=4oXBp9UAAAAJ' target='_blank'>Ying Shan <sup>2</sup></a>&emsp;
+ <a target='_blank'>Fei Wang <sup>1</sup></a>&emsp;
+ </div>
+ <br>
+ <div>
+ <sup>1</sup> Xi'an Jiaotong University &emsp; <sup>2</sup> Tencent AI Lab &emsp; <sup>3</sup> Ant Group &emsp;
+ </div>
+ <br>
+ <i><strong><a href='https://arxiv.org/abs/2211.12194' target='_blank'>CVPR 2023</a></strong></i>
+ <br>
+ <br>
+
+ ![sadtalker](https://user-images.githubusercontent.com/4397546/222490039-b1f6156b-bf00-405b-9fda-0c9a9156f991.gif)
+
+ <b>TL;DR: A realistic and stylized talking-head video generation method from a single image and audio.</b>
+
+ <br>
+
+ </div>
+
+
+ ## 📋 Changelog
+
+
+ - __2023.03.22__: Launch new feature: generating 3D face animation from a single image. New applications built on it will follow.
+
+ - __2023.03.22__: Launch new feature: `still mode`, where only a small head pose is produced, via `python inference.py --still`.
+ - __2023.03.18__: Support `expression intensity`; you can now change the intensity of the generated motion: `python inference.py --expression_scale 1.3` (some value > 1).
+
+ - __2023.03.18__: Reorganize the data folders; you can now download the checkpoints automatically using `bash scripts/download_models.sh`.
+ - __2023.03.18__: We have officially integrated [GFPGAN](https://github.com/TencentARC/GFPGAN) for face enhancement; use `python inference.py --enhancer gfpgan` for better visual quality.
+ - __2023.03.14__: Pin the version of the `joblib` package to remove the errors in using `librosa`; the [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Winfredy/SadTalker/blob/main/quick_demo.ipynb) is online!
+ &nbsp;&nbsp;&nbsp;&nbsp; <details><summary> Previous Changelogs</summary>
+ - 2023.03.06 Solve some bugs in the code and errors in installation.
+ - 2023.03.03 Release the test code for audio-driven single image animation!
+ - 2023.02.28 SadTalker has been accepted by CVPR 2023!
+
+ </details>
+
+ ## 🎼 Pipeline
+ ![main_of_sadtalker](https://user-images.githubusercontent.com/4397546/222490596-4c8a2115-49a7-42ad-a2c3-3bb3288a5f36.png)
+
+
+ ## 🚧 TODO
+
+ - [x] Generating a 2D face from a single image.
+ - [x] Generating a 3D face from audio.
+ - [x] Generating 4D free-view talking examples from audio and a single image.
+ - [x] Gradio/Colab demo.
+ - [ ] Full body/image generation.
+ - [ ] Training code for each component.
+ - [ ] Audio-driven anime avatar.
+ - [ ] Integrate ChatGPT for a conversation demo 🤔
+ - [ ] Integrate with stable-diffusion-web-ui. (stay tuned!)
+
+ https://user-images.githubusercontent.com/4397546/222513483-89161f58-83d0-40e4-8e41-96c32b47bd4e.mp4
+
+
+ ## 🔮 Inference Demo!
+
+ #### Dependency Installation
+
+ <details><summary>CLICK ME</summary>
+
+ ```
+ git clone https://github.com/Winfredy/SadTalker.git
+ cd SadTalker
+ conda create -n sadtalker python=3.8
+ source activate sadtalker
+ pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
+ conda install ffmpeg
+ pip install dlib-bin # dlib-bin installs much faster than dlib (alternative: conda install dlib)
+ pip install -r requirements.txt
+
+ ### install gfpgan for the enhancer
+ pip install git+https://github.com/TencentARC/GFPGAN
+
+ ```
+
+ </details>
+
+ #### Trained Models
+ <details><summary>CLICK ME</summary>
+
+ You can run the following script to put all the models in the right place.
+
+ ```bash
+ bash scripts/download_models.sh
+ ```
+
+ OR download our pre-trained models from [Google Drive](https://drive.google.com/drive/folders/1Wd88VDoLhVzYsQ30_qDVluQr_Xm46yHT?usp=sharing) or our [GitHub release page](https://github.com/Winfredy/SadTalker/releases/tag/v0.0.1), and then put them in ./checkpoints.
+
+ | Model | Description
+ | :--- | :----------
+ |checkpoints/auido2exp_00300-model.pth | Pre-trained ExpNet in SadTalker.
+ |checkpoints/auido2pose_00140-model.pth | Pre-trained PoseVAE in SadTalker.
+ |checkpoints/mapping_00229-model.pth.tar | Pre-trained MappingNet in SadTalker.
+ |checkpoints/facevid2vid_00189-model.pth.tar | Pre-trained face-vid2vid model from [the reappearance of face-vid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis).
+ |checkpoints/epoch_20.pth | Pre-trained 3DMM extractor from [Deep3DFaceReconstruction](https://github.com/microsoft/Deep3DFaceReconstruction).
+ |checkpoints/wav2lip.pth | Highly accurate lip-sync model from [Wav2Lip](https://github.com/Rudrabha/Wav2Lip).
+ |checkpoints/shape_predictor_68_face_landmarks.dat | Face landmark model used in [dlib](http://dlib.net/).
+ |checkpoints/BFM | 3DMM library files.
+ |checkpoints/hub | Face detection models used in [face-alignment](https://github.com/1adrianb/face-alignment).
+
+ </details>
+
+ #### Generating a 2D face from a single image
+
+ ```bash
+ python inference.py --driven_audio <audio.wav> \
+                     --source_image <video.mp4 or picture.png> \
+                     --batch_size <default is 2, a larger batch size runs faster> \
+                     --expression_scale <default is 1.0, a larger value makes the motion stronger> \
+                     --result_dir <a folder to store the results> \
+                     --enhancer <default is None, you can choose gfpgan or RestoreFormer>
+ ```
+
+ <!-- ###### The effect of the enhancer `gfpgan`. -->
+
+ | basic | w/ still mode | w/ exp_scale 1.3 | w/ gfpgan |
+ |:-------------: |:-------------: |:-------------: |:-------------: |
+ | <video src="https://user-images.githubusercontent.com/4397546/226097707-bef1dd41-403e-48d3-a6e6-6adf923843af.mp4"></video> | <video src='https://user-images.githubusercontent.com/4397546/226804933-b717229f-1919-4bd5-b6af-bea7ab66cad3.mp4'></video> | <video style='width:256px' src="https://user-images.githubusercontent.com/4397546/226806013-7752c308-8235-4e7a-9465-72d8fc1aa03d.mp4"></video> | <video style='width:256px' src="https://user-images.githubusercontent.com/4397546/226097717-12a1a2a1-ac0f-428d-b2cb-bd6917aff73e.mp4"></video> |
+
+ > Please unmute the videos; GitHub plays embedded videos without audio by default.
+
+
+ <!-- <video src="./docs/art_0##japanese_still.mp4"></video> -->
+
+
+ #### Generating a 3D face from audio
+
+
+ | Input | Animated 3D face |
+ |:-------------: | :-------------: |
+ | <img src='examples/source_image/art_0.png' width='200px'> | <video src="https://user-images.githubusercontent.com/4397546/226856847-5a6a0a4d-a5ec-49e2-9b05-3206db65e8e3.mp4"></video> |
+
+ > Please unmute the video; GitHub plays embedded videos without audio by default.
+
+ More details on generating the 3D face can be found [here](docs/face3d.md).
+
+ #### Generating 4D free-view talking examples from audio and a single image
+
+ We use `camera_yaw`, `camera_pitch`, and `camera_roll` to control the camera pose. For example, `--camera_yaw -20 30 10` means the camera yaw degree changes from -20 to 30 and then from 30 to 10.
+ ```bash
+ python inference.py --driven_audio <audio.wav> \
+                     --source_image <video.mp4 or picture.png> \
+                     --result_dir <a folder to store the results> \
+                     --camera_yaw -20 30 10
+ ```
+ ![free_view](https://github.com/Winfredy/SadTalker/blob/main/docs/free_view_result.gif)
+
+
+ ## 🛎 Citation
+
+ If you find our work useful in your research, please consider citing:
+
+ ```bibtex
+ @article{zhang2022sadtalker,
+   title={SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation},
+   author={Zhang, Wenxuan and Cun, Xiaodong and Wang, Xuan and Zhang, Yong and Shen, Xi and Guo, Yu and Shan, Ying and Wang, Fei},
+   journal={arXiv preprint arXiv:2211.12194},
+   year={2022}
+ }
+ ```
+
+ ## 💗 Acknowledgements
+
+ The Facerender code borrows heavily from [zhanglonghao's reproduction of face-vid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis) and [PIRender](https://github.com/RenYurui/PIRender). We thank the authors for sharing their wonderful code. In the training process, we also use models from [Deep3DFaceReconstruction](https://github.com/microsoft/Deep3DFaceReconstruction) and [Wav2Lip](https://github.com/Rudrabha/Wav2Lip), and we thank them for their wonderful work.
+
+
+ ## 🥂 Related Works
+ - [StyleHEAT: One-Shot High-Resolution Editable Talking Face Generation via Pre-trained StyleGAN (ECCV 2022)](https://github.com/FeiiYin/StyleHEAT)
+ - [CodeTalker: Speech-Driven 3D Facial Animation with Discrete Motion Prior (CVPR 2023)](https://github.com/Doubiiu/CodeTalker)
+ - [VideoReTalking: Audio-based Lip Synchronization for Talking Head Video Editing In the Wild (SIGGRAPH Asia 2022)](https://github.com/vinthony/video-retalking)
+ - [DPE: Disentanglement of Pose and Expression for General Video Portrait Editing (CVPR 2023)](https://github.com/Carlyx/DPE)
+ - [3D GAN Inversion with Facial Symmetry Prior (CVPR 2023)](https://github.com/FeiiYin/SPI/)
+ - [T2M-GPT: Generating Human Motion from Textual Descriptions with Discrete Representations (CVPR 2023)](https://github.com/Mael-zys/T2M-GPT)
+
+ ## 📢 Disclaimer
+
+ This is not an official product of Tencent. This repository can only be used for personal/research/non-commercial purposes.
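A note on the `--camera_yaw -20 30 10` convention documented in the 4D free-view section of the README above: the code that expands these key values (`src/generate_facerender_batch.py`) is not shown in this 50-file view, so the sketch below is only an illustrative reading of the description (piecewise-linear interpolation between the listed key angles), not the repository's exact implementation.

```python
import numpy as np

def expand_camera_keyframes(keys, num_frames):
    """Spread the listed key angles evenly over the clip and interpolate linearly between them."""
    key_positions = np.linspace(0, num_frames - 1, num=len(keys))
    return np.interp(np.arange(num_frames), key_positions, keys)

yaw = expand_camera_keyframes([-20, 30, 10], num_frames=60)
print(yaw[0], yaw[-1])   # -20.0 10.0: starts at -20, rises toward 30 mid-clip, falls back to 10
```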
app.py ADDED
@@ -0,0 +1,86 @@
+ import os, sys
+ import tempfile
+ import gradio as gr
+ from modules.text2speech import text2speech
+ from modules.gfpgan_inference import gfpgan
+ from modules.sadtalker_test import SadTalker
+
+ def get_driven_audio(audio):
+     if os.path.isfile(audio):
+         return audio
+     else:
+         save_path = tempfile.NamedTemporaryFile(
+             delete=False,
+             suffix=("." + "wav"),
+         )
+         gen_audio = text2speech(audio, save_path.name)
+         return gen_audio, gen_audio
+
+ def get_source_image(image):
+     return image
+
+ def sadtalker_demo(result_dir):
+
+     sad_talker = SadTalker()
+     with gr.Blocks(analytics_enabled=False) as sadtalker_interface:
+         with gr.Row().style(equal_height=False):
+             with gr.Column(variant='panel'):
+                 with gr.Tabs(elem_id="sadtalker_source_image"):
+                     source_image = gr.Image(visible=False, type="filepath")
+                     with gr.TabItem('Upload image'):
+                         with gr.Row():
+                             input_image = gr.Image(label="Source image", source="upload", type="filepath").style(height=256,width=256)
+                             submit_image = gr.Button('Submit', variant='primary')
+                             submit_image.click(fn=get_source_image, inputs=input_image, outputs=source_image)
+
+                 with gr.Tabs(elem_id="sadtalker_driven_audio"):
+                     driven_audio = gr.Audio(visible=False, type="filepath")
+                     with gr.TabItem('Upload audio'):
+                         with gr.Column(variant='panel'):
+                             input_audio1 = gr.Audio(label="Input audio", source="upload", type="filepath")
+                             submit_audio_1 = gr.Button('Submit', variant='primary')
+                             submit_audio_1.click(fn=get_driven_audio, inputs=input_audio1, outputs=driven_audio)
+
+                     with gr.TabItem('Microphone'):
+                         with gr.Column(variant='panel'):
+                             input_audio2 = gr.Audio(label="Recording audio", source="microphone", type="filepath")
+                             submit_audio_2 = gr.Button('Submit', variant='primary')
+                             submit_audio_2.click(fn=get_driven_audio, inputs=input_audio2, outputs=driven_audio)
+
+                     with gr.TabItem('TTS'):
+                         with gr.Column(variant='panel'):
+                             with gr.Row().style(equal_height=False):
+                                 input_text = gr.Textbox(label="Input text", lines=5, value="Please enter some text in English")
+                                 input_audio3 = gr.Audio(label="Generated audio", type="filepath")
+                             submit_audio_3 = gr.Button('Submit', variant='primary')
+                             submit_audio_3.click(fn=get_driven_audio, inputs=input_text, outputs=[input_audio3, driven_audio])
+
+             with gr.Column(variant='panel'):
+                 gen_video = gr.Video(label="Generated video", format="mp4").style(height=256,width=256)
+                 gen_text = gr.Textbox(visible=False)
+                 submit = gr.Button('Generate', elem_id="sadtalker_generate", variant='primary')
+                 scale = gr.Slider(minimum=1, maximum=8, step=1, label="GFPGAN scale", value=1)
+                 new_video = gr.Video(label="New video", format="mp4").style(height=512,width=512)
+                 change_scale = gr.Button('Restore video', elem_id="restore_video", variant='primary')
+
+         submit.click(
+             fn=sad_talker.test,
+             inputs=[source_image,
+                     driven_audio,
+                     gr.Textbox(value=result_dir, visible=False)],
+             outputs=[gen_video, gen_text]
+         )
+         change_scale.click(gfpgan, [scale, gen_text], new_video)
+
+     return sadtalker_interface
+
+
+ if __name__ == "__main__":
+
+     current_code_path = sys.argv[0]
+     current_root_dir = os.path.split(current_code_path)[0]
+     sadtalker_result_dir = os.path.join(current_root_dir, 'results', 'sadtalker')
+     demo = sadtalker_demo(sadtalker_result_dir)
+     demo.launch()
+
+
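For completeness, a minimal sketch of driving the same pipeline without the Gradio UI, using the `SadTalker` class the app instantiates (paths are the repository's bundled examples; the checkpoints/ and config/ folders must be in place, and the script must run from the repo root):

```python
from modules.sadtalker_test import SadTalker

sad_talker = SadTalker()
# Note: SadTalker.test() moves its inputs into the run folder, so pass files you can spare.
video_path, _ = sad_talker.test(
    'examples/source_image/art_0.png',     # source_image
    'examples/driven_audio/japanese.wav',  # driven_audio (an existing .wav is used as-is)
    './results/sadtalker',                 # result_dir
)
print(video_path)
```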
config/auido2exp.yaml ADDED
@@ -0,0 +1,58 @@
+ DATASET:
+   TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt
+   EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt
+   TRAIN_BATCH_SIZE: 32
+   EVAL_BATCH_SIZE: 32
+   EXP: True
+   EXP_DIM: 64
+   FRAME_LEN: 32
+   COEFF_LEN: 73
+   NUM_CLASSES: 46
+   AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
+   COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm
+   LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
+   DEBUG: True
+   NUM_REPEATS: 2
+   T: 40
+
+
+ MODEL:
+   FRAMEWORK: V2
+   AUDIOENCODER:
+     LEAKY_RELU: True
+     NORM: 'IN'
+   DISCRIMINATOR:
+     LEAKY_RELU: False
+     INPUT_CHANNELS: 6
+   CVAE:
+     AUDIO_EMB_IN_SIZE: 512
+     AUDIO_EMB_OUT_SIZE: 128
+     SEQ_LEN: 32
+     LATENT_SIZE: 256
+     ENCODER_LAYER_SIZES: [192, 1024]
+     DECODER_LAYER_SIZES: [1024, 192]
+
+
+ TRAIN:
+   MAX_EPOCH: 300
+   GENERATOR:
+     LR: 2.0e-5
+   DISCRIMINATOR:
+     LR: 1.0e-5
+   LOSS:
+     W_FEAT: 0
+     W_COEFF_EXP: 2
+     W_LM: 1.0e-2
+     W_LM_MOUTH: 0
+     W_REG: 0
+     W_SYNC: 0
+     W_COLOR: 0
+     W_EXPRESSION: 0
+     W_LIPREADING: 0.01
+     W_LIPREADING_VV: 0
+     W_EYE_BLINK: 4
+
+ TAG:
+   NAME: small_dataset
+
+
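The YAML files under config/ are consumed through attribute-style access in the model code later in this commit (e.g. `cfg.MODEL.CVAE.SEQ_LEN` in audio2pose/audio2exp). A minimal loading sketch, assuming the `yacs` package pinned in requirements.txt is the loader used:

```python
from yacs.config import CfgNode as CN

# Parse the config above into a nested, attribute-accessible node.
with open('config/auido2exp.yaml') as f:
    cfg = CN.load_cfg(f)
cfg.freeze()

print(cfg.MODEL.CVAE.SEQ_LEN)      # 32
print(cfg.TRAIN.LOSS.W_EYE_BLINK)  # 4
```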
config/auido2pose.yaml ADDED
@@ -0,0 +1,49 @@
+ DATASET:
+   TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt
+   EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt
+   TRAIN_BATCH_SIZE: 64
+   EVAL_BATCH_SIZE: 1
+   EXP: True
+   EXP_DIM: 64
+   FRAME_LEN: 32
+   COEFF_LEN: 73
+   NUM_CLASSES: 46
+   AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
+   COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
+   DEBUG: True
+
+
+ MODEL:
+   AUDIOENCODER:
+     LEAKY_RELU: True
+     NORM: 'IN'
+   DISCRIMINATOR:
+     LEAKY_RELU: False
+     INPUT_CHANNELS: 6
+   CVAE:
+     AUDIO_EMB_IN_SIZE: 512
+     AUDIO_EMB_OUT_SIZE: 6
+     SEQ_LEN: 32
+     LATENT_SIZE: 64
+     ENCODER_LAYER_SIZES: [192, 128]
+     DECODER_LAYER_SIZES: [128, 192]
+
+
+ TRAIN:
+   MAX_EPOCH: 150
+   GENERATOR:
+     LR: 1.0e-4
+   DISCRIMINATOR:
+     LR: 1.0e-4
+   LOSS:
+     LAMBDA_REG: 1
+     LAMBDA_LANDMARKS: 0
+     LAMBDA_VERTICES: 0
+     LAMBDA_GAN_MOTION: 0.7
+     LAMBDA_GAN_COEFF: 0
+     LAMBDA_KL: 1
+
+ TAG:
+   NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder
+
+
config/facerender.yaml ADDED
@@ -0,0 +1,45 @@
+ model_params:
+   common_params:
+     num_kp: 15
+     image_channel: 3
+     feature_channel: 32
+     estimate_jacobian: False   # True
+   kp_detector_params:
+     temperature: 0.1
+     block_expansion: 32
+     max_features: 1024
+     scale_factor: 0.25         # 0.25
+     num_blocks: 5
+     reshape_channel: 16384     # 16384 = 1024 * 16
+     reshape_depth: 16
+   he_estimator_params:
+     block_expansion: 64
+     max_features: 2048
+     num_bins: 66
+   generator_params:
+     block_expansion: 64
+     max_features: 512
+     num_down_blocks: 2
+     reshape_channel: 32
+     reshape_depth: 16          # 512 = 32 * 16
+     num_resblocks: 6
+     estimate_occlusion_map: True
+     dense_motion_params:
+       block_expansion: 32
+       max_features: 1024
+       num_blocks: 5
+       reshape_depth: 16
+       compress: 4
+   discriminator_params:
+     scales: [1]
+     block_expansion: 32
+     max_features: 512
+     num_blocks: 4
+     sn: True
+   mapping_params:
+     coeff_nc: 70
+     descriptor_nc: 1024
+     layer: 3
+     num_kp: 15
+     num_bins: 66
+
inference.py ADDED
@@ -0,0 +1,134 @@
+ import torch
+ from time import strftime
+ import os, sys, time
+ from argparse import ArgumentParser
+
+ from src.utils.preprocess import CropAndExtract
+ from src.test_audio2coeff import Audio2Coeff
+ from src.facerender.animate import AnimateFromCoeff
+ from src.generate_batch import get_data
+ from src.generate_facerender_batch import get_facerender_data
+
+ def main(args):
+     #torch.backends.cudnn.enabled = False
+
+     pic_path = args.source_image
+     audio_path = args.driven_audio
+     save_dir = os.path.join(args.result_dir, strftime("%Y_%m_%d_%H.%M.%S"))
+     os.makedirs(save_dir, exist_ok=True)
+     pose_style = args.pose_style
+     device = args.device
+     batch_size = args.batch_size
+     camera_yaw_list = args.camera_yaw
+     camera_pitch_list = args.camera_pitch
+     camera_roll_list = args.camera_roll
+
+     current_code_path = sys.argv[0]
+     current_root_path = os.path.split(current_code_path)[0]
+
+     os.environ['TORCH_HOME'] = os.path.join(current_root_path, args.checkpoint_dir)
+
+     path_of_lm_croper = os.path.join(current_root_path, args.checkpoint_dir, 'shape_predictor_68_face_landmarks.dat')
+     path_of_net_recon_model = os.path.join(current_root_path, args.checkpoint_dir, 'epoch_20.pth')
+     dir_of_BFM_fitting = os.path.join(current_root_path, args.checkpoint_dir, 'BFM_Fitting')
+     wav2lip_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'wav2lip.pth')
+
+     audio2pose_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'auido2pose_00140-model.pth')
+     audio2pose_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2pose.yaml')
+
+     audio2exp_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'auido2exp_00300-model.pth')
+     audio2exp_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2exp.yaml')
+
+     free_view_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'facevid2vid_00189-model.pth.tar')
+     mapping_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'mapping_00229-model.pth.tar')
+     facerender_yaml_path = os.path.join(current_root_path, 'src', 'config', 'facerender.yaml')
+
+     #init model
+     print(path_of_net_recon_model)
+     preprocess_model = CropAndExtract(path_of_lm_croper, path_of_net_recon_model, dir_of_BFM_fitting, device)
+
+     print(audio2pose_checkpoint)
+     print(audio2exp_checkpoint)
+     audio_to_coeff = Audio2Coeff(audio2pose_checkpoint, audio2pose_yaml_path,
+                                  audio2exp_checkpoint, audio2exp_yaml_path,
+                                  wav2lip_checkpoint, device)
+
+     print(free_view_checkpoint)
+     print(mapping_checkpoint)
+     animate_from_coeff = AnimateFromCoeff(free_view_checkpoint, mapping_checkpoint,
+                                           facerender_yaml_path, device)
+
+     #crop image and extract 3dmm from image
+     first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
+     os.makedirs(first_frame_dir, exist_ok=True)
+     first_coeff_path, crop_pic_path = preprocess_model.generate(pic_path, first_frame_dir)
+     if first_coeff_path is None:
+         print("Can't get the coeffs of the input")
+         return
+
+     #audio2coeff
+     batch = get_data(first_coeff_path, audio_path, device)
+     coeff_path = audio_to_coeff.generate(batch, save_dir, pose_style)
+
+     # 3d face render
+     if args.face3dvis:
+         from src.face3d.visualize import gen_composed_video
+         gen_composed_video(args, device, first_coeff_path, coeff_path, audio_path, os.path.join(save_dir, '3dface.mp4'))
+
+     #coeff2video
+     data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path,
+                                batch_size, camera_yaw_list, camera_pitch_list, camera_roll_list,
+                                expression_scale=args.expression_scale, still_mode=args.still)
+
+     animate_from_coeff.generate(data, save_dir, enhancer=args.enhancer)
+     video_name = data['video_name']
+
+     if args.enhancer is not None:
+         print(f'The generated video is named {video_name}_enhanced in {save_dir}')
+     else:
+         print(f'The generated video is named {video_name} in {save_dir}')
+
+     return os.path.join(save_dir, video_name+'.mp4'), os.path.join(save_dir, video_name+'.mp4')
+
+
+ if __name__ == '__main__':
+
+     parser = ArgumentParser()
+     parser.add_argument("--driven_audio", default='./examples/driven_audio/japanese.wav', help="path to driven audio")
+     parser.add_argument("--source_image", default='./examples/source_image/art_0.png', help="path to source image")
+     parser.add_argument("--checkpoint_dir", default='./checkpoints', help="path to the checkpoints")
+     parser.add_argument("--result_dir", default='./results', help="path to output")
+     parser.add_argument("--pose_style", type=int, default=0, help="input pose style from [0, 46)")
+     parser.add_argument("--batch_size", type=int, default=2, help="the batch size of facerender")
+     parser.add_argument("--expression_scale", type=float, default=1., help="the intensity scale of the generated expression")
+     parser.add_argument('--camera_yaw', nargs='+', type=int, default=[0], help="the camera yaw degree")
+     parser.add_argument('--camera_pitch', nargs='+', type=int, default=[0], help="the camera pitch degree")
+     parser.add_argument('--camera_roll', nargs='+', type=int, default=[0], help="the camera roll degree")
+     parser.add_argument('--enhancer', type=str, default=None, help="Face enhancer, [GFPGAN]")
+     parser.add_argument("--cpu", dest="cpu", action="store_true")
+     parser.add_argument("--face3dvis", action="store_true", help="generate 3d face and 3d landmarks")
+     parser.add_argument("--still", action="store_true")
+
+     # net structure and parameters
+     parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='not use')
+     parser.add_argument('--init_path', type=str, default=None, help='not Use')
+     parser.add_argument('--use_last_fc', default=False, help='zero initialize the last fc')
+     parser.add_argument('--bfm_folder', type=str, default='./checkpoints/BFM_Fitting/')
+     parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model')
+
+     # default renderer parameters
+     parser.add_argument('--focal', type=float, default=1015.)
+     parser.add_argument('--center', type=float, default=112.)
+     parser.add_argument('--camera_d', type=float, default=10.)
+     parser.add_argument('--z_near', type=float, default=5.)
+     parser.add_argument('--z_far', type=float, default=15.)
+
+     args = parser.parse_args()
+
+     if torch.cuda.is_available() and not args.cpu:
+         args.device = "cuda"
+     else:
+         args.device = "cpu"
+
+     main(args)
+
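For reference, a typical invocation of the script above, using only flags defined in its argument parser (the paths are the parser's bundled-example defaults; `--enhancer gfpgan` assumes GFPGAN was installed as described in the README):

```python
import subprocess

subprocess.run([
    "python", "inference.py",
    "--driven_audio", "./examples/driven_audio/japanese.wav",
    "--source_image", "./examples/source_image/art_0.png",
    "--result_dir", "./results",
    "--expression_scale", "1.3",   # stronger motion, see the README changelog
    "--enhancer", "gfpgan",        # optional face enhancement
    "--still",                     # optional: only a small head pose
], check=True)
```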
modules/__pycache__/gfpgan_inference.cpython-38.pyc ADDED
Binary file (1.36 kB).
modules/__pycache__/gfpgan_inference.cpython-39.pyc ADDED
Binary file (1.4 kB).
modules/__pycache__/sadtalker_test.cpython-38.pyc ADDED
Binary file (3.1 kB).
modules/__pycache__/sadtalker_test.cpython-39.pyc ADDED
Binary file (3.98 kB).
modules/__pycache__/text2speech.cpython-38.pyc ADDED
Binary file (473 Bytes).
modules/__pycache__/text2speech.cpython-39.pyc ADDED
Binary file (477 Bytes).
modules/gfpgan_inference.py ADDED
@@ -0,0 +1,36 @@
+ import os,sys
+
+ def gfpgan(scale, origin_mp4_path):
+     current_code_path = sys.argv[0]
+     current_root_path = os.path.split(current_code_path)[0]
+     print(current_root_path)
+     gfpgan_code_path = current_root_path+'/repositories/GFPGAN/inference_gfpgan.py'
+     print(gfpgan_code_path)
+
+     #video2pic
+     result_dir = os.path.split(origin_mp4_path)[0]
+     video_name = os.path.split(origin_mp4_path)[1]
+     video_name = video_name.split('.')[0]
+     print(video_name)
+     str_scale = str(scale).replace('.', '_')
+     output_mp4_path = os.path.join(result_dir, video_name+'##'+str_scale+'.mp4')
+     temp_output_mp4_path = os.path.join(result_dir, 'temp_'+video_name+'##'+str_scale+'.mp4')
+
+     audio_name = video_name.split('##')[-1]
+     audio_path = os.path.join(result_dir, audio_name+'.wav')
+     temp_pic_dir1 = os.path.join(result_dir, video_name)
+     temp_pic_dir2 = os.path.join(result_dir, video_name+'##'+str_scale)
+     os.makedirs(temp_pic_dir1, exist_ok=True)
+     os.makedirs(temp_pic_dir2, exist_ok=True)
+     cmd1 = 'ffmpeg -i \"{}\" -start_number 0 \"{}\"/%06d.png -loglevel error -y'.format(origin_mp4_path, temp_pic_dir1)
+     os.system(cmd1)
+     cmd2 = f'python {gfpgan_code_path} -i {temp_pic_dir1} -o {temp_pic_dir2} -s {scale}'
+     os.system(cmd2)
+     cmd3 = f'ffmpeg -r 25 -f image2 -i {temp_pic_dir2}/%06d.png -vcodec libx264 -crf 25 -pix_fmt yuv420p {temp_output_mp4_path}'
+     os.system(cmd3)
+     cmd4 = f'ffmpeg -y -i {temp_output_mp4_path} -i {audio_path} -vcodec copy {output_mp4_path}'
+     os.system(cmd4)
+     #shutil.rmtree(temp_pic_dir1)
+     #shutil.rmtree(temp_pic_dir2)
+
+     return output_mp4_path
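A usage sketch for the helper above. It shells out to `./repositories/GFPGAN/inference_gfpgan.py`, so a GFPGAN clone is assumed at that location; the mp4 path below is hypothetical and follows the `<image>##<audio>.mp4` naming produced by the renderer, with the matching `<audio>.wav` sitting in the same folder:

```python
from modules.gfpgan_inference import gfpgan

# Hypothetical path: a SadTalker output whose name encodes the driving audio after '##'.
enhanced = gfpgan(2, './results/sadtalker/2023_03_22_12.00.00/first_frame##japanese.mp4')
print(enhanced)  # .../first_frame##japanese##2.mp4, re-muxed with the original japanese.wav
```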
modules/sadtalker_test.py ADDED
@@ -0,0 +1,95 @@
+ import torch
+ from time import gmtime, strftime
+ import os, sys, shutil
+ from argparse import ArgumentParser
+ from src.utils.preprocess import CropAndExtract
+ from src.test_audio2coeff import Audio2Coeff
+ from src.facerender.animate import AnimateFromCoeff
+ from src.generate_batch import get_data
+ from src.generate_facerender_batch import get_facerender_data
+
+ from modules.text2speech import text2speech
+
+ class SadTalker():
+
+     def __init__(self, checkpoint_path='checkpoints'):
+
+         if torch.cuda.is_available():
+             device = "cuda"
+         else:
+             device = "cpu"
+
+         current_code_path = sys.argv[0]
+         modules_path = os.path.split(current_code_path)[0]
+
+         current_root_path = './'
+
+         os.environ['TORCH_HOME'] = os.path.join(current_root_path, 'checkpoints')
+
+         path_of_lm_croper = os.path.join(current_root_path, 'checkpoints', 'shape_predictor_68_face_landmarks.dat')
+         path_of_net_recon_model = os.path.join(current_root_path, 'checkpoints', 'epoch_20.pth')
+         dir_of_BFM_fitting = os.path.join(current_root_path, 'checkpoints', 'BFM_Fitting')
+         wav2lip_checkpoint = os.path.join(current_root_path, 'checkpoints', 'wav2lip.pth')
+
+         audio2pose_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2pose_00140-model.pth')
+         audio2pose_yaml_path = os.path.join(current_root_path, 'config', 'auido2pose.yaml')
+
+         audio2exp_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2exp_00300-model.pth')
+         audio2exp_yaml_path = os.path.join(current_root_path, 'config', 'auido2exp.yaml')
+
+         free_view_checkpoint = os.path.join(current_root_path, 'checkpoints', 'facevid2vid_00189-model.pth.tar')
+         mapping_checkpoint = os.path.join(current_root_path, 'checkpoints', 'mapping_00229-model.pth.tar')
+         facerender_yaml_path = os.path.join(current_root_path, 'config', 'facerender.yaml')
+
+         #init model
+         print(path_of_lm_croper)
+         self.preprocess_model = CropAndExtract(path_of_lm_croper, path_of_net_recon_model, dir_of_BFM_fitting, device)
+
+         print(audio2pose_checkpoint)
+         self.audio_to_coeff = Audio2Coeff(audio2pose_checkpoint, audio2pose_yaml_path,
+                                           audio2exp_checkpoint, audio2exp_yaml_path, wav2lip_checkpoint, device)
+         print(free_view_checkpoint)
+         self.animate_from_coeff = AnimateFromCoeff(free_view_checkpoint, mapping_checkpoint,
+                                                    facerender_yaml_path, device)
+         self.device = device
+
+     def test(self, source_image, driven_audio, result_dir):
+
+         time_tag = strftime("%Y_%m_%d_%H.%M.%S")
+         save_dir = os.path.join(result_dir, time_tag)
+         os.makedirs(save_dir, exist_ok=True)
+
+         input_dir = os.path.join(save_dir, 'input')
+         os.makedirs(input_dir, exist_ok=True)
+
+         print(source_image)
+         pic_path = os.path.join(input_dir, os.path.basename(source_image))
+         shutil.move(source_image, input_dir)
+
+         if os.path.isfile(driven_audio):
+             audio_path = os.path.join(input_dir, os.path.basename(driven_audio))
+             shutil.move(driven_audio, input_dir)
+         else:
+             # plain text input: synthesize a wav into the run folder with the TTS helper
+             audio_path = text2speech(driven_audio, os.path.join(input_dir, 'tts.wav'))
+
+
+         os.makedirs(save_dir, exist_ok=True)
+         pose_style = 0
+         #crop image and extract 3dmm from image
+         first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
+         os.makedirs(first_frame_dir, exist_ok=True)
+         first_coeff_path, crop_pic_path = self.preprocess_model.generate(pic_path, first_frame_dir)
+         if first_coeff_path is None:
+             raise AttributeError("No face is detected")
+
+         #audio2coeff
+         batch = get_data(first_coeff_path, audio_path, self.device)
+         coeff_path = self.audio_to_coeff.generate(batch, save_dir, pose_style)
+         #coeff2video
+         batch_size = 4
+         data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size)
+         self.animate_from_coeff.generate(data, save_dir)
+         video_name = data['video_name']
+         print(f'The generated video is named {video_name} in {save_dir}')
+         return os.path.join(save_dir, video_name+'.mp4'), os.path.join(save_dir, video_name+'.mp4')
+
modules/text2speech.py ADDED
@@ -0,0 +1,12 @@
+ import os
+
+ def text2speech(txt, audio_path):
+     print(txt)
+     cmd = f'tts --text "{txt}" --out_path {audio_path}'
+     print(cmd)
+     try:
+         os.system(cmd)
+         return audio_path
+     except:
+         print("Error: failed to convert text to audio")
+         return None
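A usage sketch for the helper above. It shells out to a `tts` command-line tool that is not listed in requirements.txt, so a TTS CLI (for example, Coqui-style `tts`) is assumed to be installed separately; the output path is arbitrary:

```python
from modules.text2speech import text2speech

# Requires a `tts` executable on PATH; the wav path here is just an example location.
wav_path = text2speech("Hello from SadTalker.", "/tmp/sadtalker_tts.wav")
print(wav_path)  # /tmp/sadtalker_tts.wav if the external `tts` call succeeded
```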
packages.txt ADDED
@@ -0,0 +1 @@
+ ffmpeg
requirements.txt ADDED
@@ -0,0 +1,17 @@
+ numpy==1.23.4
+ face_alignment==1.3.5
+ imageio==2.19.3
+ imageio-ffmpeg==0.4.7
+ librosa==0.6.0
+ numba==0.48.0
+ resampy==0.3.1
+ pydub==0.25.1
+ scipy==1.5.3
+ kornia==0.6.8
+ tqdm
+ yacs==0.1.8
+ pyyaml
+ joblib==1.1.0
+ scikit-image==0.19.3
+ basicsr==1.4.2
+ facexlib==0.2.5
src/__pycache__/generate_batch.cpython-38.pyc ADDED
Binary file (2.81 kB).
src/__pycache__/generate_facerender_batch.cpython-38.pyc ADDED
Binary file (3.81 kB).
src/__pycache__/test_audio2coeff.cpython-38.pyc ADDED
Binary file (2.73 kB).
src/audio2exp_models/__pycache__/audio2exp.cpython-38.pyc ADDED
Binary file (1.07 kB).
src/audio2exp_models/__pycache__/networks.cpython-38.pyc ADDED
Binary file (2.14 kB).
src/audio2exp_models/audio2exp.py ADDED
@@ -0,0 +1,30 @@
+ import torch
+ from torch import nn
+
+
+ class Audio2Exp(nn.Module):
+     def __init__(self, netG, cfg, device, prepare_training_loss=False):
+         super(Audio2Exp, self).__init__()
+         self.cfg = cfg
+         self.device = device
+         self.netG = netG.to(device)
+
+     def test(self, batch):
+
+         mel_input = batch['indiv_mels']                  # bs T 1 80 16
+         bs = mel_input.shape[0]
+         T = mel_input.shape[1]
+
+         ref = batch['ref'][:, :, :64].repeat((1,T,1))    # bs T 64
+         ratio = batch['ratio_gt']                        # bs T
+
+         audiox = mel_input.view(-1, 1, 80, 16)           # bs*T 1 80 16
+         exp_coeff_pred = self.netG(audiox, ref, ratio)   # bs T 64
+
+         # BS x T x 64
+         results_dict = {
+             'exp_coeff_pred': exp_coeff_pred
+         }
+         return results_dict
+
+
src/audio2exp_models/networks.py ADDED
@@ -0,0 +1,74 @@
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ class Conv2d(nn.Module):
+     def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act = True, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.conv_block = nn.Sequential(
+             nn.Conv2d(cin, cout, kernel_size, stride, padding),
+             nn.BatchNorm2d(cout)
+         )
+         self.act = nn.ReLU()
+         self.residual = residual
+         self.use_act = use_act
+
+     def forward(self, x):
+         out = self.conv_block(x)
+         if self.residual:
+             out += x
+
+         if self.use_act:
+             return self.act(out)
+         else:
+             return out
+
+ class SimpleWrapperV2(nn.Module):
+     def __init__(self) -> None:
+         super().__init__()
+         self.audio_encoder = nn.Sequential(
+             Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
+             Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+             Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+
+             Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
+             Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+             Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+             Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
+             Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+             Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+             Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
+             Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+             Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
+             Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
+         )
+
+         #### load the pre-trained audio_encoder
+         #self.audio_encoder = self.audio_encoder.to(device)
+         '''
+         wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict']
+         state_dict = self.audio_encoder.state_dict()
+
+         for k,v in wav2lip_state_dict.items():
+             if 'audio_encoder' in k:
+                 print('init:', k)
+                 state_dict[k.replace('module.audio_encoder.', '')] = v
+         self.audio_encoder.load_state_dict(state_dict)
+         '''
+
+         self.mapping1 = nn.Linear(512+64+1, 64)
+         #self.mapping2 = nn.Linear(30, 64)
+         #nn.init.constant_(self.mapping1.weight, 0.)
+         nn.init.constant_(self.mapping1.bias, 0.)
+
+     def forward(self, x, ref, ratio):
+         x = self.audio_encoder(x).view(x.size(0), -1)
+         ref_reshape = ref.reshape(x.size(0), -1)
+         ratio = ratio.reshape(x.size(0), -1)
+
+         y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1))
+         out = y.reshape(ref.shape[0], ref.shape[1], -1) #+ ref # residual
+         return out
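As a sanity check on the tensor shapes annotated in the two audio2exp files above, here is a small sketch (run from the repository root; random tensors stand in for the real mel-spectrogram chunks, reference expression coefficients, and blink-ratio signal):

```python
import torch
from src.audio2exp_models.networks import SimpleWrapperV2
from src.audio2exp_models.audio2exp import Audio2Exp

bs, T = 2, 5
batch = {
    'indiv_mels': torch.randn(bs, T, 1, 80, 16),  # per-frame mel chunks
    'ref':        torch.randn(bs, 1, 64),         # first-frame expression coefficients
    'ratio_gt':   torch.rand(bs, T, 1),           # eye-blink ratio signal
}
model = Audio2Exp(SimpleWrapperV2(), cfg=None, device='cpu')
with torch.no_grad():
    out = model.test(batch)
print(out['exp_coeff_pred'].shape)  # expected: torch.Size([2, 5, 64])
```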
src/audio2pose_models/__pycache__/audio2pose.cpython-38.pyc ADDED
Binary file (2.94 kB).
src/audio2pose_models/__pycache__/audio_encoder.cpython-38.pyc ADDED
Binary file (2.37 kB).
src/audio2pose_models/__pycache__/cvae.cpython-38.pyc ADDED
Binary file (4.69 kB).
src/audio2pose_models/__pycache__/discriminator.cpython-38.pyc ADDED
Binary file (2.45 kB).
src/audio2pose_models/__pycache__/networks.cpython-38.pyc ADDED
Binary file (4.74 kB).
src/audio2pose_models/__pycache__/res_unet.cpython-38.pyc ADDED
Binary file (1.91 kB).
src/audio2pose_models/audio2pose.py ADDED
@@ -0,0 +1,93 @@
+ import torch
+ from torch import nn
+ from src.audio2pose_models.cvae import CVAE
+ from src.audio2pose_models.discriminator import PoseSequenceDiscriminator
+ from src.audio2pose_models.audio_encoder import AudioEncoder
+
+ class Audio2Pose(nn.Module):
+     def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
+         super().__init__()
+         self.cfg = cfg
+         self.seq_len = cfg.MODEL.CVAE.SEQ_LEN
+         self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE
+         self.device = device
+
+         self.audio_encoder = AudioEncoder(wav2lip_checkpoint)
+         self.audio_encoder.eval()
+         for param in self.audio_encoder.parameters():
+             param.requires_grad = False
+
+         self.netG = CVAE(cfg)
+         self.netD_motion = PoseSequenceDiscriminator(cfg)
+
+         self.gan_criterion = nn.MSELoss()
+         self.reg_criterion = nn.L1Loss(reduction='none')
+         self.pair_criterion = nn.PairwiseDistance()
+         self.cosine_loss = nn.CosineSimilarity(dim=1)
+
+     def forward(self, x):
+
+         batch = {}
+         coeff_gt = x['gt'].cuda().squeeze(0)                                       # bs frame_len+1 73
+         batch['pose_motion_gt'] = coeff_gt[:, 1:, -9:-3] - coeff_gt[:, :1, -9:-3]  # bs frame_len 6
+         batch['ref'] = coeff_gt[:, 0, -9:-3]                                       # bs 6
+         batch['class'] = x['class'].squeeze(0).cuda()                              # bs
+         indiv_mels = x['indiv_mels'].cuda().squeeze(0)                             # bs seq_len+1 80 16
+
+         # forward
+         audio_emb_list = []
+         audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2))       # bs seq_len 512
+         batch['audio_emb'] = audio_emb
+         batch = self.netG(batch)
+
+         pose_motion_pred = batch['pose_motion_pred']                               # bs frame_len 6
+         pose_gt = coeff_gt[:, 1:, -9:-3].clone()                                   # bs frame_len 6
+         pose_pred = coeff_gt[:, :1, -9:-3] + pose_motion_pred                      # bs frame_len 6
+
+         batch['pose_pred'] = pose_pred
+         batch['pose_gt'] = pose_gt
+
+         return batch
+
+     def test(self, x):
+
+         batch = {}
+         ref = x['ref']                        # bs 1 70
+         batch['ref'] = x['ref'][:,0,-6:]
+         batch['class'] = x['class']
+         bs = ref.shape[0]
+
+         indiv_mels = x['indiv_mels']          # bs T 1 80 16
+         indiv_mels_use = indiv_mels[:, 1:]    # we regard the ref as the first frame
+         num_frames = x['num_frames']
+         num_frames = int(num_frames) - 1
+
+         #
+         div = num_frames//self.seq_len
+         re = num_frames%self.seq_len
+         audio_emb_list = []
+         pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype,
+                                              device=batch['ref'].device)]
+
+         for i in range(div):
+             z = torch.randn(bs, self.latent_dim).to(ref.device)
+             batch['z'] = z
+             audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len,:,:,:])  # bs seq_len 512
+             batch['audio_emb'] = audio_emb
+             batch = self.netG.test(batch)
+             pose_motion_pred_list.append(batch['pose_motion_pred'])  # list of bs seq_len 6
+         if re != 0:
+             z = torch.randn(bs, self.latent_dim).to(ref.device)
+             batch['z'] = z
+             audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:,:,:,:])  # bs seq_len 512
+             batch['audio_emb'] = audio_emb
+             batch = self.netG.test(batch)
+             pose_motion_pred_list.append(batch['pose_motion_pred'][:,-1*re:,:])
+
+         pose_motion_pred = torch.cat(pose_motion_pred_list, dim = 1)
+         batch['pose_motion_pred'] = pose_motion_pred
+
+         pose_pred = ref[:, :1, -6:] + pose_motion_pred  # bs T 6
+
+         batch['pose_pred'] = pose_pred
+         return batch
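`Audio2Pose.test` above covers an arbitrary clip length with fixed windows of `seq_len` frames plus one remainder window whose last `re` frames are kept; the windowing arithmetic is just:

```python
# Window arithmetic used by Audio2Pose.test (the reference frame itself is not re-generated).
seq_len, num_frames = 32, 101
usable = num_frames - 1
div, re = usable // seq_len, usable % seq_len
print(div, re)  # 3 4 -> three full 32-frame windows, then the last 4 frames of one extra window
```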
src/audio2pose_models/audio_encoder.py ADDED
@@ -0,0 +1,64 @@
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+ class Conv2d(nn.Module):
+     def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.conv_block = nn.Sequential(
+             nn.Conv2d(cin, cout, kernel_size, stride, padding),
+             nn.BatchNorm2d(cout)
+         )
+         self.act = nn.ReLU()
+         self.residual = residual
+
+     def forward(self, x):
+         out = self.conv_block(x)
+         if self.residual:
+             out += x
+         return self.act(out)
+
+ class AudioEncoder(nn.Module):
+     def __init__(self, wav2lip_checkpoint):
+         super(AudioEncoder, self).__init__()
+
+         self.audio_encoder = nn.Sequential(
+             Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
+             Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+             Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+
+             Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
+             Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+             Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+             Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
+             Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+             Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+             Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
+             Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+             Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
+             Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
+
+         #### load the pre-trained audio_encoder
+         wav2lip_state_dict = torch.load(wav2lip_checkpoint)['state_dict']
+         state_dict = self.audio_encoder.state_dict()
+
+         for k,v in wav2lip_state_dict.items():
+             if 'audio_encoder' in k:
+                 state_dict[k.replace('module.audio_encoder.', '')] = v
+         self.audio_encoder.load_state_dict(state_dict)
+
+
+     def forward(self, audio_sequences):
+         # audio_sequences = (B, T, 1, 80, 16)
+         B = audio_sequences.size(0)
+
+         audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
+
+         audio_embedding = self.audio_encoder(audio_sequences)  # B, 512, 1, 1
+         dim = audio_embedding.shape[1]
+         audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1))
+
+         return audio_embedding.squeeze(-1).squeeze(-1)  # B seq_len+1 512
src/audio2pose_models/cvae.py ADDED
@@ -0,0 +1,149 @@
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from src.audio2pose_models.res_unet import ResUnet
+
+ def class2onehot(idx, class_num):
+
+     assert torch.max(idx).item() < class_num
+     onehot = torch.zeros(idx.size(0), class_num).to(idx.device)
+     onehot.scatter_(1, idx, 1)
+     return onehot
+
+ class CVAE(nn.Module):
+     def __init__(self, cfg):
+         super().__init__()
+         encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES
+         decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES
+         latent_size = cfg.MODEL.CVAE.LATENT_SIZE
+         num_classes = cfg.DATASET.NUM_CLASSES
+         audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE
+         audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE
+         seq_len = cfg.MODEL.CVAE.SEQ_LEN
+
+         self.latent_size = latent_size
+
+         self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes,
+                                audio_emb_in_size, audio_emb_out_size, seq_len)
+         self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes,
+                                audio_emb_in_size, audio_emb_out_size, seq_len)
+     def reparameterize(self, mu, logvar):
+         std = torch.exp(0.5 * logvar)
+         eps = torch.randn_like(std)
+         return mu + eps * std
+
+     def forward(self, batch):
+         batch = self.encoder(batch)
+         mu = batch['mu']
+         logvar = batch['logvar']
+         z = self.reparameterize(mu, logvar)
+         batch['z'] = z
+         return self.decoder(batch)
+
+     def test(self, batch):
+         '''
+         class_id = batch['class']
+         z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device)
+         batch['z'] = z
+         '''
+         return self.decoder(batch)
+
+ class ENCODER(nn.Module):
+     def __init__(self, layer_sizes, latent_size, num_classes,
+                  audio_emb_in_size, audio_emb_out_size, seq_len):
+         super().__init__()
+
+         self.resunet = ResUnet()
+         self.num_classes = num_classes
+         self.seq_len = seq_len
+
+         self.MLP = nn.Sequential()
+         layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6
+         for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
+             self.MLP.add_module(
+                 name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
+             self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
+
+         self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
+         self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size)
+         self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
+
+         self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
+
+     def forward(self, batch):
+         class_id = batch['class']
+         pose_motion_gt = batch['pose_motion_gt']  # bs seq_len 6
+         ref = batch['ref']                        # bs 6
+         bs = pose_motion_gt.shape[0]
+         audio_in = batch['audio_emb']             # bs seq_len audio_emb_in_size
+
+         #pose encode
+         pose_emb = self.resunet(pose_motion_gt.unsqueeze(1))  # bs 1 seq_len 6
+         pose_emb = pose_emb.reshape(bs, -1)                   # bs seq_len*6
+
+         #audio mapping
+         print(audio_in.shape)
+         audio_out = self.linear_audio(audio_in)   # bs seq_len audio_emb_out_size
+         audio_out = audio_out.reshape(bs, -1)
+
+         class_bias = self.classbias[class_id]     # bs latent_size
+         x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1)  # bs seq_len*(audio_emb_out_size+6)+latent_size
+         x_out = self.MLP(x_in)
+
+         mu = self.linear_means(x_out)
+         logvar = self.linear_logvar(x_out)        # bs latent_size
+
+         batch.update({'mu':mu, 'logvar':logvar})
+         return batch
+
+ class DECODER(nn.Module):
+     def __init__(self, layer_sizes, latent_size, num_classes,
+                  audio_emb_in_size, audio_emb_out_size, seq_len):
+         super().__init__()
+
+         self.resunet = ResUnet()
+         self.num_classes = num_classes
+         self.seq_len = seq_len
+
+         self.MLP = nn.Sequential()
+         input_size = latent_size + seq_len*audio_emb_out_size + 6
+         for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)):
+             self.MLP.add_module(
+                 name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
+             if i+1 < len(layer_sizes):
+                 self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
+             else:
+                 self.MLP.add_module(name="sigmoid", module=nn.Sigmoid())
+
+         self.pose_linear = nn.Linear(6, 6)
+         self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
+
+         self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
+
+     def forward(self, batch):
+
+         z = batch['z']                            # bs latent_size
+         bs = z.shape[0]
+         class_id = batch['class']
+         ref = batch['ref']                        # bs 6
+         audio_in = batch['audio_emb']             # bs seq_len audio_emb_in_size
+         #print('audio_in: ', audio_in[:, :, :10])
+
+         audio_out = self.linear_audio(audio_in)   # bs seq_len audio_emb_out_size
+         #print('audio_out: ', audio_out[:, :, :10])
+         audio_out = audio_out.reshape([bs, -1])   # bs seq_len*audio_emb_out_size
+         class_bias = self.classbias[class_id]     # bs latent_size
+
+         z = z + class_bias
+         x_in = torch.cat([ref, z, audio_out], dim=-1)
+         x_out = self.MLP(x_in)                    # bs layer_sizes[-1]
+         x_out = x_out.reshape((bs, self.seq_len, -1))
+
+         #print('x_out: ', x_out)
+
+         pose_emb = self.resunet(x_out.unsqueeze(1))  # bs 1 seq_len 6
+
+         pose_motion_pred = self.pose_linear(pose_emb.squeeze(1))  # bs seq_len 6
+
+         batch.update({'pose_motion_pred':pose_motion_pred})
+         return batch
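For readers new to CVAEs, `CVAE.reparameterize` above is the standard reparameterization trick; spelled out on toy tensors:

```python
import torch

mu     = torch.zeros(4, 64)            # posterior mean       (bs, latent_size)
logvar = torch.zeros(4, 64)            # log of the variance  (bs, latent_size)
std = torch.exp(0.5 * logvar)          # sigma = exp(log(sigma^2) / 2)
z = mu + torch.randn_like(std) * std   # differentiable sample z ~ N(mu, sigma^2)
print(z.shape)                         # torch.Size([4, 64])
```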
src/audio2pose_models/discriminator.py ADDED
@@ -0,0 +1,76 @@
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ class ConvNormRelu(nn.Module):
+     def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False,
+                  kernel_size=None, stride=None, padding=None, norm='BN', leaky=False):
+         super().__init__()
+         if kernel_size is None:
+             if downsample:
+                 kernel_size, stride, padding = 4, 2, 1
+             else:
+                 kernel_size, stride, padding = 3, 1, 1
+
+         if conv_type == '2d':
+             self.conv = nn.Conv2d(
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride,
+                 padding,
+                 bias=False,
+             )
+             if norm == 'BN':
+                 self.norm = nn.BatchNorm2d(out_channels)
+             elif norm == 'IN':
+                 self.norm = nn.InstanceNorm2d(out_channels)
+             else:
+                 raise NotImplementedError
+         elif conv_type == '1d':
+             self.conv = nn.Conv1d(
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride,
+                 padding,
+                 bias=False,
+             )
+             if norm == 'BN':
+                 self.norm = nn.BatchNorm1d(out_channels)
+             elif norm == 'IN':
+                 self.norm = nn.InstanceNorm1d(out_channels)
+             else:
+                 raise NotImplementedError
+         nn.init.kaiming_normal_(self.conv.weight)
+
+         self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True)
+
+     def forward(self, x):
+         x = self.conv(x)
+         if isinstance(self.norm, nn.InstanceNorm1d):
+             x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1))  # normalize on [C]
+         else:
+             x = self.norm(x)
+         x = self.act(x)
+         return x
+
+
+ class PoseSequenceDiscriminator(nn.Module):
+     def __init__(self, cfg):
+         super().__init__()
+         self.cfg = cfg
+         leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU
+
+         self.seq = nn.Sequential(
+             ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky),  # B, 256, 64
+             ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky),                                     # B, 512, 32
+             ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky),                 # B, 1024, 16
+             nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True)                               # B, 1, 16
+         )
+
+     def forward(self, x):
+         x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2)
+         x = self.seq(x)
+         x = x.squeeze(1)
+         return x
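A small shape check for `ConvNormRelu` above: its `downsample=True` default of kernel 4, stride 2, padding 1 halves the temporal length, which is how `PoseSequenceDiscriminator` shrinks a pose-difference sequence before the final 1-channel convolution.

```python
import torch
from src.audio2pose_models.discriminator import ConvNormRelu

x = torch.randn(2, 6, 32)                                      # (batch, pose channels, seq_len)
y = ConvNormRelu('1d', 6, 256, downsample=True)(x)             # k=4, s=2, p=1 -> length halved
print(y.shape)                                                 # torch.Size([2, 256, 16])
```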
src/audio2pose_models/networks.py ADDED
@@ -0,0 +1,140 @@