Commit a22eb82 · init

This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- .gitattributes +34 -0
 - LICENSE +21 -0
 - README.md +194 -0
 - app.py +86 -0
 - config/auido2exp.yaml +58 -0
 - config/auido2pose.yaml +49 -0
 - config/facerender.yaml +45 -0
 - inference.py +134 -0
 - modules/__pycache__/gfpgan_inference.cpython-38.pyc +0 -0
 - modules/__pycache__/gfpgan_inference.cpython-39.pyc +0 -0
 - modules/__pycache__/sadtalker_test.cpython-38.pyc +0 -0
 - modules/__pycache__/sadtalker_test.cpython-39.pyc +0 -0
 - modules/__pycache__/text2speech.cpython-38.pyc +0 -0
 - modules/__pycache__/text2speech.cpython-39.pyc +0 -0
 - modules/gfpgan_inference.py +36 -0
 - modules/sadtalker_test.py +95 -0
 - modules/text2speech.py +12 -0
 - packages.txt +1 -0
 - requirements.txt +17 -0
 - src/__pycache__/generate_batch.cpython-38.pyc +0 -0
 - src/__pycache__/generate_facerender_batch.cpython-38.pyc +0 -0
 - src/__pycache__/test_audio2coeff.cpython-38.pyc +0 -0
 - src/audio2exp_models/__pycache__/audio2exp.cpython-38.pyc +0 -0
 - src/audio2exp_models/__pycache__/networks.cpython-38.pyc +0 -0
 - src/audio2exp_models/audio2exp.py +30 -0
 - src/audio2exp_models/networks.py +74 -0
 - src/audio2pose_models/__pycache__/audio2pose.cpython-38.pyc +0 -0
 - src/audio2pose_models/__pycache__/audio_encoder.cpython-38.pyc +0 -0
 - src/audio2pose_models/__pycache__/cvae.cpython-38.pyc +0 -0
 - src/audio2pose_models/__pycache__/discriminator.cpython-38.pyc +0 -0
 - src/audio2pose_models/__pycache__/networks.cpython-38.pyc +0 -0
 - src/audio2pose_models/__pycache__/res_unet.cpython-38.pyc +0 -0
 - src/audio2pose_models/audio2pose.py +93 -0
 - src/audio2pose_models/audio_encoder.py +64 -0
 - src/audio2pose_models/cvae.py +149 -0
 - src/audio2pose_models/discriminator.py +76 -0
 - src/audio2pose_models/networks.py +140 -0
 - src/audio2pose_models/res_unet.py +65 -0
 - src/config/auido2exp.yaml +58 -0
 - src/config/auido2pose.yaml +49 -0
 - src/config/facerender.yaml +45 -0
 - src/face3d/__pycache__/extract_kp_videos.cpython-38.pyc +0 -0
 - src/face3d/__pycache__/visualize.cpython-38.pyc +0 -0
 - src/face3d/data/__init__.py +116 -0
 - src/face3d/data/base_dataset.py +125 -0
 - src/face3d/data/flist_dataset.py +125 -0
 - src/face3d/data/image_folder.py +66 -0
 - src/face3d/data/template_dataset.py +75 -0
 - src/face3d/extract_kp_videos.py +107 -0
 - src/face3d/models/__init__.py +67 -0
 
    	
.gitattributes ADDED
@@ -0,0 +1,34 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
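These are the default large-file rules that Hugging Face repositories are initialized with: every binary model and archive format is routed through Git LFS. If an extra pattern ever needs tracking, `git lfs track` appends a matching line; a minimal sketch (`*.dat` is just an example pattern, not one required by this repo):

```bash
# appends a line like '*.dat filter=lfs diff=lfs merge=lfs -text' to .gitattributes
git lfs track "*.dat"
git add .gitattributes
```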
    	
LICENSE ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 Tencent AI Lab

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
    	
README.md ADDED
@@ -0,0 +1,194 @@
<div align="center">

<h2> 😭 SadTalker: <span style="font-size:12px">Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation </span> </h2>

  <a href='https://arxiv.org/abs/2211.12194'><img src='https://img.shields.io/badge/ArXiv-2211.12194-red'></a>   <a href='https://sadtalker.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a>   [Open In Colab](https://colab.research.google.com/github/Winfredy/SadTalker/blob/main/quick_demo.ipynb)

<div>
    <a target='_blank'>Wenxuan Zhang <sup>*,1,2</sup></a> 
    <a href='https://vinthony.github.io/' target='_blank'>Xiaodong Cun <sup>*,2</sup></a> 
    <a href='https://xuanwangvc.github.io/' target='_blank'>Xuan Wang <sup>3</sup></a> 
    <a href='https://yzhang2016.github.io/' target='_blank'>Yong Zhang <sup>2</sup></a> 
    <a href='https://xishen0220.github.io/' target='_blank'>Xi Shen <sup>2</sup></a> <br>
    <a href='https://yuguo-xjtu.github.io/' target='_blank'>Yu Guo <sup>1</sup></a> 
    <a href='https://scholar.google.com/citations?hl=zh-CN&user=4oXBp9UAAAAJ' target='_blank'>Ying Shan <sup>2</sup></a> 
    <a target='_blank'>Fei Wang <sup>1</sup></a> 
</div>
<br>
<div>
    <sup>1</sup> Xi'an Jiaotong University   <sup>2</sup> Tencent AI Lab   <sup>3</sup> Ant Group
</div>
<br>
<i><strong><a href='https://arxiv.org/abs/2211.12194' target='_blank'>CVPR 2023</a></strong></i>
<br>
<br>

<b>TL;DR: A realistic and stylized talking-head video generation method from a single image and audio.</b>

<br>

</div>


## 📋 Changelog

- __2023.03.22__: Launched a new feature: generating 3D face animation from a single image. New applications built on it will be added.
- __2023.03.22__: Launched a new feature: `still mode`, which produces only small head poses, via `python inference.py --still`.
- __2023.03.18__: Added support for `expression intensity`; you can now change the intensity of the generated motion with `python inference.py --expression_scale 1.3` (any value > 1).
- __2023.03.18__: Reorganized the data folders; the checkpoints can now be downloaded automatically with `bash scripts/download_models.sh`.
- __2023.03.18__: Officially integrated [GFPGAN](https://github.com/TencentARC/GFPGAN) for face enhancement; use `python inference.py --enhancer gfpgan` for better visual quality.
- __2023.03.14__: Pinned the version of the `joblib` package to fix errors when using `librosa`; the [Colab demo](https://colab.research.google.com/github/Winfredy/SadTalker/blob/main/quick_demo.ipynb) is online!

<details><summary>Previous changelogs</summary>

 - 2023.03.06: Fixed some bugs in the code and errors in the installation.
 - 2023.03.03: Released the test code for audio-driven single-image animation!
 - 2023.02.28: SadTalker has been accepted by CVPR 2023!

</details>

## 🎼 Pipeline


## 🚧 TODO

- [x] Generating a 2D face from a single image.
- [x] Generating a 3D face from audio.
- [x] Generating 4D free-view talking examples from audio and a single image.
- [x] Gradio/Colab demo.
- [ ] Full-body/image generation.
- [ ] Training code for each component.
- [ ] Audio-driven anime avatar.
- [ ] Incorporate ChatGPT for a conversation demo 🤔
- [ ] Integrate with stable-diffusion-web-ui. (Stay tuned!)

https://user-images.githubusercontent.com/4397546/222513483-89161f58-83d0-40e4-8e41-96c32b47bd4e.mp4


## 🔮 Inference Demo!

#### Dependency Installation

<details><summary>CLICK ME</summary>

```
git clone https://github.com/Winfredy/SadTalker.git
cd SadTalker
conda create -n sadtalker python=3.8
source activate sadtalker
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
conda install ffmpeg
pip install dlib-bin  # dlib-bin installs much faster than dlib (alternative: conda install dlib)
pip install -r requirements.txt

### install GFPGAN for the enhancer
pip install git+https://github.com/TencentARC/GFPGAN
```

</details>

#### Trained Models
<details><summary>CLICK ME</summary>

You can run the following script to put all the models in the right place.

```bash
bash scripts/download_models.sh
```

Or download our pre-trained models from [Google Drive](https://drive.google.com/drive/folders/1Wd88VDoLhVzYsQ30_qDVluQr_Xm46yHT?usp=sharing) or our [GitHub release page](https://github.com/Winfredy/SadTalker/releases/tag/v0.0.1), and then put them in `./checkpoints`.

| Model | Description
| :--- | :----------
|checkpoints/auido2exp_00300-model.pth | Pre-trained ExpNet in SadTalker.
|checkpoints/auido2pose_00140-model.pth | Pre-trained PoseVAE in SadTalker.
|checkpoints/mapping_00229-model.pth.tar | Pre-trained MappingNet in SadTalker.
|checkpoints/facevid2vid_00189-model.pth.tar | Pre-trained face-vid2vid model from [the reproduction of face-vid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis).
|checkpoints/epoch_20.pth | Pre-trained 3DMM extractor from [Deep3DFaceReconstruction](https://github.com/microsoft/Deep3DFaceReconstruction).
|checkpoints/wav2lip.pth | Highly accurate lip-sync model from [Wav2lip](https://github.com/Rudrabha/Wav2Lip).
|checkpoints/shape_predictor_68_face_landmarks.dat | Face landmark model used in [dlib](http://dlib.net/).
|checkpoints/BFM | 3DMM library files.
|checkpoints/hub | Face detection models used in [face-alignment](https://github.com/1adrianb/face-alignment).

</details>

#### Generating a 2D face from a single image

```bash
python inference.py --driven_audio <audio.wav> \
                    --source_image <video.mp4 or picture.png> \
                    --batch_size <default is 2, a larger value runs faster> \
                    --expression_scale <default is 1.0, a larger value makes the motion stronger> \
                    --result_dir <a folder to store the results> \
                    --enhancer <default is None, you can choose gfpgan or RestoreFormer>
```
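As a concrete sketch of the flags above: the portrait `examples/source_image/art_0.png` is the sample referenced later in this README, while the driven-audio path below is a hypothetical placeholder, so substitute your own WAV file.

```bash
# hypothetical audio path; art_0.png is the sample portrait shipped with the repo
python inference.py --driven_audio examples/driven_audio/sample.wav \
                    --source_image examples/source_image/art_0.png \
                    --expression_scale 1.1 \
                    --result_dir ./results \
                    --enhancer gfpgan
```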

<!-- ###### The effect of the enhancer `gfpgan`. -->

| basic        | w/ still mode |  w/ exp_scale 1.3   | w/ gfpgan  |
|:-------------: |:-------------: |:-------------: |:-------------: |
|  <video src="https://user-images.githubusercontent.com/4397546/226097707-bef1dd41-403e-48d3-a6e6-6adf923843af.mp4"></video>  | <video src='https://user-images.githubusercontent.com/4397546/226804933-b717229f-1919-4bd5-b6af-bea7ab66cad3.mp4'></video>  |  <video style='width:256px' src="https://user-images.githubusercontent.com/4397546/226806013-7752c308-8235-4e7a-9465-72d8fc1aa03d.mp4"></video>     | <video style='width:256px' src="https://user-images.githubusercontent.com/4397546/226097717-12a1a2a1-ac0f-428d-b2cb-bd6917aff73e.mp4"></video>    |

> Please turn the audio on manually; GitHub plays the embedded videos muted by default.


<!-- <video src="./docs/art_0##japanese_still.mp4"></video> -->


#### Generating a 3D face from audio

| Input        | Animated 3D face |
|:-------------: | :-------------: |
|  <img src='examples/source_image/art_0.png' width='200px'> | <video src="https://user-images.githubusercontent.com/4397546/226856847-5a6a0a4d-a5ec-49e2-9b05-3206db65e8e3.mp4"></video>  |

> Please turn the audio on manually; GitHub plays the embedded videos muted by default.

More details on generating the 3D face can be found [here](docs/face3d.md).

#### Generating 4D free-view talking examples from audio and a single image

We use `camera_yaw`, `camera_pitch`, and `camera_roll` to control the camera pose. For example, `--camera_yaw -20 30 10` means the camera yaw changes from -20 to 30 degrees and then from 30 to 10.
```bash
python inference.py --driven_audio <audio.wav> \
                    --source_image <video.mp4 or picture.png> \
                    --result_dir <a folder to store the results> \
                    --camera_yaw -20 30 10
```


## 🛎 Citation

If you find our work useful in your research, please consider citing:

```bibtex
@article{zhang2022sadtalker,
  title={SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation},
  author={Zhang, Wenxuan and Cun, Xiaodong and Wang, Xuan and Zhang, Yong and Shen, Xi and Guo, Yu and Shan, Ying and Wang, Fei},
  journal={arXiv preprint arXiv:2211.12194},
  year={2022}
}
```

## 💗 Acknowledgements

The facerender code borrows heavily from [zhanglonghao's reproduction of face-vid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis) and [PIRender](https://github.com/RenYurui/PIRender). We thank the authors for sharing their wonderful code. During training, we also use models from [Deep3DFaceReconstruction](https://github.com/microsoft/Deep3DFaceReconstruction) and [Wav2lip](https://github.com/Rudrabha/Wav2Lip), and we thank the authors for their wonderful work.


## 🥂 Related Works
- [StyleHEAT: One-Shot High-Resolution Editable Talking Face Generation via Pre-trained StyleGAN (ECCV 2022)](https://github.com/FeiiYin/StyleHEAT)
- [CodeTalker: Speech-Driven 3D Facial Animation with Discrete Motion Prior (CVPR 2023)](https://github.com/Doubiiu/CodeTalker)
- [VideoReTalking: Audio-based Lip Synchronization for Talking Head Video Editing In the Wild (SIGGRAPH Asia 2022)](https://github.com/vinthony/video-retalking)
- [DPE: Disentanglement of Pose and Expression for General Video Portrait Editing (CVPR 2023)](https://github.com/Carlyx/DPE)
- [3D GAN Inversion with Facial Symmetry Prior (CVPR 2023)](https://github.com/FeiiYin/SPI/)
- [T2M-GPT: Generating Human Motion from Textual Descriptions with Discrete Representations (CVPR 2023)](https://github.com/Mael-zys/T2M-GPT)

## 📢 Disclaimer

This is not an official product of Tencent. This repository can only be used for personal/research/non-commercial purposes.
    	
app.py ADDED
@@ -0,0 +1,86 @@
import os, sys
import tempfile
import gradio as gr
from modules.text2speech import text2speech
from modules.gfpgan_inference import gfpgan
from modules.sadtalker_test import SadTalker


def get_driven_audio(audio):
    # Uploaded/recorded audio arrives as a file path and is passed through unchanged;
    # text from the TTS tab is synthesized into a temporary wav file instead.
    if os.path.isfile(audio):
        return audio
    else:
        save_path = tempfile.NamedTemporaryFile(
            delete=False,
            suffix=("." + "wav"),
        )
        gen_audio = text2speech(audio, save_path.name)
        # returned twice: once for the audio preview widget, once for the hidden driven_audio state
        return gen_audio, gen_audio


def get_source_image(image):
    return image


def sadtalker_demo(result_dir):

    sad_talker = SadTalker()
    with gr.Blocks(analytics_enabled=False) as sadtalker_interface:
        with gr.Row().style(equal_height=False):
            with gr.Column(variant='panel'):
                with gr.Tabs(elem_id="sadtalker_source_image"):
                    source_image = gr.Image(visible=False, type="filepath")
                    with gr.TabItem('Upload image'):
                        with gr.Row():
                            input_image = gr.Image(label="Source image", source="upload", type="filepath").style(height=256, width=256)
                        submit_image = gr.Button('Submit', variant='primary')
                    submit_image.click(fn=get_source_image, inputs=input_image, outputs=source_image)

                with gr.Tabs(elem_id="sadtalker_driven_audio"):
                    driven_audio = gr.Audio(visible=False, type="filepath")
                    with gr.TabItem('Upload audio'):
                        with gr.Column(variant='panel'):
                            input_audio1 = gr.Audio(label="Input audio", source="upload", type="filepath")
                            submit_audio_1 = gr.Button('Submit', variant='primary')
                        submit_audio_1.click(fn=get_driven_audio, inputs=input_audio1, outputs=driven_audio)

                    with gr.TabItem('Microphone'):
                        with gr.Column(variant='panel'):
                            input_audio2 = gr.Audio(label="Recording audio", source="microphone", type="filepath")
                            submit_audio_2 = gr.Button('Submit', variant='primary')
                        submit_audio_2.click(fn=get_driven_audio, inputs=input_audio2, outputs=driven_audio)

                    with gr.TabItem('TTS'):
                        with gr.Column(variant='panel'):
                            with gr.Row().style(equal_height=False):
                                input_text = gr.Textbox(label="Input text", lines=5, value="Please enter some text in English")
                                input_audio3 = gr.Audio(label="Generated audio", type="filepath")
                            submit_audio_3 = gr.Button('Submit', variant='primary')
                        submit_audio_3.click(fn=get_driven_audio, inputs=input_text, outputs=[input_audio3, driven_audio])

            with gr.Column(variant='panel'):
                gen_video = gr.Video(label="Generated video", format="mp4").style(height=256, width=256)
                gen_text = gr.Textbox(visible=False)
                submit = gr.Button('Generate', elem_id="sadtalker_generate", variant='primary')
                scale = gr.Slider(minimum=1, maximum=8, step=1, label="GFPGAN scale", value=1)
                new_video = gr.Video(label="New video", format="mp4").style(height=512, width=512)
                change_scale = gr.Button('Restore video', elem_id="restore_video", variant='primary')

        # run SadTalker on the selected image/audio pair, then optionally restore the result with GFPGAN
        submit.click(
            fn=sad_talker.test,
            inputs=[source_image,
                    driven_audio,
                    gr.Textbox(value=result_dir, visible=False)],
            outputs=[gen_video, gen_text]
        )
        change_scale.click(gfpgan, [scale, gen_text], new_video)

    return sadtalker_interface


if __name__ == "__main__":

    current_code_path = sys.argv[0]
    current_root_dir = os.path.split(current_code_path)[0]
    sadtalker_result_dir = os.path.join(current_root_dir, 'results', 'sadtalker')
    demo = sadtalker_demo(sadtalker_result_dir)
    demo.launch()
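app.py is the Space's Gradio entry point. Run locally (after the dependency and model steps from the README above), launching the demo is a single command; a minimal sketch, noting that Gradio serves on http://127.0.0.1:7860 by default:

```bash
# launch the Gradio demo defined in app.py
python app.py
```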
    	
config/auido2exp.yaml ADDED
@@ -0,0 +1,58 @@
DATASET:
  TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt
  EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt
  TRAIN_BATCH_SIZE: 32
  EVAL_BATCH_SIZE: 32
  EXP: True
  EXP_DIM: 64
  FRAME_LEN: 32
  COEFF_LEN: 73
  NUM_CLASSES: 46
  AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
  COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm
  LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
  DEBUG: True
  NUM_REPEATS: 2
  T: 40

MODEL:
  FRAMEWORK: V2
  AUDIOENCODER:
    LEAKY_RELU: True
    NORM: 'IN'
  DISCRIMINATOR:
    LEAKY_RELU: False
    INPUT_CHANNELS: 6
  CVAE:
    AUDIO_EMB_IN_SIZE: 512
    AUDIO_EMB_OUT_SIZE: 128
    SEQ_LEN: 32
    LATENT_SIZE: 256
    ENCODER_LAYER_SIZES: [192, 1024]
    DECODER_LAYER_SIZES: [1024, 192]

TRAIN:
  MAX_EPOCH: 300
  GENERATOR:
    LR: 2.0e-5
  DISCRIMINATOR:
    LR: 1.0e-5
  LOSS:
    W_FEAT: 0
    W_COEFF_EXP: 2
    W_LM: 1.0e-2
    W_LM_MOUTH: 0
    W_REG: 0
    W_SYNC: 0
    W_COLOR: 0
    W_EXPRESSION: 0
    W_LIPREADING: 0.01
    W_LIPREADING_VV: 0
    W_EYE_BLINK: 4

TAG:
  NAME: small_dataset
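These configs are plain YAML, so any YAML loader can read them. A quick way to sanity-check a value from the shell, as a sketch that assumes PyYAML is available in the environment:

```bash
# print the CVAE latent size from the ExpNet config (assumes PyYAML is installed)
python -c "import yaml; cfg = yaml.safe_load(open('config/auido2exp.yaml')); print(cfg['MODEL']['CVAE']['LATENT_SIZE'])"
```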
    	
config/auido2pose.yaml ADDED
@@ -0,0 +1,49 @@
DATASET:
  TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt
  EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt
  TRAIN_BATCH_SIZE: 64
  EVAL_BATCH_SIZE: 1
  EXP: True
  EXP_DIM: 64
  FRAME_LEN: 32
  COEFF_LEN: 73
  NUM_CLASSES: 46
  AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
  COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
  DEBUG: True

MODEL:
  AUDIOENCODER:
    LEAKY_RELU: True
    NORM: 'IN'
  DISCRIMINATOR:
    LEAKY_RELU: False
    INPUT_CHANNELS: 6
  CVAE:
    AUDIO_EMB_IN_SIZE: 512
    AUDIO_EMB_OUT_SIZE: 6
    SEQ_LEN: 32
    LATENT_SIZE: 64
    ENCODER_LAYER_SIZES: [192, 128]
    DECODER_LAYER_SIZES: [128, 192]

TRAIN:
  MAX_EPOCH: 150
  GENERATOR:
    LR: 1.0e-4
  DISCRIMINATOR:
    LR: 1.0e-4
  LOSS:
    LAMBDA_REG: 1
    LAMBDA_LANDMARKS: 0
    LAMBDA_VERTICES: 0
    LAMBDA_GAN_MOTION: 0.7
    LAMBDA_GAN_COEFF: 0
    LAMBDA_KL: 1

TAG:
  NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder
    	
config/facerender.yaml
ADDED
@@ -0,0 +1,45 @@
model_params:
  common_params:
    num_kp: 15
    image_channel: 3
    feature_channel: 32
    estimate_jacobian: False   # True
  kp_detector_params:
    temperature: 0.1
    block_expansion: 32
    max_features: 1024
    scale_factor: 0.25         # 0.25
    num_blocks: 5
    reshape_channel: 16384     # 16384 = 1024 * 16
    reshape_depth: 16
  he_estimator_params:
    block_expansion: 64
    max_features: 2048
    num_bins: 66
  generator_params:
    block_expansion: 64
    max_features: 512
    num_down_blocks: 2
    reshape_channel: 32
    reshape_depth: 16          # 512 = 32 * 16
    num_resblocks: 6
    estimate_occlusion_map: True
    dense_motion_params:
      block_expansion: 32
      max_features: 1024
      num_blocks: 5
      reshape_depth: 16
      compress: 4
  discriminator_params:
    scales: [1]
    block_expansion: 32
    max_features: 512
    num_blocks: 4
    sn: True
  mapping_params:
    coeff_nc: 70
    descriptor_nc: 1024
    layer: 3
    num_kp: 15
    num_bins: 66
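
Not part of the commit: the inline comments above encode the relation max_features = reshape_channel * reshape_depth for the generator's 3D feature volume. A quick check with PyYAML (how the project actually loads this file is an assumption on my part):

import yaml

with open('config/facerender.yaml') as f:
    gen = yaml.safe_load(f)['model_params']['generator_params']

# 512 == 32 * 16: the 2D feature map is reshaped into a 32-channel, 16-deep volume
assert gen['max_features'] == gen['reshape_channel'] * gen['reshape_depth']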
    	
inference.py
ADDED
@@ -0,0 +1,134 @@
import torch
from time import strftime
import os, sys, time
from argparse import ArgumentParser

from src.utils.preprocess import CropAndExtract
from src.test_audio2coeff import Audio2Coeff
from src.facerender.animate import AnimateFromCoeff
from src.generate_batch import get_data
from src.generate_facerender_batch import get_facerender_data

def main(args):
    #torch.backends.cudnn.enabled = False

    pic_path = args.source_image
    audio_path = args.driven_audio
    save_dir = os.path.join(args.result_dir, strftime("%Y_%m_%d_%H.%M.%S"))
    os.makedirs(save_dir, exist_ok=True)
    pose_style = args.pose_style
    device = args.device
    batch_size = args.batch_size
    camera_yaw_list = args.camera_yaw
    camera_pitch_list = args.camera_pitch
    camera_roll_list = args.camera_roll

    current_code_path = sys.argv[0]
    current_root_path = os.path.split(current_code_path)[0]

    os.environ['TORCH_HOME'] = os.path.join(current_root_path, args.checkpoint_dir)

    path_of_lm_croper = os.path.join(current_root_path, args.checkpoint_dir, 'shape_predictor_68_face_landmarks.dat')
    path_of_net_recon_model = os.path.join(current_root_path, args.checkpoint_dir, 'epoch_20.pth')
    dir_of_BFM_fitting = os.path.join(current_root_path, args.checkpoint_dir, 'BFM_Fitting')
    wav2lip_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'wav2lip.pth')

    audio2pose_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'auido2pose_00140-model.pth')
    audio2pose_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2pose.yaml')

    audio2exp_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'auido2exp_00300-model.pth')
    audio2exp_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2exp.yaml')

    free_view_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'facevid2vid_00189-model.pth.tar')
    mapping_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'mapping_00229-model.pth.tar')
    facerender_yaml_path = os.path.join(current_root_path, 'src', 'config', 'facerender.yaml')

    #init model
    print(path_of_net_recon_model)
    preprocess_model = CropAndExtract(path_of_lm_croper, path_of_net_recon_model, dir_of_BFM_fitting, device)

    print(audio2pose_checkpoint)
    print(audio2exp_checkpoint)
    audio_to_coeff = Audio2Coeff(audio2pose_checkpoint, audio2pose_yaml_path,
                                 audio2exp_checkpoint, audio2exp_yaml_path,
                                 wav2lip_checkpoint, device)

    print(free_view_checkpoint)
    print(mapping_checkpoint)
    animate_from_coeff = AnimateFromCoeff(free_view_checkpoint, mapping_checkpoint,
                                          facerender_yaml_path, device)

    #crop image and extract 3dmm from image
    first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
    os.makedirs(first_frame_dir, exist_ok=True)
    first_coeff_path, crop_pic_path = preprocess_model.generate(pic_path, first_frame_dir)
    if first_coeff_path is None:
        print("Can't get the coeffs of the input")
        return

    #audio2coeff
    batch = get_data(first_coeff_path, audio_path, device)
    coeff_path = audio_to_coeff.generate(batch, save_dir, pose_style)

    # 3d face render
    if args.face3dvis:
        from src.face3d.visualize import gen_composed_video
        gen_composed_video(args, device, first_coeff_path, coeff_path, audio_path, os.path.join(save_dir, '3dface.mp4'))

    #coeff2video
    data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path,
                               batch_size, camera_yaw_list, camera_pitch_list, camera_roll_list,
                               expression_scale=args.expression_scale, still_mode=args.still)

    animate_from_coeff.generate(data, save_dir, enhancer=args.enhancer)
    video_name = data['video_name']

    if args.enhancer is not None:
        print(f'The generated video is named {video_name}_enhanced in {save_dir}')
    else:
        print(f'The generated video is named {video_name} in {save_dir}')

    return os.path.join(save_dir, video_name+'.mp4'), os.path.join(save_dir, video_name+'.mp4')


if __name__ == '__main__':

    parser = ArgumentParser()
    parser.add_argument("--driven_audio", default='./examples/driven_audio/japanese.wav', help="path to driven audio")
    parser.add_argument("--source_image", default='./examples/source_image/art_0.png', help="path to source image")
    parser.add_argument("--checkpoint_dir", default='./checkpoints', help="path to checkpoints")
    parser.add_argument("--result_dir", default='./results', help="path to output")
    parser.add_argument("--pose_style", type=int, default=0, help="input pose style from [0, 46)")
    parser.add_argument("--batch_size", type=int, default=2, help="the batch size of facerender")
    parser.add_argument("--expression_scale", type=float, default=1., help="the expression intensity scale of facerender")
    parser.add_argument('--camera_yaw', nargs='+', type=int, default=[0], help="the camera yaw degree")
    parser.add_argument('--camera_pitch', nargs='+', type=int, default=[0], help="the camera pitch degree")
    parser.add_argument('--camera_roll', nargs='+', type=int, default=[0], help="the camera roll degree")
    parser.add_argument('--enhancer', type=str, default=None, help="Face enhancer, [GFPGAN]")
    parser.add_argument("--cpu", dest="cpu", action="store_true")
    parser.add_argument("--face3dvis", action="store_true", help="generate 3d face and 3d landmarks")
    parser.add_argument("--still", action="store_true")

    # net structure and parameters
    parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='not used')
    parser.add_argument('--init_path', type=str, default=None, help='not used')
    parser.add_argument('--use_last_fc', default=False, help='zero initialize the last fc')
    parser.add_argument('--bfm_folder', type=str, default='./checkpoints/BFM_Fitting/')
    parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model')

    # default renderer parameters
    parser.add_argument('--focal', type=float, default=1015.)
    parser.add_argument('--center', type=float, default=112.)
    parser.add_argument('--camera_d', type=float, default=10.)
    parser.add_argument('--z_near', type=float, default=5.)
    parser.add_argument('--z_far', type=float, default=15.)

    args = parser.parse_args()

    if torch.cuda.is_available() and not args.cpu:
        args.device = "cuda"
    else:
        args.device = "cpu"

    main(args)
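
Not part of the commit: a sketch of driving main() directly instead of via the command line. The Namespace fields mirror the argparse flags defined above, and the example asset paths are the parser defaults; everything else is an assumption about a typical local setup.

from argparse import Namespace
import torch
from inference import main

args = Namespace(
    driven_audio='./examples/driven_audio/japanese.wav',
    source_image='./examples/source_image/art_0.png',
    checkpoint_dir='./checkpoints', result_dir='./results',
    pose_style=0, batch_size=2, expression_scale=1.,
    camera_yaw=[0], camera_pitch=[0], camera_roll=[0],
    enhancer=None, face3dvis=False, still=False,
    device='cuda' if torch.cuda.is_available() else 'cpu')
main(args)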
    	
modules/__pycache__/gfpgan_inference.cpython-38.pyc
ADDED
Binary file (1.36 kB)

modules/__pycache__/gfpgan_inference.cpython-39.pyc
ADDED
Binary file (1.4 kB)

modules/__pycache__/sadtalker_test.cpython-38.pyc
ADDED
Binary file (3.1 kB)

modules/__pycache__/sadtalker_test.cpython-39.pyc
ADDED
Binary file (3.98 kB)

modules/__pycache__/text2speech.cpython-38.pyc
ADDED
Binary file (473 Bytes)

modules/__pycache__/text2speech.cpython-39.pyc
ADDED
Binary file (477 Bytes)
    	
modules/gfpgan_inference.py
ADDED
@@ -0,0 +1,36 @@
import os, sys

def gfpgan(scale, origin_mp4_path):
    current_code_path = sys.argv[0]
    current_root_path = os.path.split(current_code_path)[0]
    print(current_root_path)
    gfpgan_code_path = current_root_path+'/repositories/GFPGAN/inference_gfpgan.py'
    print(gfpgan_code_path)

    #video2pic
    result_dir = os.path.split(origin_mp4_path)[0]
    video_name = os.path.split(origin_mp4_path)[1]
    video_name = video_name.split('.')[0]
    print(video_name)
    str_scale = str(scale).replace('.', '_')
    output_mp4_path = os.path.join(result_dir, video_name+'##'+str_scale+'.mp4')
    temp_output_mp4_path = os.path.join(result_dir, 'temp_'+video_name+'##'+str_scale+'.mp4')

    audio_name = video_name.split('##')[-1]
    audio_path = os.path.join(result_dir, audio_name+'.wav')
    temp_pic_dir1 = os.path.join(result_dir, video_name)
    temp_pic_dir2 = os.path.join(result_dir, video_name+'##'+str_scale)
    os.makedirs(temp_pic_dir1, exist_ok=True)
    os.makedirs(temp_pic_dir2, exist_ok=True)
    # extract the frames of the original video
    cmd1 = 'ffmpeg -i \"{}\" -start_number 0  \"{}\"/%06d.png -loglevel error -y'.format(origin_mp4_path, temp_pic_dir1)
    os.system(cmd1)
    # enhance every frame with GFPGAN
    cmd2 = f'python {gfpgan_code_path} -i {temp_pic_dir1} -o {temp_pic_dir2} -s {scale}'
    os.system(cmd2)
    # re-encode the enhanced frames into a video
    cmd3 = f'ffmpeg -r 25 -f image2 -i {temp_pic_dir2}/%06d.png  -vcodec libx264 -crf 25  -pix_fmt yuv420p {temp_output_mp4_path}'
    os.system(cmd3)
    # merge the original audio track back into the enhanced video
    cmd4 = f'ffmpeg -y -i {temp_output_mp4_path}  -i {audio_path}  -vcodec copy {output_mp4_path}'
    os.system(cmd4)
    #shutil.rmtree(temp_pic_dir1)
    #shutil.rmtree(temp_pic_dir2)

    return output_mp4_path
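
Not part of the commit: a hypothetical call to the helper above. gfpgan() recovers the audio track from the '##'-separated video name (the text after the last '##' is taken as the wav basename); the concrete path below is only an illustration.

from modules.gfpgan_inference import gfpgan

enhanced = gfpgan(2, './results/2023_01_01_00.00.00/art_0##japanese.mp4')
print(enhanced)   # ./results/2023_01_01_00.00.00/art_0##japanese##2.mp4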
    	
modules/sadtalker_test.py
ADDED
@@ -0,0 +1,95 @@
import torch
from time import gmtime, strftime
import os, sys, shutil
from argparse import ArgumentParser
from src.utils.preprocess import CropAndExtract
from src.test_audio2coeff import Audio2Coeff
from src.facerender.animate import AnimateFromCoeff
from src.generate_batch import get_data
from src.generate_facerender_batch import get_facerender_data

from modules.text2speech import text2speech

class SadTalker():

    def __init__(self, checkpoint_path='checkpoints'):

        if torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"

        current_code_path = sys.argv[0]
        modules_path = os.path.split(current_code_path)[0]

        current_root_path = './'

        os.environ['TORCH_HOME'] = os.path.join(current_root_path, 'checkpoints')

        path_of_lm_croper = os.path.join(current_root_path, 'checkpoints', 'shape_predictor_68_face_landmarks.dat')
        path_of_net_recon_model = os.path.join(current_root_path, 'checkpoints', 'epoch_20.pth')
        dir_of_BFM_fitting = os.path.join(current_root_path, 'checkpoints', 'BFM_Fitting')
        wav2lip_checkpoint = os.path.join(current_root_path, 'checkpoints', 'wav2lip.pth')

        audio2pose_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2pose_00140-model.pth')
        audio2pose_yaml_path = os.path.join(current_root_path, 'config', 'auido2pose.yaml')

        audio2exp_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2exp_00300-model.pth')
        audio2exp_yaml_path = os.path.join(current_root_path, 'config', 'auido2exp.yaml')

        free_view_checkpoint = os.path.join(current_root_path, 'checkpoints', 'facevid2vid_00189-model.pth.tar')
        mapping_checkpoint = os.path.join(current_root_path, 'checkpoints', 'mapping_00229-model.pth.tar')
        facerender_yaml_path = os.path.join(current_root_path, 'config', 'facerender.yaml')

        #init model
        print(path_of_lm_croper)
        self.preprocess_model = CropAndExtract(path_of_lm_croper, path_of_net_recon_model, dir_of_BFM_fitting, device)

        print(audio2pose_checkpoint)
        self.audio_to_coeff = Audio2Coeff(audio2pose_checkpoint, audio2pose_yaml_path,
                                          audio2exp_checkpoint, audio2exp_yaml_path, wav2lip_checkpoint, device)
        print(free_view_checkpoint)
        self.animate_from_coeff = AnimateFromCoeff(free_view_checkpoint, mapping_checkpoint,
                                                   facerender_yaml_path, device)
        self.device = device

    def test(self, source_image, driven_audio, result_dir):

        time_tag = strftime("%Y_%m_%d_%H.%M.%S")
        save_dir = os.path.join(result_dir, time_tag)
        os.makedirs(save_dir, exist_ok=True)

        input_dir = os.path.join(save_dir, 'input')
        os.makedirs(input_dir, exist_ok=True)

        print(source_image)
        pic_path = os.path.join(input_dir, os.path.basename(source_image))
        shutil.move(source_image, input_dir)

        if os.path.isfile(driven_audio):
            audio_path = os.path.join(input_dir, os.path.basename(driven_audio))
            shutil.move(driven_audio, input_dir)
        else:
            # driven_audio is treated as text and synthesized to speech
            # (the output filename 'answer.wav' is an editorial assumption;
            # the original commit left this branch as a bare `text2speech`)
            audio_path = os.path.join(input_dir, 'answer.wav')
            text2speech(driven_audio, audio_path)

        os.makedirs(save_dir, exist_ok=True)
        pose_style = 0
        #crop image and extract 3dmm from image
        first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
        os.makedirs(first_frame_dir, exist_ok=True)
        first_coeff_path, crop_pic_path = self.preprocess_model.generate(pic_path, first_frame_dir)
        if first_coeff_path is None:
            raise AttributeError("No face is detected")

        #audio2coeff
        batch = get_data(first_coeff_path, audio_path, self.device)
        coeff_path = self.audio_to_coeff.generate(batch, save_dir, pose_style)
        #coeff2video
        batch_size = 4
        data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size)
        self.animate_from_coeff.generate(data, save_dir)
        video_name = data['video_name']
        print(f'The generated video is named {video_name} in {save_dir}')
        return os.path.join(save_dir, video_name+'.mp4'), os.path.join(save_dir, video_name+'.mp4')
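
Not part of the commit: a hypothetical use of the wrapper above, with the example assets that inference.py defaults to. Note that test() moves both inputs into the result folder, so pass disposable copies.

from modules.sadtalker_test import SadTalker

talker = SadTalker(checkpoint_path='checkpoints')
video_path, _ = talker.test('./examples/source_image/art_0.png',
                            './examples/driven_audio/japanese.wav',
                            './results')
print(video_path)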
    	
modules/text2speech.py
ADDED
@@ -0,0 +1,12 @@
import os

def text2speech(txt, audio_path):
    print(txt)
    cmd = f'tts --text "{txt}" --out_path {audio_path}'
    print(cmd)
    try:
        os.system(cmd)
        return audio_path
    except:
        print("Error: Failed to convert txt to audio")
        return None
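
Not part of the commit: text2speech() shells out to a `tts` command-line tool (Coqui-TTS style); the tool is not listed in requirements.txt, so it must be installed separately. A hypothetical call:

from modules.text2speech import text2speech

wav = text2speech("Hello, this is a SadTalker demo.", "./results/answer.wav")
print(wav)   # ./results/answer.wav if the tts command succeeded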
    	
packages.txt
ADDED
@@ -0,0 +1 @@
ffmpeg
    	
requirements.txt
ADDED
@@ -0,0 +1,17 @@
numpy==1.23.4
face_alignment==1.3.5
imageio==2.19.3
imageio-ffmpeg==0.4.7
librosa==0.6.0
numba==0.48.0
resampy==0.3.1
pydub==0.25.1
scipy==1.5.3
kornia==0.6.8
tqdm
yacs==0.1.8
pyyaml
joblib==1.1.0
scikit-image==0.19.3
basicsr==1.4.2
facexlib==0.2.5
    	
src/__pycache__/generate_batch.cpython-38.pyc
ADDED
Binary file (2.81 kB)

src/__pycache__/generate_facerender_batch.cpython-38.pyc
ADDED
Binary file (3.81 kB)

src/__pycache__/test_audio2coeff.cpython-38.pyc
ADDED
Binary file (2.73 kB)

src/audio2exp_models/__pycache__/audio2exp.cpython-38.pyc
ADDED
Binary file (1.07 kB)

src/audio2exp_models/__pycache__/networks.cpython-38.pyc
ADDED
Binary file (2.14 kB)
    	
src/audio2exp_models/audio2exp.py
ADDED
@@ -0,0 +1,30 @@
import torch
from torch import nn


class Audio2Exp(nn.Module):
    def __init__(self, netG, cfg, device, prepare_training_loss=False):
        super(Audio2Exp, self).__init__()
        self.cfg = cfg
        self.device = device
        self.netG = netG.to(device)

    def test(self, batch):

        mel_input = batch['indiv_mels']                         # bs T 1 80 16
        bs = mel_input.shape[0]
        T = mel_input.shape[1]

        ref = batch['ref'][:, :, :64].repeat((1,T,1))           # bs T 64
        ratio = batch['ratio_gt']                               # bs T

        audiox = mel_input.view(-1, 1, 80, 16)                  # bs*T 1 80 16
        exp_coeff_pred = self.netG(audiox, ref, ratio)          # bs T 64

        # BS x T x 64
        results_dict = {
            'exp_coeff_pred': exp_coeff_pred
            }
        return results_dict
    	
src/audio2exp_models/networks.py
ADDED
@@ -0,0 +1,74 @@
import torch
import torch.nn.functional as F
from torch import nn

class Conv2d(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act=True, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
                            nn.Conv2d(cin, cout, kernel_size, stride, padding),
                            nn.BatchNorm2d(cout)
                            )
        self.act = nn.ReLU()
        self.residual = residual
        self.use_act = use_act

    def forward(self, x):
        out = self.conv_block(x)
        if self.residual:
            out += x

        if self.use_act:
            return self.act(out)
        else:
            return out

class SimpleWrapperV2(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.audio_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
            )

        #### load the pre-trained audio_encoder
        #self.audio_encoder = self.audio_encoder.to(device)
        '''
        wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict']
        state_dict = self.audio_encoder.state_dict()

        for k,v in wav2lip_state_dict.items():
            if 'audio_encoder' in k:
                print('init:', k)
                state_dict[k.replace('module.audio_encoder.', '')] = v
        self.audio_encoder.load_state_dict(state_dict)
        '''

        self.mapping1 = nn.Linear(512+64+1, 64)
        #self.mapping2 = nn.Linear(30, 64)
        #nn.init.constant_(self.mapping1.weight, 0.)
        nn.init.constant_(self.mapping1.bias, 0.)

    def forward(self, x, ref, ratio):
        x = self.audio_encoder(x).view(x.size(0), -1)
        ref_reshape = ref.reshape(x.size(0), -1)
        ratio = ratio.reshape(x.size(0), -1)

        y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1))
        out = y.reshape(ref.shape[0], ref.shape[1], -1) #+ ref # residual
        return out
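
Not part of the commit: a shape sketch for SimpleWrapperV2, using the tensor layouts documented in audio2exp.py above (bs*T mel chunks of 1x80x16, a 64-dim reference expression per frame, and a per-frame scalar ratio mapping to one 64-dim expression per frame).

import torch
from src.audio2exp_models.networks import SimpleWrapperV2

bs, T = 2, 5
net = SimpleWrapperV2()
audio = torch.randn(bs * T, 1, 80, 16)   # per-frame mel spectrogram chunks
ref = torch.randn(bs, T, 64)             # reference expression coefficients
ratio = torch.rand(bs, T)                # per-frame scalar ratio
out = net(audio, ref, ratio)
print(out.shape)                          # torch.Size([2, 5, 64])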
    	
src/audio2pose_models/__pycache__/audio2pose.cpython-38.pyc
ADDED
Binary file (2.94 kB)

src/audio2pose_models/__pycache__/audio_encoder.cpython-38.pyc
ADDED
Binary file (2.37 kB)

src/audio2pose_models/__pycache__/cvae.cpython-38.pyc
ADDED
Binary file (4.69 kB)

src/audio2pose_models/__pycache__/discriminator.cpython-38.pyc
ADDED
Binary file (2.45 kB)

src/audio2pose_models/__pycache__/networks.cpython-38.pyc
ADDED
Binary file (4.74 kB)

src/audio2pose_models/__pycache__/res_unet.cpython-38.pyc
ADDED
Binary file (1.91 kB)
    	
        src/audio2pose_models/audio2pose.py
    ADDED
    
    | 
         @@ -0,0 +1,93 @@ 
     | 
| 1 | 
         
            +
            import torch
         
     | 
| 2 | 
         
            +
            from torch import nn
         
     | 
| 3 | 
         
            +
            from src.audio2pose_models.cvae import CVAE
         
     | 
| 4 | 
         
            +
            from src.audio2pose_models.discriminator import PoseSequenceDiscriminator
         
     | 
| 5 | 
         
            +
            from src.audio2pose_models.audio_encoder import AudioEncoder
         
     | 
| 6 | 
         
            +
             
     | 
| 7 | 
         
            +
            class Audio2Pose(nn.Module):
         
     | 
| 8 | 
         
            +
                def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
         
     | 
| 9 | 
         
            +
                    super().__init__()
         
     | 
| 10 | 
         
            +
                    self.cfg = cfg
         
     | 
| 11 | 
         
            +
                    self.seq_len = cfg.MODEL.CVAE.SEQ_LEN
         
     | 
| 12 | 
         
            +
                    self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE
         
     | 
| 13 | 
         
            +
                    self.device = device
         
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
                    self.audio_encoder = AudioEncoder(wav2lip_checkpoint)
         
     | 
| 16 | 
         
            +
                    self.audio_encoder.eval()
         
     | 
| 17 | 
         
            +
                    for param in self.audio_encoder.parameters():
         
     | 
| 18 | 
         
            +
                        param.requires_grad = False
         
     | 
| 19 | 
         
            +
             
     | 
| 20 | 
         
            +
                    self.netG = CVAE(cfg)
         
     | 
| 21 | 
         
            +
                    self.netD_motion = PoseSequenceDiscriminator(cfg)
         
     | 
| 22 | 
         
            +
                    
         
     | 
| 23 | 
         
            +
                    self.gan_criterion = nn.MSELoss()
         
     | 
| 24 | 
         
            +
                    self.reg_criterion = nn.L1Loss(reduction='none')
         
     | 
| 25 | 
         
            +
                    self.pair_criterion = nn.PairwiseDistance()
         
     | 
| 26 | 
         
            +
                    self.cosine_loss = nn.CosineSimilarity(dim=1)
         
     | 
| 27 | 
         
            +
                    
         
     | 
| 28 | 
         
            +
                def forward(self, x):
         
     | 
| 29 | 
         
            +
             
     | 
| 30 | 
         
            +
                    batch = {}
         
     | 
| 31 | 
         
            +
                    coeff_gt = x['gt'].cuda().squeeze(0)           #bs frame_len+1 73
         
     | 
| 32 | 
         
            +
                    batch['pose_motion_gt'] = coeff_gt[:, 1:, -9:-3] - coeff_gt[:, :1, -9:-3] #bs frame_len 6
         
     | 
| 33 | 
         
            +
                    batch['ref'] = coeff_gt[:, 0, -9:-3]  #bs  6
         
     | 
| 34 | 
         
            +
                    batch['class'] = x['class'].squeeze(0).cuda() # bs
         
     | 
| 35 | 
         
            +
                    indiv_mels= x['indiv_mels'].cuda().squeeze(0) # bs seq_len+1 80 16
         
     | 
| 36 | 
         
            +
             
     | 
| 37 | 
         
            +
                    # forward
         
     | 
| 38 | 
         
            +
                    audio_emb_list = []
         
     | 
| 39 | 
         
            +
                    audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2)) #bs seq_len 512
         
     | 
| 40 | 
         
            +
                    batch['audio_emb'] = audio_emb
         
     | 
| 41 | 
         
            +
                    batch = self.netG(batch)
         
     | 
| 42 | 
         
            +
             
     | 
| 43 | 
         
            +
                    pose_motion_pred = batch['pose_motion_pred']           # bs frame_len 6
         
     | 
| 44 | 
         
            +
                    pose_gt = coeff_gt[:, 1:, -9:-3].clone()               # bs frame_len 6
         
     | 
| 45 | 
         
            +
                    pose_pred = coeff_gt[:, :1, -9:-3] + pose_motion_pred  # bs frame_len 6
         
     | 
| 46 | 
         
            +
             
     | 
| 47 | 
         
            +
                    batch['pose_pred'] = pose_pred
         
     | 
| 48 | 
         
            +
                    batch['pose_gt'] = pose_gt
         
     | 
| 49 | 
         
            +
             
     | 
| 50 | 
         
            +
                    return batch
         
     | 
| 51 | 
         
            +
             
     | 
| 52 | 
         
            +
                def test(self, x):
         
     | 
| 53 | 
         
            +
             
     | 
| 54 | 
         
            +
                    batch = {}
         
     | 
| 55 | 
         
            +
                    ref = x['ref']                            #bs 1 70
         
     | 
| 56 | 
         
            +
                    batch['ref'] = x['ref'][:,0,-6:]  
         
     | 
| 57 | 
         
            +
                    batch['class'] = x['class']  
         
     | 
| 58 | 
         
            +
                    bs = ref.shape[0]
         
     | 
| 59 | 
         
            +
                    
         
     | 
| 60 | 
         
            +
                    indiv_mels= x['indiv_mels']               # bs T 1 80 16
         
     | 
| 61 | 
         
            +
                    indiv_mels_use = indiv_mels[:, 1:]        # we regard the ref as the first frame
         
     | 
| 62 | 
         
            +
                    num_frames = x['num_frames']
         
     | 
| 63 | 
         
            +
                    num_frames = int(num_frames) - 1
         
     | 
| 64 | 
         
            +
             
     | 
| 65 | 
         
            +
                    #  
         
     | 
| 66 | 
         
            +
                    div = num_frames//self.seq_len
         
     | 
| 67 | 
         
            +
                    re = num_frames%self.seq_len
         
     | 
| 68 | 
         
            +
                    audio_emb_list = []
         
     | 
| 69 | 
         
            +
                    pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype, 
         
     | 
| 70 | 
         
            +
                                                            device=batch['ref'].device)]
         
     | 
| 71 | 
         
            +
             
     | 
| 72 | 
         
            +
                    for i in range(div):
         
     | 
| 73 | 
         
            +
                        z = torch.randn(bs, self.latent_dim).to(ref.device)
         
     | 
| 74 | 
         
            +
                        batch['z'] = z
         
     | 
| 75 | 
         
            +
                        audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len,:,:,:]) #bs seq_len 512
         
     | 
| 76 | 
         
            +
                        batch['audio_emb'] = audio_emb
         
     | 
| 77 | 
         
            +
                        batch = self.netG.test(batch)
         
     | 
| 78 | 
         
            +
                        pose_motion_pred_list.append(batch['pose_motion_pred'])  #list of bs seq_len 6
         
     | 
| 79 | 
         
            +
                    if re != 0:
         
     | 
| 80 | 
         
            +
                        z = torch.randn(bs, self.latent_dim).to(ref.device)
         
     | 
| 81 | 
         
            +
                        batch['z'] = z
         
     | 
| 82 | 
         
            +
                        audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:,:,:,:]) #bs seq_len  512
         
     | 
| 83 | 
         
            +
                        batch['audio_emb'] = audio_emb
         
     | 
| 84 | 
         
            +
                        batch = self.netG.test(batch)
         
     | 
| 85 | 
         
            +
                        pose_motion_pred_list.append(batch['pose_motion_pred'][:,-1*re:,:])   
         
     | 
| 86 | 
         
            +
                    
         
     | 
| 87 | 
         
            +
                    pose_motion_pred = torch.cat(pose_motion_pred_list, dim = 1)
         
     | 
| 88 | 
         
            +
                    batch['pose_motion_pred'] = pose_motion_pred
         
     | 
| 89 | 
         
            +
             
     | 
| 90 | 
         
            +
                    pose_pred = ref[:, :1, -6:] + pose_motion_pred  # bs T 6
         
     | 
| 91 | 
         
            +
             
     | 
| 92 | 
         
            +
                    batch['pose_pred'] = pose_pred
         
     | 
| 93 | 
         
            +
                    return batch
         
     | 
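Audio2Pose.test above runs the CVAE decoder on fixed-length windows: the frames after the reference frame are split into div full windows of seq_len, and any remainder is covered by re-running the last full window and keeping only its final re predictions. A small standalone sketch of that index arithmetic (seq_len and num_frames are assumed values, not read from the config in this commit):

    seq_len = 32                 # plays the role of cfg.MODEL.CVAE.SEQ_LEN
    num_frames = 100             # total frames including the reference frame
    usable = num_frames - 1      # the reference frame is handled separately

    div, re = usable // seq_len, usable % seq_len
    windows = [(i * seq_len, (i + 1) * seq_len) for i in range(div)]
    if re != 0:
        # the last window overlaps earlier ones; only its final `re` outputs are kept
        windows.append((usable - seq_len, usable))

    print(windows)               # [(0, 32), (32, 64), (64, 96), (67, 99)]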
    	
        src/audio2pose_models/audio_encoder.py
    ADDED
    
    | 
         @@ -0,0 +1,64 @@ 
     | 
| 1 | 
         
            +
            import torch
         
     | 
| 2 | 
         
            +
            from torch import nn
         
     | 
| 3 | 
         
            +
            from torch.nn import functional as F
         
     | 
| 4 | 
         
            +
             
     | 
| 5 | 
         
            +
            class Conv2d(nn.Module):
         
     | 
| 6 | 
         
            +
                def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
         
     | 
| 7 | 
         
            +
                    super().__init__(*args, **kwargs)
         
     | 
| 8 | 
         
            +
                    self.conv_block = nn.Sequential(
         
     | 
| 9 | 
         
            +
                                        nn.Conv2d(cin, cout, kernel_size, stride, padding),
         
     | 
| 10 | 
         
            +
                                        nn.BatchNorm2d(cout)
         
     | 
| 11 | 
         
            +
                                        )
         
     | 
| 12 | 
         
            +
                    self.act = nn.ReLU()
         
     | 
| 13 | 
         
            +
                    self.residual = residual
         
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
                def forward(self, x):
         
     | 
| 16 | 
         
            +
                    out = self.conv_block(x)
         
     | 
| 17 | 
         
            +
                    if self.residual:
         
     | 
| 18 | 
         
            +
                        out += x
         
     | 
| 19 | 
         
            +
                    return self.act(out)
         
     | 
| 20 | 
         
            +
             
     | 
| 21 | 
         
            +
            class AudioEncoder(nn.Module):
         
     | 
| 22 | 
         
            +
                def __init__(self, wav2lip_checkpoint):
         
     | 
| 23 | 
         
            +
                    super(AudioEncoder, self).__init__()
         
     | 
| 24 | 
         
            +
             
     | 
| 25 | 
         
            +
                    self.audio_encoder = nn.Sequential(
         
     | 
| 26 | 
         
            +
                        Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
         
     | 
| 27 | 
         
            +
                        Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
         
     | 
| 28 | 
         
            +
                        Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
         
     | 
| 29 | 
         
            +
             
     | 
| 30 | 
         
            +
                        Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
         
     | 
| 31 | 
         
            +
                        Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
         
     | 
| 32 | 
         
            +
                        Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
         
     | 
| 33 | 
         
            +
             
     | 
| 34 | 
         
            +
                        Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
         
     | 
| 35 | 
         
            +
                        Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
         
     | 
| 36 | 
         
            +
                        Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
         
     | 
| 37 | 
         
            +
             
     | 
| 38 | 
         
            +
                        Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
         
     | 
| 39 | 
         
            +
                        Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
         
     | 
| 40 | 
         
            +
             
     | 
| 41 | 
         
            +
                        Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
         
     | 
| 42 | 
         
            +
                        Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
         
     | 
| 43 | 
         
            +
             
     | 
| 44 | 
         
            +
        #### load the pre-trained audio_encoder
         
     | 
| 45 | 
         
            +
                    wav2lip_state_dict = torch.load(wav2lip_checkpoint)['state_dict']
         
     | 
| 46 | 
         
            +
                    state_dict = self.audio_encoder.state_dict()
         
     | 
| 47 | 
         
            +
             
     | 
| 48 | 
         
            +
                    for k,v in wav2lip_state_dict.items():
         
     | 
| 49 | 
         
            +
                        if 'audio_encoder' in k:
         
     | 
| 50 | 
         
            +
                            state_dict[k.replace('module.audio_encoder.', '')] = v
         
     | 
| 51 | 
         
            +
                    self.audio_encoder.load_state_dict(state_dict)
         
     | 
| 52 | 
         
            +
             
     | 
| 53 | 
         
            +
             
     | 
| 54 | 
         
            +
                def forward(self, audio_sequences):
         
     | 
| 55 | 
         
            +
                    # audio_sequences = (B, T, 1, 80, 16)
         
     | 
| 56 | 
         
            +
                    B = audio_sequences.size(0)
         
     | 
| 57 | 
         
            +
             
     | 
| 58 | 
         
            +
                    audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
         
     | 
| 59 | 
         
            +
             
     | 
| 60 | 
         
            +
                    audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
         
     | 
| 61 | 
         
            +
                    dim = audio_embedding.shape[1]
         
     | 
| 62 | 
         
            +
                    audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1))
         
     | 
| 63 | 
         
            +
             
     | 
| 64 | 
         
            +
                    return audio_embedding.squeeze(-1).squeeze(-1) #B seq_len+1 512 
         
     | 
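AudioEncoder.forward above folds the time axis into the batch axis so the 2D convolutional stack sees one mel chunk at a time, then unfolds the embeddings back into a per-frame sequence. A shape-only sketch of that round trip (random tensors stand in for real mel chunks and for the conv stack, since the wav2lip checkpoint is not loaded here):

    import torch

    B, T = 2, 5
    mels = torch.randn(B, T, 1, 80, 16)                        # (B, T, 1, 80, 16)
    folded = torch.cat([mels[:, i] for i in range(T)], dim=0)  # (B*T, 1, 80, 16)
    print(folded.shape)                                        # torch.Size([10, 1, 80, 16])

    emb = torch.randn(B * T, 512, 1, 1)                        # stand-in for self.audio_encoder(folded)
    emb = emb.reshape(B, -1, 512, 1, 1).squeeze(-1).squeeze(-1)
    print(emb.shape)                                           # torch.Size([2, 5, 512])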
    	
        src/audio2pose_models/cvae.py
    ADDED
    
    | 
         @@ -0,0 +1,149 @@ 
     | 
| 1 | 
         
            +
            import torch
         
     | 
| 2 | 
         
            +
            import torch.nn.functional as F
         
     | 
| 3 | 
         
            +
            from torch import nn
         
     | 
| 4 | 
         
            +
            from src.audio2pose_models.res_unet import ResUnet
         
     | 
| 5 | 
         
            +
             
     | 
| 6 | 
         
            +
            def class2onehot(idx, class_num):
         
     | 
| 7 | 
         
            +
             
     | 
| 8 | 
         
            +
                assert torch.max(idx).item() < class_num
         
     | 
| 9 | 
         
            +
                onehot = torch.zeros(idx.size(0), class_num).to(idx.device)
         
     | 
| 10 | 
         
            +
                onehot.scatter_(1, idx, 1)
         
     | 
| 11 | 
         
            +
                return onehot
         
     | 
| 12 | 
         
            +
             
     | 
| 13 | 
         
            +
            class CVAE(nn.Module):
         
     | 
| 14 | 
         
            +
                def __init__(self, cfg):
         
     | 
| 15 | 
         
            +
                    super().__init__()
         
     | 
| 16 | 
         
            +
                    encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES
         
     | 
| 17 | 
         
            +
                    decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES
         
     | 
| 18 | 
         
            +
                    latent_size = cfg.MODEL.CVAE.LATENT_SIZE
         
     | 
| 19 | 
         
            +
                    num_classes = cfg.DATASET.NUM_CLASSES
         
     | 
| 20 | 
         
            +
                    audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE
         
     | 
| 21 | 
         
            +
                    audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE
         
     | 
| 22 | 
         
            +
                    seq_len = cfg.MODEL.CVAE.SEQ_LEN
         
     | 
| 23 | 
         
            +
             
     | 
| 24 | 
         
            +
                    self.latent_size = latent_size
         
     | 
| 25 | 
         
            +
             
     | 
| 26 | 
         
            +
                    self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes,
         
     | 
| 27 | 
         
            +
                                            audio_emb_in_size, audio_emb_out_size, seq_len)
         
     | 
| 28 | 
         
            +
                    self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes,
         
     | 
| 29 | 
         
            +
                                            audio_emb_in_size, audio_emb_out_size, seq_len)
         
     | 
| 30 | 
         
            +
                def reparameterize(self, mu, logvar):
         
     | 
| 31 | 
         
            +
                    std = torch.exp(0.5 * logvar)
         
     | 
| 32 | 
         
            +
                    eps = torch.randn_like(std)
         
     | 
| 33 | 
         
            +
                    return mu + eps * std
         
     | 
| 34 | 
         
            +
             
     | 
| 35 | 
         
            +
                def forward(self, batch):
         
     | 
| 36 | 
         
            +
                    batch = self.encoder(batch)
         
     | 
| 37 | 
         
            +
                    mu = batch['mu']
         
     | 
| 38 | 
         
            +
                    logvar = batch['logvar']
         
     | 
| 39 | 
         
            +
                    z = self.reparameterize(mu, logvar)
         
     | 
| 40 | 
         
            +
                    batch['z'] = z
         
     | 
| 41 | 
         
            +
                    return self.decoder(batch)
         
     | 
| 42 | 
         
            +
             
     | 
| 43 | 
         
            +
                def test(self, batch):
         
     | 
| 44 | 
         
            +
                    '''
         
     | 
| 45 | 
         
            +
                    class_id = batch['class']
         
     | 
| 46 | 
         
            +
                    z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device)
         
     | 
| 47 | 
         
            +
                    batch['z'] = z
         
     | 
| 48 | 
         
            +
                    '''
         
     | 
| 49 | 
         
            +
                    return self.decoder(batch)
         
     | 
| 50 | 
         
            +
             
     | 
| 51 | 
         
            +
            class ENCODER(nn.Module):
         
     | 
| 52 | 
         
            +
                def __init__(self, layer_sizes, latent_size, num_classes, 
         
     | 
| 53 | 
         
            +
                            audio_emb_in_size, audio_emb_out_size, seq_len):
         
     | 
| 54 | 
         
            +
                    super().__init__()
         
     | 
| 55 | 
         
            +
             
     | 
| 56 | 
         
            +
                    self.resunet = ResUnet()
         
     | 
| 57 | 
         
            +
                    self.num_classes = num_classes
         
     | 
| 58 | 
         
            +
                    self.seq_len = seq_len
         
     | 
| 59 | 
         
            +
             
     | 
| 60 | 
         
            +
                    self.MLP = nn.Sequential()
         
     | 
| 61 | 
         
            +
                    layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6
         
     | 
| 62 | 
         
            +
                    for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
         
     | 
| 63 | 
         
            +
                        self.MLP.add_module(
         
     | 
| 64 | 
         
            +
                            name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
         
     | 
| 65 | 
         
            +
                        self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
         
     | 
| 66 | 
         
            +
             
     | 
| 67 | 
         
            +
                    self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
         
     | 
| 68 | 
         
            +
                    self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size)
         
     | 
| 69 | 
         
            +
                    self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
         
     | 
| 70 | 
         
            +
             
     | 
| 71 | 
         
            +
                    self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
         
     | 
| 72 | 
         
            +
             
     | 
| 73 | 
         
            +
                def forward(self, batch):
         
     | 
| 74 | 
         
            +
                    class_id = batch['class']
         
     | 
| 75 | 
         
            +
                    pose_motion_gt = batch['pose_motion_gt']                             #bs seq_len 6
         
     | 
| 76 | 
         
            +
                    ref = batch['ref']                             #bs 6
         
     | 
| 77 | 
         
            +
                    bs = pose_motion_gt.shape[0]
         
     | 
| 78 | 
         
            +
                    audio_in = batch['audio_emb']                          # bs seq_len audio_emb_in_size
         
     | 
| 79 | 
         
            +
             
     | 
| 80 | 
         
            +
                    #pose encode
         
     | 
| 81 | 
         
            +
                    pose_emb = self.resunet(pose_motion_gt.unsqueeze(1))          #bs 1 seq_len 6 
         
     | 
| 82 | 
         
            +
                    pose_emb = pose_emb.reshape(bs, -1)                    #bs seq_len*6
         
     | 
| 83 | 
         
            +
             
     | 
| 84 | 
         
            +
                    #audio mapping
         
     | 
| 85 | 
         
            +
        #print(audio_in.shape)
         
     | 
| 86 | 
         
            +
                    audio_out = self.linear_audio(audio_in)                # bs seq_len audio_emb_out_size
         
     | 
| 87 | 
         
            +
                    audio_out = audio_out.reshape(bs, -1)
         
     | 
| 88 | 
         
            +
             
     | 
| 89 | 
         
            +
                    class_bias = self.classbias[class_id]                  #bs latent_size
         
     | 
| 90 | 
         
            +
        x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1) #bs 6+seq_len*(audio_emb_out_size+6)+latent_size
         
     | 
| 91 | 
         
            +
                    x_out = self.MLP(x_in)
         
     | 
| 92 | 
         
            +
             
     | 
| 93 | 
         
            +
                    mu = self.linear_means(x_out)
         
     | 
| 94 | 
         
            +
        logvar = self.linear_logvar(x_out)                     #bs latent_size
         
     | 
| 95 | 
         
            +
             
     | 
| 96 | 
         
            +
                    batch.update({'mu':mu, 'logvar':logvar})
         
     | 
| 97 | 
         
            +
                    return batch
         
     | 
| 98 | 
         
            +
             
     | 
| 99 | 
         
            +
            class DECODER(nn.Module):
         
     | 
| 100 | 
         
            +
                def __init__(self, layer_sizes, latent_size, num_classes, 
         
     | 
| 101 | 
         
            +
                            audio_emb_in_size, audio_emb_out_size, seq_len):
         
     | 
| 102 | 
         
            +
                    super().__init__()
         
     | 
| 103 | 
         
            +
             
     | 
| 104 | 
         
            +
                    self.resunet = ResUnet()
         
     | 
| 105 | 
         
            +
                    self.num_classes = num_classes
         
     | 
| 106 | 
         
            +
                    self.seq_len = seq_len
         
     | 
| 107 | 
         
            +
             
     | 
| 108 | 
         
            +
                    self.MLP = nn.Sequential()
         
     | 
| 109 | 
         
            +
                    input_size = latent_size + seq_len*audio_emb_out_size + 6
         
     | 
| 110 | 
         
            +
                    for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)):
         
     | 
| 111 | 
         
            +
                        self.MLP.add_module(
         
     | 
| 112 | 
         
            +
                            name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
         
     | 
| 113 | 
         
            +
                        if i+1 < len(layer_sizes):
         
     | 
| 114 | 
         
            +
                            self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
         
     | 
| 115 | 
         
            +
                        else:
         
     | 
| 116 | 
         
            +
                            self.MLP.add_module(name="sigmoid", module=nn.Sigmoid())
         
     | 
| 117 | 
         
            +
                    
         
     | 
| 118 | 
         
            +
                    self.pose_linear = nn.Linear(6, 6)
         
     | 
| 119 | 
         
            +
                    self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
         
     | 
| 120 | 
         
            +
             
     | 
| 121 | 
         
            +
                    self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
         
     | 
| 122 | 
         
            +
             
     | 
| 123 | 
         
            +
                def forward(self, batch):
         
     | 
| 124 | 
         
            +
             
     | 
| 125 | 
         
            +
                    z = batch['z']                                          #bs latent_size
         
     | 
| 126 | 
         
            +
                    bs = z.shape[0]
         
     | 
| 127 | 
         
            +
                    class_id = batch['class']
         
     | 
| 128 | 
         
            +
                    ref = batch['ref']                             #bs 6
         
     | 
| 129 | 
         
            +
                    audio_in = batch['audio_emb']                           # bs seq_len audio_emb_in_size
         
     | 
| 130 | 
         
            +
                    #print('audio_in: ', audio_in[:, :, :10])
         
     | 
| 131 | 
         
            +
             
     | 
| 132 | 
         
            +
                    audio_out = self.linear_audio(audio_in)                 # bs seq_len audio_emb_out_size
         
     | 
| 133 | 
         
            +
                    #print('audio_out: ', audio_out[:, :, :10])
         
     | 
| 134 | 
         
            +
                    audio_out = audio_out.reshape([bs, -1])                 # bs seq_len*audio_emb_out_size
         
     | 
| 135 | 
         
            +
                    class_bias = self.classbias[class_id]                   #bs latent_size
         
     | 
| 136 | 
         
            +
             
     | 
| 137 | 
         
            +
                    z = z + class_bias
         
     | 
| 138 | 
         
            +
                    x_in = torch.cat([ref, z, audio_out], dim=-1)
         
     | 
| 139 | 
         
            +
                    x_out = self.MLP(x_in)                                  # bs layer_sizes[-1]
         
     | 
| 140 | 
         
            +
                    x_out = x_out.reshape((bs, self.seq_len, -1))
         
     | 
| 141 | 
         
            +
             
     | 
| 142 | 
         
            +
                    #print('x_out: ', x_out)
         
     | 
| 143 | 
         
            +
             
     | 
| 144 | 
         
            +
                    pose_emb = self.resunet(x_out.unsqueeze(1))             #bs 1 seq_len 6
         
     | 
| 145 | 
         
            +
             
     | 
| 146 | 
         
            +
                    pose_motion_pred = self.pose_linear(pose_emb.squeeze(1))       #bs seq_len 6
         
     | 
| 147 | 
         
            +
             
     | 
| 148 | 
         
            +
                    batch.update({'pose_motion_pred':pose_motion_pred})
         
     | 
| 149 | 
         
            +
                    return batch
         
     | 
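CVAE.reparameterize above uses the standard reparameterization trick: the latent sample is written as a deterministic function of mu, logvar and an external noise draw, so gradients can flow back through the encoder outputs. A minimal numeric sketch (batch size and latent size are assumed here, not read from the config):

    import torch

    mu = torch.zeros(4, 64)         # (batch, latent_size)
    logvar = torch.zeros(4, 64)     # logvar = 0  ->  std = 1
    std = torch.exp(0.5 * logvar)
    z = mu + torch.randn_like(std) * std
    print(z.shape)                  # torch.Size([4, 64])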
    	
        src/audio2pose_models/discriminator.py
    ADDED
    
    | 
         @@ -0,0 +1,76 @@ 
     | 
| 1 | 
         
            +
            import torch
         
     | 
| 2 | 
         
            +
            import torch.nn.functional as F
         
     | 
| 3 | 
         
            +
            from torch import nn
         
     | 
| 4 | 
         
            +
             
     | 
| 5 | 
         
            +
            class ConvNormRelu(nn.Module):
         
     | 
| 6 | 
         
            +
                def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False,
         
     | 
| 7 | 
         
            +
                             kernel_size=None, stride=None, padding=None, norm='BN', leaky=False):
         
     | 
| 8 | 
         
            +
                    super().__init__()
         
     | 
| 9 | 
         
            +
                    if kernel_size is None:
         
     | 
| 10 | 
         
            +
                        if downsample:
         
     | 
| 11 | 
         
            +
                            kernel_size, stride, padding = 4, 2, 1
         
     | 
| 12 | 
         
            +
                        else:
         
     | 
| 13 | 
         
            +
                            kernel_size, stride, padding = 3, 1, 1
         
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
                    if conv_type == '2d':
         
     | 
| 16 | 
         
            +
                        self.conv = nn.Conv2d(
         
     | 
| 17 | 
         
            +
                            in_channels,
         
     | 
| 18 | 
         
            +
                            out_channels,
         
     | 
| 19 | 
         
            +
                            kernel_size,
         
     | 
| 20 | 
         
            +
                            stride,
         
     | 
| 21 | 
         
            +
                            padding,
         
     | 
| 22 | 
         
            +
                            bias=False,
         
     | 
| 23 | 
         
            +
                        )
         
     | 
| 24 | 
         
            +
                        if norm == 'BN':
         
     | 
| 25 | 
         
            +
                            self.norm = nn.BatchNorm2d(out_channels)
         
     | 
| 26 | 
         
            +
                        elif norm == 'IN':
         
     | 
| 27 | 
         
            +
                            self.norm = nn.InstanceNorm2d(out_channels)
         
     | 
| 28 | 
         
            +
                        else:
         
     | 
| 29 | 
         
            +
                            raise NotImplementedError
         
     | 
| 30 | 
         
            +
                    elif conv_type == '1d':
         
     | 
| 31 | 
         
            +
                        self.conv = nn.Conv1d(
         
     | 
| 32 | 
         
            +
                            in_channels,
         
     | 
| 33 | 
         
            +
                            out_channels,
         
     | 
| 34 | 
         
            +
                            kernel_size,
         
     | 
| 35 | 
         
            +
                            stride,
         
     | 
| 36 | 
         
            +
                            padding,
         
     | 
| 37 | 
         
            +
                            bias=False,
         
     | 
| 38 | 
         
            +
                        )
         
     | 
| 39 | 
         
            +
                        if norm == 'BN':
         
     | 
| 40 | 
         
            +
                            self.norm = nn.BatchNorm1d(out_channels)
         
     | 
| 41 | 
         
            +
                        elif norm == 'IN':
         
     | 
| 42 | 
         
            +
                            self.norm = nn.InstanceNorm1d(out_channels)
         
     | 
| 43 | 
         
            +
                        else:
         
     | 
| 44 | 
         
            +
                            raise NotImplementedError
         
     | 
| 45 | 
         
            +
                    nn.init.kaiming_normal_(self.conv.weight)
         
     | 
| 46 | 
         
            +
             
     | 
| 47 | 
         
            +
                    self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True)
         
     | 
| 48 | 
         
            +
             
     | 
| 49 | 
         
            +
                def forward(self, x):
         
     | 
| 50 | 
         
            +
                    x = self.conv(x)
         
     | 
| 51 | 
         
            +
                    if isinstance(self.norm, nn.InstanceNorm1d):
         
     | 
| 52 | 
         
            +
                        x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1))  # normalize on [C]
         
     | 
| 53 | 
         
            +
                    else:
         
     | 
| 54 | 
         
            +
                        x = self.norm(x)
         
     | 
| 55 | 
         
            +
                    x = self.act(x)
         
     | 
| 56 | 
         
            +
                    return x
         
     | 
| 57 | 
         
            +
             
     | 
| 58 | 
         
            +
             
     | 
| 59 | 
         
            +
            class PoseSequenceDiscriminator(nn.Module):
         
     | 
| 60 | 
         
            +
                def __init__(self, cfg):
         
     | 
| 61 | 
         
            +
                    super().__init__()
         
     | 
| 62 | 
         
            +
                    self.cfg = cfg
         
     | 
| 63 | 
         
            +
                    leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU
         
     | 
| 64 | 
         
            +
             
     | 
| 65 | 
         
            +
                    self.seq = nn.Sequential(
         
     | 
| 66 | 
         
            +
                        ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky),  # B, 256, 64
         
     | 
| 67 | 
         
            +
                        ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky),  # B, 512, 32
         
     | 
| 68 | 
         
            +
                        ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky),  # B, 1024, 16
         
     | 
| 69 | 
         
            +
                        nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True)  # B, 1, 16
         
     | 
| 70 | 
         
            +
                    )
         
     | 
| 71 | 
         
            +
             
     | 
| 72 | 
         
            +
                def forward(self, x):
         
     | 
| 73 | 
         
            +
                    x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2)
         
     | 
| 74 | 
         
            +
                    x = self.seq(x)
         
     | 
| 75 | 
         
            +
                    x = x.squeeze(1)
         
     | 
| 76 | 
         
            +
                    return x
         
     | 
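PoseSequenceDiscriminator.forward above treats the six pose coefficients as channels and convolves over time, which is why the input (B, T, 6) is first flattened and transposed to (B, 6, T). A small sketch of that layout change and of the temporal downsampling done by one stride-2 block (channel counts here are illustrative, not the config values):

    import torch

    poses = torch.randn(2, 64, 6)                                       # (B, T, 6)
    x = poses.reshape(poses.size(0), poses.size(1), -1).transpose(1, 2)
    print(x.shape)                                                      # torch.Size([2, 6, 64])

    # a downsampling ConvNormRelu uses kernel 4, stride 2, padding 1, halving T
    conv = torch.nn.Conv1d(6, 256, kernel_size=4, stride=2, padding=1)
    print(conv(x).shape)                                                # torch.Size([2, 256, 32])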
    	
        src/audio2pose_models/networks.py
    ADDED
    
    | 
         @@ -0,0 +1,140 @@ 
     | 
| 1 | 
         
            +
            import torch.nn as nn
         
     | 
| 2 | 
         
            +
            import torch
         
     | 
| 3 | 
         
            +
             
     | 
| 4 | 
         
            +
             
     | 
| 5 | 
         
            +
            class ResidualConv(nn.Module):
         
     | 
| 6 | 
         
            +
                def __init__(self, input_dim, output_dim, stride, padding):
         
     | 
| 7 | 
         
            +
                    super(ResidualConv, self).__init__()
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
                    self.conv_block = nn.Sequential(
         
     | 
| 10 | 
         
            +
                        nn.BatchNorm2d(input_dim),
         
     | 
| 11 | 
         
            +
                        nn.ReLU(),
         
     | 
| 12 | 
         
            +
                        nn.Conv2d(
         
     | 
| 13 | 
         
            +
                            input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
         
     | 
| 14 | 
         
            +
                        ),
         
     | 
| 15 | 
         
            +
                        nn.BatchNorm2d(output_dim),
         
     | 
| 16 | 
         
            +
                        nn.ReLU(),
         
     | 
| 17 | 
         
            +
                        nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
         
     | 
| 18 | 
         
            +
                    )
         
     | 
| 19 | 
         
            +
                    self.conv_skip = nn.Sequential(
         
     | 
| 20 | 
         
            +
                        nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
         
     | 
| 21 | 
         
            +
                        nn.BatchNorm2d(output_dim),
         
     | 
| 22 | 
         
            +
                    )
         
     | 
| 23 | 
         
            +
             
     | 
| 24 | 
         
            +
                def forward(self, x):
         
     | 
| 25 | 
         
            +
             
     | 
| 26 | 
         
            +
                    return self.conv_block(x) + self.conv_skip(x)
         
     | 
| 27 | 
         
            +
             
     | 
| 28 | 
         
            +
             
     | 
| 29 | 
         
            +
            class Upsample(nn.Module):
         
     | 
| 30 | 
         
            +
                def __init__(self, input_dim, output_dim, kernel, stride):
         
     | 
| 31 | 
         
            +
                    super(Upsample, self).__init__()
         
     | 
| 32 | 
         
            +
             
     | 
| 33 | 
         
            +
                    self.upsample = nn.ConvTranspose2d(
         
     | 
| 34 | 
         
            +
                        input_dim, output_dim, kernel_size=kernel, stride=stride
         
     | 
| 35 | 
         
            +
                    )
         
     | 
| 36 | 
         
            +
             
     | 
| 37 | 
         
            +
                def forward(self, x):
         
     | 
| 38 | 
         
            +
                    return self.upsample(x)
         
     | 
| 39 | 
         
            +
             
     | 
| 40 | 
         
            +
             
     | 
| 41 | 
         
            +
            class Squeeze_Excite_Block(nn.Module):
         
     | 
| 42 | 
         
            +
                def __init__(self, channel, reduction=16):
         
     | 
| 43 | 
         
            +
                    super(Squeeze_Excite_Block, self).__init__()
         
     | 
| 44 | 
         
            +
                    self.avg_pool = nn.AdaptiveAvgPool2d(1)
         
     | 
| 45 | 
         
            +
                    self.fc = nn.Sequential(
         
     | 
| 46 | 
         
            +
                        nn.Linear(channel, channel // reduction, bias=False),
         
     | 
| 47 | 
         
            +
                        nn.ReLU(inplace=True),
         
     | 
| 48 | 
         
            +
                        nn.Linear(channel // reduction, channel, bias=False),
         
     | 
| 49 | 
         
            +
                        nn.Sigmoid(),
         
     | 
| 50 | 
         
            +
                    )
         
     | 
| 51 | 
         
            +
             
     | 
| 52 | 
         
            +
                def forward(self, x):
         
     | 
| 53 | 
         
            +
                    b, c, _, _ = x.size()
         
     | 
| 54 | 
         
            +
                    y = self.avg_pool(x).view(b, c)
         
     | 
| 55 | 
         
            +
                    y = self.fc(y).view(b, c, 1, 1)
         
     | 
| 56 | 
         
            +
                    return x * y.expand_as(x)
         
     | 
| 57 | 
         
            +
             
     | 
| 58 | 
         
            +
             
     | 
| 59 | 
         
            +
            class ASPP(nn.Module):
         
     | 
| 60 | 
         
            +
                def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
         
     | 
| 61 | 
         
            +
                    super(ASPP, self).__init__()
         
     | 
| 62 | 
         
            +
             
     | 
| 63 | 
         
            +
                    self.aspp_block1 = nn.Sequential(
         
     | 
| 64 | 
         
            +
                        nn.Conv2d(
         
     | 
| 65 | 
         
            +
                            in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0]
         
     | 
| 66 | 
         
            +
                        ),
         
     | 
| 67 | 
         
            +
                        nn.ReLU(inplace=True),
         
     | 
| 68 | 
         
            +
                        nn.BatchNorm2d(out_dims),
         
     | 
| 69 | 
         
            +
                    )
         
     | 
| 70 | 
         
            +
                    self.aspp_block2 = nn.Sequential(
         
     | 
| 71 | 
         
            +
                        nn.Conv2d(
         
     | 
| 72 | 
         
            +
                            in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1]
         
     | 
| 73 | 
         
            +
                        ),
         
     | 
| 74 | 
         
            +
                        nn.ReLU(inplace=True),
         
     | 
| 75 | 
         
            +
                        nn.BatchNorm2d(out_dims),
         
     | 
| 76 | 
         
            +
                    )
         
     | 
| 77 | 
         
            +
                    self.aspp_block3 = nn.Sequential(
         
     | 
| 78 | 
         
            +
            nn.Conv2d(
                in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2]
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_dims),
        )

        self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)
        self._init_weights()

    def forward(self, x):
        x1 = self.aspp_block1(x)
        x2 = self.aspp_block2(x)
        x3 = self.aspp_block3(x)
        out = torch.cat([x1, x2, x3], dim=1)
        return self.output(out)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


class Upsample_(nn.Module):
    def __init__(self, scale=2):
        super(Upsample_, self).__init__()

        self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale)

    def forward(self, x):
        return self.upsample(x)


class AttentionBlock(nn.Module):
    def __init__(self, input_encoder, input_decoder, output_dim):
        super(AttentionBlock, self).__init__()

        self.conv_encoder = nn.Sequential(
            nn.BatchNorm2d(input_encoder),
            nn.ReLU(),
            nn.Conv2d(input_encoder, output_dim, 3, padding=1),
            nn.MaxPool2d(2, 2),
        )

        self.conv_decoder = nn.Sequential(
            nn.BatchNorm2d(input_decoder),
            nn.ReLU(),
            nn.Conv2d(input_decoder, output_dim, 3, padding=1),
        )

        self.conv_attn = nn.Sequential(
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, 1, 1),
        )

    def forward(self, x1, x2):
        out = self.conv_encoder(x1) + self.conv_decoder(x2)
        out = self.conv_attn(out)
        return out * x2
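Note: AttentionBlock gates a decoder feature map with a 1-channel mask computed from an encoder/decoder feature pair, and Upsample_ simply wraps bilinear upsampling. A minimal sketch of wiring the two together; the tensor shapes are illustrative assumptions, not values taken from this repo:

# Hypothetical smoke test for the modules defined above.
# x1 is an encoder feature map, x2 a decoder feature map at half its resolution,
# which is what AttentionBlock.forward expects (conv_encoder pools x1 by 2).
import torch

attn = AttentionBlock(input_encoder=64, input_decoder=128, output_dim=64)
up = Upsample_(scale=2)

x1 = torch.randn(2, 64, 32, 32)     # encoder features (full resolution)
x2 = torch.randn(2, 128, 16, 16)    # decoder features (half resolution)

gated = attn(x1, x2)                # (2, 128, 16, 16): 1-channel mask broadcast over x2
print(gated.shape, up(gated).shape) # up(gated) -> (2, 128, 32, 32)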
    	
src/audio2pose_models/res_unet.py
ADDED
@@ -0,0 +1,65 @@
import torch
import torch.nn as nn
from src.audio2pose_models.networks import ResidualConv, Upsample


class ResUnet(nn.Module):
    def __init__(self, channel=1, filters=[32, 64, 128, 256]):
        super(ResUnet, self).__init__()

        self.input_layer = nn.Sequential(
            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
            nn.BatchNorm2d(filters[0]),
            nn.ReLU(),
            nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
        )
        self.input_skip = nn.Sequential(
            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
        )

        self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2, 1), padding=1)
        self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2, 1), padding=1)

        self.bridge = ResidualConv(filters[2], filters[3], stride=(2, 1), padding=1)

        self.upsample_1 = Upsample(filters[3], filters[3], kernel=(2, 1), stride=(2, 1))
        self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1)

        self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2, 1), stride=(2, 1))
        self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1)

        self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2, 1), stride=(2, 1))
        self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1)

        self.output_layer = nn.Sequential(
            nn.Conv2d(filters[0], 1, 1, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Encode
        x1 = self.input_layer(x) + self.input_skip(x)
        x2 = self.residual_conv_1(x1)
        x3 = self.residual_conv_2(x2)
        # Bridge
        x4 = self.bridge(x3)

        # Decode
        x4 = self.upsample_1(x4)
        x5 = torch.cat([x4, x3], dim=1)

        x6 = self.up_residual_conv1(x5)

        x6 = self.upsample_2(x6)
        x7 = torch.cat([x6, x2], dim=1)

        x8 = self.up_residual_conv2(x7)

        x8 = self.upsample_3(x8)
        x9 = torch.cat([x8, x1], dim=1)

        x10 = self.up_residual_conv3(x9)

        output = self.output_layer(x10)

        return output
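Note: a hedged smoke test for ResUnet. It assumes the ResidualConv and Upsample blocks imported from src/audio2pose_models/networks.py behave like standard ResUNet building blocks, i.e. a (2, 1) stride halves only the height and Upsample(kernel=(2, 1), stride=(2, 1)) doubles it again, so the input height should be divisible by 8; the shapes below are illustrative, not taken from the repo.

# Hypothetical usage sketch; requires the SadTalker repo root on PYTHONPATH.
import torch
from src.audio2pose_models.res_unet import ResUnet

net = ResUnet(channel=1)
x = torch.randn(4, 1, 64, 32)   # (batch, channel, height, width), height divisible by 8
y = net(x)
print(y.shape)                  # expected (4, 1, 64, 32), values in (0, 1) from the Sigmoid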
    	
src/config/auido2exp.yaml
ADDED
@@ -0,0 +1,58 @@
DATASET:
  TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt
  EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt
  TRAIN_BATCH_SIZE: 32
  EVAL_BATCH_SIZE: 32
  EXP: True
  EXP_DIM: 64
  FRAME_LEN: 32
  COEFF_LEN: 73
  NUM_CLASSES: 46
  AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
  COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm
  LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
  DEBUG: True
  NUM_REPEATS: 2
  T: 40

MODEL:
  FRAMEWORK: V2
  AUDIOENCODER:
    LEAKY_RELU: True
    NORM: 'IN'
  DISCRIMINATOR:
    LEAKY_RELU: False
    INPUT_CHANNELS: 6
  CVAE:
    AUDIO_EMB_IN_SIZE: 512
    AUDIO_EMB_OUT_SIZE: 128
    SEQ_LEN: 32
    LATENT_SIZE: 256
    ENCODER_LAYER_SIZES: [192, 1024]
    DECODER_LAYER_SIZES: [1024, 192]

TRAIN:
  MAX_EPOCH: 300
  GENERATOR:
    LR: 2.0e-5
  DISCRIMINATOR:
    LR: 1.0e-5
  LOSS:
    W_FEAT: 0
    W_COEFF_EXP: 2
    W_LM: 1.0e-2
    W_LM_MOUTH: 0
    W_REG: 0
    W_SYNC: 0
    W_COLOR: 0
    W_EXPRESSION: 0
    W_LIPREADING: 0.01
    W_LIPREADING_VV: 0
    W_EYE_BLINK: 4

TAG:
  NAME: small_dataset
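Note: the nested UPPER_CASE sections are in the style consumed by yacs. A minimal hedged sketch of reading this file that way; the use of yacs here is an assumption, since the loading code is not part of this diff:

# Hypothetical loading sketch for src/config/auido2exp.yaml.
from yacs.config import CfgNode as CN

with open('src/config/auido2exp.yaml') as f:
    cfg = CN.load_cfg(f)            # parse the nested UPPER_CASE sections
cfg.freeze()

print(cfg.MODEL.CVAE.LATENT_SIZE)   # 256
print(cfg.TRAIN.GENERATOR.LR)       # 2e-05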
    	
src/config/auido2pose.yaml
ADDED
@@ -0,0 +1,49 @@
DATASET:
  TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt
  EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt
  TRAIN_BATCH_SIZE: 64
  EVAL_BATCH_SIZE: 1
  EXP: True
  EXP_DIM: 64
  FRAME_LEN: 32
  COEFF_LEN: 73
  NUM_CLASSES: 46
  AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
  COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
  DEBUG: True

MODEL:
  AUDIOENCODER:
    LEAKY_RELU: True
    NORM: 'IN'
  DISCRIMINATOR:
    LEAKY_RELU: False
    INPUT_CHANNELS: 6
  CVAE:
    AUDIO_EMB_IN_SIZE: 512
    AUDIO_EMB_OUT_SIZE: 6
    SEQ_LEN: 32
    LATENT_SIZE: 64
    ENCODER_LAYER_SIZES: [192, 128]
    DECODER_LAYER_SIZES: [128, 192]

TRAIN:
  MAX_EPOCH: 150
  GENERATOR:
    LR: 1.0e-4
  DISCRIMINATOR:
    LR: 1.0e-4
  LOSS:
    LAMBDA_REG: 1
    LAMBDA_LANDMARKS: 0
    LAMBDA_VERTICES: 0
    LAMBDA_GAN_MOTION: 0.7
    LAMBDA_GAN_COEFF: 0
    LAMBDA_KL: 1

TAG:
  NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder
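Note: the pose config follows the same layout as the expression config but with a much smaller CVAE (latent size 64, audio embedding projected to 6 dims). A short hedged sketch of loading it and overriding a value programmatically, again assuming yacs as the consumer:

# Hypothetical sketch: load src/config/auido2pose.yaml and override values in code.
from yacs.config import CfgNode as CN

with open('src/config/auido2pose.yaml') as f:
    cfg = CN.load_cfg(f)

# Override a couple of values without editing the YAML, e.g. for a quick debug run.
cfg.merge_from_list(['TRAIN.MAX_EPOCH', 10, 'DATASET.TRAIN_BATCH_SIZE', 8])
print(cfg.TRAIN.MAX_EPOCH, cfg.MODEL.CVAE.LATENT_SIZE)   # 10 64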
    	
src/config/facerender.yaml
ADDED
@@ -0,0 +1,45 @@
model_params:
  common_params:
    num_kp: 15
    image_channel: 3
    feature_channel: 32
    estimate_jacobian: False   # True
  kp_detector_params:
    temperature: 0.1
    block_expansion: 32
    max_features: 1024
    scale_factor: 0.25         # 0.25
    num_blocks: 5
    reshape_channel: 16384     # 16384 = 1024 * 16
    reshape_depth: 16
  he_estimator_params:
    block_expansion: 64
    max_features: 2048
    num_bins: 66
  generator_params:
    block_expansion: 64
    max_features: 512
    num_down_blocks: 2
    reshape_channel: 32
    reshape_depth: 16          # 512 = 32 * 16
    num_resblocks: 6
    estimate_occlusion_map: True
    dense_motion_params:
      block_expansion: 32
      max_features: 1024
      num_blocks: 5
      reshape_depth: 16
      compress: 4
  discriminator_params:
    scales: [1]
    block_expansion: 32
    max_features: 512
    num_blocks: 4
    sn: True
  mapping_params:
    coeff_nc: 70
    descriptor_nc: 1024
    layer: 3
    num_kp: 15
    num_bins: 66
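Note: unlike the two configs above, this file uses lower-case, kwargs-style keys, which suggests it is read with plain PyYAML and each block unpacked into a module constructor. A hedged sketch of that pattern; the generator class name at the end is a placeholder, not code shown in this diff:

# Hypothetical sketch: read facerender.yaml with PyYAML and unpack one parameter block.
import yaml

with open('src/config/facerender.yaml') as f:
    config = yaml.safe_load(f)

gen_params = config['model_params']['generator_params']
common = config['model_params']['common_params']
print(gen_params['block_expansion'], common['num_kp'])   # 64 15

# Typical usage pattern for such configs (class name is illustrative only):
# generator = SomeGenerator(**gen_params, **common)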
    	
src/face3d/__pycache__/extract_kp_videos.cpython-38.pyc
ADDED
Binary file (3.57 kB)

src/face3d/__pycache__/visualize.cpython-38.pyc
ADDED
Binary file (1.7 kB)
    	
src/face3d/data/__init__.py
ADDED
@@ -0,0 +1,116 @@
"""This package includes all the modules related to data loading and preprocessing

 To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
 You need to implement four functions:
    -- <__init__>:                      initialize the class, first call BaseDataset.__init__(self, opt).
    -- <__len__>:                       return the size of dataset.
    -- <__getitem__>:                   get a data point from data loader.
    -- <modify_commandline_options>:    (optionally) add dataset-specific options and set default options.

Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
See our template dataset class 'template_dataset.py' for more details.
"""
import numpy as np
import importlib
import torch.utils.data
from face3d.data.base_dataset import BaseDataset


def find_dataset_using_name(dataset_name):
    """Import the module "data/[dataset_name]_dataset.py".

    In the file, the class called DatasetNameDataset() will
    be instantiated. It has to be a subclass of BaseDataset,
    and it is case-insensitive.
    """
    dataset_filename = "data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    dataset = None
    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    for name, cls in datasetlib.__dict__.items():
        if name.lower() == target_dataset_name.lower() \
           and issubclass(cls, BaseDataset):
            dataset = cls

    if dataset is None:
        raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))

    return dataset


def get_option_setter(dataset_name):
    """Return the static method <modify_commandline_options> of the dataset class."""
    dataset_class = find_dataset_using_name(dataset_name)
    return dataset_class.modify_commandline_options


def create_dataset(opt, rank=0):
    """Create a dataset given the option.

    This function wraps the class CustomDatasetDataLoader.
        This is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from data import create_dataset
        >>> dataset = create_dataset(opt)
    """
    data_loader = CustomDatasetDataLoader(opt, rank=rank)
    dataset = data_loader.load_data()
    return dataset

class CustomDatasetDataLoader():
    """Wrapper class of Dataset class that performs multi-threaded data loading"""

    def __init__(self, opt, rank=0):
        """Initialize this class

        Step 1: create a dataset instance given the name [dataset_mode]
        Step 2: create a multi-threaded data loader.
        """
        self.opt = opt
        dataset_class = find_dataset_using_name(opt.dataset_mode)
        self.dataset = dataset_class(opt)
        self.sampler = None
        print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__))
        if opt.use_ddp and opt.isTrain:
            world_size = opt.world_size
            self.sampler = torch.utils.data.distributed.DistributedSampler(
                    self.dataset,
                    num_replicas=world_size,
                    rank=rank,
                    shuffle=not opt.serial_batches
                )
            self.dataloader = torch.utils.data.DataLoader(
                        self.dataset,
                        sampler=self.sampler,
                        num_workers=int(opt.num_threads / world_size),
                        batch_size=int(opt.batch_size / world_size),
                        drop_last=True)
        else:
            self.dataloader = torch.utils.data.DataLoader(
                self.dataset,
                batch_size=opt.batch_size,
                shuffle=(not opt.serial_batches) and opt.isTrain,
                num_workers=int(opt.num_threads),
                drop_last=True
            )

    def set_epoch(self, epoch):
        self.dataset.current_epoch = epoch
        if self.sampler is not None:
            self.sampler.set_epoch(epoch)

    def load_data(self):
        return self

    def __len__(self):
        """Return the number of data in the dataset"""
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        """Return a batch of data"""
        for i, data in enumerate(self.dataloader):
            if i * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield data
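Note: a hedged sketch of driving this loader, based only on the attributes the code above reads from `opt`. The attribute values and the dataset mode are illustrative; a real run would use the repo's option parser, needs an importable '<mode>_dataset.py', and assumes src/ (and src/face3d/) are on PYTHONPATH so the module's own imports resolve.

# Hypothetical usage sketch for create_dataset / CustomDatasetDataLoader.
from types import SimpleNamespace
from face3d.data import create_dataset

opt = SimpleNamespace(
    dataset_mode='flist',        # resolved via importlib to data.flist_dataset.FlistDataset
    use_ddp=False, isTrain=True,
    batch_size=4, serial_batches=False,
    num_threads=2, max_dataset_size=float('inf'),
    # ...plus whatever the chosen dataset class itself requires (e.g. bfm_folder, file lists)
)

dataset = create_dataset(opt)
for batch in dataset:            # CustomDatasetDataLoader.__iter__ yields batches
    break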
    	
src/face3d/data/base_dataset.py
ADDED
@@ -0,0 +1,125 @@
"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.

It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
"""
import random
import numpy as np
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
from abc import ABC, abstractmethod


class BaseDataset(data.Dataset, ABC):
    """This class is an abstract base class (ABC) for datasets.

    To create a subclass, you need to implement the following four functions:
    -- <__init__>:                      initialize the class, first call BaseDataset.__init__(self, opt).
    -- <__len__>:                       return the size of dataset.
    -- <__getitem__>:                   get a data point.
    -- <modify_commandline_options>:    (optionally) add dataset-specific options and set default options.
    """

    def __init__(self, opt):
        """Initialize the class; save the options in the class

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        self.opt = opt
        # self.root = opt.dataroot
        self.current_epoch = 0

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        return parser

    @abstractmethod
    def __len__(self):
        """Return the total number of images in the dataset."""
        return 0

    @abstractmethod
    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index -- a random integer for data indexing

        Returns:
            a dictionary of data with their names. It usually contains the data itself and its metadata information.
        """
        pass


def get_transform(grayscale=False):
    transform_list = []
    if grayscale:
        transform_list.append(transforms.Grayscale(1))
    transform_list += [transforms.ToTensor()]
    return transforms.Compose(transform_list)

def get_affine_mat(opt, size):
    shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False
    w, h = size

    if 'shift' in opt.preprocess:
        shift_pixs = int(opt.shift_pixs)
        shift_x = random.randint(-shift_pixs, shift_pixs)
        shift_y = random.randint(-shift_pixs, shift_pixs)
    if 'scale' in opt.preprocess:
        scale = 1 + opt.scale_delta * (2 * random.random() - 1)
    if 'rot' in opt.preprocess:
        rot_angle = opt.rot_angle * (2 * random.random() - 1)
        rot_rad = -rot_angle * np.pi/180
    if 'flip' in opt.preprocess:
        flip = random.random() > 0.5

    shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3])
    flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3])
    shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3])
    rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3])
    scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3])
    shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3])

    affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin
    affine_inv = np.linalg.inv(affine)
    return affine, affine_inv, flip

def apply_img_affine(img, affine_inv, method=Image.BICUBIC):
    return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=Image.BICUBIC)

def apply_lm_affine(landmark, affine, flip, size):
    _, h = size
    lm = landmark.copy()
    lm[:, 1] = h - 1 - lm[:, 1]
    lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1)
    lm = lm @ np.transpose(affine)
    lm[:, :2] = lm[:, :2] / lm[:, 2:]
    lm = lm[:, :2]
    lm[:, 1] = h - 1 - lm[:, 1]
    if flip:
        lm_ = lm.copy()
        lm_[:17] = lm[16::-1]
        lm_[17:22] = lm[26:21:-1]
        lm_[22:27] = lm[21:16:-1]
        lm_[31:36] = lm[35:30:-1]
        lm_[36:40] = lm[45:41:-1]
        lm_[40:42] = lm[47:45:-1]
        lm_[42:46] = lm[39:35:-1]
        lm_[46:48] = lm[41:39:-1]
        lm_[48:55] = lm[54:47:-1]
        lm_[55:60] = lm[59:54:-1]
        lm_[60:65] = lm[64:59:-1]
        lm_[65:68] = lm[67:64:-1]
        lm = lm_
    return lm
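Note: the augmentation helpers at the bottom of this file build a 3x3 affine matrix from the options and apply it to images and 68-point landmarks. A small hedged sketch of calling them directly; the option values are illustrative, and `preprocess` includes 'rot' because, as written, get_affine_mat only defines rot_rad inside that branch.

# Hypothetical usage sketch; assumes src/ is on PYTHONPATH so face3d resolves.
from types import SimpleNamespace
import numpy as np
from face3d.data.base_dataset import get_affine_mat, apply_lm_affine

opt = SimpleNamespace(preprocess='shift_scale_rot_flip',
                      shift_pixs=10, scale_delta=0.1, rot_angle=10)

affine, affine_inv, flip = get_affine_mat(opt, size=(224, 224))
lm = np.random.rand(68, 2) * 224                       # 68 facial landmarks (illustrative)
lm_aug = apply_lm_affine(lm, affine, flip, (224, 224))
print(lm_aug.shape)                                     # (68, 2)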
    	
src/face3d/data/flist_dataset.py
ADDED
@@ -0,0 +1,125 @@
"""This script defines the custom dataset for Deep3DFaceRecon_pytorch
"""

import os.path
from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine
from data.image_folder import make_dataset
from PIL import Image
import random
import util.util as util
import numpy as np
import json
import torch
from scipy.io import loadmat, savemat
import pickle
from util.preprocess import align_img, estimate_norm
from util.load_mats import load_lm3d


def default_flist_reader(flist):
    """
    flist format: impath label\nimpath label\n ... (same as caffe's filelist)
    """
    imlist = []
    with open(flist, 'r') as rf:
        for line in rf.readlines():
            impath = line.strip()
            imlist.append(impath)

    return imlist

def jason_flist_reader(flist):
    with open(flist, 'r') as fp:
        info = json.load(fp)
    return info

def parse_label(label):
    return torch.tensor(np.array(label).astype(np.float32))


class FlistDataset(BaseDataset):
    """
    It requires one directory to host training images '/path/to/data/train'.
    You can train the model with the dataset flag '--dataroot /path/to/data'.
    """

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseDataset.__init__(self, opt)

        self.lm3d_std = load_lm3d(opt.bfm_folder)

        msk_names = default_flist_reader(opt.flist)
        self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]

        self.size = len(self.msk_paths)
        self.opt = opt

        self.name = 'train' if opt.isTrain else 'val'
        if '_' in opt.flist:
            self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0]

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index (int)      -- a random integer for data indexing

        Returns a dictionary that contains imgs, lms, msks, M, im_paths, aug_flag and dataset
            img (tensor)       -- an image in the input domain
            msk (tensor)       -- its corresponding attention mask
            lm  (tensor)       -- its corresponding 3d landmarks
            im_paths (str)     -- image paths
            aug_flag (bool)    -- a flag used to tell whether it is raw or augmented
        """
        msk_path = self.msk_paths[index % self.size]  # make sure index is within the range
        img_path = msk_path.replace('mask/', '')
        lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'

        raw_img = Image.open(img_path).convert('RGB')
        raw_msk = Image.open(msk_path).convert('RGB')
        raw_lm = np.loadtxt(lm_path).astype(np.float32)

        _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)

        aug_flag = self.opt.use_aug and self.opt.isTrain
        if aug_flag:
            img, lm, msk = self._augmentation(img, lm, self.opt, msk)

        _, H = img.size
        M = estimate_norm(lm, H)
        transform = get_transform()
        img_tensor = transform(img)
        msk_tensor = transform(msk)[:1, ...]
        lm_tensor = parse_label(lm)
        M_tensor = parse_label(M)

        return {'imgs': img_tensor,
                'lms': lm_tensor,
                'msks': msk_tensor,
                'M': M_tensor,
                'im_paths': img_path,
                'aug_flag': aug_flag,
                'dataset': self.name}

    def _augmentation(self, img, lm, opt, msk=None):
        affine, affine_inv, flip = get_affine_mat(opt, img.size)
        img = apply_img_affine(img, affine_inv)
        lm = apply_lm_affine(lm, affine, flip, img.size)
        if msk is not None:
            msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
        return img, lm, msk

    def __len__(self):
        """Return the total number of images in the dataset."""
        return self.size
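Note (not part of the file above): the dataset derives the image and landmark paths from each mask path in the file list by simple string substitution. A minimal sketch of that mapping, using a hypothetical file-list entry, is:

# Hypothetical file-list entry; the real lists are provided with the training data.
msk_path = '/data/root/mask/person_0001/000001.png'

img_path = msk_path.replace('mask/', '')
lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'

print(img_path)  # /data/root/person_0001/000001.png
print(lm_path)   # /data/root/landmarks/person_0001/000001.txt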
    	
        src/face3d/data/image_folder.py
    ADDED
    
@@ -0,0 +1,66 @@
"""A modified image folder class

We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
so that this class can load images from both the current directory and its subdirectories.
"""
import numpy as np
import torch.utils.data as data

from PIL import Image
import os
import os.path

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
    '.tif', '.TIF', '.tiff', '.TIFF',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir, max_dataset_size=float("inf")):
    images = []
    assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    return images[:min(max_dataset_size, len(images))]


def default_loader(path):
    return Image.open(path).convert('RGB')


class ImageFolder(data.Dataset):

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        else:
            return img

    def __len__(self):
        return len(self.imgs)
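Note (not part of the file above): a minimal sketch of using ImageFolder with a DataLoader. The directory path and the import location are assumptions, and batch_size=1 is used so images of different sizes do not need to be stacked:

import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from data.image_folder import ImageFolder  # assumed import path within this repo

# Hypothetical directory containing .jpg/.png files.
dataset = ImageFolder('/path/to/images',
                      transform=transforms.ToTensor(),
                      return_paths=True)
loader = DataLoader(dataset, batch_size=1, shuffle=False)

for img, path in loader:
    print(img.shape, path[0])
    break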
    	
        src/face3d/data/template_dataset.py
    ADDED
    
@@ -0,0 +1,75 @@
"""Dataset class template

This module provides a template for users to implement custom datasets.
You can specify '--dataset_mode template' to use this dataset.
The class name should be consistent with both the filename and its dataset_mode option.
The filename should be <dataset_mode>_dataset.py
The class name should be <Dataset_mode>Dataset
You need to implement the following functions:
    -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
    -- <__init__>: Initialize this dataset class.
    -- <__getitem__>: Return a data point and its metadata information.
    -- <__len__>: Return the number of images.
"""
from data.base_dataset import BaseDataset, get_transform
# from data.image_folder import make_dataset
# from PIL import Image


class TemplateDataset(BaseDataset):
    """A template dataset class for you to implement custom datasets."""
    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
        parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0)  # specify dataset-specific default values
        return parser

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions

        A few things can be done here.
        - save the options (already done in BaseDataset)
        - get image paths and meta information of the dataset.
        - define the image transformation.
        """
        # save the option and dataset root
        BaseDataset.__init__(self, opt)
        # get the image paths of your dataset;
        self.image_paths = []  # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
        # define the default transform function. You can use <base_dataset.get_transform>; you can also define your custom transform function
        self.transform = get_transform(opt)

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index -- a random integer for data indexing

        Returns:
            a dictionary of data with their names. It usually contains the data itself and its metadata information.

        Step 1: get a random image path: e.g., path = self.image_paths[index]
        Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
        Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image)
        Step 4: return a data point as a dictionary.
        """
        path = 'temp'    # needs to be a string
        data_A = None    # needs to be a tensor
        data_B = None    # needs to be a tensor
        return {'data_A': data_A, 'data_B': data_B, 'path': path}

    def __len__(self):
        """Return the total number of images."""
        return len(self.image_paths)
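Note (not part of the file above): for illustration only, a minimal dataset that fills in the template's four steps. The class name, the use of opt.dataroot, and the assumption that opt carries the fields BaseOptions normally provides are all hypothetical, not part of this repo:

from PIL import Image
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset


class SingleImageDataset(BaseDataset):
    """Hypothetical dataset built by filling in the template above."""

    def __init__(self, opt):
        BaseDataset.__init__(self, opt)
        # collect the image paths under opt.dataroot (assumed option)
        self.image_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size))
        self.transform = get_transform(opt)

    def __getitem__(self, index):
        path = self.image_paths[index]               # Step 1: pick a path
        image = Image.open(path).convert('RGB')      # Step 2: load from disk
        data = self.transform(image)                 # Step 3: convert to a tensor
        return {'data': data, 'path': path}          # Step 4: return a dictionary

    def __len__(self):
        return len(self.image_paths)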
    	
        src/face3d/extract_kp_videos.py
    ADDED
    
@@ -0,0 +1,107 @@
import os
import cv2
import time
import glob
import argparse
import face_alignment
import numpy as np
from PIL import Image
from tqdm import tqdm
from itertools import cycle

from torch.multiprocessing import Pool, Process, set_start_method

class KeypointExtractor():
    def __init__(self):
        self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D)

    def extract_keypoint(self, images, name=None, info=True):
        if isinstance(images, list):
            keypoints = []
            if info:
                i_range = tqdm(images, desc='landmark Det:')
            else:
                i_range = images

            for image in i_range:
                current_kp = self.extract_keypoint(image)
                # if no face is detected, reuse the previous frame's landmarks
                if np.mean(current_kp) == -1 and keypoints:
                    keypoints.append(keypoints[-1])
                else:
                    keypoints.append(current_kp[None])

            keypoints = np.concatenate(keypoints, 0)
            np.savetxt(os.path.splitext(name)[0] + '.txt', keypoints.reshape(-1))
            return keypoints
        else:
            while True:
                try:
                    keypoints = self.detector.get_landmarks_from_image(np.array(images))[0]
                    break
                except RuntimeError as e:
                    if str(e).startswith('CUDA'):
                        print("Warning: out of memory, sleep for 1s")
                        time.sleep(1)
                    else:
                        print(e)
                        break
                except TypeError:
                    print('No face detected in this image')
                    shape = [68, 2]
                    keypoints = -1. * np.ones(shape)
                    break
            if name is not None:
                np.savetxt(os.path.splitext(name)[0] + '.txt', keypoints.reshape(-1))
            return keypoints

def read_video(filename):
    frames = []
    cap = cv2.VideoCapture(filename)
    while cap.isOpened():
        ret, frame = cap.read()
        if ret:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame)
            frames.append(frame)
        else:
            break
    cap.release()
    return frames

def run(data):
    filename, opt, device = data
    os.environ['CUDA_VISIBLE_DEVICES'] = device
    kp_extractor = KeypointExtractor()
    images = read_video(filename)
    name = filename.split('/')[-2:]
    os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
    kp_extractor.extract_keypoint(
        images,
        name=os.path.join(opt.output_dir, name[-2], name[-1])
    )

if __name__ == '__main__':
    set_start_method('spawn')
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--input_dir', type=str, help='the folder of the input files')
    parser.add_argument('--output_dir', type=str, help='the folder of the output files')
    parser.add_argument('--device_ids', type=str, default='0,1')
    parser.add_argument('--workers', type=int, default=4)

    opt = parser.parse_args()
    filenames = list()
    VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
    VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
    extensions = VIDEO_EXTENSIONS

    for ext in extensions:
        print(f'{opt.input_dir}/*.{ext}')
        # accumulate matches for every extension instead of overwriting the list
        filenames.extend(sorted(glob.glob(f'{opt.input_dir}/*.{ext}')))
    print('Total number of videos:', len(filenames))
    pool = Pool(opt.workers)
    args_list = cycle([opt])
    device_ids = opt.device_ids.split(",")
    device_ids = cycle(device_ids)
    for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
        pass
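Note (not part of the file above): a minimal sketch of using KeypointExtractor directly on one frame. The import path and the image path are assumptions:

from PIL import Image
from src.face3d.extract_kp_videos import KeypointExtractor  # assumed import path

# Hypothetical frame; any RGB face image works.
frame = Image.open('examples/face.png').convert('RGB')

extractor = KeypointExtractor()                                # loads the face_alignment 2D detector
kp = extractor.extract_keypoint(frame, name='examples/face')   # also writes examples/face.txt
print(kp.shape)                                                # (68, 2), or all -1 if no face was found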
    	
        src/face3d/models/__init__.py
    ADDED
    
@@ -0,0 +1,67 @@
"""This package contains modules related to objective functions, optimizations, and network architectures.

To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
You need to implement the following five functions:
    -- <__init__>:                      initialize the class; first call BaseModel.__init__(self, opt).
    -- <set_input>:                     unpack data from dataset and apply preprocessing.
    -- <forward>:                       produce intermediate results.
    -- <optimize_parameters>:           calculate loss, gradients, and update network weights.
    -- <modify_commandline_options>:    (optionally) add model-specific options and set default options.

In the function <__init__>, you need to define four lists:
    -- self.loss_names (str list):          specify the training losses that you want to plot and save.
    -- self.model_names (str list):         define networks used in our training.
    -- self.visual_names (str list):        specify the images that you want to display and save.
    -- self.optimizers (optimizer list):    define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.

Now you can use the model class by specifying flag '--model dummy'.
See our template model class 'template_model.py' for more details.
"""

import importlib
from src.face3d.models.base_model import BaseModel


def find_model_using_name(model_name):
    """Import the module "models/[model_name]_model.py".

    In the file, the class called DatasetNameModel() will
    be instantiated. It has to be a subclass of BaseModel,
    and it is case-insensitive.
    """
    model_filename = "face3d.models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)
    model = None
    target_model_name = model_name.replace('_', '') + 'model'
    for name, cls in modellib.__dict__.items():
        if name.lower() == target_model_name.lower() \
           and issubclass(cls, BaseModel):
            model = cls

    if model is None:
        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
        exit(0)

    return model


def get_option_setter(model_name):
    """Return the static method <modify_commandline_options> of the model class."""
    model_class = find_model_using_name(model_name)
    return model_class.modify_commandline_options


def create_model(opt):
    """Create a model given the option.

    This function wraps the class CustomDatasetDataLoader.
    This is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from models import create_model
        >>> model = create_model(opt)
    """
    model = find_model_using_name(opt.model)
    instance = model(opt)
    print("model [%s] was created" % type(instance).__name__)
    return instance
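Note (not part of the file above): an illustration of the name resolution performed by find_model_using_name for a hypothetical '--model' value; whether a given *_model.py actually exists depends on the rest of the repo:

# 'face_recon' -> module 'face3d.models.face_recon_model'
#              -> class whose lowercased name equals 'facereconmodel'
model_name = 'face_recon'                               # hypothetical --model value
module_name = "face3d.models." + model_name + "_model"
class_name = model_name.replace('_', '') + 'model'
print(module_name, class_name)                          # face3d.models.face_recon_model facereconmodel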